In [1]:
import torch
import torch.nn as nn
inputs = torch.tensor([
    [0.43, 0.15, 0.891],  # "Your"
    [0.55, 0.87, 0.66],   # "journey"
    [0.57, 0.85, 0.641],  # "starts"
    [0.22, 0.58, 0.331], # "with"
    [0.77, 0.25, 0.101],  # "one"
    [0.05, 0.80, 0.551]   # "step"
])

print(inputs.shape)

input_batch = torch.stack((inputs,inputs),dim=0)
input_batch

torch.Size([6, 3])


tensor([[[0.4300, 0.1500, 0.8910],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6410],
         [0.2200, 0.5800, 0.3310],
         [0.7700, 0.2500, 0.1010],
         [0.0500, 0.8000, 0.5510]],

        [[0.4300, 0.1500, 0.8910],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6410],
         [0.2200, 0.5800, 0.3310],
         [0.7700, 0.2500, 0.1010],
         [0.0500, 0.8000, 0.5510]]])

In [2]:
class MaskedSelfAttention(nn.Module):
  def __init__(self,in_dim,out_dim,context_len,bias=False,dropout=0.25):
    super().__init__()
    self.d_out = out_dim
    self.W_K = nn.Linear(in_dim,out_dim,bias= False)
    self.W_Q = nn.Linear(in_dim,out_dim,bias= False)
    self.W_V = nn.Linear(in_dim,out_dim,bias= False)
    self.dropout=nn.Dropout(dropout)
    self.register_buffer('mask',torch.triu(torch.ones(context_len,context_len),diagonal=1))
  def forward(self,x):
    b, num_tokens , d_in = x.shape
    keys = self.W_K(x)
    queries = self.W_Q(x)
    values = self.W_V(x)
    attention_scores = queries @ keys.transpose(1,2)# here if shape is (2,6,2) , it will swap 1st index dimension and 2nd index dimension.
    attention_scores.masked_fill_(
        self.mask.bool()[:num_tokens,:num_tokens],-torch.inf
    )
    attention_weights  = torch.softmax(attention_scores/keys.shape[-1]**0.5,dim=-1)
    attention_weights =self.dropout(attention_weights)
    return attention_weights @ values

In [3]:
class MultiHeadAttentionWrapper(nn.Module):
  def __init__(self,d_in,d_out,context_len,dropout,num_heads,qkv_bias=False):
    super().__init__()
    self.heads = nn.ModuleList([
        MaskedSelfAttention(d_in,d_out,context_len,dropout=dropout,bias=qkv_bias) for _ in range(num_heads)
    ])
  def forward(self,input_batch):
    return torch.cat([head(input_batch) for head in self.heads],dim=-1
                     )


In [4]:
torch.manual_seed(123)

d_in , d_out = 3,2 # vector embedding dimension , key matrix dimensioncla
multi_attention = MultiHeadAttentionWrapper(d_in,d_out,6,0,2)


In [5]:
multi_attention(input_batch).shape

torch.Size([2, 6, 4])

In [6]:
multi_attention(input_batch)

tensor([[[-0.4521,  0.2220,  0.4774,  0.1063],
         [-0.5791,  0.0195,  0.5771,  0.3018],
         [-0.6227, -0.0510,  0.6102,  0.3660],
         [-0.5671, -0.0790,  0.5471,  0.3513],
         [-0.5503, -0.0916,  0.5337,  0.3406],
         [-0.5309, -0.1039,  0.5074,  0.3425]],

        [[-0.4521,  0.2220,  0.4774,  0.1063],
         [-0.5791,  0.0195,  0.5771,  0.3018],
         [-0.6227, -0.0510,  0.6102,  0.3660],
         [-0.5671, -0.0790,  0.5471,  0.3513],
         [-0.5503, -0.0916,  0.5337,  0.3406],
         [-0.5309, -0.1039,  0.5074,  0.3425]]], grad_fn=<CatBackward0>)

###Multihead Attention With weight splits

In [7]:
torch.manual_seed(123)

