# Coding Attention Mechanisms

## Attending to different parts of the input with self-attention

### simple self-attention without trainable weights

In [1]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [3]:
input_query = inputs[1] ## journey
input_query 

tensor([0.5500, 0.8700, 0.6600])

In [5]:
input_1 = inputs[0]
input_1

tensor([0.4300, 0.1500, 0.8900])

In [7]:
torch.dot(input_query, input_1)

tensor(0.9544)

In [None]:
torch.matmul(input_query, inputs.T)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

Above tensor, each element shows the similarity or attention score for the `input_query` with the corresponding `inputs`


In [25]:
## the below calculates the attention scores of second word wrt to each word in the sentence.
attention_scores_2 = torch.softmax(torch.matmul(input_query, inputs.T), dim=-1)

In [28]:
attention_scores_2

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

Let's calculate the attention score for each word in the sentence wrt every other word in the sentence


In [19]:
torch.matmul(inputs, inputs.T)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

Every row is the attention score of the input word wrt every other word


In [None]:
## Now, let's normalize attention scores to not let the magnitudes introduces biases
torch.softmax(torch.matmul(inputs, inputs.T), dim=-1)


tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [30]:
context_vector_2 = torch.matmul(attention_scores_2, inputs)
context_vector_2

tensor([0.4419, 0.6515, 0.5683])

`attention_scores_2` corresponds with second element of the attention matrix as it should.

`context_vector_2` is calculated as the weighted sum of the input vectors generally speaking.

### simple self-attention without trainable weights for all tokens

In [33]:
torch.softmax(torch.matmul(inputs, inputs.T), dim=-1)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

Attention score of each token with respect to all other tokens.

Each row corresponds to the attention scores of a single token at that index.

```
[[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452], # your
[0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581], # journey
[0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565], # starts
[0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720], # with
[0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295], # one
[0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]] # step

```

In [34]:
torch.softmax(torch.matmul(inputs, inputs.T), dim=-1).sum(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

The above result shows that the normalization works and the sum of each attention vector is 1

In [35]:
attention_scores = inputs @ inputs.T
attention_weights = torch.softmax(attention_scores, dim=-1)
context_vector = attention_weights @ inputs
context_vector

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

The above matrix, each row is the context vector for the corresponding input tokens. 

```
[[0.4421, 0.5931, 0.5790], # Your
[0.4419, 0.6515, 0.5683], # Journey 
[0.4431, 0.6496, 0.5671], # Starts
[0.4304, 0.6298, 0.5510], # with 
[0.4671, 0.5910, 0.5266], # one
[0.4177, 0.6503, 0.5645]] # step
```


### self-attention with trainable weights
Now we will see that the inputs are not used as is, they are multiplied with the trainable weights which yield Q, K and V matrices hence the context vector is not the same dimensionality as the input. 

In [37]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [38]:
torch.manual_seed(123)

W_query = torch.nn.Parameter(torch.randn(d_in,d_out))
W_key = torch.nn.Parameter(torch.randn(d_in,d_out))
W_value = torch.nn.Parameter(torch.randn(d_in,d_out))

In [39]:
W_query

Parameter containing:
tensor([[-0.1115,  0.1204],
        [-0.3696, -0.2404],
        [-1.1969,  0.2093]], requires_grad=True)

In [40]:
query_2 = x_2 @ W_query

In [41]:
query_2

tensor([-1.1729, -0.0048], grad_fn=<SqueezeBackward4>)

In [None]:
keys = inputs @ W_key
keys

tensor([[-0.1823, -0.6888],
        [-0.1142, -0.7676],
        [-0.1443, -0.7728],
        [ 0.0434, -0.3580],
        [-0.6467, -0.6476],
        [ 0.3262, -0.3395]], grad_fn=<MmBackward0>)

In [44]:
values = inputs @ W_value
values


tensor([[ 0.1196, -0.3566],
        [ 0.4107,  0.6274],
        [ 0.4091,  0.6390],
        [ 0.2436,  0.4182],
        [ 0.2653,  0.6668],
        [ 0.2728,  0.3242]], grad_fn=<MmBackward0>)

In [45]:
query = inputs @ W_query
query

tensor([[-1.1686,  0.2019],
        [-1.1729, -0.0048],
        [-1.1438, -0.0018],
        [-0.6339, -0.0439],
        [-0.2979,  0.0535],
        [-0.9596, -0.0712]], grad_fn=<MmBackward0>)

In [50]:
query_2 = query[1]
keys_2 = keys[1]
query_2, keys_2

(tensor([-1.1729, -0.0048], grad_fn=<SelectBackward0>),
 tensor([-0.1142, -0.7676], grad_fn=<SelectBackward0>))

In [55]:
attention_scores_22 = query_2 @ keys_2
attention_scores_22

tensor(0.1376, grad_fn=<DotBackward0>)

In [57]:
attention_scores_2 = query_2 @ keys.T
attention_scores_2

tensor([ 0.2172,  0.1376,  0.1730, -0.0491,  0.7616, -0.3809],
       grad_fn=<SqueezeBackward4>)

In [59]:
d_k = keys.shape[1]
torch.softmax(attention_scores_2 / d_k**0.5, dim=-1)

tensor([0.1704, 0.1611, 0.1652, 0.1412, 0.2505, 0.1117],
       grad_fn=<SoftmaxBackward0>)

In [60]:
torch.softmax(attention_scores_2 / d_k**0.5, dim=-1).sum(dim=-1)


tensor(1.0000, grad_fn=<SumBackward1>)

In [61]:
attention_scores = query @ keys.T / d_k**0.5
attention_weights = torch.softmax(attention_scores, dim=-1)
context_vector = attention_weights @ values
context_vector

tensor([[0.2845, 0.4071],
        [0.2854, 0.4081],
        [0.2854, 0.4075],
        [0.2864, 0.3974],
        [0.2863, 0.3910],
        [0.2860, 0.4039]], grad_fn=<MmBackward0>)

If you see, each element of the above matrix, again corresponds to each word, but the dimensions are different as we did not use the vector representation of the words directly rather their projections on a different plane where their dimensions shrunk (in this case).

### Implementing the self-attention class

In [72]:
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, bias = False):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=bias)
        self.W_k = nn.Linear(d_in, d_out, bias=bias)
        self.W_v = nn.Linear(d_in, d_out, bias=bias)

    def forward(self, x):
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        d_k = K.shape[-1]
        
        attentions_scores = Q @ K.T
        attention_weights = torch.softmax(attentions_scores / d_k**0.5, dim=-1)
        context_vector = attention_weights @ V
        return context_vector


