In [51]:
class CausalAttention(nn.Module):
  def __init__(self, d_in, d_out,context_length, dropout, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.dropout = nn.Dropout(dropout)
    self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

  def forward(self, x):
    b, num_tokens, d_in = x.shape # New batch of dimension b
    queries = self.W_query(x)
    keys = self.W_key(x)
    values = self.W_value(x)

    attn_scores = queries @ keys.transpose(1,2)
    attn_scores = attn_scores.masked_fill(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
    attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)
    context_vector = attn_weights @ values

    return context_vector

In [52]:
print(d_in)

3


In [53]:
print(d_out)

2


In [54]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
context_vecs = ca(batch)
print(context_vecs.shape)

torch.Size([2, 6, 2])


In [55]:
print(context_vecs)

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)


# Multi-Head Attention

In multi head attention, we run each heads in parallel and combine them for final output context vector.

In causal attention, we used one key, query and value (i.e. single head)
If we use multiple key,query,value for same inputs and combine all the output context vectors from each to give final context vector, then it is known as multi head attention.


In [58]:
# when we use multiple instances of causal attention, we create multi head attention
class MultiHeadAttentionWrapper(nn.Module):
  def __init__(self, d_in, d_out, num_heads, context_length, dropout, qkv_bias=False):
    super().__init__()
    # create instance of causal attention class as per num_heads
    self.heads = nn.ModuleList(
        [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)]
    )

  def forward(self, x):
    # concat all the instance
    return torch.cat([head(x) for head in self.heads], dim=-1)

In [59]:
inputs = torch.tensor(
    [[0.43,0.15,0.89], # Your (x^1)
     [0.55,0.87,0.66], # journey (x^2)
     [0.57,0.85,0.64], # starts (x^3)
     [0.22,0.58,0.33], # with (x^4)
     [0.77,0.25,0.10], # one (x^6)
     [0.05,0.80,0.55]] # step (x^7)
    )
batch = torch.stack((inputs,inputs), dim=0)
print(batch.shape)


torch.Size([2, 6, 3])


In [62]:
# suppose we have 2 num_heads and each head gives output as 2 dimensional context,
# then final combination of 2 and 2 will produce 4 dimensional context matrix.

context_length = batch.shape[1] # no. of tokens = 6
d_in, d_out = 3,2
mha = MultiHeadAttentionWrapper(d_in, d_out, num_heads=2, context_length=context_length, dropout=0.0)
context_vecs = mha(batch)
print(context_vecs.shape)
print(context_vecs)

torch.Size([2, 6, 4])
tensor([[[-0.1471,  0.4106,  0.4675, -0.2793],
         [-0.2493,  0.3548,  0.4651, -0.0590],
         [-0.2782,  0.3323,  0.4578,  0.0089],
         [-0.2636,  0.2932,  0.4108,  0.0479],
         [-0.2197,  0.2186,  0.3167,  0.0217],
         [-0.2420,  0.2433,  0.3495,  0.0638]],

        [[-0.1471,  0.4106,  0.4675, -0.2793],
         [-0.2493,  0.3548,  0.4651, -0.0590],
         [-0.2782,  0.3323,  0.4578,  0.0089],
         [-0.2636,  0.2932,  0.4108,  0.0479],
         [-0.2197,  0.2186,  0.3167,  0.0217],
         [-0.2420,  0.2433,  0.3495,  0.0638]]], grad_fn=<CatBackward0>)


Now, the problem here is we are calculating each instance one by one and combining, which is highly inefficient. We can solve this with parallel computation of heads.

Computing simultaneously the output of all heads by matrix multiplication.