# Define the tensor with 3 rows and 6 columns
inputs = torch.tensor(
    [[0.43, 0.15, 0.89, 0.55, 0.87, 0.66],  # Row 1
     [0.57, 0.85, 0.64, 0.22, 0.58, 0.33],  # Row 2
     [0.77, 0.25, 0.10, 0.05, 0.80, 0.55]]  # Row 3
)
inputs.shape

torch.Size([3, 6])

In [8]:
input_batch = torch.stack((inputs,inputs),dim=0)
b,n_tokens,_ = input_batch.shape

In [9]:
d_out , num_heads =6,2
#initialize weight matrices mxn m should be same as token embedding dimension and n is d_out
query_weights = nn.Linear(input_batch.shape[-1],d_out)
key_weights = nn.Linear(input_batch.shape[-1],d_out)
value_weights = nn.Linear(input_batch.shape[-1],d_out)

In [10]:
keys = key_weights(input_batch)
values = value_weights(input_batch)
queries = query_weights(input_batch)

In [11]:
keys

tensor([[[-0.0169,  0.8474, -0.1203, -0.5320,  0.2600, -0.1741],
         [ 0.2498,  0.9276, -0.3431, -0.0579,  0.0654, -0.4072],
         [ 0.0256,  0.8131, -0.1163, -0.0246,  0.3894, -0.4116]],

        [[-0.0169,  0.8474, -0.1203, -0.5320,  0.2600, -0.1741],
         [ 0.2498,  0.9276, -0.3431, -0.0579,  0.0654, -0.4072],
         [ 0.0256,  0.8131, -0.1163, -0.0246,  0.3894, -0.4116]]],
       grad_fn=<ViewBackward0>)

In [12]:
head_dim = int((d_out/num_heads))
n_heads = 2
#introduces
keys = keys.view(b,n_tokens,n_heads,head_dim)
values = values.view(b,n_tokens,n_heads,head_dim)
queries = queries.view(b,n_tokens,n_heads,head_dim)
print(keys,"\n",keys.shape)

tensor([[[[-0.0169,  0.8474, -0.1203],
          [-0.5320,  0.2600, -0.1741]],

         [[ 0.2498,  0.9276, -0.3431],
          [-0.0579,  0.0654, -0.4072]],

         [[ 0.0256,  0.8131, -0.1163],
          [-0.0246,  0.3894, -0.4116]]],


        [[[-0.0169,  0.8474, -0.1203],
          [-0.5320,  0.2600, -0.1741]],

         [[ 0.2498,  0.9276, -0.3431],
          [-0.0579,  0.0654, -0.4072]],

         [[ 0.0256,  0.8131, -0.1163],
          [-0.0246,  0.3894, -0.4116]]]], grad_fn=<ViewBackward0>) 
 torch.Size([2, 3, 2, 3])


In [13]:
#group matrices by number of heads which is at index 2  torch.Size([2, 3, 2, 3])
keys = keys.transpose(1,2)
values = values.transpose(1,2)
queries = queries.transpose(1,2)

In [14]:
print(keys,"\n",keys.shape)

tensor([[[[-0.0169,  0.8474, -0.1203],
          [ 0.2498,  0.9276, -0.3431],
          [ 0.0256,  0.8131, -0.1163]],

         [[-0.5320,  0.2600, -0.1741],
          [-0.0579,  0.0654, -0.4072],
          [-0.0246,  0.3894, -0.4116]]],


        [[[-0.0169,  0.8474, -0.1203],
          [ 0.2498,  0.9276, -0.3431],
          [ 0.0256,  0.8131, -0.1163]],

         [[-0.5320,  0.2600, -0.1741],
          [-0.0579,  0.0654, -0.4072],
          [-0.0246,  0.3894, -0.4116]]]], grad_fn=<TransposeBackward0>) 
 torch.Size([2, 2, 3, 3])


In [15]:
#compute attention scores
attention_scores = queries @ keys.transpose(-1,-2)
attention_scores

