## Coding attention mechanisms
Types of attention mechanisms -
1. Simplified self-attention
2. Self-attention
3. Causal attention
4. Multi-head attention

## Creating a simple self-attention (Without trainable weights)

In [1]:
import torch

# Embedded input sequence
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89], # Your (x1)
        [0.55, 0.87, 0.66], # journey (x2)
        [0.57, 0.85, 0.64], # starts (x3)
        [0.22, 0.58, 0.33], # with (x4)
        [0.77, 0.25, 0.10], # one (x5)
        [0.05, 0.80, 0.55] # step (x6)
    ]
)

In [2]:
# Calculating attention scores
query = inputs[1] # x2
attn_scores_2 = torch.empty(inputs.shape[0])

In [3]:
attn_scores_2.shape

torch.Size([6])

In [4]:
query

tensor([0.5500, 0.8700, 0.6600])

In [5]:
for i, x_i in enumerate(inputs):
    print(f"Calculating attention score between query and input {i}: {x_i}")
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

Calculating attention score between query and input 0: tensor([0.4300, 0.1500, 0.8900])
Calculating attention score between query and input 1: tensor([0.5500, 0.8700, 0.6600])
Calculating attention score between query and input 2: tensor([0.5700, 0.8500, 0.6400])
Calculating attention score between query and input 3: tensor([0.2200, 0.5800, 0.3300])
Calculating attention score between query and input 4: tensor([0.7700, 0.2500, 0.1000])
Calculating attention score between query and input 5: tensor([0.0500, 0.8000, 0.5500])
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [6]:
# Calculating attention weights by normalizing attention scores

attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print(f"Attention weights: {attn_weights_2_tmp}")
print(f"Sum: {attn_weights_2_tmp.sum()}")

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: 1.0000001192092896


In [7]:
attn_weights_2_tmp.sum()

tensor(1.0000)

In [8]:
# Normalizing attention scores using softmax

def softmax_naive(x):
    exp_x = torch.exp(x)
    return exp_x / exp_x.sum()

In [9]:
attn_weights_2_naive = softmax_naive(attn_scores_2)
print(f"Attention weights (softmax): {attn_weights_2_naive}")
print(f"Sum: {attn_weights_2_naive.sum()}")

Attention weights (softmax): tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: 1.0


In [10]:
# Pytorch softmax implementation
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print(f"Attention weights (Pytorch softmax): {attn_weights_2}")
print(f"Sum: {attn_weights_2.sum()}")

Attention weights (Pytorch softmax): tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: 1.0


In [11]:
# Calculating the context vector by multiplying the embedded inputs with the corresponding attention weights

query = inputs[1] # x2
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    print(f"Input sequence {i}: {x_i}")
    print(f"Attention weight for input {i}: {attn_weights_2[i]}")
    context_vec_2 += attn_weights_2[i] * x_i
    print(f"Context vector so far: {context_vec_2}")
print(f"Context vector(x2): {context_vec_2}")

Input sequence 0: tensor([0.4300, 0.1500, 0.8900])
Attention weight for input 0: 0.13854756951332092
Context vector so far: tensor([0.0596, 0.0208, 0.1233])
Input sequence 1: tensor([0.5500, 0.8700, 0.6600])
Attention weight for input 1: 0.2378913015127182
Context vector so far: tensor([0.1904, 0.2277, 0.2803])
Input sequence 2: tensor([0.5700, 0.8500, 0.6400])
Attention weight for input 2: 0.23327402770519257
Context vector so far: tensor([0.3234, 0.4260, 0.4296])
Input sequence 3: tensor([0.2200, 0.5800, 0.3300])
Attention weight for input 3: 0.12399158626794815
Context vector so far: tensor([0.3507, 0.4979, 0.4705])
Input sequence 4: tensor([0.7700, 0.2500, 0.1000])
Attention weight for input 4: 0.10818186402320862
Context vector so far: tensor([0.4340, 0.5250, 0.4813])
Input sequence 5: tensor([0.0500, 0.8000, 0.5500])
Attention weight for input 5: 0.15811361372470856
Context vector so far: tensor([0.4419, 0.6515, 0.5683])
Context vector(x2): tensor([0.4419, 0.6515, 0.5683])


