## Multihead Attention

In this notebook, we will implement the multihead attention mechanism.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        if not d_model%num_heads == 0:
            raise ValueError("D_model should be a multiple of num_heads")
        self.d_model = d_model
        self.num_heads = num_heads
        print(d_model/num_heads)
        self.head_embed_dim = int(d_model/num_heads)
        print(self.head_embed_dim)
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        print(f"Dimension of x is: {x.shape}") ## shape of x should be (batch_size, seq_len, d_model)\
        batch_size, seq_len, d_model = x.shape
        q_proj = self.W_q(x) ## shape (batch_size, seq_len, d_model)
        k_proj = self.W_k(x) ## shape (batch_size, seq_len, d_model)
        v_proj = self.W_v(x) ## shape (batch_size, seq_len, d_model)
        #out_proj = self.W_out(x) ## shape (batch_size, seq_len, d_model)

        # Muliheads should be of shape (num_heads,batch_size, seq_len, d_model/num_heads)        
        multi_q_proj = q_proj.view(self.num_heads, batch_size, seq_len, self.head_embed_dim)
        multi_k_proj = k_proj.view(self.num_heads, batch_size, seq_len, self.head_embed_dim)
        multi_v_proj = v_proj.view(self.num_heads, batch_size, seq_len, self.head_embed_dim)

        logits = multi_q_proj @ multi_k_proj.transpose(-2,-1) # (num_heads,batch_size, seq_len, seq_len)
        logits_scaled = logits/math.sqrt(float(d_model/self.num_heads)) # (num_heads,batch_size, seq_len, seq_len)

        head_attention_probs = torch.softmax(logits_scaled, dim=-1) # (num_heads,batch_size, seq_len, seq_len)
        head_attention_scores =     head_attention_probs @   multi_v_proj ## (num_heads,batch_size, seq_len, d_model/num_heads)
        attention_score_permuted =   head_attention_scores.permute(1,2,0,3)   ## (batch_size, seq_len, num_heads,d_model/num_heads)
        attention_scores = attention_score_permuted.reshape(batch_size, seq_len, d_model) 
        out = self.W_out(attention_scores)
        print(f"Output DImension: {out.shape}")
        return(out)


multiHeadAttention = MultiHeadAttention(d_model=12, num_heads=3)

4.0
4


In [35]:
input = torch.rand((12,5,12))

In [36]:
multiHeadAttention(input)

Dimension of x is: torch.Size([12, 5, 12])
Output DImension: torch.Size([12, 5, 12])