tensor([[[[-0.2827, -0.3169, -0.2963],
          [-0.4942, -0.5187, -0.4996],
          [-0.3080, -0.3412, -0.3198]],

         [[-0.4602, -0.2366, -0.2496],
          [-0.3532, -0.2252, -0.2302],
          [-0.2715, -0.1768, -0.2114]]],


        [[[-0.2827, -0.3169, -0.2963],
          [-0.4942, -0.5187, -0.4996],
          [-0.3080, -0.3412, -0.3198]],

         [[-0.4602, -0.2366, -0.2496],
          [-0.3532, -0.2252, -0.2302],
          [-0.2715, -0.1768, -0.2114]]]], grad_fn=<UnsafeViewBackward0>)

In [16]:
#finding attention weights , normalizing
scores = attention_scores / (head_dim ** 0.5)

#mask is n_tokensxn_tokens because attention scores arae between token and token
mask = torch.triu(torch.ones(n_tokens, n_tokens), diagonal=1).bool()  # (T, T)
scores = scores.masked_fill(mask.unsqueeze(0), float("-inf"))

attention_weights = torch.softmax(scores, dim=-1)

In [17]:
print(attention_weights,"\n",attention_weights.shape)

tensor([[[[1.0000, 0.0000, 0.0000],
          [0.5035, 0.4965, 0.0000],
          [0.3362, 0.3298, 0.3339]],

         [[1.0000, 0.0000, 0.0000],
          [0.4815, 0.5185, 0.0000],
          [0.3235, 0.3417, 0.3349]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.5035, 0.4965, 0.0000],
          [0.3362, 0.3298, 0.3339]],

         [[1.0000, 0.0000, 0.0000],
          [0.4815, 0.5185, 0.0000],
          [0.3235, 0.3417, 0.3349]]]], grad_fn=<SoftmaxBackward0>) 
 torch.Size([2, 2, 3, 3])


In [18]:
context_vectors = attention_weights @ values
print(context_vectors,"\n",context_vectors.shape)

tensor([[[[ 0.0081,  0.2463,  0.9245],
          [-0.1415,  0.2808,  1.0286],
          [-0.1234,  0.1638,  0.9805]],

         [[-0.0649,  0.7756, -0.1287],
          [-0.1837,  0.7628, -0.1644],
          [-0.1185,  0.7449, -0.0999]]],


        [[[ 0.0081,  0.2463,  0.9245],
          [-0.1415,  0.2808,  1.0286],
          [-0.1234,  0.1638,  0.9805]],

         [[-0.0649,  0.7756, -0.1287],
          [-0.1837,  0.7628, -0.1644],
          [-0.1185,  0.7449, -0.0999]]]], grad_fn=<UnsafeViewBackward0>) 
 torch.Size([2, 2, 3, 3])


In [19]:
#in order to concat two attention heads we have to bring them to gether in shape
# torch.Size([2, 2, 3, 3]) [b,attn_heads,vectorsm]
context_vectors.contiguous().view(b, n_tokens, 6)

tensor([[[ 0.0081,  0.2463,  0.9245, -0.1415,  0.2808,  1.0286],
         [-0.1234,  0.1638,  0.9805, -0.0649,  0.7756, -0.1287],
         [-0.1837,  0.7628, -0.1644, -0.1185,  0.7449, -0.0999]],

        [[ 0.0081,  0.2463,  0.9245, -0.1415,  0.2808,  1.0286],
         [-0.1234,  0.1638,  0.9805, -0.0649,  0.7756, -0.1287],
         [-0.1837,  0.7628, -0.1644, -0.1185,  0.7449, -0.0999]]],
       grad_fn=<ViewBackward0>)

In [20]:
context_vectors.reshape(2,3,6)

tensor([[[ 0.0081,  0.2463,  0.9245, -0.1415,  0.2808,  1.0286],
         [-0.1234,  0.1638,  0.9805, -0.0649,  0.7756, -0.1287],
         [-0.1837,  0.7628, -0.1644, -0.1185,  0.7449, -0.0999]],

        [[ 0.0081,  0.2463,  0.9245, -0.1415,  0.2808,  1.0286],
         [-0.1234,  0.1638,  0.9805, -0.0649,  0.7756, -0.1287],
         [-0.1837,  0.7628, -0.1644, -0.1185,  0.7449, -0.0999]]],
       grad_fn=<ViewBackward0>)