In [12]:
# Computing attention weights for all inputs.
attn_scores = torch.empty((inputs.shape[0], inputs.shape[0]))

In [13]:
attn_scores

tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])

In [14]:
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        print(f"Computing dot product of {x_i} and {x_j}")
        attn_scores[i, j] = torch.dot(x_i, x_j)

Computing dot product of tensor([0.4300, 0.1500, 0.8900]) and tensor([0.4300, 0.1500, 0.8900])
Computing dot product of tensor([0.4300, 0.1500, 0.8900]) and tensor([0.5500, 0.8700, 0.6600])
Computing dot product of tensor([0.4300, 0.1500, 0.8900]) and tensor([0.5700, 0.8500, 0.6400])
Computing dot product of tensor([0.4300, 0.1500, 0.8900]) and tensor([0.2200, 0.5800, 0.3300])
Computing dot product of tensor([0.4300, 0.1500, 0.8900]) and tensor([0.7700, 0.2500, 0.1000])
Computing dot product of tensor([0.4300, 0.1500, 0.8900]) and tensor([0.0500, 0.8000, 0.5500])
Computing dot product of tensor([0.5500, 0.8700, 0.6600]) and tensor([0.4300, 0.1500, 0.8900])
Computing dot product of tensor([0.5500, 0.8700, 0.6600]) and tensor([0.5500, 0.8700, 0.6600])
Computing dot product of tensor([0.5500, 0.8700, 0.6600]) and tensor([0.5700, 0.8500, 0.6400])
Computing dot product of tensor([0.5500, 0.8700, 0.6600]) and tensor([0.2200, 0.5800, 0.3300])
Computing dot product of tensor([0.5500, 0.8700, 0

In [15]:
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [16]:
# For loops are slow. So use matrix multiplication
attn_scores = inputs @ inputs.T
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [17]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [18]:
inputs.T

tensor([[0.4300, 0.5500, 0.5700, 0.2200, 0.7700, 0.0500],
        [0.1500, 0.8700, 0.8500, 0.5800, 0.2500, 0.8000],
        [0.8900, 0.6600, 0.6400, 0.3300, 0.1000, 0.5500]])

In [19]:
attn_scores_2

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [20]:
attn_weights = torch.softmax(attn_scores, dim=1)
attn_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [21]:
attn_weights_2

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

In [22]:
# Computing the context vector
context_vec = torch.zeros(inputs.shape)
for i, x_i in enumerate(inputs):
    attn_weight = attn_weights[i]
    input_context_vector = context_vec[i]
    print(f"Input sequence {i}: {x_i}")
    print(f"Attention weights for input {i}: {attn_weight}")
    # Multiply each input vector by its corresponding attention weight
    # and accumulate the result in the context vector
    for j, x_j in enumerate(attn_weight):
        context_vec[i] += x_j * inputs[j]
        if i == 1:
            print(f"Multiplying {x_j} with {x_i}")
            print(f"Conext vector so far: {context_vec[i]}")
print(f"Context vector: {context_vec}")

Input sequence 0: tensor([0.4300, 0.1500, 0.8900])
Attention weights for input 0: tensor([0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452])
Input sequence 1: tensor([0.5500, 0.8700, 0.6600])
Attention weights for input 1: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Multiplying 0.13854756951332092 with tensor([0.5500, 0.8700, 0.6600])
Conext vector so far: tensor([0.0596, 0.0208, 0.1233])
Multiplying 0.2378913015127182 with tensor([0.5500, 0.8700, 0.6600])
Conext vector so far: tensor([0.1904, 0.2277, 0.2803])
Multiplying 0.23327402770519257 with tensor([0.5500, 0.8700, 0.6600])
Conext vector so far: tensor([0.3234, 0.4260, 0.4296])
Multiplying 0.12399158626794815 with tensor([0.5500, 0.8700, 0.6600])
Conext vector so far: tensor([0.3507, 0.4979, 0.4705])
Multiplying 0.10818186402320862 with tensor([0.5500, 0.8700, 0.6600])
Conext vector so far: tensor([0.4340, 0.5250, 0.4813])
Multiplying 0.15811361372470856 with tensor([0.5500, 0.8700, 0.6600])
Conext vector so far: tensor

In [23]:
context_vec

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

In [24]:
# Again, for loops are slow. So use matrix multiplication
context_vec = attn_weights @ inputs
context_vec

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

## Implementing self-attention with trainable weights
This is also called <i>Scaled dot-product attention</i>

In [25]:
x_2 = inputs[1] # The second input element
d_in = inputs.shape[1] # Input embedding size, d= 3
d_out = 2 # Output embedding size, d_out = 2


In [26]:
x_2

tensor([0.5500, 0.8700, 0.6600])

In [27]:
# Initialize weight matrices Wq, Wk, Wv

torch.manual_seed(123)  # For reproducibility
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [28]:
W_query

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])