tensor([[[-0.0664,  0.1414, -0.0788,  0.2497, -0.2256,  0.0281,  0.0019,
          -0.1521,  0.0098, -0.0464,  0.0029,  0.0747],
         [-0.0731,  0.1391, -0.0748,  0.2339, -0.2150,  0.0314,  0.0183,
          -0.1485,  0.0056, -0.0447,  0.0154,  0.0677],
         [-0.0717,  0.1358, -0.0773,  0.2362, -0.2151,  0.0306,  0.0108,
          -0.1434,  0.0072, -0.0412,  0.0092,  0.0712],
         [-0.0696,  0.1381, -0.0798,  0.2450, -0.2192,  0.0336,  0.0134,
          -0.1523,  0.0096, -0.0507,  0.0064,  0.0714],
         [-0.0679,  0.1427, -0.0710,  0.2347, -0.2217,  0.0239,  0.0059,
          -0.1497,  0.0040, -0.0364,  0.0155,  0.0702]],

        [[-0.1322,  0.0795, -0.0826,  0.2003, -0.1383,  0.1013, -0.0308,
          -0.0939,  0.0773, -0.1021, -0.0484,  0.0929],
         [-0.1484,  0.0716, -0.0866,  0.1858, -0.1265,  0.1116, -0.0233,
          -0.0700,  0.0860, -0.1117, -0.0543,  0.0981],
         [-0.1461,  0.0758, -0.0852,  0.1784, -0.1112,  0.1174, -0.0191,
          -0.0740,  0.

## OpenAI Implementation

In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads.")
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension per head

        # Linear layers for Q, K, V (with bias for illustration)
        self.W_q = nn.Linear(d_model, d_model, bias=True)
        self.W_k = nn.Linear(d_model, d_model, bias=True)
        self.W_v = nn.Linear(d_model, d_model, bias=True)

        # Final linear layer to recombine heads
        self.W_out = nn.Linear(d_model, d_model, bias=True)

        # Optional dropout on the attention probabilities
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        x:    (batch_size, seq_len, d_model)
        mask: (batch_size, 1, 1, seq_len) or (batch_size, 1, seq_len, seq_len)
              depending on need. Typically 1/True for keep, 0/False for masked.
        """
        B, L, _ = x.size()  # (Batch, Seq_len, d_model)

        # 1) Linear projections: Q, K, V
        Q = self.W_q(x)  # (B, L, d_model)
        K = self.W_k(x)  # (B, L, d_model)
        V = self.W_v(x)  # (B, L, d_model)

        # 2) Reshape to split heads: (B, L, num_heads, d_k)
        Q = Q.view(B, L, self.num_heads, self.d_k)
        K = K.view(B, L, self.num_heads, self.d_k)
        V = V.view(B, L, self.num_heads, self.d_k)

        # 3) Permute to get (B, num_heads, L, d_k)
        Q = Q.permute(0, 2, 1, 3)
        K = K.permute(0, 2, 1, 3)
        V = V.permute(0, 2, 1, 3)

        # 4) Scaled dot product attention
        #    Q @ K^T => (B, H, L, d_k) x (B, H, d_k, L) = (B, H, L, L)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 5) Optional masking
        if mask is not None:
            # mask == 0 => fill with -inf to exclude
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        attn_probs = F.softmax(attn_scores, dim=-1)  # (B, H, L, L)
        attn_probs = self.dropout(attn_probs)        # dropout on the attention map

        # 6) Compute final attention outputs
        #    (B, H, L, L) x (B, H, L, d_k) = (B, H, L, d_k)
        context = torch.matmul(attn_probs, V)

        # 7) Permute back to (B, L, H, d_k), then flatten
        context = context.permute(0, 2, 1, 3).contiguous()  # (B, L, H, d_k)
        context = context.view(B, L, self.d_model)          # (B, L, d_model)

        # 8) Final linear to recombine all heads
        out = self.W_out(context)                           # (B, L, d_model)

        return out


In many cases, you *can* technically do without calling `.contiguous()`, especially if nothing else in your code relies on having a strictly contiguous layout. However, **the safest practice** when you perform a `permute()` and then immediately call `.view()` (or `reshape()`) is to use `.contiguous()` first. 

Here’s why:

- **`permute()`** reorders tensor dimensions **without** rearranging them in memory. This often results in a “non-contiguous” tensor (i.e. the data no longer has a simple, row-major memory layout).
- **`view()`** (unlike `reshape()`) strictly requires that the underlying tensor storage is contiguous. If it’s not, `view()` will throw an error (unless you use `reshape()`, which will try to handle it but may still internally call `.contiguous()` anyway).

By calling `.contiguous()`, you ensure the tensor is laid out in memory in a contiguous block in the new order of dimensions, making the subsequent `view()` safe and predictable. 

So while it may *sometimes* work without `.contiguous()`, the recommended approach is to do:

```python
context = context.permute(0, 2, 1, 3).contiguous()
context = context.view(B, L, self.d_model)
```

to avoid any surprises if later operations demand a contiguous layout or if PyTorch changes behavior in future versions.

In [None]:
ddddmultiHeadAttention = MultiHeadAttention(d_model=12, num_heads=3)

In [41]:
multiHeadAttention = MultiHeadAttention(d_model=12, num_heads=3)

In [42]:
multiHeadAttention(input)

tensor([[[ 7.4925e-02, -8.2597e-02,  1.5392e-01,  2.9184e-01, -1.6251e-01,
          -3.3844e-01,  2.2570e-01,  2.2068e-01, -5.2972e-01,  2.3687e-02,
          -3.6595e-02,  2.0736e-01],
         [ 8.9961e-03, -3.9324e-03,  5.6516e-02,  2.2119e-01, -1.8616e-01,
          -3.8412e-01,  2.3201e-01,  1.5972e-01, -4.4214e-01,  4.6280e-02,
          -1.3124e-01,  1.8051e-01],
         [ 2.3962e-02, -5.5031e-02,  1.2626e-01,  3.0347e-01, -1.5344e-01,
          -3.8098e-01,  2.4902e-01,  2.3129e-01, -5.3142e-01,  3.2017e-02,
          -7.9054e-02,  2.1008e-01],
         [-3.9417e-03, -3.0084e-02,  9.9333e-02,  2.4949e-01, -1.3996e-01,
          -4.1857e-01,  1.7047e-01,  1.7855e-01, -4.5039e-01,  3.2127e-02,
          -1.1767e-01,  1.9613e-01],
         [ 3.0660e-02, -4.4391e-02,  1.2024e-01,  3.2198e-01, -1.6135e-01,
          -4.1303e-01,  1.7937e-01,  2.2037e-01, -4.8638e-01,  3.0240e-02,
          -1.1032e-01,  2.1702e-01]],

        [[ 1.4514e-01, -1.5398e-01,  1.5875e-01,  2.3907e-01, -