In [88]:
import torch

# sample token embeddings:
inputs = torch.tensor(
  [
    [0.43, 0.15, 0.89], # Your     (x^1)
    [0.55, 0.87, 0.66], # journey  (x^2)
    [0.57, 0.85, 0.64], # starts   (x^3)
    [0.22, 0.58, 0.33], # with     (x^4)
    [0.77, 0.25, 0.10], # one      (x^5)
    [0.05, 0.80, 0.55]  # step     (x^6)
  ] 
)


In [89]:
x_2 = inputs[1] # journey
d_in = inputs.shape[1]
d_out = 2

`torch.nn.Parameter`: It's a tensor, just like the ones you use in PyTorch, but it's specifically designed to be used as a parameter in neural networks. When you use torch.nn.Parameter, it tells PyTorch that this tensor should be updated during training.

In [90]:
torch.manual_seed(42)
W_query = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)


In [91]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

In [92]:
query_2, key_2, value_2

(tensor([-0.3519,  0.1483]),
 tensor([1.9692, 0.4159]),
 tensor([0.6229, 0.4434]))

In [93]:
# all weights
keys = inputs @ W_key
queries = inputs @ W_query
values = inputs @ W_value

keys.shape, queries.shape, values.shape

(torch.Size([6, 2]), torch.Size([6, 2]), torch.Size([6, 2]))

In [94]:
#attention scores
attn_scores_2 = query_2 @ keys.T
attn_scores_2

tensor([-0.4540, -0.6313, -0.6450, -0.2855, -0.7087, -0.1794])

In [95]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / (d_k ** 0.5), dim=-1)
attn_weights_2

tensor([0.1686, 0.1487, 0.1473, 0.1899, 0.1408, 0.2047])

In [96]:
# context vectors

context_2 = attn_weights_2 @ values
context_2

tensor([0.5633, 0.3251])

## Self Attention Class

In [97]:
import torch.nn as nn
import torch

class SelfAttentionV1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.randn(d_in, d_out))
        self.W_key = nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
        context_vectors = attn_weights @ values

        return context_vectors
    

In [98]:
torch.manual_seed(42)

sa = SelfAttentionV1(d_in, d_out)
context_vec = sa(inputs)
context_vec

tensor([[0.5141, 0.3639],
        [0.5633, 0.3251],
        [0.5659, 0.3221],
        [0.5839, 0.2941],
        [0.6180, 0.2539],
        [0.5575, 0.3262]], grad_fn=<MmBackward0>)

In [99]:
# self attention using linear layers

import torch.nn as nn
import torch

class SelfAttentionV2(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
        context_vectors = attn_weights @ values

        return context_vectors
    

In [100]:
torch.manual_seed(789)

sa_v2 = SelfAttentionV2(d_in, d_out)
context_vec = sa_v2(inputs)
context_vec

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

## Causal Attention

1. Create mask with the shape of the attn weights
2. multiply the mask with the attn weights
3. apply softmax to the masked attn weights

In [101]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs) 
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [102]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [103]:
masked_weights = attn_weights * mask_simple
masked_weights

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)

In [104]:
row_sums = masked_weights.sum(dim=1, keepdim=True)
masked_weights = masked_weights / row_sums
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)

In [105]:
# mask with -inf 
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked_weights = attn_weights.masked_fill(mask.bool(), -torch.inf)
masked_weights

tensor([[0.1921,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.2041, 0.1659,   -inf,   -inf,   -inf,   -inf],
        [0.2036, 0.1659, 0.1662,   -inf,   -inf,   -inf],
        [0.1869, 0.1667, 0.1668, 0.1571,   -inf,   -inf],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658,   -inf],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MaskedFillBackward0>)

In [106]:
attn_weights = torch.softmax(masked_weights / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5068, 0.4932, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3392, 0.3303, 0.3304, 0.0000, 0.0000, 0.0000],
        [0.2531, 0.2495, 0.2495, 0.2478, 0.0000, 0.0000],
        [0.2021, 0.1998, 0.1998, 0.1987, 0.1996, 0.0000],
        [0.1698, 0.1666, 0.1666, 0.1652, 0.1666, 0.1650]],
       grad_fn=<SoftmaxBackward0>)


In [107]:
context_vec = attn_weights @ values
context_vec

tensor([[1.5058, 0.1444],
        [1.0703, 0.2919],
        [0.9275, 0.3191],
        [0.7224, 0.3529],
        [0.7253, 0.0870],
        [0.6055, 0.2591]], grad_fn=<MmBackward0>)

In [108]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


When applying dropout to an attention weight matrix with a rate of 50%, half of the elements in the matrix are randomly set to zero. To compensate for the reduction in active elements, the values of the remaining elements in the matrix are scaled up by a factor of 1/0.5 = 2. This scaling is crucial to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phases.

In [109]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6785, 0.6607, 0.6608, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4990, 0.4991, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3996, 0.0000, 0.3973, 0.0000, 0.0000],
        [0.0000, 0.3332, 0.3333, 0.3304, 0.3333, 0.0000]],
       grad_fn=<MulBackward0>)


### Causal Attention Class


In [110]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [111]:
class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(
                torch.ones(context_length, context_length),
                diagonal=1
            )
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf
        )
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] * 0.5), dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vectors = attn_weights @ values
        return context_vectors


In [112]:
torch.manual_seed(123)

context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vec = ca(batch)
print(context_vec)


tensor([[[-0.5740,  0.2727],
         [-0.8709,  0.1008],
         [-0.8628,  0.1060],
         [-0.4789,  0.0051],
         [-0.4744,  0.1696],
         [-0.5888, -0.0388]],

        [[-0.5740,  0.2727],
         [-0.8709,  0.1008],
         [-0.8628,  0.1060],
         [-0.4789,  0.0051],
         [-0.4744,  0.1696],
         [-0.5888, -0.0388]]], grad_fn=<UnsafeViewBackward0>) torch.Size([2, 6, 2])
tensor([[[-0.4519,  0.2216],
         [-0.5893,  0.0029],
         [-0.6316, -0.0656],
         [-0.5685, -0.0853],
         [-0.5540, -0.0984],
         [-0.5308, -0.1089]],

        [[-0.4519,  0.2216],
         [-0.5893,  0.0029],
         [-0.6316, -0.0656],
         [-0.5685, -0.0853],
         [-0.5540, -0.0984],
         [-0.5308, -0.1089]]], grad_fn=<UnsafeViewBackward0>)