In [29]:
# Compute query, key, and value vectors for the second input element

query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

In [30]:
query_2

tensor([0.4306, 1.4551])

In [31]:
# Compute keys and values for all inputs
keys = inputs @ W_key
values = inputs @ W_value

In [32]:
keys

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])

In [33]:
values

tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])

In [34]:
# Computing attention scores using the query vector and keys

key_2 = keys[1]  # Using the key for the second input element
attn_score_22 = query_2.dot(key_2)
attn_score_22

tensor(1.8524)

In [35]:
# Get all attention scores for the second input element
attn_scores_2 = query_2 @ keys.T
attn_scores_2

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

In [36]:
# Computing attention weights from attention scores by scaling (dividing by square root of embedding dimension of keys) the scores and using softmax
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
attn_weights_2

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])

In [37]:
# Computing context vector for the second input element. This is the weighted sum over the value vectors
context_vec_2 = attn_weights_2 @ values
context_vec_2

tensor([0.3061, 0.8210])

### Implementing a compact class for self-attention

In [38]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attention_scores = queries @ keys.T # Omega
        d_k = keys.shape[-1]
        attention_weights = torch.softmax(attention_scores / d_k**0.5, dim=-1)
        context_vector = attention_weights @ values
        return context_vector
    
    def assign_custom_weights(self, W_query, W_key, W_value):
        self.W_query = nn.Parameter(W_query)
        self.W_key = nn.Parameter(W_key)
        self.W_value = nn.Parameter(W_value)

In [39]:
torch.manual_seed(123)  # For reproducibility
sa_v1 = SelfAttention_v1(d_in, d_out)
context_vec = sa_v1(inputs)
context_vec

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

In [40]:
# A Self Attention class using PyTorch's Linear layers
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attention_scores = queries @ keys.T  # Omega
        d_k = keys.shape[-1]
        attention_weights = torch.softmax(attention_scores / d_k**0.5, dim=-1)
        context_vector = attention_weights @ values
        return context_vector
    
    def retrieve_weights(self):
        return self.W_query.weight, self.W_key.weight, self.W_value.weight

In [41]:
torch.manual_seed(789)  # For reproducibility
sa_v2 = SelfAttention_v2(d_in, d_out)
context_vec = sa_v2(inputs)
context_vec

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

In [42]:
v2_weights = sa_v2.retrieve_weights()
print(f"Query weights: {v2_weights[0]}")
print(f"Key weights: {v2_weights[1]}")
print(f"Value weights: {v2_weights[2]}")

Query weights: Parameter containing:
tensor([[ 0.3161,  0.4568,  0.5118],
        [-0.1683, -0.3379, -0.0918]], requires_grad=True)
Key weights: Parameter containing:
tensor([[ 0.4058, -0.4704,  0.2368],
        [ 0.2134, -0.2601, -0.5105]], requires_grad=True)
Value weights: Parameter containing:
tensor([[ 0.2526, -0.1415, -0.1962],
        [ 0.5191, -0.0852, -0.2043]], requires_grad=True)


In [43]:
# Since nn.Linear stores weights in transposed form, we need to transpose them to make them compatible with V1
v2_weights_trans = (v2_weights[0].T, v2_weights[1].T, v2_weights[2].T)

# Transfer weights from V2 to V1
sa_v1.assign_custom_weights(*v2_weights_trans)

In [44]:
# Now we can use the V1 class with the weights from V2. Verify that the output is the same from V1 and V2
torch.manual_seed(123)  # For reproducibility
sa_v1(inputs)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

## Hiding future words with causal attention (masked attention)

In [45]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)

attn_scores = queries @ keys.T
attn_scores

