# Self Attention

## 1. https://github.com/datnnt1997/multi-head_self-attention

In [1]:

import torch
import torch.nn as nn

import math

# [cell output, truncated import-time warning]   from .autonotebook import tqdm as notebook_tqdm


In [2]:
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an output projection.

    ``config`` is a mapping that must provide ``"hidden_size"`` and
    ``"num_of_attention_heads"``. The hidden size has to be divisible by the
    number of heads so the feature dimension can be split evenly across them.
    """

    def __init__(self, config):
        """Create the query/key/value projections and the final dense layer."""
        super().__init__()
        hidden_size = config["hidden_size"]
        num_heads = config["num_of_attention_heads"]
        assert hidden_size % num_heads == 0, "The hidden size is not a multiple of the number of attention heads"

        self.num_attention_heads = num_heads
        self.attention_head_size = int(hidden_size / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Keep this creation order: parameter initialisation consumes the
        # global RNG, so reordering would change seeded results.
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dense = nn.Linear(hidden_size, hidden_size)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis forward.

        A 3-D input ``[batch, seq, all_head_size]`` is viewed as
        ``[batch, seq, heads, head_size]`` and returned as
        ``[batch, heads, seq, head_size]``.
        """
        leading_dims = x.shape[:-1]
        per_head = x.view(*leading_dims, self.num_attention_heads, self.attention_head_size)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply self-attention to ``hidden_states`` of shape ``[batch, seq, hidden]``.

        Returns a tensor of the same shape: the per-head attention results
        concatenated and passed through the ``dense`` projection.
        """
        # Project, then split into heads: each is [batch, heads, seq, head_size].
        queries = self.transpose_for_scores(self.query(hidden_states))
        keys = self.transpose_for_scores(self.key(hidden_states))
        values = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product scores: [batch, heads, seq, seq].
        scale = math.sqrt(self.attention_head_size)
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / scale

        # Normalise over the key axis and take the weighted sum of the values.
        weights = scores.softmax(dim=-1)
        per_head_context = torch.matmul(weights, values)  # [batch, heads, seq, head_size]

        # Move heads next to head_size and merge them back into one feature axis.
        # ``contiguous`` is required before ``view`` because permute only changes strides.
        merged = per_head_context.permute(0, 2, 1, 3).contiguous()
        merged = merged.view(*merged.shape[:-2], self.all_head_size)  # [batch, seq, hidden]

        return self.dense(merged)

In [3]:

# Demo hyper-parameters: 2 heads over a hidden size of 4 (head size 2).
config = dict(num_of_attention_heads=2, hidden_size=4)
     

In [4]:

# Instantiate the attention layer from the demo config and list its submodules.
selfattn = BertSelfAttention(config=config)
print(selfattn)
     

BertSelfAttention(
  (query): Linear(in_features=4, out_features=4, bias=True)
  (key): Linear(in_features=4, out_features=4, bias=True)
  (value): Linear(in_features=4, out_features=4, bias=True)
  (dense): Linear(in_features=4, out_features=4, bias=True)
)


In [5]:
# Random input of shape (1, 3, 4), drawn uniformly from [0, 1).
embed_rand = torch.rand(1, 3, 4)
print(f"Embed Shape: {embed_rand.shape}")
print(f"Embed Values:\n{embed_rand}")

Embed Shape: torch.Size([1, 3, 4])
Embed Values:
tensor([[[0.0381, 0.2856, 0.5005, 0.9897],
         [0.7127, 0.7762, 0.2505, 0.2072],
         [0.6184, 0.2971, 0.1384, 0.2929]]])


In [6]:

# Run the layer on the random input; the result has the same shape as the input.
output = selfattn(embed_rand)
print(f"Output Shape: {output.shape}", f"Output Values:\n{output}", sep="\n")

Output Shape: torch.Size([1, 3, 4])
Output Values:
tensor([[[ 0.0489, -0.3140,  0.3379,  0.1220],
         [ 0.0499, -0.3146,  0.3362,  0.1221],
         [ 0.0468, -0.3133,  0.3371,  0.1227]]], grad_fn=<ViewBackward0>)