In [73]:
sa_v1 = SelfAttention(d_in, d_out)
sa_v1(inputs)

tensor([[-0.2410,  0.2378],
        [-0.2419,  0.2406],
        [-0.2418,  0.2403],
        [-0.2409,  0.2409],
        [-0.2392,  0.2355],
        [-0.2420,  0.2433]], grad_fn=<MmBackward0>)

## Hiding future words with causal attention
### Applying causal attention mask

In [74]:
# Your journey starts with one step

In [76]:
Q = sa_v1.W_q(inputs)
K = sa_v1.W_k(inputs)
V = sa_v1.W_v(inputs)
d_k = K.shape[-1]

attentions_scores = Q @ K.T
attention_weights = torch.softmax(attentions_scores / d_k**0.5, dim=-1)
        

In [78]:
attention_weights

tensor([[0.1566, 0.1692, 0.1691, 0.1690, 0.1662, 0.1699],
        [0.1623, 0.1748, 0.1746, 0.1614, 0.1624, 0.1644],
        [0.1619, 0.1743, 0.1741, 0.1621, 0.1628, 0.1649],
        [0.1667, 0.1729, 0.1728, 0.1611, 0.1631, 0.1634],
        [0.1560, 0.1622, 0.1623, 0.1756, 0.1702, 0.1738],
        [0.1703, 0.1784, 0.1781, 0.1547, 0.1595, 0.1589]],
       grad_fn=<SoftmaxBackward0>)

In [84]:
context_length = attentions_scores.shape[0]
mask = torch.tril(torch.ones(context_length,context_length))
mask

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [81]:
masked_simple = attention_weights * mask
masked_simple

tensor([[0.1566, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1623, 0.1748, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1619, 0.1743, 0.1741, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1729, 0.1728, 0.1611, 0.0000, 0.0000],
        [0.1560, 0.1622, 0.1623, 0.1756, 0.1702, 0.0000],
        [0.1703, 0.1784, 0.1781, 0.1547, 0.1595, 0.1589]],
       grad_fn=<MulBackward0>)

But now the rows are not normalized

In [83]:
masked_simple/masked_simple.sum(dim=-1, keepdim=True)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4814, 0.5186, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3173, 0.3416, 0.3412, 0.0000, 0.0000, 0.0000],
        [0.2475, 0.2567, 0.2565, 0.2392, 0.0000, 0.0000],
        [0.1887, 0.1964, 0.1965, 0.2125, 0.2059, 0.0000],
        [0.1703, 0.1784, 0.1781, 0.1547, 0.1595, 0.1589]],
       grad_fn=<DivBackward0>)

Above we first calculated `attn scores`(unnormalized) -> `attn weights`(normalized) -> `masked_values`(unnormalized) -> `attn values`(normalized)

We can instead skip attention weights and do in 2 steps.
`attn scores`-> `masked`  -> `attn values`(normalized)

