#### Causal Self-Attention

We need to mask the upper triangle (the entries above the diagonal) of the attention scores or attention weights so that a token cannot see the context of future tokens. There are two ways to do this:

* Mask the attention weights after the softmax has been applied: set the upper triangle of the attention weight matrix to 0 and then renormalise each row so that it sums to 1 again.
* Mask the upper triangle of the attention score matrix with negative infinity and then apply the softmax normalisation.

The 2nd method is more efficient: it normalises only once (a single softmax), whereas the 1st method applies a softmax and then needs a second renormalisation pass.
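Why both methods produce identical weights can be checked for a single query row $i$ over $n$ tokens. The notation here is introduced for illustration: $s_j = \mathbf{q}_i \cdot \mathbf{k}_j / \sqrt{d_k}$ is the scaled score and $m_j = 1$ if $j \le i$, else $m_j = 0$.

$$
\text{Method 1:}\quad a_j = \frac{e^{s_j}}{\sum_{k=1}^{n} e^{s_k}}, \qquad
w_j = \frac{m_j\, a_j}{\sum_{l=1}^{n} m_l\, a_l} = \frac{m_j\, e^{s_j}}{\sum_{l=1}^{n} m_l\, e^{s_l}}
$$

$$
\text{Method 2:}\quad \tilde{s}_j = \begin{cases} s_j & \text{if } m_j = 1 \\ -\infty & \text{if } m_j = 0 \end{cases}, \qquad
w_j = \frac{e^{\tilde{s}_j}}{\sum_{l=1}^{n} e^{\tilde{s}_l}} = \frac{m_j\, e^{s_j}}{\sum_{l=1}^{n} m_l\, e^{s_l}}
$$

In method 1 the shared softmax denominator $\sum_k e^{s_k}$ cancels during renormalisation; in method 2 the masked terms vanish because $e^{-\infty} = 0$. Both end at the same expression.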

**Efficient method**

Attention scores --> Upper triangle -ve infinity mask --> softmax 
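The efficient pipeline above can be sketched as a small helper, with the 1st method alongside for comparison. This is a minimal sketch: the function names `causal_attention_weights` and `renormalised_weights` are hypothetical, and the random score matrix is only a stand-in for `queries @ keys.T`.

```python
import torch


def causal_attention_weights(attn_scores: torch.Tensor, d_k: int) -> torch.Tensor:
    """Efficient method: mask scores above the diagonal with -inf, then apply a scaled softmax."""
    context_length = attn_scores.shape[-1]
    # True strictly above the main diagonal, i.e. at the positions of future tokens
    future_mask = torch.triu(
        torch.ones(context_length, context_length, dtype=torch.bool), diagonal=1
    )
    masked_scores = attn_scores.masked_fill(future_mask, float("-inf"))
    return torch.softmax(masked_scores / d_k**0.5, dim=-1)


def renormalised_weights(attn_scores: torch.Tensor, d_k: int) -> torch.Tensor:
    """1st method: softmax first, zero out the upper triangle, then renormalise each row."""
    context_length = attn_scores.shape[-1]
    weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
    keep = torch.tril(torch.ones(context_length, context_length))  # ones on and below the diagonal
    masked = weights * keep
    return masked / masked.sum(dim=-1, keepdim=True)


torch.manual_seed(0)
scores = torch.randn(6, 6)  # stand-in for a 6x6 attention score matrix
w_efficient = causal_attention_weights(scores, d_k=2)
w_two_pass = renormalised_weights(scores, d_k=2)

print(torch.allclose(w_efficient, w_two_pass, atol=1e-6))  # True
print(w_efficient.sum(dim=-1))  # every row sums to 1
```

The single `masked_fill` plus one softmax replaces the softmax, multiply, row-sum and divide of the 1st method.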

In [3]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np

In [22]:
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your
     [0.55, 0.87, 0.66], # journey
     [0.57, 0.85, 0.64], # starts
     [0.22, 0.58, 0.33], # with
     [0.77, 0.25, 0.10], # one
     [0.05, 0.80, 0.55]] # step
)

Defining elements

* A: The second input element
* B: Input embedding size, d_in=3
* C: Output embedding size, d_out=2 

In [23]:
d_in = inputs.shape[1]
print(f"Input shape is {d_in}")
d_out = 2
print(f"Output shape is {d_out}")

# Initialising the weight matrices

"""
requires_grad is set to False to reduce clutter in the outputs. If we were to use the weight matrices for model training, we would set
requires_grad=True so that the matrices are updated during training.
"""
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False) 
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)

print(f"Query weights:{W_query}")
print(" ")
print(f"Key weights:{W_key}")
print(" ")
print(f"Value weights:{W_value}")
print(" ")

"""
For GPT-like models the input and output dimensions are usually the same, but for demonstration purposes we use different dimensions here.
"""
x_2 = inputs[1]
print(f"The tensor value for the second token is {x_2}")


Input shape is 3
Output shape is 2
Query weights:Parameter containing:
tensor([[0.8943, 0.8260],
        [0.4370, 0.6476],
        [0.5942, 0.3844]])
 
Key weights:Parameter containing:
tensor([[0.1322, 0.7668],
        [0.7084, 0.6391],
        [0.6801, 0.3480]])
 
Value weights:Parameter containing:
tensor([[0.7519, 0.4936],
        [0.6580, 0.5615],
        [0.2156, 0.6007]])
 
The tensor value for the second token is tensor([0.5500, 0.8700, 0.6600])


In [24]:
"""
We get a 2-dimensional query, key and value vector for the 2nd token. Even though our temporary goal is to compute only the context vector z(2),
we still need the key and value vectors for all input tokens, because they are required to compute the attention weights
with respect to the query q(2).
"""
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(f"The query vector for the 2nd token is {query_2}")
print(f"The key vector for the 2nd token is {key_2}")
print(f"The value vector for the 2nd token is {value_2}")

The query vector for the 2nd token is tensor([1.2643, 1.2714])
The key vector for the 2nd token is tensor([1.1379, 1.2074])
The value vector for the 2nd token is tensor([1.1282, 1.1564])


In [None]:
# Computing the query, key and value vectors

queries = inputs @ W_query
keys = inputs @ W_key
values = inputs @ W_value

# We have projected the 6 input tokens from a 3D space onto a 2D embedding space.
print(f"Shape of queries matrix: {queries.shape}")
print(f"Shape of keys matrix: {keys.shape}")
print(f"Shape of values matrix: {values.shape}")
print(" ")

# Computing the attention score between the 2nd query and the 2nd key
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(f"Attention score between the 2nd token and the 2nd token: {attn_score_22}")
print(" ")

# Generalising the computation to get all attention scores by matrix multiplication for the 2nd token
attn_score_2 = query_2 @ keys.T #all attention scores for the 2nd token(query)
print(f"Attention score for the entire 2nd token: {attn_score_2}")
print(" ")

# Entire attention score matrix
attn_score = queries @ keys.T
print(f"Entire attention score matrix: {attn_score}")

Shape of queries matrix: torch.Size([6, 2])
Shape of keys matrix: torch.Size([6, 2])
Shape of values matrix: torch.Size([6, 2])
 
Attention score between the 2nd token and the 2nd token: 2.9736814498901367
 
Attention score for the entire 2nd token: tensor([1.9062, 2.9737, 2.9363, 1.6717, 1.4366, 2.1398])
 
Entire attention score matrix: tensor([[1.3363, 2.0731, 2.0450, 1.1701, 0.9632, 1.5161],
        [1.9062, 2.9737, 2.9363, 1.6717, 1.4366, 2.1398],
        [1.9011, 2.9656, 2.9283, 1.6672, 1.4325, 2.1342],
        [0.9997, 1.5615, 1.5423, 0.8770, 0.7613, 1.1193],
        [1.2737, 1.9853, 1.9601, 1.1167, 0.9539, 1.4319],
        [1.1209, 1.7513, 1.7297, 0.9834, 0.8552, 1.2544]])


The next step is to calculate the attention weights by scaling the attention scores by the square root of the key dimension and applying a softmax. For causal attention we also need to mask the upper triangle (the entries above the diagonal) with -ve infinity. We can reuse the class SelfAttention_v2 from the multi_head_attention.ipynb notebook.
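Before moving to the full matrix, the scaling, masking and softmax step can be sketched for the 2nd token alone, ending with its context vector z(2). This sketch assumes the weight matrices printed earlier (copied at 4 decimal places, so results match the printed scores only to about 3 decimals); the variable name `future` is illustrative.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)

# Weight values copied from the printed Parameter outputs (rounded to 4 decimals)
W_query = torch.tensor([[0.8943, 0.8260], [0.4370, 0.6476], [0.5942, 0.3844]])
W_key = torch.tensor([[0.1322, 0.7668], [0.7084, 0.6391], [0.6801, 0.3480]])
W_value = torch.tensor([[0.7519, 0.4936], [0.6580, 0.5615], [0.2156, 0.6007]])

queries, keys, values = inputs @ W_query, inputs @ W_key, inputs @ W_value
d_k = keys.shape[-1]

# Query 2 is row index 1; causally it may only attend to "Your" and "journey" (indices 0 and 1)
attn_scores_2 = queries[1] @ keys.T
future = torch.arange(inputs.shape[0]) > 1
masked_scores_2 = attn_scores_2.masked_fill(future, float("-inf"))
attn_weights_2 = torch.softmax(masked_scores_2 / d_k**0.5, dim=-1)

context_vec_2 = attn_weights_2 @ values  # z(2): weighted sum of the first two value vectors
print(attn_weights_2)
print(context_vec_2)
```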

In [7]:
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your
     [0.55, 0.87, 0.66], # journey
     [0.57, 0.85, 0.64], # starts
     [0.22, 0.58, 0.33], # with
     [0.77, 0.25, 0.10], # one
     [0.05, 0.80, 0.55]] # step
)

d_in = 3
d_out = 2

In [8]:
class SelfAttention_v2(nn.Module):
    
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x) 

        attn_scores = queries @ keys.T
        attn_weight = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim = -1)

        context_vec = attn_weight @ values
        return context_vec
        
sa_v2 = SelfAttention_v2(d_in, d_out)

In [None]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim = 1 )
print(attn_weights)

tensor([[0.1600, 0.1734, 0.1732, 0.1638, 0.1629, 0.1666],
        [0.1617, 0.1700, 0.1702, 0.1648, 0.1695, 0.1638],
        [0.1620, 0.1697, 0.1699, 0.1649, 0.1697, 0.1638],
        [0.1638, 0.1686, 0.1686, 0.1656, 0.1683, 0.1651],
        [0.1681, 0.1634, 0.1637, 0.1675, 0.1736, 0.1636],
        [0.1611, 0.1717, 0.1716, 0.1644, 0.1658, 0.1655]],
       grad_fn=<SoftmaxBackward0>)


In [18]:
# 1st Method --> Set the attention weights above the diagonal to zero and then renormalise

# We can use PyTorch tril function to create a mask where the values above the diagonal are zero
context_length = attn_scores.shape[0]
print(torch.ones(context_length, context_length))
print(" ")

mask_simple = torch.tril(torch.ones(context_length, context_length)) # Ones on and below the diagonal, zeros above it
print(mask_simple)
print(" ")

# Multiplying the attention weights by the mask zeroes out the values above the diagonal.
masked_simple = attn_weights * mask_simple
print(masked_simple)
print(" ")

# The elements above the diagonal are zeroed out, but each row must be renormalised so that it sums to 1
row_sums = masked_simple.sum(dim =1, keepdim = True)
masked_simple_norm = masked_simple/row_sums
print(masked_simple_norm)

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
 
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
 
tensor([[0.1600, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1617, 0.1700, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1620, 0.1697, 0.1699, 0.0000, 0.0000, 0.0000],
        [0.1638, 0.1686, 0.1686, 0.1656, 0.0000, 0.0000],
        [0.1681, 0.1634, 0.1637, 0.1675, 0.1736, 0.0000],
        [0.1611, 0.1717, 0.1716, 0.1644, 0.1658, 0.1655]],
       grad_fn=<MulBackward0>)
 
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4875, 0.5125, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3230, 0.3384, 0.3387, 0.0000, 0.0000, 0.0000],
        [0.2457, 0.2528, 0.253

In [None]:
# 2nd Method --> Set the attention scores above the diagonal to -ve infinity and then apply the scaled softmax to get the attention weights
print(attn_scores)  
print(" ")

mask = torch.triu(torch.ones(context_length,context_length), diagonal = 1 )
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)
print(" ")

# Applying softmax to the masked matrix turns the -ve infinity entries into 0s, and every row sums to 1
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim = 1)
print(attn_weights) 

# Both methods give the same attention weights, but the 2nd method is more efficient because it normalises only once.

tensor([[ 0.0817,  0.1954,  0.1936,  0.1148,  0.1072,  0.1386],
        [ 0.0263,  0.0971,  0.0983,  0.0533,  0.0927,  0.0449],
        [ 0.0223,  0.0882,  0.0896,  0.0481,  0.0886,  0.0382],
        [ 0.0149,  0.0551,  0.0559,  0.0303,  0.0528,  0.0254],
        [-0.0556, -0.0959, -0.0928, -0.0604, -0.0101, -0.0940],
        [ 0.0545,  0.1446,  0.1441,  0.0834,  0.0959,  0.0925]],
       grad_fn=<MmBackward0>)
 
tensor([[ 0.0817,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.0263,  0.0971,    -inf,    -inf,    -inf,    -inf],
        [ 0.0223,  0.0882,  0.0896,    -inf,    -inf,    -inf],
        [ 0.0149,  0.0551,  0.0559,  0.0303,    -inf,    -inf],
        [-0.0556, -0.0959, -0.0928, -0.0604, -0.0101,    -inf],
        [ 0.0545,  0.1446,  0.1441,  0.0834,  0.0959,  0.0925]],
       grad_fn=<MaskedFillBackward0>)
 
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4875, 0.5125, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3230, 0.3384, 0.3387, 0.0000, 0
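The efficient masking step can also be folded into the forward pass of a module modelled on SelfAttention_v2. This is a minimal sketch, not part of the original notebook: the class name `CausalSelfAttention`, the `context_length` argument and the registered `mask` buffer are assumptions, and it handles a single unbatched sequence without dropout.

```python
import torch
import torch.nn as nn


class CausalSelfAttention(nn.Module):
    """Hypothetical causal variant of SelfAttention_v2 (single head, unbatched, no dropout)."""

    def __init__(self, d_in, d_out, context_length, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Boolean mask, True above the diagonal; a buffer follows .to(device) but is not trained
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length, dtype=torch.bool), diagonal=1),
        )

    def forward(self, x):
        # x: (num_tokens, d_in); num_tokens may be shorter than context_length
        num_tokens = x.shape[0]
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        # Efficient method: -inf above the diagonal, then a single scaled softmax
        attn_scores = attn_scores.masked_fill(self.mask[:num_tokens, :num_tokens], float("-inf"))
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


torch.manual_seed(123)
inputs = torch.rand(6, 3)  # stand-in for six 3-dimensional token embeddings
ca = CausalSelfAttention(d_in=3, d_out=2, context_length=6)
context_vecs = ca(inputs)
print(context_vecs.shape)  # torch.Size([6, 2])
```

A quick sanity check of causality: the 1st token can only attend to itself, so its context vector equals its own value vector, and changing the last token leaves the earlier context vectors unchanged.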