#### Casual Self Attention

We need to mask the upper triangle of the attention scores/ attention weights to not allow the tokens see the context for the future tokens. There are 2 ways to do it:

* Masking the attention weights after the softmax is applied. We mask the uppper triangle of the attention weights matrix with 0 and then peform another round of normalization.
* The 2nd method is to mask the uper triangle of the attention scores matrix and then apply the soft-max nornmalization technique.

The 1st method is more efficient as it avoids multiple normalisation and uses less computation.

**Efficinent method**

Attention scores --> Upper triangle -ve infinity mask --> softmax 

In [1]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np

In [2]:
inputs = torch.tensor(
    [[0.43,0.15,0.89], # Your
    [0.55, 0.87, 0.66], # journey
    [0.57, 0.85, 0.64], # starts
    [0.22, 0.58, 0.33], # with
    [0.77, 0.25, 0.10], # one
    [0.05, 0.80, 0.55]] # step
)

Defining elements

* A: The second input element
* B: Input embedding size, d_in=3
* C: Output embedding size, d_out=2 

In [3]:
d_in = inputs.shape[1]
print(f"Input shape is {d_in}")
d_out = 2
print(f"Output shape is {d_out}")

# Intialising the weight matrices

"""
requires_grad is set to False to reduce clutter. But if we were to use the weight matrices for training we would set up to be equal to True.
So that it updates the amtricews during model training.
"""
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False) 
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)

print(f"Query weights:{W_query}")
print(" ")
print(f"Key weights:{W_key}")
print(" ")
print(f"Value weights:{W_value}")
print(" ")

"""
For GPT like models the input and the output dimensions are usually the same. But for demostration we are using different dimensions
"""
x_2 = inputs[1]
print(f"The tensor value for the second token is {x_2}")


Input shape is 3
Output shape is 2
Query weights:Parameter containing:
tensor([[0.0293, 0.7416],
        [0.2832, 0.3228],
        [0.9760, 0.4255]])
 
Key weights:Parameter containing:
tensor([[0.9567, 0.2613],
        [0.8034, 0.3252],
        [0.9757, 0.5019]])
 
Value weights:Parameter containing:
tensor([[0.7378, 0.9607],
        [0.6535, 0.1652],
        [0.6051, 0.2268]])
 
The tensor value for the second token is tensor([0.5500, 0.8700, 0.6600])


In [4]:
"""
We get a 1x2 dimensional query, key and value vector. Even though our temporary goal is to only compute the one context vector z(2).
We still require key and value vectors for all the input. As this is required for the calculation of the attention weights
with respect to the query q(2).
"""
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(f"The query vector for the 2nd token is {query_2}")
print(f"The key vector for the 2nd token is {key_2}")
print(f"The value vector for the 2nd token is {value_2}")

The query vector for the 2nd token is tensor([0.9067, 0.9696])
The key vector for the 2nd token is tensor([1.8691, 0.7578])
The value vector for the 2nd token is tensor([1.3737, 0.8218])


In [5]:
# Computing the query, key and value vectors

queries = inputs @ W_query
keys = inputs @ W_key
values = inputs @ W_value

# We have projected the 6 input tokens from a 3D space onto a 2D embedding space.
print(f"Shape of queries matrix: {queries.shape}")
print(f"Shape of keys matrix: {keys.shape}")
print(f"Shape of values matrix: {values.shape}")
print(" ")

# Computing the attention score for the 2nd token
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(f"Attention score between the 2nd token and the 2nd token: {attn_score_22}")
print(" ")

# Generalising the computation to get all attention scores by matrix multiplication for the 2nd token
attn_score_2 = query_2 @ keys.T #all attention scores for the 2nd token(query)
print(f"Attention score for the entire 2nd token: {attn_score_2}")
print(" ")

# Entire attention score matrix
attn_score = queries @ keys.T
print(f"Entire attention score matrix: {attn_score}")

Shape of queries matrix: torch.Size([6, 2])
Shape of keys matrix: torch.Size([6, 2])
Shape of values matrix: torch.Size([6, 2])
 
Attention score between the 2nd token and the 2nd token: 2.4293951988220215
 
Attention score for the entire 2nd token: tensor([1.8588, 2.4294, 2.4035, 1.3044, 1.2610, 1.6452])
 
Entire attention score matrix: tensor([[1.7468, 2.2918, 2.2682, 1.2294, 1.2043, 1.5433],
        [1.8588, 2.4294, 2.4035, 1.3044, 1.2610, 1.6452],
        [1.8243, 2.3833, 2.3578, 1.2798, 1.2355, 1.6149],
        [0.9883, 1.2930, 1.2793, 0.6941, 0.6733, 0.8743],
        [0.6894, 0.8831, 0.8721, 0.4765, 0.4286, 0.6157],
        [1.3927, 1.8307, 1.8121, 0.9815, 0.9677, 1.2293]])


The next step is to calculate the attention weights by scaling the attention scores and performing a softmax operation. For causal attention we need to mask the upper triangle with -ve infinity. We can resue the class SelfAttention_v2 from multi_head_attention.ipynb notebook

