In [None]:
import math
import torch
import torch.nn.functional as F
from torch import nn

In [None]:
def scaled_dot_product_attention(Q, K, V, d_k=4):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Works for plain 2-D inputs as well as inputs with leading batch dims.

    Args:
        Q: queries, shape (..., n_q, d).
        K: keys, shape (..., n_k, d).
        V: values, shape (..., n_k, d_v).
        d_k: dimension used in the 1/sqrt(d_k) scaling. Defaults to 4, which is
            what produced the outputs shown in this notebook (even though the
            example vectors are 3-dim). Pass ``None`` to use ``Q.shape[-1]``,
            the standard choice.

    Returns:
        (output, attention_weights) with shapes (..., n_q, d_v) and (..., n_q, n_k);
        each row of attention_weights sums to 1.
    """
    if d_k is None:
        d_k = Q.shape[-1]

    # Swap only the last two axes: `K.T` reverses *every* axis, which silently
    # scrambles batched keys (and is deprecated for tensors that are not 2-D).
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    attention_weights = F.softmax(scores, dim=-1)

    output = torch.matmul(attention_weights, V)

    return output, attention_weights

In [None]:
# Keys: four 3-dim vectors (the last two are identical). Shape (4, 3).
temp_k = torch.tensor([
    [10.0,  0.0,  0.0],
    [ 0.0, 10.0,  0.0],
    [ 0.0,  0.0, 10.0],
    [ 0.0,  0.0, 10.0],
])

# Values, paired row-by-row with the keys above. Shape (4, 3).
temp_v = torch.tensor([
    [   1.0, 0.0, 1.0],
    [  10.0, 0.0, 2.0],
    [ 100.0, 5.0, 0.0],
    [1000.0, 6.0, 0.0],
])

In [None]:
temp_q = torch.tensor([[0.0, 10.0, 0.0]])  # one query, same direction as temp_k[1]; shape (1, 3)

In [None]:
scaled_dot_product_attention(Q=temp_q, K=temp_k, V=temp_v)

(tensor([[1.0000e+01, 2.1216e-21, 2.0000e+00]]),
 tensor([[1.9287e-22, 1.0000e+00, 1.9287e-22, 1.9287e-22]]))

##### Example 3

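`scaled_dot_product_attention` can also process several queries at once: each row of `temp_q` is scored against every row of `temp_k`, and the resulting weights are used to mix the rows of `temp_v`.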

In [None]:
# Three queries at once (this rebinds the single-query `temp_q`). Shape (3, 3).
temp_q = torch.tensor([
    [ 0.0, 10.0,  0.0],
    [ 0.0,  0.0, 10.0],
    [10.0, 10.0,  0.0],
])

In [None]:
# Keys and values are reused unchanged from the single-query example.
(temp_k, temp_v)

(tensor([[10.,  0.,  0.],
         [ 0., 10.,  0.],
         [ 0.,  0., 10.],
         [ 0.,  0., 10.]]),
 tensor([[   1.,    0.,    1.],
         [  10.,    0.,    2.],
         [ 100.,    5.,    0.],
         [1000.,    6.,    0.]]))

In [None]:
# Show the three query vectors, one per row.
temp_q

tensor([[ 0., 10.,  0.],
        [ 0.,  0., 10.],
        [10., 10.,  0.]])

In [None]:
# Score all three queries against the same keys/values in a single call.
output, attention_weights = scaled_dot_product_attention(Q=temp_q, K=temp_k, V=temp_v)

Interpreting the output of `attention_weights`

In [None]:
# Round to 3 decimals so the near-zero softmax weights display as 0.
attention_weights.round(decimals=3)

tensor([[0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.0000, 0.0000]])

**Explain**

Each row in `attention_weights` shows how strongly the corresponding query in `temp_q` matches each key in `temp_k` (after the softmax, every row sums to 1):
- First row: `1.0000` at index 1 means `temp_q[0]` attends (almost) entirely to `temp_k[1]`.
- Second row: `0.5000` at indices 2 and 3 means `temp_q[1]` attends equally to `temp_k[2]` and `temp_k[3]` (those two keys are identical).
- Third row: `0.5000` at indices 0 and 1 means `temp_q[2]` attends equally to `temp_k[0]` and `temp_k[1]`.

In [None]:
# Sanity check instead of a plain repeat of the cell above: the softmax makes
# every row of the attention matrix a probability distribution over the keys,
# so each row must sum to 1.
row_sums = attention_weights.sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums)), row_sums
torch.round(attention_weights, decimals=3)

tensor([[0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.0000, 0.0000]])

### Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a separate Q/K/V projection per head.

    Each head projects the input from ``d_model`` to ``d_h = d_model // num_heads``
    dimensions and runs scaled dot-product attention. The head outputs are
    concatenated back to ``d_model`` and passed through a final linear layer,
    then through dropout.

    Args:
        d_model: model (embedding) dimensionality; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the final projection.
    """

    def __init__(self, d_model=4, num_heads=2, dropout=0.3):
        super().__init__()

        # dimensionality of each head's query/key/value vectors
        self.d_h = d_model // num_heads

        assert self.d_h * num_heads == d_model, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )

        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)

        # one d_model -> d_h projection per head, for each of Q, K and V
        # (created in the same order as before so seeded initialisation is unchanged)
        self.linear_qs = nn.ModuleList([
            nn.Linear(d_model, self.d_h) for _ in range(num_heads)
        ])
        self.linear_ks = nn.ModuleList([
            nn.Linear(d_model, self.d_h) for _ in range(num_heads)
        ])
        self.linear_vs = nn.ModuleList([
            nn.Linear(d_model, self.d_h) for _ in range(num_heads)
        ])
        # output projection applied to the concatenated heads
        self.linear = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V):
        """Attention for a single head.

        Args:
            Q, K, V: tensors of shape [... x seq_len x d_h]; the leading
                batch dimension is optional.

        Returns:
            (output, attn_weights) with shapes [... x seq_len x d_h] and
            [... x seq_len x seq_len].
        """
        # Transpose only the last two axes (instead of permute(0, 2, 1)) so the
        # same code handles batched and unbatched inputs.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_h)

        attn_weights = F.softmax(scores, dim=-1)

        output = torch.matmul(attn_weights, V)

        return output, attn_weights

    def forward(self, x):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: input of shape [batch_size x seq_len x d_model], or
                [seq_len x d_model] for a single unbatched sequence.

        Returns:
            projection: [batch_size x seq_len x d_model] (no batch dim for
                unbatched input).
            attn_weights: [batch_size x num_heads x seq_len x seq_len]
                ([num_heads x seq_len x seq_len] for unbatched input).
        """
        # per-head projections: lists of num_heads tensors, each [... x seq_len x d_h]
        Q = [linear_q(x) for linear_q in self.linear_qs]
        K = [linear_k(x) for linear_k in self.linear_ks]
        V = [linear_v(x) for linear_v in self.linear_vs]

        output_per_head = []
        attn_weight_per_head = []

        for Q_, K_, V_ in zip(Q, K, V):
            head_output, head_weights = self.scaled_dot_product_attention(Q_, K_, V_)
            output_per_head.append(head_output)
            attn_weight_per_head.append(head_weights)

        # concatenate heads along the feature axis -> [... x seq_len x d_model]
        output = torch.cat(output_per_head, dim=-1)

        # insert the head axis just before the (seq_len, seq_len) axes
        # -> [... x num_heads x seq_len x seq_len]
        attn_weights = torch.stack(attn_weight_per_head, dim=-3)

        projection = self.dropout(self.linear(output))

        return projection, attn_weights

NameError: name 'nn' is not defined

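- **Step 1**: `x` is the embedded input sequence, with shape `[batch_size x seq_len x d_model]`.

- **Step 2**: After `x` goes through the per-head query projections, `Q` is a list of `num_heads` tensors, each of shape `[batch_size x seq_len x d_h]` (the same holds for `K` and `V`).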


Notations
- `d_model`: the model dimensionality
- `d_h`: the head dimensionality

In [None]:
# Toy input: a batch of one sequence with 3 tokens, each a 4-dim embedding.
# Shape (batch_size=1, seq_len=3, d_model=4).
toy_encodings = torch.tensor([
    [
        [0.0, 0.1, 0.2, 0.3],
        [1.0, 1.1, 1.2, 1.3],
        [2.0, 2.1, 2.2, 2.3],
    ],
])

In [None]:
toy_encodings.size()  # [batch_size, seq_len, d_model]

torch.Size([1, 3, 4])

In [None]:
# Seed first: the per-head projections are randomly initialised, and the
# module's dropout (p=0.3 by default) is active while the module is in training
# mode. Without a seed, the outputs below change on every run.
torch.manual_seed(0)
mha = MultiHeadAttention(d_model=4, num_heads=2)

In [None]:
# Forward pass: returns the projected output and the per-head attention weights
# (this rebinds `output` from the earlier example).
output, attn_weights = mha(toy_encodings)

NameError: name 'mha' is not defined

In [None]:
# Multi-head attention preserves the input shape: [batch_size x seq_len x d_model].
assert output.shape == toy_encodings.shape, output.shape
output.shape

NameError: name 'output' is not defined

In [None]:
# One seq_len x seq_len attention matrix per head:
# [batch_size x num_heads x seq_len x seq_len].
batch_size, seq_len, _ = toy_encodings.shape
assert attn_weights.shape == (batch_size, mha.num_heads, seq_len, seq_len), attn_weights.shape
attn_weights.shape

torch.Size([1, 2, 3, 3])