In [1]:
import torch
import torch.nn as nn

## Recap

### Self Attention
Refer to chapter `01_attention.ipynb`.

In [32]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention without masking.

    Three linear layers project the input to queries, keys and values; the
    softmax-normalised, scaled query/key similarities are then used to mix
    the values. Dropout (p=0.5) is applied to the attention weights.

    Args:
        embed_size: Size of the last dimension of the input.
        d_model: Output size of each of the three projections.
        qkv_bias: Whether the projections carry a bias term.
    """

    def __init__(self, embed_size, d_model, qkv_bias=False):
        super().__init__()
        # Creation order is kept (Wq, Wk, Wv) so seeded initialisation is unchanged.
        self.Wq = nn.Linear(embed_size, d_model, bias=qkv_bias)
        self.Wk = nn.Linear(embed_size, d_model, bias=qkv_bias)
        self.Wv = nn.Linear(embed_size, d_model, bias=qkv_bias)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map x of shape (B, T, embed_size) to context vectors (B, T, d_model)."""
        q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)   # each (B, T, d_model)

        # Scaling factor: square root of the projection width.
        scale = v.shape[-1] ** 0.5

        # Pairwise query/key similarities, one (T, T) matrix per batch item.
        scores = torch.matmul(q, k.transpose(1, 2))    # (B, T, T)

        # Normalise every row into a distribution, then regularise with dropout.
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))  # (B, T, T)

        # Weighted sum of the values.
        return torch.matmul(weights, v)                # (B, T, d_model)

# Demo: push one random batch through SelfAttention and print the output shape.
batch_size = 2
T = 4096  # Seq_len
embed_size = 4608  # size of the last input dimension fed to the projections
d_model = 4608  # output size of the query/key/value projections
X = torch.rand(batch_size, T, embed_size)  # random input of shape (batch_size, T, embed_size)
selfattention = SelfAttention(embed_size, d_model, qkv_bias=False)
Z = selfattention(X)  # context vectors, (batch_size, T, d_model)
print(Z.shape)

torch.Size([2, 4096, 4608])


### Causal Attention
Refer to chapter `02_causal_attention.ipynb`.

In [40]:
class CausalAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Position ``t`` may only attend to positions ``<= t``: the scores of all
    later positions are set to ``-inf`` before the softmax, so they receive
    zero attention weight. Dropout (p=0.5) is applied to the weights.

    Args:
        embed_size: Size of the last dimension of the input.
        d_model: Output size of the query/key/value projections.
        qkv_bias: Whether the three projections use a bias term.
    """

    def __init__(self, embed_size, d_model, qkv_bias=False):
        super(CausalAttention, self).__init__()
        self.Wq = nn.Linear(embed_size, d_model, bias=qkv_bias)
        self.Wk = nn.Linear(embed_size, d_model, bias=qkv_bias)
        self.Wv = nn.Linear(embed_size, d_model, bias=qkv_bias)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Compute causally masked attention.

        Args:
            x: Tensor of shape (B, T, embed_size).

        Returns:
            Tensor of shape (B, T, d_model).
        """
        # x: (B, T, embed_size)
        # Take the sequence length from the input itself rather than from a
        # notebook-level variable, so the module works for any T and on a
        # fresh kernel.
        seq_len = x.shape[1]

        queries = self.Wq(x)   # (B, T, d_model)
        keys = self.Wk(x)      # (B, T, d_model)
        values = self.Wv(x)    # (B, T, d_model)

        # Compute attention scores
        attention_scores = queries @ keys.transpose(1, 2)   # (B, T, T)

        # Compute masked attention weights.
        # True strictly above the diagonal == "future" positions to hide.
        # Built on the input's device so the module also works off-CPU.
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()                                             # (T, T)
        masked = attention_scores.masked_fill(mask, -torch.inf)
        attention_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=-1)   # (B, T, T)
        attention_weights = self.dropout(attention_weights)

        # Compute context vector
        Z = attention_weights @ values   # (B, T, d_model)

        return Z

# Demo: push one random batch through CausalAttention and print the output shape.
batch_size = 1
T = 4096  # Seq_len
embed_size = 4608  # size of the last input dimension fed to the projections
d_model = 4608  # output size of the query/key/value projections
X = torch.rand(batch_size, T, embed_size)  # random input of shape (batch_size, T, embed_size)
causalattention = CausalAttention(embed_size, d_model, qkv_bias=False)
Z = causalattention(X)  # context vectors, (batch_size, T, d_model)
print(Z.shape)

torch.Size([1, 4096, 4608])


## Multi Head Attention

Multi-head attention divides the attention mechanism into multiple heads, each operating independently.
In the single causal attention above (also referred to as single-head attention), there is only one set of attention weights processing the input sequentially.


