In [13]:
import math
import torch
import torch.nn.functional as F
from torch import nn

### Multi-Head Attention (without masking)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention without masking.

    Every head owns its own query, key and value projection
    (d_model -> d_h, where d_h = d_model // num_heads). The per-head results
    are concatenated back to d_model, passed through one final linear layer
    and then through dropout.

    Args:
        d_model: size of each input/output feature vector.
        num_heads: number of attention heads; must divide d_model exactly.
        dropout: dropout probability applied to the final projection.
    """

    def __init__(self, d_model=4, num_heads=2, dropout=0.3):
        super().__init__()

        # width of the sub-space each head works in
        self.d_h = d_model // num_heads

        # the heads have to tile d_model exactly
        assert d_model == num_heads * self.d_h

        self.d_model, self.num_heads = d_model, num_heads
        self.dropout = nn.Dropout(dropout)

        def head_projections():
            # one independent d_model -> d_h linear layer per head
            return nn.ModuleList(
                [nn.Linear(d_model, self.d_h) for _ in range(num_heads)]
            )

        # created in the order queries, keys, values
        self.linear_qs = head_projections()
        self.linear_ks = head_projections()
        self.linear_vs = head_projections()

        # mixes the concatenated heads: d_model -> d_model
        self.linear = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V):
        """Attention for a single head.

        Q, K and V are each [batch_size x seq_len x d_h]. Returns the
        attended values [batch_size x seq_len x d_h] and the attention
        weights [batch_size x seq_len x seq_len].
        """
        # [batch_size x seq_len x d_h] @ [batch_size x d_h x seq_len]
        #   -> [batch_size x seq_len x seq_len]
        similarity = torch.matmul(Q, K.permute(0, 2, 1))

        # scale by sqrt(d_h), then normalise over the key positions
        attn_weights = F.softmax(similarity / math.sqrt(self.d_h), dim=-1)

        # weighted sum of the values: [batch_size x seq_len x d_h]
        attended = torch.matmul(attn_weights, V)

        return attended, attn_weights

    def forward(self, x):
        """Run every head on x and merge the results.

        x is [batch_size x seq_len x d_model]. Returns the projection
        [batch_size x seq_len x d_model] and the attention weights
        [batch_size x num_heads x seq_len x seq_len].
        """
        # one (output, weights) pair per head
        per_head = [
            self.scaled_dot_product_attention(to_q(x), to_k(x), to_v(x))
            for to_q, to_k, to_v in zip(
                self.linear_qs, self.linear_ks, self.linear_vs
            )
        ]
        head_outputs, head_weights = zip(*per_head)

        # heads side by side along the feature axis:
        # [batch_size x seq_len x d_model]
        merged = torch.cat(head_outputs, dim=-1)

        # stacked as [num_heads x batch_size x seq_len x seq_len], then the
        # first two axes are swapped so the batch axis comes first
        attn_weights = torch.stack(head_weights).permute(1, 0, 2, 3)

        mixed = self.linear(merged)
        projection = self.dropout(mixed)

        return projection, attn_weights

In [None]:
# Toy batch laid out as [batch_size x seq_len x d_model]:
# 1 sentence, 3 words, 4 features per word.
# torch.tensor is the recommended factory for building a tensor from Python
# data; the legacy torch.Tensor constructor is avoided. Float literals yield
# the default floating dtype.
text_encodings = torch.tensor([[
    [0.0, 0.1, 0.2, 0.3],
    [1.0, 1.1, 1.2, 1.3],
    [2.0, 2.1, 2.2, 2.3],
]])

##### Example 8

In [None]:
# dims read as [batch_size x seq_len x d_model], the layout forward() expects
text_encodings.shape

torch.Size([1, 3, 4])

In [None]:
# d_model=4 split over 2 heads -> each head works in d_h = 2 dimensions;
# dropout keeps its default of 0.3
mha = MultiHeadAttention(d_model=4, num_heads=2)

In [None]:
# forward() returns (projection, attention weights). No .eval() call is made
# on `mha`, so dropout (p=0.3) is still active and the values in `output`
# change from run to run; the shapes do not.
output, attn_weights= mha(text_encodings)

Predict the shape of `output` (the projection returned by the attention layer). Explain why.

In [None]:
# same layout as the input: [batch_size x seq_len x d_model]
output.shape

torch.Size([1, 3, 4])

**Explain**

- `1`: batch size is one
- `3`: there are three words in a sequence
- `4`: each word is represented by a vector of size 4 (`d_model`)

##### Example 9

In [None]:
# [batch_size x seq_len x d_model] = 1 sentence, 3 words, 4 features per word
text_encodings.shape

torch.Size([1, 3, 4])

In [None]:
# fresh layer with newly initialised weights: 2 heads, each working in
# d_h = 4 // 2 = 2 dimensions
mha = MultiHeadAttention(d_model=4, num_heads=2)

In [None]:
# Run the layer on the toy batch inspected above. forward() returns the
# projected output and the per-head attention weights.
projection, attn_weights = mha(text_encodings)

Predict the shape of `attn_weights`. Explain why.

In [None]:
# [batch_size x num_heads x seq_len x seq_len]
attn_weights.shape

torch.Size([1, 2, 3, 3])

**Explain**
- `1`: batch size is one
- `2`: there're two heads
- `3`: there're three words
- `3`: for each word, there is one attention weight for every word in the sentence (including itself)

### Masked Self Attention

##### Example 1

In [44]:
# Causal (lower-triangular) mask, shape [batch_size x n_heads x seq_len x seq_len].
# Convention: 1 = may attend, 0 = blocked, so each word only sees itself and
# the words before it. An all-zero mask would block every position and make
# the softmax fall back to uniform weights.
mask = torch.tril(torch.ones(10, 4, 3, 3))

In [55]:
# Random queries, keys and values, each [batch_size x n_heads x seq_len x d_head];
# the three tensors are drawn in the order q, k, v.
q, k, v = (torch.randn(10, 4, 3, 6) for _ in range(3))

In [73]:
import math
from torch import nn
import torch.nn.functional as F

Write a scaled dot-product **masked self-attention** layer from scratch, suitable for **use inside multi-head attention**.

**Hints**
- The valid shape of matrix multiplication: `[10, 4, x, y]` @ `[10, 4, y, x]`
- Divide the scores by the square root of the per-head embedding dimension (`d_head`)

In [74]:
class SelfAttention(nn.Module):
    """Scaled dot-product attention with an optional mask.

    Args:
        d_head: per-head feature size; scores are divided by sqrt(d_head).
    """

    def __init__(self, d_head):
        super().__init__()
        self.d_head = d_head

    def forward(self, q, k, v, mask=None):
        """Attend from q over k and v.

        Args:
            q, k, v: tensors of shape [batch_size x n_heads x seq_len x d_head].
                Any number of leading batch dimensions works because only the
                last two axes are used.
            mask: optional tensor broadcastable to the score shape
                [batch_size x n_heads x seq_len x seq_len]. Positions where
                ``mask == 0`` are blocked.

        Returns:
            output: [batch_size x n_heads x seq_len x d_head]
            attention_weights: [batch_size x n_heads x seq_len x seq_len]
        """
        # swap the last two axes of k:
        # [batch_size x n_heads x d_head x seq_len]
        keys_t = k.transpose(-2, -1)

        # shape(scores) = [batch_size x n_heads x seq_len x seq_len]
        scores = torch.matmul(q, keys_t) / math.sqrt(self.d_head)

        if mask is not None:
            # Fill with the most negative finite value of the scores' dtype.
            # A hard-coded -1e9 is not representable in float16, and -inf
            # would turn a fully masked row into NaN after the softmax. With
            # a finite fill such a row becomes uniform instead.
            blocked_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, blocked_value)

        # shape(attention_weights) = [batch_size x n_heads x seq_len x seq_len]
        attention_weights = F.softmax(scores, dim=-1)

        # shape(output) = [batch_size x n_heads x seq_len x d_head]
        output = torch.matmul(attention_weights, v)

        return output, attention_weights

`q` contains `10` sentences, each with `3` words, and there are `4` heads; within each head, every word is a `6`-dimensional vector. The same holds for `k` and `v`.

In [75]:
# each is [batch_size x n_heads x seq_len x d_head]
q.shape, k.shape, v.shape

(torch.Size([10, 4, 3, 6]),
 torch.Size([10, 4, 3, 6]),
 torch.Size([10, 4, 3, 6]))

In [76]:
# same shape as the attention scores: one entry per (query word, key word)
# pair, for every sentence and head
mask.shape

torch.Size([10, 4, 3, 3])

In [77]:
# d_head=6 matches the last dimension of q/k/v; it sets the 1/sqrt(d_head)
# scaling of the scores
attention = SelfAttention(d_head=6)

In [78]:
# positions where mask == 0 are suppressed before the softmax
output, attention_weights = attention(q, k, v, mask=mask)

In [79]:
# output:            [batch_size x n_heads x seq_len x d_head]
# attention_weights: [batch_size x n_heads x seq_len x seq_len]
output.shape, attention_weights.shape

(torch.Size([10, 4, 3, 6]), torch.Size([10, 4, 3, 3]))

In [72]:
from foundation.transformer.efficient_attention import ScaleDotProductAttention