# Multi-Head attention with weight splits

In [60]:
class MultiHeadAttention(nn.Module):
  """Causal multi-head self-attention using one weight matrix per projection.

  Queries, keys and values are each produced by a single ``d_in -> d_out``
  linear layer; the ``d_out`` axis is then split into ``num_heads`` heads of
  ``head_dim = d_out // num_heads`` features so that all heads are computed
  in one batched matrix multiplication.

  Args:
    d_in: Size of the last dimension of the input.
    d_out: Total output size across all heads; must be divisible by
      ``num_heads``.
    context_length: Side length of the square causal mask that is stored as
      a buffer and sliced to the actual sequence length in ``forward``.
    dropout: Probability handed to ``nn.Dropout``, applied to the attention
      weights.
    num_heads: Number of attention heads.
    qkv_bias: Whether the query/key/value projections carry a bias term.

  Raises:
    AssertionError: If ``d_out`` is not divisible by ``num_heads``.
  """

  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
    super().__init__()

    assert (d_out % num_heads) == 0, "d_out must be divisible by num_heads"

    self.d_out = d_out
    self.num_heads = num_heads
    # Each head works on an equal slice of the d_out features.
    self.head_dim = d_out // num_heads

    # Learned projections from the input space to query/key/value space.
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    # The output projection consumes the concatenated heads, whose width is
    # d_out (not d_in), so it must map d_out -> d_out. Using d_in here only
    # works by coincidence when d_in == d_out.
    self.out_proj = nn.Linear(d_out, d_out)
    self.dropout = nn.Dropout(dropout)

    # Upper-triangular ones above the diagonal mark the "future" positions.
    # Registered as a buffer so it follows the module across devices without
    # being treated as a trainable parameter.
    self.register_buffer(
        "mask",
        torch.triu(torch.ones(context_length, context_length), diagonal=1)
    )

  def forward(self, x):
    """Apply causal multi-head self-attention to ``x``.

    Args:
      x: Tensor of shape ``(batch, num_tokens, d_in)``.

    Returns:
      Tensor of shape ``(batch, num_tokens, d_out)``.
    """
    b, num_tokens, _ = x.shape

    # Project the inputs: (b, num_tokens, d_in) -> (b, num_tokens, d_out).
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    # Split d_out into (num_heads, head_dim).
    keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
    queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
    values = values.view(b, num_tokens, self.num_heads, self.head_dim)

    # Move the head axis in front of the token axis so every head is handled
    # as its own batch entry:
    # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
    keys = keys.transpose(1, 2)
    queries = queries.transpose(1, 2)
    values = values.transpose(1, 2)

    # Dot product of every query with every key, per head:
    # (b, num_heads, num_tokens, head_dim) @ (b, num_heads, head_dim, num_tokens)
    #   -> (b, num_heads, num_tokens, num_tokens)
    attn_scores = queries @ keys.transpose(2, 3)

    # Crop the stored mask to the current sequence length and hide future
    # positions by setting their scores to -inf (softmax turns these into 0).
    mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
    attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)

    # Scale by sqrt(head_dim), normalise over the key axis, then apply dropout.
    attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)

    # Weighted sum of values, then bring the token axis back in front:
    # (b, num_heads, num_tokens, num_tokens) @ (b, num_heads, num_tokens, head_dim)
    #   -> (b, num_heads, num_tokens, head_dim)
    #   -> transpose(1, 2) -> (b, num_tokens, num_heads, head_dim)
    context_vector = (attn_weights @ values).transpose(1, 2)

    # Merge (num_heads, head_dim) back into d_out. transpose() returns a
    # non-contiguous view, so contiguous() is needed before view() can
    # reinterpret the layout.
    context_vector = context_vector.contiguous().view(b, num_tokens, self.d_out)
    context_vector = self.out_proj(context_vector)

    return context_vector

In [66]:
# Demo: push a tiny duplicated batch through MultiHeadAttention.

torch.manual_seed(123)  # fix the RNG so the layer's initial weights are reproducible

# Three tokens, each described by six features.
inputs = torch.tensor([
    [0.43, 0.15, 0.89, 0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64, 0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10, 0.05, 0.80, 0.55],
])

# Stack the same sample twice along a new leading axis to form a batch of two.
batch = torch.stack([inputs, inputs], dim=0)
print(batch.shape)

# Read the layer sizes straight off the batch tensor.
batch_size, context_length, d_in = batch.shape
d_out = 6

mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.0,
    num_heads=2,
)
context_vec = mha(batch)

print(context_vec)
print("Context vector shape: ", context_vec.shape)

torch.Size([2, 3, 6])
tensor([[[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1117, -0.0547,  0.0406, -0.0213, -0.3251, -0.2993],
         [ 0.1196, -0.0491,  0.0318, -0.0635, -0.2788, -0.2578]],

        [[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1117, -0.0547,  0.0406, -0.0213, -0.3251, -0.2993],
         [ 0.1196, -0.0491,  0.0318, -0.0635, -0.2788, -0.2578]]],
       grad_fn=<ViewBackward0>)
Context vector shape:  torch.Size([2, 3, 6])
