In [53]:
import torch
import torch.nn as nn
import torch.functional as F

In [None]:
# Coding the self attention block
'''
Input: input embedding vector
Parameters: d_in (embedding dimension of each input token), d_out (dimension of each output context vector), qkv_bias (whether the query/key/value projections use a bias term)
Output: context vector
What happens: we create three weight matrices called query, key and value. We multiply
these weight matrices with the input embedding vector (input + position).
We then get attn scores (query*key.T) -> attn weights (scaling + softmax) -> context vector (attn weights*value)

'''

class SelfAttention(nn.Module):
  """Single-head scaled dot-product self-attention without any masking.

  Args:
    d_in: size of the last dimension of the input fed to the projections.
    d_out: size of the last dimension of the queries/keys/values and of the
      returned context vectors.
    qkv_bias: whether the query/key/value linear layers carry a bias term.
  """

  def __init__(self,d_in,d_out,qkv_bias=False):
    super().__init__()
    self.W_query = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.W_key = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.W_value = nn.Linear(d_in,d_out,bias=qkv_bias)

  def forward(self,X):
    """Compute one context vector per token.

    Args:
      X: tensor of shape (num_tokens, d_in) or (..., num_tokens, d_in).

    Returns:
      Tensor shaped like X with the last dimension replaced by d_out; every
      row is the context vector of the corresponding token.
    """
    queries = self.W_query(X) # shape (..., num_tokens, d_out)
    keys = self.W_key(X) # shape (..., num_tokens, d_out)
    values = self.W_value(X) # shape (..., num_tokens, d_out)

    # Swap only the last two axes. `.T` reverses *all* axes, which is only a
    # matrix transpose for 2-D tensors and breaks the matmul for batched input.
    attn_scores = queries @ keys.transpose(-2,-1) # shape (..., num_tokens, num_tokens)

    # Scale by sqrt(d_out) before the softmax, then normalise over the key axis
    # (the last one) so the weights of every query row sum to 1.
    attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)

    context_vector = attn_weights @ values # shape (..., num_tokens, d_out)
    return context_vector

In [55]:
# Applying attention to an input example
'''
Input shape: [batch_size=2, seq_len=5, d_in=3]
'''

# Fix the RNG so the sample below is reproducible.
torch.manual_seed(34234)

# One (5, 3) sample, plus a batch made of two copies of it stacked on a new leading axis.
X = torch.randn(5, 3)
batch = torch.stack([X, X])

print("Input Sample:\n")
print(X, X.shape)
print("\n")
print("Batched Input Sample:\n")
print(batch, batch.shape)

Input Sample:

tensor([[-1.0915,  0.7671,  1.3740],
        [ 1.2651,  1.5660,  0.5978],
        [-0.2918, -1.0493, -0.0675],
        [ 1.1399,  0.5300,  1.0625],
        [ 0.9558, -0.4200, -2.1112]]) torch.Size([5, 3])


Batched Input Sample:

tensor([[[-1.0915,  0.7671,  1.3740],
         [ 1.2651,  1.5660,  0.5978],
         [-0.2918, -1.0493, -0.0675],
         [ 1.1399,  0.5300,  1.0625],
         [ 0.9558, -0.4200, -2.1112]],

        [[-1.0915,  0.7671,  1.3740],
         [ 1.2651,  1.5660,  0.5978],
         [-0.2918, -1.0493, -0.0675],
         [ 1.1399,  0.5300,  1.0625],
         [ 0.9558, -0.4200, -2.1112]]]) torch.Size([2, 5, 3])


In [56]:
# Input and output sizes are both taken from the trailing dimension of `batch`.
d_in = d_out = batch.shape[-1]
attn = SelfAttention(d_in, d_out)

# Note: the single (un-batched) sample `X` is what is passed through the layer here.
context_vector = attn(X)

print("Context vector:\n",context_vector,context_vector.shape)

Context vector:
 tensor([[ 0.1478, -0.0381, -0.1513],
        [ 0.1391,  0.0445, -0.0116],
        [ 0.1680,  0.0959,  0.1104],
        [ 0.1404,  0.0836,  0.0815],
        [ 0.1921,  0.1566,  0.1930]], grad_fn=<MmBackward0>) torch.Size([5, 3])


In [None]:
# Coding the causal attention / masked attention block

'''
Input: input embedding vector
Parameters: d_in (embedding dimension of each input token), d_out (dimension of each output context vector), context_length (maximum number of input tokens covered by the causal mask), dropout (dropout rate applied to the attention weights)
Output: context vector
What happens: For any given output the inputs are the tokens that come before it.

ex - hi, how are you?
  hi -> how
  hi how -> are
  hi how are -> you
  hi how are you => ?

we do not have access to future tokens at all.
The goal is to restrict the model to only consider the previous and current inputs in the sequence
for a given token.
Mask out all the tokens in the upper triangular matrix
attention scores matrix:
[
  [a,-inf,-inf],
  [a,b,-inf],
  [a,b,c]
]
We add '-inf' as the mask because when we use softmax all the mask values become '0' (e^-inf).

Steps:

attention scores -> upper triangular infinity mask -> softmax

'''

class CausalAttention(nn.Module):
  """Single-head scaled dot-product attention with a causal mask.

  Each token can only attend to itself and to earlier tokens of the sequence.

  Args:
    d_in: size of the last dimension of the input.
    d_out: size of the last dimension of the queries/keys/values and of the
      returned context vectors.
    context_length: side length of the pre-built causal mask, i.e. the longest
      sequence `forward` accepts.
    dropout: probability handed to `nn.Dropout`, applied to the attention weights.
    qkv_bias: whether the query/key/value linear layers carry a bias term.
  """

  def __init__(self,d_in,d_out,context_length,dropout,qkv_bias=False):
    super().__init__()
    self.d_out = d_out

    self.W_query = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.W_key = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.W_value = nn.Linear(d_in,d_out,bias=qkv_bias)

    self.dropout = nn.Dropout(dropout)
    # 1 strictly above the main diagonal (the "future" positions), 0 elsewhere.
    mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
    # Registered as a buffer so it follows the module across devices without being a parameter.
    self.register_buffer('mask', mask)


  def forward(self,X):
    """Compute causally-masked context vectors.

    Args:
      X: tensor of shape (b, num_tokens, d_in) with num_tokens <= context_length.

    Returns:
      Tensor of shape (b, num_tokens, d_out); row i only depends on tokens 0..i.

    Raises:
      ValueError: if num_tokens exceeds the context_length the mask was built for.
    """
    _, num_tokens, _ = X.shape # unpacking also enforces a 3-D (batched) input

    max_tokens = self.mask.shape[0]
    if num_tokens > max_tokens:
      # Without this check the sliced mask would be too small: it either fails with an
      # opaque broadcast error or (for a 1x1 mask) broadcasts and masks nothing at all.
      raise ValueError(
          f"sequence length {num_tokens} exceeds context_length {max_tokens}"
      )

    queries = self.W_query(X) # shape (b,num_tokens,d_out)
    keys = self.W_key(X) # shape (b,num_tokens,d_out)
    values = self.W_value(X) # shape (b,num_tokens,d_out)

    # (b,num_tokens,d_out) @ (b,d_out,num_tokens) -> (b,num_tokens,num_tokens)
    attn_scores = queries @ keys.transpose(1,2)

    # Only the top-left num_tokens x num_tokens corner of the mask is needed for this sequence.
    future_positions = self.mask.bool()[:num_tokens,:num_tokens]
    # -inf scores turn into exactly 0 weight after the softmax.
    attn_scores = attn_scores.masked_fill(future_positions,-torch.inf)

    # Scale by sqrt(d_out) and normalise over the key axis.
    attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
    attn_weights = self.dropout(attn_weights) # randomly zeroes attention weights while training
    context_vector = attn_weights @ values # every row is the context vector of one token

    return context_vector


In [58]:
# Applying causal attention to an input example

# Fix the RNG so the sample below is reproducible.
torch.manual_seed(34234)

# One (5, 3) sample, plus a batch made of two copies of it stacked on a new leading axis.
X = torch.randn(5, 3)
batch = torch.stack([X, X])

print("Input Sample:\n")
print(X, X.shape)
print("\n")
print("Batched Input Sample:\n")
print(batch, batch.shape)

Input Sample:

tensor([[-1.0915,  0.7671,  1.3740],
        [ 1.2651,  1.5660,  0.5978],
        [-0.2918, -1.0493, -0.0675],
        [ 1.1399,  0.5300,  1.0625],
        [ 0.9558, -0.4200, -2.1112]]) torch.Size([5, 3])


Batched Input Sample:

tensor([[[-1.0915,  0.7671,  1.3740],
         [ 1.2651,  1.5660,  0.5978],
         [-0.2918, -1.0493, -0.0675],
         [ 1.1399,  0.5300,  1.0625],
         [ 0.9558, -0.4200, -2.1112]],

        [[-1.0915,  0.7671,  1.3740],
         [ 1.2651,  1.5660,  0.5978],
         [-0.2918, -1.0493, -0.0675],
         [ 1.1399,  0.5300,  1.0625],
         [ 0.9558, -0.4200, -2.1112]]]) torch.Size([2, 5, 3])


In [59]:
# Sequence length and input width come straight from the batch; the output width is chosen here.
batch_size, context_length, d_in = batch.shape
d_out = 4

# Dropout is set to 0.0 for this demo.
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
context_vector = ca(batch)

print("Context Vector:\n", context_vector,context_vector.shape)

Context Vector:
 tensor([[[-1.1218,  0.8028, -0.7146,  0.6664],
         [-0.4846,  0.1710, -0.2459,  0.9051],
         [-0.5040,  0.3261, -0.3307,  0.5079],
         [-0.2419,  0.1548, -0.2198,  0.4937],
         [-0.2158,  0.1187, -0.1569,  0.3810]],

        [[-1.1218,  0.8028, -0.7146,  0.6664],
         [-0.4846,  0.1710, -0.2459,  0.9051],
         [-0.5040,  0.3261, -0.3307,  0.5079],
         [-0.2419,  0.1548, -0.2198,  0.4937],
         [-0.2158,  0.1187, -0.1569,  0.3810]]], grad_fn=<UnsafeViewBackward0>) torch.Size([2, 5, 4])


In [None]:
# Coding the multi head attention block

'''
Input: input embedding vector
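Parameters: d_in (embedding dimension of each input token), d_out (dimension of each output context vector), context_length (maximum number of input tokens covered by the causal mask), dropout (dropout rate applied to the attention weights), num_heads (number of attention heads)
Output: context vectors (the per-head results merged into a single tensor)
What happens: use more than one attention head, with head_dim = d_out / number of heads. This helps capture multiple
perspectives of the same input. Each attention head produces its own context vectors, and these are finally combined
to form a single big context vector.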

'''

class MultiHeadAttention(nn.Module):
  """Causal multi-head scaled dot-product attention.

  The d_out-wide query/key/value projections are split into `num_heads` heads of
  width `head_dim = d_out // num_heads`; attention runs independently per head and
  the per-head results are concatenated and passed through an output projection.

  Args:
    d_in: size of the last dimension of the input.
    d_out: total width of the projections and of the returned context vectors;
      must be divisible by num_heads.
    context_length: side length of the pre-built causal mask, i.e. the longest
      sequence `forward` accepts.
    dropout: probability handed to `nn.Dropout`, applied to the attention weights.
    num_heads: number of attention heads.
    qkv_bias: whether the query/key/value linear layers carry a bias term.
  """

  def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False): #included a num_heads parameter
    super().__init__()
    assert (d_out%num_heads==0), \
    "d_out must be divisible by num_heads"

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out//num_heads

    self.W_query = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.W_key = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.W_value = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.output_proj = nn.Linear(d_out,d_out) # linear layer to combine the outputs from the different attention heads

    self.dropout = nn.Dropout(dropout)
    # 1 strictly above the main diagonal (the "future" positions), 0 elsewhere.
    mask = torch.triu(torch.ones(context_length,context_length),diagonal=1)
    self.register_buffer('mask',mask)


  def forward(self,X):
    """Compute causally-masked multi-head context vectors.

    Args:
      X: tensor of shape (b, num_tokens, d_in) with num_tokens <= context_length.

    Returns:
      Tensor of shape (b, num_tokens, d_out).

    Raises:
      ValueError: if num_tokens exceeds the context_length the mask was built for.
    """
    b, num_tokens, _ = X.shape # unpacking also enforces a 3-D (batched) input

    max_tokens = self.mask.shape[0]
    if num_tokens > max_tokens:
      # Without this check the sliced mask would be too small: it either fails with an
      # opaque broadcast error or (for a 1x1 mask) broadcasts and masks nothing at all.
      raise ValueError(
          f"sequence length {num_tokens} exceeds context_length {max_tokens}"
      )

    queries = self.W_query(X) # shape (b,num_tokens,d_out)
    keys = self.W_key(X) # shape (b,num_tokens,d_out)
    values = self.W_value(X) # shape (b,num_tokens,d_out)

    # Split the last dimension into heads:
    # (b,num_tokens,d_out) -> (b,num_tokens,num_heads,head_dim)
    split_shape = (b,num_tokens,self.num_heads,self.head_dim)
    queries = queries.view(*split_shape)
    keys = keys.view(*split_shape)
    values = values.view(*split_shape)

    # Group by head instead of by token:
    # (b,num_tokens,num_heads,head_dim) -> (b,num_heads,num_tokens,head_dim)
    queries = queries.transpose(1,2)
    keys = keys.transpose(1,2)
    values = values.transpose(1,2)

    # (b,num_heads,num_tokens,head_dim) @ (b,num_heads,head_dim,num_tokens)
    # -> (b,num_heads,num_tokens,num_tokens)
    attn_scores = queries @ keys.transpose(2,3)

    # Only the top-left num_tokens x num_tokens corner of the mask is needed; it
    # broadcasts over the batch and head axes. -inf scores become 0 weight after softmax.
    future_positions = self.mask.bool()[:num_tokens,:num_tokens]
    attn_scores = attn_scores.masked_fill(future_positions,-torch.inf)

    # Scale by sqrt(head_dim) (the last dimension of the per-head keys) and
    # normalise over the key axis.
    attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
    attn_weights = self.dropout(attn_weights) # randomly zeroes attention weights while training

    # Back to token-major layout so the heads of each token sit next to each other:
    # (b,num_heads,num_tokens,head_dim) -> (b,num_tokens,num_heads,head_dim)
    context_vector = (attn_weights @ values).transpose(1,2)
    # Merge the heads (d_out = num_heads*head_dim). `view` needs contiguous memory,
    # which the transpose above does not guarantee, hence `.contiguous()`.
    # link - 'https://stackoverflow.com/questions/48915810/what-does-contiguous-do-in-pytorch' - notes on contiguous
    context_vector = context_vector.contiguous().view(b,num_tokens,self.d_out)
    context_vector = self.output_proj(context_vector) # optional projection

    return context_vector

In [64]:
# Applying attention to an input example
'''
Input shape: [batch_size=2, seq_len=5, d_in=6]
'''

# Fix the RNG so the sample below is reproducible.
torch.manual_seed(34234)

# One (5, 6) sample, plus a batch made of two copies of it stacked on a new leading axis.
X = torch.randn(5, 6)
batch = torch.stack([X, X])

print("Input Sample:\n")
print(X, X.shape)
print("\n")
print("Batched Input Sample:\n")
print(batch, batch.shape)

Input Sample:

tensor([[ 1.0112e+00, -4.9833e-01, -4.9400e-01, -6.6797e-01, -3.9417e-02,
         -5.2224e-01],
        [-6.7874e-01, -6.8613e-01,  1.1078e+00, -1.5318e-01,  7.3355e-01,
         -6.1861e-01],
        [-1.6816e-01,  4.1421e-04, -4.5069e-01,  8.1781e-01, -2.5656e-02,
          3.9588e-01],
        [ 9.5480e-01, -5.9837e-01,  3.5129e-01,  5.7757e-01,  2.1269e-01,
         -1.1841e+00],
        [ 4.6373e-02, -2.6153e+00, -4.3345e-02, -1.0756e+00, -1.6027e+00,
          3.0792e-01]]) torch.Size([5, 6])


Batched Input Sample:

tensor([[[ 1.0112e+00, -4.9833e-01, -4.9400e-01, -6.6797e-01, -3.9417e-02,
          -5.2224e-01],
         [-6.7874e-01, -6.8613e-01,  1.1078e+00, -1.5318e-01,  7.3355e-01,
          -6.1861e-01],
         [-1.6816e-01,  4.1421e-04, -4.5069e-01,  8.1781e-01, -2.5656e-02,
           3.9588e-01],
         [ 9.5480e-01, -5.9837e-01,  3.5129e-01,  5.7757e-01,  2.1269e-01,
          -1.1841e+00],
         [ 4.6373e-02, -2.6153e+00, -4.3345e-02, -1.0756e+0

In [69]:
# Sequence length and input width come straight from the batch; output width and head count are chosen here.
batch_shape, context_length, d_in = batch.shape
d_out = 4
num_heads = 2

# Dropout is set to 0.0 for this demo.
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=num_heads)
context_vector = mha(batch)

print("Context Vector:\n", context_vector,context_vector.shape)

Context Vector:
 tensor([[[-0.2299,  0.0491,  0.0877, -0.4266],
         [-0.3895,  0.2767,  0.2569, -0.4146],
         [-0.4020,  0.2992,  0.3061, -0.3834],
         [-0.3664,  0.3511,  0.3741, -0.3293],
         [-0.3619,  0.2452,  0.2486, -0.3982]],

        [[-0.2299,  0.0491,  0.0877, -0.4266],
         [-0.3895,  0.2767,  0.2569, -0.4146],
         [-0.4020,  0.2992,  0.3061, -0.3834],
         [-0.3664,  0.3511,  0.3741, -0.3293],
         [-0.3619,  0.2452,  0.2486, -0.3982]]], grad_fn=<ViewBackward0>) torch.Size([2, 5, 4])