tensor([[ 0.2899,  0.0716,  0.0760, -0.0138,  0.1344, -0.0511],
        [ 0.4656,  0.1723,  0.1751,  0.0259,  0.1771,  0.0085],
        [ 0.4594,  0.1703,  0.1731,  0.0259,  0.1745,  0.0090],
        [ 0.2642,  0.1024,  0.1036,  0.0186,  0.0973,  0.0122],
        [ 0.2183,  0.0874,  0.0882,  0.0177,  0.0786,  0.0144],
        [ 0.3408,  0.1270,  0.1290,  0.0198,  0.1290,  0.0078]],
       grad_fn=<MmBackward0>)

In [46]:
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

In [47]:
# Create a mask using PyTorch's tril function
context_length = attn_weights.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [48]:
# Multiply attn_weights with the mask to zero-out the values above the diagonal
masked_simple = attn_weights * mask_simple
masked_simple

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)

In [49]:
# Re-normlize the masked attention weights
masked_simple = masked_simple / masked_simple.sum(dim=-1, keepdim=True)
masked_simple

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)

In [50]:
# Improved masking based on mathematical property of softmax function
# Creating a mask with 1s above the diagonal and then replacing these 1s with -inf
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1) # Note that previously we used tril (Not triu) with diagonal arg was 0, which is default value
mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [51]:
mask.bool()

tensor([[False,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True],
        [False, False, False, False,  True,  True],
        [False, False, False, False, False,  True],
        [False, False, False, False, False, False]])

In [52]:
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)

In [53]:
# Now, if we apply softmax to the masked attention scores, it will zero-out the values above the diagonal because softmax will treat -inf as 0

attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

### Masking additional attention weights with dropout (Useful for reducing overfitting while training LLM)

In [54]:
# Demonstrating dropout
torch.manual_seed(123)  # For reproducibility
dropout = nn.Dropout(p=0.5)  # 50% dropout rate
example_tensor = torch.ones(6, 6)
example_tensor

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

In [55]:
dropout(example_tensor)

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])

After applying the dropout, randomly 50% elements of the attention weight matrix becomes zero. To compensate this reduction, rest of the active elements are scaled up by a factor of `1 / 0.5 = 2`. This balancing is crucial to maintain the overall balance of attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phase. Note that we don't apply dropout during inference.

In [56]:
# Now lets apply dropout to the attention weights
dropout(attn_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.0000, 0.4638, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.3968, 0.3775, 0.3941, 0.0000],
        [0.3869, 0.3327, 0.0000, 0.0000, 0.3331, 0.3058]],
       grad_fn=<MulBackward0>)

## Implementing a compact class for causal attention

In [57]:
# Building intuition of batch inputs
example_batch = torch.stack((inputs, inputs), dim=0)
example_batch

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

In [58]:
example_batch.shape

torch.Size([2, 6, 3])

In [59]:
# Implementing causal attention class

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2) # Transpose dimentions 1 and 2, keep the batch dimension intact
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        d_k = keys.shape[-1]
        attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vector = attn_weights @ values
        return context_vector

In [60]:
torch.manual_seed(123)
context_length = example_batch.shape[1]  # Number of tokens in the input sequence
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
context_vecs = ca(example_batch)
context_vecs.shape

torch.Size([2, 6, 2])

In [61]:
context_vecs

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)

## Multi-head attention

In [62]:
# A wrapper class to implement multi-head attention which uses multiple instances of causal attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.heads = nn.ModuleList([
            CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            for _ in range(num_heads)
        ])

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [63]:
example_batch

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

In [64]:
example_batch.shape

torch.Size([2, 6, 3])

In [65]:
torch.manual_seed(123)
context_length = example_batch.shape[1]  # Number of tokens in the input sequence
d_in = example_batch.shape[-1]  # Input embedding size
d_out = 2  # Output embedding size
num_heads = 2  # Number of attention heads
mha = MultiHeadAttention(d_in, d_out, num_heads, context_length, dropout=0.0)
context_vecs = mha(example_batch)
context_vecs

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)

In [66]:
context_vecs.shape

torch.Size([2, 6, 4])

In [67]:
# Returning two-dimensional context vectors
torch.manual_seed(123)
d_out = 1  # Output embedding size
mha = MultiHeadAttention(d_in, d_out, num_heads, context_length, dropout=0.0)
context_vecs = mha(example_batch)
context_vecs

