In [1]:
import torch 
import torch.nn as nn

In [2]:
class SelfAttentionV2(nn.Module):
    """Single-head scaled dot-product self-attention with learned projections.

    Three independent ``nn.Linear`` layers project the input into queries,
    keys and values. The attention weights are the softmax (over the last
    dimension) of the query/key dot products divided by ``sqrt(d_out)``.

    Args:
        d_in: Size of the last dimension of the input tensor.
        d_out: Output size of each projection, and of the returned
            context vectors' last dimension.
        qkv_bias: Whether the three projection layers include a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Creation order (wq, wk, wv) is kept so that a fixed RNG seed
        # yields the same initial weights as before.
        self.wq = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.wk = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.wv = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute the context vectors for ``x``.

        Args:
            x: Tensor of shape ``(..., num_tokens, d_in)``. It must have at
                least two dimensions; any leading dimensions are treated as
                batch dimensions.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        query = self.wq(x)
        keys = self.wk(x)
        values = self.wv(x)

        # Swap only the last two dimensions. `keys.T` reverses *all*
        # dimensions, which is only correct for 2-D tensors and breaks
        # (and is deprecated) for batched input.
        attention_score = query @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the softmax over the last dimension.
        attention_weights = torch.softmax(
            attention_score / (keys.shape[-1] ** 0.5), dim=-1
        )
        context_matrix = attention_weights @ values
        return context_matrix

In [3]:
# Seed the RNG before anything random happens so that both the toy inputs
# and the layer initialisation below are reproducible.
torch.manual_seed(123)
inputs = torch.rand(6, 3)  # random tensor of shape (6, 3)

d_in, d_out = 3, 2
sa_v2 = SelfAttentionV2(d_in=d_in, d_out=d_out)

In [7]:
# Apply the three projection layers of `sa_v2` by hand so that the
# intermediate tensors are available at notebook level.
queries, keys, values = (
    projection(inputs) for projection in (sa_v2.wq, sa_v2.wk, sa_v2.wv)
)

# Pairwise query/key dot products, then softmax over the last dimension of
# the scores scaled by the square root of the key dimension.
attention_scores = torch.matmul(queries, keys.t())
attention_weights = torch.softmax(
    attention_scores / keys.size(-1) ** 0.5, dim=-1
)

In [8]:
# Show the unmasked attention weights (softmax was taken over the last dimension).
print(attention_weights)

tensor([[0.1769, 0.1593, 0.1697, 0.1673, 0.1681, 0.1588],
        [0.1885, 0.1521, 0.1719, 0.1689, 0.1683, 0.1502],
        [0.1718, 0.1630, 0.1682, 0.1671, 0.1673, 0.1626],
        [0.1861, 0.1537, 0.1714, 0.1688, 0.1681, 0.1519],
        [0.1757, 0.1600, 0.1695, 0.1669, 0.1682, 0.1597],
        [0.1877, 0.1521, 0.1723, 0.1680, 0.1690, 0.1508]],
       grad_fn=<SoftmaxBackward0>)


In [9]:
# Take the size of the first dimension (number of rows) of the score matrix.
context_length = attention_scores.size(0)
print(context_length)

6


In [11]:
# Lower-triangular matrix of ones: entry (i, j) is 1 when j <= i, else 0.
mask = torch.ones(context_length, context_length).tril()
print(mask)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [12]:
# Element-wise product with the mask zeroes every weight above the diagonal.
# The rows are not renormalised here.
masked_attention = attention_weights * mask
print(masked_attention)

tensor([[0.1769, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1885, 0.1521, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1718, 0.1630, 0.1682, 0.0000, 0.0000, 0.0000],
        [0.1861, 0.1537, 0.1714, 0.1688, 0.0000, 0.0000],
        [0.1757, 0.1600, 0.1695, 0.1669, 0.1682, 0.0000],
        [0.1877, 0.1521, 0.1723, 0.1680, 0.1690, 0.1508]],
       grad_fn=<MulBackward0>)


In [16]:
# Per-row totals of the masked weights, kept as a column (keepdim=True).
torch.sum(masked_attention, dim=-1, keepdim=True)

tensor([[0.1769],
        [0.3406],
        [0.5030],
        [0.6800],
        [0.8403],
        [1.0000]], grad_fn=<SumBackward1>)

In [17]:
# Divide each row by its own sum so the masked weights in every row add up to 1 again.
masked_attention.div(masked_attention.sum(dim=-1, keepdim=True))

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5534, 0.4466, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3417, 0.3240, 0.3343, 0.0000, 0.0000, 0.0000],
        [0.2737, 0.2260, 0.2520, 0.2483, 0.0000, 0.0000],
        [0.2091, 0.1904, 0.2018, 0.1987, 0.2002, 0.0000],
        [0.1877, 0.1521, 0.1723, 0.1680, 0.1690, 0.1508]],
       grad_fn=<DivBackward0>)