In [1]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

In [2]:
class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
    
    def forward(self, Q, K, V, mask=None):
        # shape(Q) = [B x num_heads x Q_len x D/num_heads]
        # shape(K, V) = [B x num_heads x KV_len x D/num_heads]

        # reshaped(K) = [B x num_heads x D/num_heads x KV_len]
        Q_K_matmul = torch.matmul(Q, K.permute(0, 1, 3, 2))
        scores = Q_K_matmul/math.sqrt(self.d_model)
        # shape(scores) = [B x num_heads x Q_len x KV_len]

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = F.softmax(scores, dim=-1)
        # shape(attention_weights) = [B x num_heads x Q_len x KV_len]

        output = torch.matmul(attention_weights, V)
        # shape(output) = [B x num_heads x Q_len x D/num_heads]

        return output, attention_weights

##### Example 1

In [3]:
batch_size = 10

In [4]:
d_model, n_heads = 24, 4

In [5]:
d_head = 24 // 4

In [6]:
self_attention = SelfAttention(d_model=d_model)

In [7]:
q, k, v = torch.randn(batch_size, n_heads, 3, d_head), torch.randn(batch_size, n_heads, 3, d_head), torch.randn(batch_size, n_heads, 3, d_head)

In [8]:
import math
from torch import nn

Write a scale product self-attention that **uses in multi-head attention** from scratch

**Hints**
- The valid shape of matrix multiplication: `[10, 4, x, y]` @ `[10, 4, y, x]`
- Scores divide by (the square root of the dimension of a word embedding)

In [46]:
class _SelfAttention(nn.Module):
    def __init__(self, d_head):
        super().__init__()
        self.d_head = d_head
    
    def forward(self, q, k, v):
        # shape(q, k, v) = [batch_size x num_heads x seq_len x d_model/n_heads]

        # reshaped(k) = [batch_size x num_heads x d_model/n_heads x seq_len]
        reshaped_k = k.permute(0, 1, 3, 2)
                
        Q_K_matmul = torch.matmul(q, reshaped_k)
        # shape(scores) = [batch_size x num_heads x q_len x kv_len]
        scores = Q_K_matmul / math.sqrt(self.d_head)

        # shape(attention_weights) = [batch_size x n_heads x q_len x kv_len]
        attention_weights = F.softmax(scores, dim=-1)

        # shape(output) = [batch_size x n_heads x q_len x d_model/n_heads]
        output = torch.matmul(attention_weights, v)

        return output, attention_weights

`q` contains `10` sentences, each sentence has `3` words. And there're `4` heads. Same for `k` and `v`

In [47]:
q.shape, k.shape, v.shape

(torch.Size([10, 4, 3, 6]),
 torch.Size([10, 4, 3, 6]),
 torch.Size([10, 4, 3, 6]))

In [48]:
self_attention = SelfAttention(d_model=6)

TypeError: __init__() got an unexpected keyword argument 'd_model'

In [None]:
output, attention_weights = self_attention(q, k, v)

In [49]:
output.shape, attention_weights.shape

NameError: name 'output' is not defined

##### Example 2

Write an efficient multi-head attention (one matrix for all heads) from scratch in pytorch

In [28]:
class MultiHeadAttention(nn.Module):
    def __init__(self, attention, d_model, num_heads, dropout):
        super().__init__()
        # d_q, d_k, d_v
        self.d = d_model//num_heads

        self.d_model = d_model
        self.num_heads = num_heads

        self.dropout = nn.Dropout(dropout)
        self.attention = attention

        self.linear_Q = nn.Linear(d_model, d_model)
        self.linear_K = nn.Linear(d_model, d_model)
        self.linear_V = nn.Linear(d_model, d_model)

        self.mha_linear = nn.Linear(d_model, d_model)

    def forward(self, pre_q, pre_k, pre_v, mask=None):
        # shape(x) = [B x seq_len x D]

        Q = self.linear_Q(pre_q)
        K = self.linear_K(pre_k)
        V = self.linear_V(pre_v)
        # shape(Q) = [B x seq_len x D] (if in encoder, seq_len = SRC_seq_len; if in decoder, seq_len = TRG_seq_len)
        # shape(K, V) = [B x seq_len x D] (always SRC_seq_len unless in masked-multihead-attention)

        batch_size = pre_q.shape[0]

        Q = Q.reshape(batch_size, self.num_heads, -1, self.d)
        K = K.reshape(batch_size, self.num_heads, -1, self.d)
        V = V.reshape(batch_size, self.num_heads, -1, self.d)
        # shape(Q) = [B x num_heads x seq_len x D]
        # shape(K, V) = [B x num_heads x seq_len x D]

        # run scaled_dot_product_attention
        output, attn_weights = self.attention(Q, K, V)
        # shape(output) = [B x num_heads x Q_len x D/num_heads]
        # shape(attn_weights) = [B x num_heads x Q_len x KV_len]

        output = output.reshape(batch_size, -1, self.d_model)
        # shape(output) = [B x seq_len x D]

        projection = self.dropout(self.mha_linear(output))

        return projection, attn_weights

