In [5]:
import torch
import torch.nn as nn

In [2]:
# Toy input matrix: 6 rows with 3 features each. Later cells use the row count
# as the context length and the column count as d_in.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55]
])

In [4]:
# Stack the same sequence twice along a new leading axis to get a batch of two
# identical examples.
batch = torch.stack([inputs, inputs], dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [7]:
class CausalAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Each position may attend only to itself and to earlier positions; scores
    for later positions are set to -inf before the softmax.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
        context_length: Largest sequence length the stored mask supports.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections have a bias term.
    """

    def __init__(self,d_in, d_out,context_length,dropout,qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Creation order (query, key, value) is kept so that seeded weight
        # initialisation yields the same parameters as before.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

        # Strictly upper-triangular matrix of ones: entry (i, j) is 1 when j > i,
        # i.e. when position j lies in the "future" of position i.
        self.register_buffer('mask',
                             torch.triu(torch.ones(context_length,context_length), diagonal=1)
                             )

    def forward(self, x):
        """Compute causally masked attention over ``x``.

        Args:
            x: Tensor of shape ``(..., num_tokens, d_in)``; any number of
                leading batch dimensions (including none) is accepted.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` the
                mask was built for.
        """
        num_tokens = x.shape[-2]
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            # Without this check the sliced mask is too small: it either fails
            # to broadcast or (for a 1x1 mask) silently masks nothing.
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Indexing from the end keeps this valid for any number of batch dims.
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        # Scale by sqrt(d_out); masked entries stay -inf and become 0 after softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec

In [8]:
# Seed before building the module so its randomly initialised weights are reproducible.
torch.manual_seed(123)
context_length = batch.size(1)  # tokens per sequence
d_in = inputs.size(1)           # features per token
d_out = 2
ca = CausalAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.0,
)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])


In [10]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent ``CausalAttention`` modules.

    Every head receives the same input and produces ``d_out`` features; the
    per-head results are joined along the last dimension, so the output has
    ``num_heads * d_out`` features.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Heads are created one after another so parameter initialisation
        # consumes the random generator in a fixed order.
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            self.heads.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )

    def forward(self,x):
        """Apply every head to ``x`` and concatenate the outputs on the last axis."""
        per_head_outputs = tuple(attention(x) for attention in self.heads)
        return torch.cat(per_head_outputs, dim=-1)

In [11]:
# Re-seed so both heads get reproducible initial weights.
torch.manual_seed(123)
context_length = batch.size(1)
d_in = 3
d_out = 2
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, dropout=0.0, num_heads=2
)
context_vecs = mha(batch)

In [12]:
# Each token row holds the two heads' outputs side by side (2 heads x d_out=2 -> 4 values).
print(context_vecs)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)


In [15]:
# Show the first attention head stored in the wrapper's ModuleList.
mha.heads[0]

CausalAttention(
  (W_query): Linear(in_features=3, out_features=2, bias=False)
  (W_key): Linear(in_features=3, out_features=2, bias=False)
  (W_value): Linear(in_features=3, out_features=2, bias=False)
  (dropout): Dropout(p=0.0, inplace=False)
)

In [51]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention using one fused projection per Q/K/V.

    The ``d_out`` projected features are split into ``num_heads`` chunks of
    ``head_dim = d_out // num_heads`` features; attention is computed per chunk,
    the chunks are merged back, and a final linear layer mixes them.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output size across all heads; must be divisible by
            ``num_heads``.
        context_length: Largest sequence length the stored mask supports.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out,
                 context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert(d_out % num_heads == 0), "d_out must be divisible by number of heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        # Creation order (query, key, value, out_proj) is kept so that seeded
        # weight initialisation yields the same parameters as before.
        self.W_query = nn.Linear(d_in, d_out, bias= qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in,d_out, bias=qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Strictly upper-triangular ones: entry (i, j) is 1 when j > i (a future token).
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """Compute causally masked multi-head attention over ``x``.

        Args:
            x: Tensor of shape ``(..., num_tokens, d_in)``; any number of
                leading batch dimensions (including none) is accepted.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` the
                mask was built for.
        """
        *batch_dims, num_tokens, _ = x.shape
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            # Without this check the sliced mask is too small: it either fails
            # to broadcast or (for a 1x1 mask) silently masks nothing.
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        split_shape = (*batch_dims, num_tokens, self.num_heads, self.head_dim)

        # Project, split the feature axis into (num_heads, head_dim), then move
        # the head axis in front of the token axis:
        # (..., num_tokens, d_out) -> (..., num_heads, num_tokens, head_dim)
        keys = self.W_key(x).view(split_shape).transpose(-3, -2)
        queries = self.W_query(x).view(split_shape).transpose(-3, -2)
        values = self.W_value(x).view(split_shape).transpose(-3, -2)

        # Per-head scores: (..., num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(-2, -1)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # The 2-D mask broadcasts over all leading (batch and head) dimensions.
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim); masked entries become 0 after the softmax.
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)

        attn_weights = self.dropout(attn_weights)

        # Back to (..., num_tokens, num_heads, head_dim); the transpose makes the
        # tensor non-contiguous, so .contiguous() is required before .view().
        context_vec = (attn_weights @ values).transpose(-3, -2)
        context_vec = context_vec.contiguous().view(*batch_dims, num_tokens, self.d_out)

        context_vec = self.out_proj(context_vec)
        return context_vec

        

In [52]:
# Re-seed so the freshly built module starts from reproducible weights.
torch.manual_seed(123)
# NOTE(review): "batc_size" looks like a typo for "batch_size"; the name is kept
# unchanged in case other cells refer to it.
batc_size, context_length, d_in = batch.size()
d_out = 2
# With d_out=2 and num_heads=2, each head works on a single feature.
mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.0,
    num_heads=2,
)
context_vecs = mha(batch)
print(context_vecs)
print(context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
torch.Size([2, 6, 2])


### How does the trick work?

In [37]:
# Toy sizes for the reshape demonstration: 2 heads, 5 tokens, 3 features per head.
my_heads, my_tokens, my_head_dim = 2, 5, 3

In [38]:
# 3x6 integer matrix: columns 0-2 are all 1s and columns 3-5 are all 2s, so the
# two halves stay distinguishable after the product is reshaped.
a = torch.tensor([[1, 1, 1, 2, 2, 2] for _ in range(3)])

In [41]:
# 5x3 integer matrix of ones: one row per token, three features per row.
b = torch.tensor([[1, 1, 1]] * 5)

In [42]:
# (5, 3) @ (3, 6) -> (5, 6): a single matrix product yields the features for both halves.
c = torch.matmul(b, a)

In [46]:
# Display the product; each row holds three 3s followed by three 6s.
c

tensor([[3, 3, 3, 6, 6, 6],
        [3, 3, 3, 6, 6, 6],
        [3, 3, 3, 6, 6, 6],
        [3, 3, 3, 6, 6, 6],
        [3, 3, 3, 6, 6, 6]])

In [45]:
# Split the 6 features of each row into (my_heads, my_head_dim), then swap the
# token and head axes so each head gets its own (my_tokens, my_head_dim) matrix.
c.view(my_tokens,my_heads, my_head_dim).transpose(0,1)

tensor([[[3, 3, 3],
         [3, 3, 3],
         [3, 3, 3],
         [3, 3, 3],
         [3, 3, 3]],

        [[6, 6, 6],
         [6, 6, 6],
         [6, 6, 6],
         [6, 6, 6],
         [6, 6, 6]]])