In [21]:
context_vectors.contiguous().view(b, n_tokens, d_out)

tensor([[[ 0.0081,  0.2463,  0.9245, -0.1415,  0.2808,  1.0286],
         [-0.1234,  0.1638,  0.9805, -0.0649,  0.7756, -0.1287],
         [-0.1837,  0.7628, -0.1644, -0.1185,  0.7449, -0.0999]],

        [[ 0.0081,  0.2463,  0.9245, -0.1415,  0.2808,  1.0286],
         [-0.1234,  0.1638,  0.9805, -0.0649,  0.7756, -0.1287],
         [-0.1837,  0.7628, -0.1644, -0.1185,  0.7449, -0.0999]]],
       grad_fn=<ViewBackward0>)

In [22]:
context_vectors.contiguous().view(b,n_tokens,d_out)

tensor([[[ 0.0081,  0.2463,  0.9245, -0.1415,  0.2808,  1.0286],
         [-0.1234,  0.1638,  0.9805, -0.0649,  0.7756, -0.1287],
         [-0.1837,  0.7628, -0.1644, -0.1185,  0.7449, -0.0999]],

        [[ 0.0081,  0.2463,  0.9245, -0.1415,  0.2808,  1.0286],
         [-0.1234,  0.1638,  0.9805, -0.0649,  0.7756, -0.1287],
         [-0.1837,  0.7628, -0.1644, -0.1185,  0.7449, -0.0999]]],
       grad_fn=<ViewBackward0>)

In [76]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
      super().__init__()
      assert (d_out % num_heads == 0), \
          "d_out must be divisible by num_heads"
      self.d_out = d_out
      self.num_heads = num_heads
      self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
      self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
      self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
      self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
      self.out_proj = nn.Linear(d_out,d_out)
      self.dropout = nn.Dropout(dropout)
      self.register_buffer(
          "mask",
          torch.triu(torch.ones(context_length,context_length),diagonal=1)
      )
    def forward(self,x):
      b , num_tokens, d_in = x.shape
      print(x.shape)
      queries = self.W_query(x).view(b,num_tokens,self.num_heads,self.head_dim)
      keys = self.W_key(x).view(b,num_tokens,self.num_heads,self.head_dim)
      values = self.W_value(x).view(b,num_tokens,self.num_heads,self.head_dim)
      # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
      keys = keys.transpose(1,2)
      values = values.transpose(1,2)
      queries = queries.transpose(1,2)


      attention_scores = queries @ keys.transpose(-1,-2)
      # Original mask truncated to the number of tokens and converted to boolean
      mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
      attention_scores.masked_fill_(mask_bool, -torch.inf)
      attn_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1)

      attn_weights = self.dropout(attn_weights)

      # Shape: (b, num_tokens, num_heads, head_dim)
      context_vec = (attn_weights @ values).transpose(1, 2)

      # Combine heads, where self.d_out = self.num_heads * self.head_dim
      context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
      context_vec = self.out_proj(context_vec) # optional projection

      return context_vec


In [77]:
d_in,d_out =6,6
attention = MultiHeadAttention(d_in,d_out,6,0.25,2)

In [78]:
attention.forward(input_batch)

torch.Size([2, 3, 6])


tensor([[[ 0.2690, -0.1768,  0.5210,  0.2621,  0.0503, -0.5156],
         [-0.0447, -0.3097,  0.3205,  0.1767, -0.1432, -0.3263],
         [ 0.0896, -0.1831,  0.4364,  0.2409,  0.0094, -0.4319]],

        [[ 0.2690, -0.1768,  0.5210,  0.2621,  0.0503, -0.5156],
         [ 0.1739, -0.1214,  0.5048,  0.1656,  0.0788, -0.4511],
         [ 0.0756, -0.2017,  0.4310,  0.2203,  0.0047, -0.4249]]],
       grad_fn=<ViewBackward0>)

In [None]:
keys.transpose(-1,-2)

In [43]:
input_batch

tensor([[[0.4300, 0.1500, 0.8900, 0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400, 0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000, 0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900, 0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400, 0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000, 0.0500, 0.8000, 0.5500]]])