In [29]:
d_model = 6

In [30]:
attention = SelfAttention(d_head=d_head)

In [31]:
mha = MultiHeadAttention(
    attention=attention,
    d_model=d_model, num_heads=2,
    dropout=0.3
)

In [32]:
pre_q = torch.randn(1, 3, d_model)

In [33]:
pre_k = torch.randn(1, 3, d_model)

In [34]:
pre_v = torch.randn(1, 3, d_model)

In [35]:
projection, attention_weights = mha(
    pre_q=pre_q, pre_k=pre_k, pre_v=pre_v
)

In [36]:
projection.shape, attention_weights.shape

(torch.Size([1, 3, 6]), torch.Size([1, 2, 3, 3]))

##### Example 2.2: Better variable names

In [37]:
pre_q = torch.randn(10, 3, d_model)

In [38]:
pre_k = torch.randn(10, 3, d_model)

In [39]:
pre_v = torch.randn(10, 3, d_model)

In [40]:
from torch import nn

In [41]:
class MultiHeadAttention(nn.Module):
    def __init__(self, attention, d_model, n_heads):
        super().__init__()
        
        self.d_model, self.n_heads = d_model, n_heads
        self.attention = attention
        self.d_head = d_model // n_heads
        
        self.w_q = nn.Linear(d_model, n_heads * self.d_head)
        self.w_k = nn.Linear(d_model, n_heads * self.d_head)
        self.w_v = nn.Linear(d_model, n_heads * self.d_head)
        self.mha_linear = nn.Linear(n_heads * self.d_head, d_model)
    
    def split_heads(self, x):
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, self.n_heads, seq_len, self.d_head)
    
    def concat(self, x):
        batch_size, n_heads, seq_len, d_head = x.size()
        d_model = n_heads * d_head
        return x.view(batch_size, seq_len, d_model)

    def forward(self, pre_q, pre_k, pre_v):
        # shape(q, k, v) = [batch_size x seq_len x d_model]
        q, k, v = self.w_q(pre_q), self.w_k(pre_k), self.w_v(pre_v)

        # shape(v, k, v) = [batch_size x n_heads x seq_len x d_model]
        k, v, q = self.split_heads(k), self.split_heads(v), self.split_heads(q)

        # shape(output) = [batch_size x n_heads x q_len x d_model/n_heads]
        # shape(attn_weights) = [batch_size x n_heads x q_len x kv_len]
        output, attn_weights = self.attention(q, k, v)
        
        # shape(output) = [batch_size x seq_len x d_model]
        output = self.concat(output)
        projection = self.mha_linear(output)
        return projection, attn_weights

Write an efficient multi-head attention (one matrix for all heads) from scratch in pytorch

**Hints**
- `SelfAttention` takes `q, k, v` as input in forward pass
- The last linear layer in multi-head attention has shape `(d_model, d_model)`

In [42]:
d_model = 6

In [43]:
attention = SelfAttention(d_model=d_model)

TypeError: __init__() got an unexpected keyword argument 'd_model'

In [44]:
mha = MultiHeadAttention(
    attention=attention,
    d_model=d_model, n_heads=2,
)

`pre_q` is a batch of `10` sentence, each sentece include `3` words. Same for `pre_k` and `pre_v`

In [45]:
pre_q.shape, pre_k.shape, pre_v.shape

(torch.Size([10, 3, 6]), torch.Size([10, 3, 6]), torch.Size([10, 3, 6]))

In [254]:
projection, attention_weights = mha(
    pre_q=pre_q, pre_k=pre_k, pre_v=pre_v
)

In [255]:
projection.shape, attention_weights.shape

(torch.Size([10, 3, 6]), torch.Size([10, 2, 3, 3]))