In [6]:
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Yopur
     [0.55, 0.87, 0.66], # journey
     [0.57, 0.85, 0.64], # starts
     [0.22, 0.58, 0.33], # with
     [0.77, 0.25, 0.10], # one
     [0.05, 0.80, 0.55]] # step
)

d_in = 3
d_out = 2

In [7]:
class SelfAttention_v2(nn.Module):
    
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x) 

        attn_scores = queries @ keys.T
        attn_weight = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim = -1)

        context_vec = attn_weight @ values
        return context_vec
        
sa_v2 = SelfAttention_v2(d_in, d_out)

In [8]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim = 1 )
print(attn_weights)

tensor([[0.1794, 0.1561, 0.1568, 0.1682, 0.1792, 0.1603],
        [0.1656, 0.1572, 0.1577, 0.1745, 0.1764, 0.1686],
        [0.1658, 0.1570, 0.1575, 0.1746, 0.1766, 0.1685],
        [0.1636, 0.1630, 0.1631, 0.1711, 0.1701, 0.1692],
        [0.1692, 0.1560, 0.1566, 0.1736, 0.1782, 0.1664],
        [0.1620, 0.1645, 0.1646, 0.1707, 0.1682, 0.1699]],
       grad_fn=<SoftmaxBackward0>)


In [9]:
# 1st Method --> Updating attention weights above teh diagonal to zero and then normalising 

# We can use PyTorch tril function to create a mask where the values above the diagonal are zero
context_length = attn_scores.shape[0]
print(torch.ones(context_length, context_length))
print(" ")

mask_simple = torch.tril(torch.ones(context_length, context_length)) # Masking upper diagonal with zero
print(mask_simple)
print(" ")

# Multiplying the masked matrix with the attention weights to zsero out the upper diagonal values.
masked_simple = attn_weights * mask_simple
print(masked_simple)
print(" ")

# The elements above the diagonal are zeroed out but needs to be normalised
row_sums = masked_simple.sum(dim =1, keepdim = True)
masked_simple_norm = masked_simple/row_sums
print(masked_simple_norm)

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
 
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
 
tensor([[0.1794, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1656, 0.1572, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1658, 0.1570, 0.1575, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1630, 0.1631, 0.1711, 0.0000, 0.0000],
        [0.1692, 0.1560, 0.1566, 0.1736, 0.1782, 0.0000],
        [0.1620, 0.1645, 0.1646, 0.1707, 0.1682, 0.1699]],
       grad_fn=<MulBackward0>)
 
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5130, 0.4870, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3452, 0.3270, 0.3279, 0.0000, 0.0000, 0.0000],
        [0.2475, 0.2466, 0.246

In [10]:
# 2nd Method --> Updating attention scores above the diagonal to -ve infinity and then applying scaling normalising(softmax) to get attention scores
print(attn_scores)  
print(" ")

mask = torch.triu(torch.ones(context_length,context_length), diagonal = 1 )
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)
print(" ")

# applying softmax to the masked matrix, changes the -ve infinity to 0s and sum of every row = 1
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim = 1)
print(attn_weights) 

# Both the methods give us the same answer. But the 2nd method is more efficient than the 1st one.

tensor([[-0.0935, -0.2903, -0.2843, -0.1848, -0.0955, -0.2524],
        [-0.2478, -0.3212, -0.3173, -0.1734, -0.1584, -0.2227],
        [-0.2506, -0.3272, -0.3232, -0.1771, -0.1607, -0.2276],
        [-0.1330, -0.1383, -0.1371, -0.0690, -0.0778, -0.0855],
        [-0.2293, -0.3443, -0.3395, -0.1937, -0.1565, -0.2530],
        [-0.1155, -0.0931, -0.0928, -0.0409, -0.0619, -0.0474]],
       grad_fn=<MmBackward0>)
 
tensor([[-0.0935,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.2478, -0.3212,    -inf,    -inf,    -inf,    -inf],
        [-0.2506, -0.3272, -0.3232,    -inf,    -inf,    -inf],
        [-0.1330, -0.1383, -0.1371, -0.0690,    -inf,    -inf],
        [-0.2293, -0.3443, -0.3395, -0.1937, -0.1565,    -inf],
        [-0.1155, -0.0931, -0.0928, -0.0409, -0.0619, -0.0474]],
       grad_fn=<MaskedFillBackward0>)
 
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5130, 0.4870, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3452, 0.3270, 0.3279, 0.0000, 0

In [11]:
# Masking additional weights with dropout 
"""
Using 50% dropout rate, masking half of the attention weights. Ideally when training the GPT models a lower dropout rate is prefered (0.1 or 0.2).
Applying PyTorch's dropout implementation to a 6x6 tensor consisting of ones
"""

example = torch.ones(6,6)
print(example)
print(" ")