tensor([[[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]],

        [[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]]], grad_fn=<CatBackward0>)

### Implementing multi-head attention with weight splits
The idea is to split the input into multiple heads by reshaping the projected query, key and value tensors and then combine the results form thease heads after computing attention.

In [68]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads, context_length, dropout, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "Output embedding size must be divisible by number of heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduces the projection dimension to match the desired output dimension
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Reshape the keys, queries and values to split them into multiple heads
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose from (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute dot product for each head
        attn_scores = queries @ keys.transpose(2, 3)

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens] # Masks truncated to the number of tokens in the input sequence

        # Use mask to fill attention scores with -inf
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Compute attention weights
        d_k = keys.shape[-1]
        attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Compute context vector for each head
        context_vector = (attn_weights @ values).transpose(1, 2) # Transpose back to (b, num_tokens, num_heads, head_dim)

        # Combine heads where d_out = num_heads * head_dim
        context_vector = context_vector.contiguous().view(b, num_tokens, self.d_out) # Note: Instead of using contiguous().view(), you can use reshape() as well, but contiguous ensures that the memory layout is correct.

        # Apply an optional linear projection to combine the heads
        context_vector = self.out_proj(context_vector)

        return context_vector

In [69]:
torch.manual_seed(123)
batch_size, context_length, d_in = example_batch.shape
d_out = 2  # Output embedding size
num_heads = 2  # Number of attention heads
mha = MultiHeadAttention(d_in, d_out, num_heads, context_length, dropout=0.0)
context_vecs = mha(example_batch)
context_vecs

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)

In [70]:
example_batch

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

In [71]:
example_batch.shape

torch.Size([2, 6, 3])

In [72]:
# Creating a MultiHeadAttention instance for smallest GPT-2 model (117M parameters)
batch_size, context_length, d_in = [12, 1024, 768] # 12 attention heads, 1024 tokens, 768 embedding size.
d_out = 768  # Output embedding size
num_heads = 12  # Number of attention heads
mha = MultiHeadAttention(d_in, d_out, num_heads, context_length, dropout=0.2)

In [73]:
mha

MultiHeadAttention(
  (W_query): Linear(in_features=768, out_features=768, bias=False)
  (W_key): Linear(in_features=768, out_features=768, bias=False)
  (W_value): Linear(in_features=768, out_features=768, bias=False)
  (out_proj): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

### An alternate Multi-Head Attention with combined weights
The idea is to a single combined weight matrix, QKV, instead of individual weight matrices for Query, Key and Value

In [74]:
class MultiHeadAttentionCombinedQKV(nn.Module):
    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "Output embedding size must be divisible by number of heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_qkv = nn.Linear(d_in, d_out * 3, bias=qkv_bias)  # Combined weight matrix for Q, K and V
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        # (b, num_tokens, d_in) -> (b, num_tokens, d_out * 3)
        qkv = self.W_qkv(x)

        # (b, num_tokens, d_out * 3) -> (b, num_tokens, 3, num_heads, head_dim)
        qkv = qkv.view(b, num_tokens, 3, self.num_heads, self.head_dim)

        # (b, num_tokens, 3, num_heads, head_dim) -> (3, b, num_heads, num_tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
        queries, keys, values = qkv.unbind(dim=0)

        attn_scores = queries @ keys.transpose(-2, -1)  # (b, num_heads, num_tokens, num_tokens)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        d_k = keys.shape[-1]
        attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, num_tokens) -> (b, num_heads, num_tokens, head_dim)
        context_vec = attn_weights @ values

        context_vec = context_vec.transpose(1, 2)  # Transpose back to (b, num_tokens, num_heads, head_dim)
        context_vec = context_vec.contiguous().view(b, num_tokens, d_in)
        context_vec = self.out_proj(context_vec)

        return context_vec

In [75]:
torch.manual_seed(123)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 8
context_len = 1024
embed_dim = 768
num_heads = 12
embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)

mha_combined_qkv = MultiHeadAttentionCombinedQKV(embed_dim, embed_dim, num_heads, context_len, dropout=0.0)
context_vecs_combined = mha_combined_qkv(embeddings)
context_vecs_combined.shape

torch.Size([8, 1024, 768])