In [52]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent CausalAttention modules.

    Every head receives the full input and produces a (B, T, d_model)
    tensor; the per-head results are concatenated along the last dimension.

    Args:
        embed_size: Size of the last dimension of the input.
        d_model: Output size of each individual head.
        num_heads: Number of CausalAttention heads to create.
        qkv_bias: Forwarded to every head's projections.
    """

    def __init__(self, embed_size, d_model, num_heads, qkv_bias=False):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(CausalAttention(embed_size, d_model, qkv_bias))
        self.heads = nn.ModuleList(head_modules)

    def forward(self, x):
        """Map x of shape (B, T, embed_size) to (B, T, num_heads * d_model)."""
        # Heads are evaluated one after another; each yields (B, T, d_model).
        per_head = tuple(head(x) for head in self.heads)
        # Join the head outputs along the feature axis.
        return torch.cat(per_head, dim=-1)
        
# Demo: run the wrapper with 4 heads; the last output dimension is
# num_heads * d_model because the head outputs are concatenated.
batch_size = 1
T = 4096  # Seq_len
embed_size = 4608  # size of the last input dimension fed to every head
d_model = 4608  # output size of each individual head
num_heads = 4
X = torch.rand(batch_size, T, embed_size)  # random input of shape (batch_size, T, embed_size)
multiheadattention = MultiHeadAttentionWrapper(embed_size, d_model, num_heads ,qkv_bias=False)
Z = multiheadattention(X)  # (batch_size, T, num_heads * d_model)
print(Z.shape)

torch.Size([1, 4096, 18432])


Note: In the above implementation, the heads are processed sequentially in the `forward` method:
```
out = [head(x) for head in self.heads]
```
We can overcome this sequential implementation by processing the heads in parallel.


## Multi Head Attention With Weight Splits

Split the projected input into multiple heads by reshaping the query, key, and value tensors, so that all heads are computed together in one batched matrix multiplication.

In [115]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head attention computed for all heads at once.

    The query/key/value projections have width ``d_model`` and are reshaped
    into ``num_heads`` heads of width ``d_model // num_heads``. Attention is
    evaluated for every head in one batched matmul, the heads are merged back
    and passed through a final linear output projection.

    Args:
        T: Maximum sequence length; sets the size of the causal mask buffer.
        embed_size: Size of the last dimension of the input.
        d_model: Total width of the projections (all heads together).
        num_heads: Number of heads; must divide ``d_model``.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, T, embed_size, d_model, num_heads, qkv_bias=False):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by number of heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.Wq = nn.Linear(embed_size, d_model, bias=qkv_bias)
        self.Wk = nn.Linear(embed_size, d_model, bias=qkv_bias)
        self.Wv = nn.Linear(embed_size, d_model, bias=qkv_bias)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(0.5)
        # Ones strictly above the diagonal mark the "future" positions.
        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer('mask', torch.triu(torch.ones(T, T), diagonal=1))

    def forward(self, x):
        """Compute causally masked multi-head attention.

        Args:
            x: Tensor of shape (B, T, embed_size), with T no larger than the
                maximum sequence length given to the constructor.

        Returns:
            Tensor of shape (B, T, d_model).

        Raises:
            ValueError: If the sequence length exceeds the mask size.
        """
        # x: (B, T, embed_size)
        B = x.shape[0]
        T = x.shape[1]
        if T > self.mask.shape[0]:
            raise ValueError(
                f"sequence length {T} exceeds the maximum of {self.mask.shape[0]}"
            )

        queries = self.Wq(x)  # (B, T, d_model)
        keys = self.Wk(x)     # (B, T, d_model)
        values = self.Wv(x)   # (B, T, d_model)

        # divide into num_heads
        queries = queries.view(B, T, self.num_heads, self.head_dim) # (B, T, num_heads, head_dim)
        keys = keys.view(B, T, self.num_heads, self.head_dim) # (B, T, num_heads, head_dim)
        values = values.view(B, T, self.num_heads, self.head_dim) # (B, T, num_heads, head_dim)

        # Move the head axis in front of time so matmul batches over (B, num_heads).
        queries = queries.transpose(1, 2)  # (B, num_heads, T, head_dim)
        keys = keys.transpose(1, 2)  # (B, num_heads, T, head_dim)
        values = values.transpose(1, 2)  # (B, num_heads, T, head_dim)

        # Compute attention scores
        attention_scores = queries @ keys.transpose(2, 3)  # (B, num_heads, T, T)

        # Slice the mask to the actual sequence length; it broadcasts over
        # the batch and head axes.
        mask_bool = self.mask[:T, :T].bool()  # (T, T)
        # masked_fill is out-of-place: the result has to be kept, otherwise
        # no masking happens at all.
        attention_scores = attention_scores.masked_fill(mask_bool, -torch.inf) # (B, num_heads, T, T)

        # Compute attention weights
        attention_weights = torch.softmax(attention_scores / self.head_dim ** 0.5, dim=-1)
        attention_weights = self.dropout(attention_weights)  # (B, num_heads, T, T)

        # Compute context vector
        Z = attention_weights @ values  # (B, num_heads, T, head_dim)
        Z = Z.transpose(1, 2)   # (B, T, num_heads, head_dim)
        Z = Z.contiguous()  # view() needs contiguous memory after the transpose
        # Use the module's own width, not a notebook-level variable.
        Z = Z.view(B, T, self.d_model)   # (B, T, d_model)

        Z = self.out_proj(Z)

        return Z

# Demo: 48 heads over d_model = 4608 gives head_dim = 4608 // 48 = 96.
batch_size = 1
T = 4096  # Seq_len
embed_size = 4608  # size of the last input dimension fed to the projections
d_model = 4608  # total projection width, shared by all heads
num_heads = 48
X = torch.rand(batch_size, T, embed_size)  # random input of shape (batch_size, T, embed_size)
multiheadattention = MultiHeadAttention(T, embed_size, d_model, num_heads ,qkv_bias=False)
Z = multiheadattention(X)  # (batch_size, T, d_model)
print(Z.shape)


torch.Size([1, 4096, 4608])


### Count number of parameters

In [100]:
# Total number of trainable scalars in the module, reported in billions.
params = sum(
    tensor.numel()
    for _, tensor in multiheadattention.named_parameters()
    if tensor.requires_grad
)
print(params / 1e9, 'B')

0.084939264 B


## Just rewriting again

In [118]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head attention with all heads evaluated in one pass.

    The query/key/value projections of width ``d_model`` are split into
    ``num_heads`` heads of width ``d_model // num_heads``. Causally masked,
    scaled dot-product attention is computed per head, the heads are merged
    back to width ``d_model`` and passed through an output projection.

    Args:
        T: Maximum sequence length; sets the size of the causal mask buffer.
        embed_size: Size of the last dimension of the input.
        d_model: Total width of the projections (all heads together).
        num_heads: Number of heads; must divide ``d_model``.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, T, embed_size, d_model, num_heads, qkv_bias=False):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model should be divisible by num_heads"
        self.T = T
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.Wq = nn.Linear(embed_size, d_model, bias=qkv_bias)
        self.Wk = nn.Linear(embed_size, d_model, bias=qkv_bias)
        self.Wv = nn.Linear(embed_size, d_model, bias=qkv_bias)

        self.dropout = nn.Dropout(0.5)
        self.out_proj = nn.Linear(d_model, d_model)
        # Ones strictly above the diagonal mark the "future" positions.
        self.register_buffer('mask', torch.triu(torch.ones(T, T), diagonal=1))

    def _split_heads(self, t, B, seq_len):
        """Reshape (B, T, d_model) into (B, num_heads, T, head_dim)."""
        return t.view(B, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Compute causally masked multi-head attention.

        Args:
            x: Tensor of shape (B, T, embed_size), with T no larger than the
                maximum sequence length given to the constructor.

        Returns:
            Tensor of shape (B, T, d_model).

        Raises:
            ValueError: If the sequence length exceeds the mask size.
        """
        # x: (B, T, embed_size)
        # Read both sizes from the input instead of notebook-level variables.
        B, seq_len = x.shape[0], x.shape[1]
        if seq_len > self.T:
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum of {self.T}"
            )

        queries = self._split_heads(self.Wq(x), B, seq_len)  # (B, num_heads, T, head_dim)
        keys = self._split_heads(self.Wk(x), B, seq_len)     # (B, num_heads, T, head_dim)
        values = self._split_heads(self.Wv(x), B, seq_len)   # (B, num_heads, T, head_dim)

        # Compute attention scores
        attention_scores = queries @ keys.transpose(2, 3)   # (B, num_heads, T, T)
        # Slice the mask to the actual length; it broadcasts over batch and heads.
        causal_mask = self.mask[:seq_len, :seq_len].bool()  # (T, T)
        attention_scores_masked = attention_scores.masked_fill(causal_mask, -torch.inf)  # (B, num_heads, T, T)

        # Compute attention weights
        attention_weights = torch.softmax(attention_scores_masked / self.head_dim ** 0.5, dim=-1)  # (B, num_heads, T, T)
        # The dropout layer was constructed but never used; apply it to the
        # attention weights (it is a no-op in eval mode).
        attention_weights = self.dropout(attention_weights)

        # Compute context vector
        Z = attention_weights @ values  # (B, num_heads, T, head_dim)

        # Merge heads: the time axis must come back in front of the head axis
        # BEFORE flattening, otherwise view() interleaves time steps and heads.
        Z = Z.transpose(1, 2).contiguous().view(B, seq_len, self.d_model)  # (B, T, d_model)
        Z = self.out_proj(Z)

        return Z


# Demo: build the rewritten MultiHeadAttention and run one random batch.
B = 1  # batch size
T = 4096  # sequence length; also the size of the module's mask buffer
embed_size = 4608  # size of the last input dimension fed to the projections
d_model = 4608  # total projection width, shared by all heads
num_heads = 48  # head_dim = d_model // num_heads = 96
model = MultiHeadAttention(T, embed_size, d_model, num_heads)

X = torch.rand(B, T, embed_size)  # random input of shape (B, T, embed_size)
Z = model(X)  # (B, T, d_model)
print(Z.shape)


torch.Size([1, 4096, 4608])