"""The dropout rate would be on an average, all the rows does not necessarily have to have 50% of the length being cut short.
With 0.5 dropout factor the neurons which is not put to zero would be scaled by 1/ 0.5 .The scaling is to maintain the overall balance of the
attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phase.
"""
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
print(f"Dropout example: {dropout(example)}\n")

# Applying dropout to the attention weigt matrix
torch.manual_seed(123)
print(f"Dropout example for the attention weights: {dropout(attn_weights)}")

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
 
Dropout example: tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])

Dropout example for the attention weights: tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6903, 0.6539, 0.6558, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4932, 0.4937, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3744, 0.0000, 0.4164, 0.0000, 0.0000],
        [0.0000, 0.3291, 0.3292, 0.3415, 0.3364, 0.0000]],
       grad_fn=<MulBackward0>)


#### Putting all the pieces for a compact casual attention class

In [12]:
# Adding the casual attention and dropout modifications into the SelfAttention Python class.
""""
The code should handle the batches consisting of more than one input. This ensures that the casual attention class would handle the batch outputs
produceed by the data loader. The batch inputs can be created by duplicating the input text example. 2 inputs with 6 tokens and each token has a 
embedding dimension of 3.
"""

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Yopur
     [0.55, 0.87, 0.66], # journey
     [0.57, 0.85, 0.64], # starts
     [0.22, 0.58, 0.33], # with
     [0.77, 0.25, 0.10], # one
     [0.05, 0.80, 0.55]] # step
)

batch = torch.stack((inputs, inputs), dim = 0) 
print(batch.shape) # a 3D tensor of 2 input texts with 6 tokens each, where each token is a 3D embedding vector.

torch.Size([2, 6, 3])


### Creating the Causal Attention Class

Step 1: Add a dropout layer\
Step 2: register_buffer call is added\
Step 3: transpose dimensions 1 and 2, keeping the batch dimension at the first position (0).\
Step 4: In PyTorch, operations with a trailing underscore are performed in-place, avoiding unnecessary memory copies

In [13]:
class CasualAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length,context_length), diagonal = 1))
        """
        Using register_buffer is useful as buffers are automatically moved to the apprppriate device(CPU or GPU) along with our model, this is
        relevant when training LLMs. Thus we do not need to manually ensure these tensors are on the same device as your model parameters, avoiding
        device mismatch errors. 
        """

    def forward(self, x):
        b, num_tokens, d_in = x.shape # b= batch size, num_tokens = number of tokens and d_in = input dimensions 
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1,2)
        attn_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) # replacing the upper diagonal with -ve infinity
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] **0.5, dim = -1)
        attn_weights = self.dropout(attn_weights) # introducing the dropout

        context_vec = attn_weights @ values

        return context_vec
    
print(f"Input dimension: {d_in}")
print(f"Output dimension: {d_out}\n")

torch.manual_seed(123)
context_length = batch.shape[1]
ca = CasualAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print(f"Context vector shape: {context_vecs.shape}\n")
print(context_vecs)

Input dimension: 3
Output dimension: 2

Context vector shape: torch.Size([2, 6, 2])

tensor([[[-0.5337, -0.1051],
         [-0.5323, -0.1080],
         [-0.5323, -0.1079],
         [-0.5297, -0.1076],
         [-0.5311, -0.1066],
         [-0.5299, -0.1081]],

        [[-0.5337, -0.1051],
         [-0.5323, -0.1080],
         [-0.5323, -0.1079],
         [-0.5297, -0.1076],
         [-0.5311, -0.1066],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)


In [18]:
class CasualAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length,context_length), diagonal = 1))
        """
        Using register_buffer is useful as buffers are automatically moved to the apprppriate device(CPU or GPU) along with our model, this is
        relevant when training LLMs. Thus we do not need to manually ensure these tensors are on the same device as your model parameters, avoiding
        device mismatch errors. 
        """

    def forward(self, x):
        b, num_tokens, d_in = x.shape # b= batch size, num_tokens = number of tokens and d_in = input dimensions 
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1,2)
        attn_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) # replacing the upper diagonal with -ve infinity
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] **0.5, dim = -1)

        print(f"Attention weights before dropout:\n{attn_weights}")

        attn_weights = self.dropout(attn_weights) # introducing the dropout
        print(f"\nAttention weights after dropout:\n{attn_weights}")

        context_vec = attn_weights @ values

        return context_vec
    
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CasualAttention(d_in, d_out, context_length, 0.5)
context_vecs = ca(batch)
# print(f"Context vector shape: {context_vecs.shape}\n")
# print(context_vecs)   

Attention weights before dropout:
tensor([[[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
         [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
         [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
         [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
         [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
         [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],

        [[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
         [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
         [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
         [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
         [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
         [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]]],
       grad_fn=<SoftmaxBackward0>)

Attention weights after dropout:
tensor([[[0.3433, 0.0000, 0.3522, 0.3110, 0.3253, 0.3158],
         [0.3272, 0.0000, 0.3492, 0.0000, 0.3210, 0.3303],
         [0.0000, 0.0000, 0.3492, 0.0000, 0.3211, 0.0000],
        