## Input sequence: "Dream big and work for it"

In [33]:
import torch

# Tokens of the input sentence, in order; words[i] labels row i of `inputs`
words = ['Dream', 'big', 'and', 'work', 'for', 'it']

# One 3-dimensional embedding vector per token
inputs = torch.tensor([
    [0.72, 0.45, 0.31],  # x^1 -> Dream
    [0.75, 0.20, 0.55],  # x^2 -> big
    [0.30, 0.80, 0.40],  # x^3 -> and
    [0.85, 0.35, 0.60],  # x^4 -> work
    [0.55, 0.15, 0.75],  # x^5 -> for
    [0.25, 0.20, 0.85],  # x^6 -> it
])

## Class implementing self-attention

In [34]:
from torch import nn

class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention with learned projections.

    Args:
        d_in: Size of the last dimension of the input (embedding width).
        d_out: Output size of the query/key/value projections, and therefore
            of each returned context vector.
        qkv_bias: Whether the three linear projections include a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape (num_tokens, d_in), optionally with leading
                batch dimensions, i.e. (..., num_tokens, d_in).

        Returns:
            Tensor of shape (..., num_tokens, d_out).
        """
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Swap only the last two dims: same result as `keys.T` for 2-D input,
        # but also correct when leading batch dimensions are present.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the softmax over the key dimension.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

## Context vectors corresponding to inputs

In [35]:
# Input width is read from the data; the projection width is picked by hand.
d_in, d_out = inputs.shape[-1], 2

# Seed before constructing the module so its initial weights are reproducible.
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)
print(sa_v2(inputs))

tensor([[-0.0184,  0.1495],
        [-0.0180,  0.1502],
        [-0.0183,  0.1495],
        [-0.0178,  0.1505],
        [-0.0177,  0.1506],
        [-0.0177,  0.1507]], grad_fn=<MmBackward0>)


## Final attention weights

In [36]:
# Repeat the module's query/key projections outside of forward() so the
# intermediate scores and weights can be inspected directly.
queries, keys = sa_v2.W_query(inputs), sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
# Scale by sqrt(key dimension), then normalize each row with softmax.
attn_weights = (attn_scores / keys.size(-1) ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[0.1582, 0.1729, 0.1456, 0.1732, 0.1771, 0.1730],
        [0.1590, 0.1739, 0.1444, 0.1738, 0.1771, 0.1718],
        [0.1568, 0.1734, 0.1433, 0.1739, 0.1785, 0.1741],
        [0.1570, 0.1750, 0.1403, 0.1751, 0.1793, 0.1732],
        [0.1594, 0.1745, 0.1435, 0.1742, 0.1772, 0.1711],
        [0.1600, 0.1744, 0.1441, 0.1741, 0.1767, 0.1707]],
       grad_fn=<SoftmaxBackward0>)


## Lower triangular matrix (mask)

In [37]:
# Number of tokens in the sequence (rows of `inputs`)
context_length = len(inputs)
# Ones on and below the diagonal, zeros above it
mask_simple = torch.ones(context_length, context_length).tril()
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


## Attention weights after applying mask

In [38]:
# Element-wise product zeroes every weight above the diagonal.
masked_simple = torch.mul(attn_weights, mask_simple)
print(masked_simple)

tensor([[0.1582, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1590, 0.1739, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1568, 0.1734, 0.1433, 0.0000, 0.0000, 0.0000],
        [0.1570, 0.1750, 0.1403, 0.1751, 0.0000, 0.0000],
        [0.1594, 0.1745, 0.1435, 0.1742, 0.1772, 0.0000],
        [0.1600, 0.1744, 0.1441, 0.1741, 0.1767, 0.1707]],
       grad_fn=<MulBackward0>)


## Masked attention weights renormalized so each row sums to 1

In [39]:
# Per-row totals; keepdim=True keeps a trailing axis so the division broadcasts row-wise.
row_sums = torch.sum(masked_simple, dim=-1, keepdim=True)
print(row_sums.size())
# Divide each row by its own total so the remaining weights sum to 1 again.
masked_simple_norm = torch.div(masked_simple, row_sums)
print(masked_simple_norm)

torch.Size([6, 1])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4777, 0.5223, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3312, 0.3662, 0.3026, 0.0000, 0.0000, 0.0000],
        [0.2425, 0.2703, 0.2167, 0.2704, 0.0000, 0.0000],
        [0.1923, 0.2105, 0.1732, 0.2102, 0.2138, 0.0000],
        [0.1600, 0.1744, 0.1441, 0.1741, 0.1767, 0.1707]],
       grad_fn=<DivBackward0>)


## Attention scores

In [40]:
# Raw query-key dot products computed earlier (before scaling and softmax).
# Shown here because the -inf masking that follows is applied to these scores
# rather than to the already-normalized attention weights.
print(attn_scores)

tensor([[ 0.1278,  0.2536,  0.0104,  0.2559,  0.2873,  0.2540],
        [ 0.1236,  0.2499, -0.0124,  0.2494,  0.2758,  0.2329],
        [ 0.1458,  0.2882,  0.0182,  0.2916,  0.3285,  0.2934],
        [ 0.1517,  0.3052, -0.0074,  0.3055,  0.3393,  0.2904],
        [ 0.1222,  0.2499, -0.0263,  0.2477,  0.2714,  0.2224],
        [ 0.1160,  0.2386, -0.0314,  0.2357,  0.2571,  0.2075]],
       grad_fn=<MmBackward0>)


## Upper triangular matrix

In [41]:
# diagonal=1 keeps only the entries strictly above the main diagonal as ones.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
print(mask)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])


## Masked positions (ones in the mask) are replaced with negative infinity

In [42]:
# Ones strictly above the diagonal mark the positions to hide.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
print(mask)
# Wherever the mask is non-zero, overwrite the score with -inf.
masked = attn_scores.masked_fill(mask.bool(), float("-inf"))
print(masked)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[ 0.1278,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.1236,  0.2499,    -inf,    -inf,    -inf,    -inf],
        [ 0.1458,  0.2882,  0.0182,    -inf,    -inf,    -inf],
        [ 0.1517,  0.3052, -0.0074,  0.3055,    -inf,    -inf],
        [ 0.1222,  0.2499, -0.0263,  0.2477,  0.2714,    -inf],
        [ 0.1160,  0.2386, -0.0314,  0.2357,  0.2571,  0.2075]],
       grad_fn=<MaskedFillBackward0>)


## Attention weights after taking softmax

In [43]:
# Softmax over the scaled, masked scores; the -inf entries come out as 0.
# NOTE: this rebinds `attn_weights`, which previously held the unmasked weights,
# so earlier cells that read `attn_weights` depend on execution order.
attn_weights = (masked / keys.size(-1) ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4777, 0.5223, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3312, 0.3662, 0.3026, 0.0000, 0.0000, 0.0000],
        [0.2425, 0.2703, 0.2167, 0.2704, 0.0000, 0.0000],
        [0.1923, 0.2105, 0.1732, 0.2102, 0.2138, 0.0000],
        [0.1600, 0.1744, 0.1441, 0.1741, 0.1767, 0.1707]],
       grad_fn=<SoftmaxBackward0>)


## Ones matrix

In [None]:
# Square matrix of ones with one row and one column per input token
example = torch.ones(len(inputs), len(inputs))
print(example)

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])


## Random dropout with 50% probability

In [50]:
# Seed first so the randomly chosen dropout pattern is repeatable.
torch.manual_seed(123)
# p=0.5: each element is zeroed with probability 0.5; kept elements are
# scaled by 1 / (1 - p) = 2.
dropout = nn.Dropout(p=0.5)
dropout(example)

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])

## Attention weights after applying the dropout mask

In [51]:
# Re-seed with the same value used for the ones-matrix demo; the input has the
# same 6x6 shape, so the same positions are dropped here.
torch.manual_seed(123)
# Surviving weights are scaled by 1 / (1 - 0.5) = 2, so rows no longer sum to 1.
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6623, 0.7325, 0.6052, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5407, 0.4335, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4210, 0.0000, 0.4204, 0.0000, 0.0000],
        [0.0000, 0.3489, 0.2882, 0.3482, 0.3535, 0.0000]],
       grad_fn=<MulBackward0>)