In [90]:
mask = torch.triu(torch.ones(context_length,context_length), diagonal=1)
masked_simple = attention_scores.masked_fill(mask.bool(), -torch.inf)
masked_simple

tensor([[ 0.0523,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.1536,  0.0973,    -inf,    -inf,    -inf,    -inf],
        [ 0.1483,  0.0933,  0.1177,    -inf,    -inf,    -inf],
        [ 0.1031,  0.0750,  0.0887, -0.0083,    -inf,    -inf],
        [ 0.0124, -0.0050,  0.0012, -0.0227,  0.1117,    -inf],
        [ 0.1584,  0.1161,  0.1368, -0.0114,  0.4714, -0.2042]],
       grad_fn=<MaskedFillBackward0>)

In [93]:
attention_weights = torch.softmax(masked_simple/keys.shape[0]**0.5, dim=-1)

In [94]:
attention_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5057, 0.4943, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3372, 0.3297, 0.3330, 0.0000, 0.0000, 0.0000],
        [0.2539, 0.2510, 0.2524, 0.2426, 0.0000, 0.0000],
        [0.1994, 0.1980, 0.1985, 0.1965, 0.2076, 0.0000],
        [0.1693, 0.1664, 0.1678, 0.1580, 0.1924, 0.1460]],
       grad_fn=<SoftmaxBackward0>)

### Masking addtional attention weights with dropout

In [95]:
torch.manual_seed(123)

layer = torch.nn.Dropout(p=0.5)
layer

Dropout(p=0.5, inplace=False)

In [96]:
layer(torch.ones(6,6))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])

In [98]:
dropout = 0.3
1 / (1 - dropout)

1.4285714285714286

### Causal self-attention class

In [100]:
batch = torch.stack([inputs, inputs], dim=0)
batch.shape

torch.Size([2, 6, 3])

So you can say the shape of the input batch is supposed to be `(batch_size, num_Tokens, input vector dimensions)`

Below, `num_tokens` is to cut-off context length in case the batch contains a sequence of tokens that is longer than the context length, this does not happen in practice as that is taken care of before.

In [None]:
import torch.nn as nn

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, dropout, context_length, bias=False):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=bias)
        self.W_k = nn.Linear(d_in, d_out, bias=bias)
        self.W_v = nn.Linear(d_in, d_out, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        Q = self.W_q(x) # (batch_size, num_tokens, d_out)
        K = self.W_k(x) # (batch_size, num_tokens, d_out)
        V = self.W_v(x) # (batch_size, num_tokens, d_out)
        d_k = K.shape[-1]
        
        attention_scores = Q @ K.transpose(1,2)
        attention_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens], # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
            -torch.inf
        )
    
        attention_weights = torch.softmax(attention_scores / d_k**0.5, dim=-1)
        context_vector = attention_weights @ V
        return context_vector

In [None]:
ca = CausalAttention(d_in, d_out, 0.0, context_length)

In [106]:
ca(batch)

tensor([[[-0.2410,  0.2378],
         [-0.2419,  0.2406],
         [-0.2418,  0.2403],
         [-0.2409,  0.2409],
         [-0.2392,  0.2355],
         [-0.2420,  0.2433]],

        [[-0.2410,  0.2378],
         [-0.2419,  0.2406],
         [-0.2418,  0.2403],
         [-0.2409,  0.2409],
         [-0.2392,  0.2355],
         [-0.2420,  0.2433]]], grad_fn=<UnsafeViewBackward0>)

### Stacking Multi-head attention

In [110]:
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads, dropout, context_length, bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, dropout, context_length, bias) for _ in range(num_heads)]
        )
    def forward(self, x):
        head_outputs = torch.cat([head(x) for head in self.heads], dim=-1)
        return head_outputs
        

In [112]:
context_length = batch.shape[1]
d_in = batch.shape[-1]
d_out = 2

ma = MultiHeadAttention(d_in, d_out, 2, 0.0, context_length)
ma(batch)

tensor([[[-0.3948,  0.5029, -0.6581,  0.2660],
         [-0.3927,  0.5003, -0.6597,  0.2673],
         [-0.3929,  0.5005, -0.6594,  0.2670],
         [-0.3953,  0.5035, -0.6584,  0.2633],
         [-0.3979,  0.5066, -0.6540,  0.2605],
         [-0.3933,  0.5011, -0.6610,  0.2659]],

        [[-0.3948,  0.5029, -0.6581,  0.2660],
         [-0.3927,  0.5003, -0.6597,  0.2673],
         [-0.3929,  0.5005, -0.6594,  0.2670],
         [-0.3953,  0.5035, -0.6584,  0.2633],
         [-0.3979,  0.5066, -0.6540,  0.2605],
         [-0.3933,  0.5011, -0.6610,  0.2659]]], grad_fn=<CatBackward0>)