In [1]:
import torch

inputs = torch.tensor([
    [0.43,0.15,0.89],       # your
    [0.55,0.87,0.66],       # journey
    [0.57,0.85,0.64],       # starts
    [0.22,0.58,0.33],       # with
    [0.77,0.25,0.10],       # one
    [0.05,0.80,0.55]        # step
])

In [10]:
import torch.nn as nn

class selfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias = False):
        super().__init__()
        self.w_query = nn.Linear(d_in,d_out , bias=qkv_bias)
        self.w_key = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.w_value = nn.Linear(d_in,d_out, bias=qkv_bias)

    def forward(self, x):
        query = self.w_query(x)
        key = self.w_key(x)
        value = self.w_value(x)

        attn_score = query @ key.T
        attn_weight = torch.softmax(attn_score/key.shape[-1]**0.5, dim=-1)

        return attn_weight,attn_score

In [14]:
ca_v2 = selfAttention_v2(3,2)
attn_weight,attn_score = ca_v2(inputs)
print(attn_weight)
print(attn_score)

tensor([[0.1722, 0.1709, 0.1708, 0.1606, 0.1630, 0.1625],
        [0.1656, 0.1669, 0.1668, 0.1674, 0.1647, 0.1685],
        [0.1652, 0.1666, 0.1664, 0.1679, 0.1651, 0.1688],
        [0.1663, 0.1672, 0.1671, 0.1668, 0.1650, 0.1677],
        [0.1582, 0.1601, 0.1603, 0.1762, 0.1721, 0.1731],
        [0.1703, 0.1706, 0.1704, 0.1622, 0.1614, 0.1651]],
       grad_fn=<SoftmaxBackward0>)
tensor([[ 0.1794,  0.1683,  0.1672,  0.0801,  0.1015,  0.0970],
        [-0.0109,  0.0003, -0.0007,  0.0047, -0.0182,  0.0141],
        [-0.0252, -0.0135, -0.0144, -0.0020, -0.0259,  0.0056],
        [ 0.0054,  0.0127,  0.0119,  0.0094, -0.0058,  0.0175],
        [-0.2754, -0.2580, -0.2564, -0.1227, -0.1563, -0.1482],
        [ 0.1432,  0.1461,  0.1441,  0.0747,  0.0676,  0.0997]],
       grad_fn=<MmBackward0>)


In [7]:
attn_weight[1].sum()

tensor(1., grad_fn=<SumBackward0>)

In [8]:
attn_weight[1].sum(dim=-1)

tensor(1., grad_fn=<SumBackward1>)

In [12]:
context_length = attn_score.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [13]:
masked_simple = attn_weight* mask_simple
print(masked_simple)

tensor([[0.1602, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1679, 0.1630, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1679, 0.1630, 0.1629, 0.0000, 0.0000, 0.0000],
        [0.1685, 0.1644, 0.1645, 0.1679, 0.0000, 0.0000],
        [0.1674, 0.1643, 0.1643, 0.1689, 0.1663, 0.0000],
        [0.1685, 0.1642, 0.1642, 0.1682, 0.1675, 0.1674]],
       grad_fn=<MulBackward0>)


In [16]:
## Normalization such that all element in a row sum up to one
row_sums = masked_simple.sum(dim=1,keepdim=True)
masked_weigths = masked_simple/row_sums
print(masked_weigths)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5074, 0.4926, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3399, 0.3301, 0.3300, 0.0000, 0.0000, 0.0000],
        [0.2533, 0.2472, 0.2472, 0.2523, 0.0000, 0.0000],
        [0.2014, 0.1977, 0.1976, 0.2032, 0.2001, 0.0000],
        [0.1685, 0.1642, 0.1642, 0.1682, 0.1675, 0.1674]],
       grad_fn=<DivBackward0>)


### normalization before masking lead to data leakage, to prevent this we perform out operation on attn_score and after masking we do normalization

In [17]:
print(attn_score)

tensor([[ 0.1794,  0.1683,  0.1672,  0.0801,  0.1015,  0.0970],
        [-0.0109,  0.0003, -0.0007,  0.0047, -0.0182,  0.0141],
        [-0.0252, -0.0135, -0.0144, -0.0020, -0.0259,  0.0056],
        [ 0.0054,  0.0127,  0.0119,  0.0094, -0.0058,  0.0175],
        [-0.2754, -0.2580, -0.2564, -0.1227, -0.1563, -0.1482],
        [ 0.1432,  0.1461,  0.1441,  0.0747,  0.0676,  0.0997]],
       grad_fn=<MmBackward0>)


In [18]:
mask = torch.triu(torch.ones(context_length , context_length),diagonal=1)
masked = attn_score.masked_fill(mask.bool(),-torch.inf)
masked

tensor([[ 0.1794,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0109,  0.0003,    -inf,    -inf,    -inf,    -inf],
        [-0.0252, -0.0135, -0.0144,    -inf,    -inf,    -inf],
        [ 0.0054,  0.0127,  0.0119,  0.0094,    -inf,    -inf],
        [-0.2754, -0.2580, -0.2564, -0.1227, -0.1563,    -inf],
        [ 0.1432,  0.1461,  0.1441,  0.0747,  0.0676,  0.0997]],
       grad_fn=<MaskedFillBackward0>)

In [20]:
attn_weight_masked = torch.softmax(masked/2**0.5,dim=-1)
print(attn_weight)
print(attn_weight_masked)

tensor([[0.1722, 0.1709, 0.1708, 0.1606, 0.1630, 0.1625],
        [0.1656, 0.1669, 0.1668, 0.1674, 0.1647, 0.1685],
        [0.1652, 0.1666, 0.1664, 0.1679, 0.1651, 0.1688],
        [0.1663, 0.1672, 0.1671, 0.1668, 0.1650, 0.1677],
        [0.1582, 0.1601, 0.1603, 0.1762, 0.1721, 0.1731],
        [0.1703, 0.1706, 0.1704, 0.1622, 0.1614, 0.1651]],
       grad_fn=<SoftmaxBackward0>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4980, 0.5020, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3316, 0.3343, 0.3341, 0.0000, 0.0000, 0.0000],
        [0.2492, 0.2505, 0.2504, 0.2499, 0.0000, 0.0000],
        [0.1913, 0.1937, 0.1939, 0.2131, 0.2081, 0.0000],
        [0.1703, 0.1706, 0.1704, 0.1622, 0.1614, 0.1651]],
       grad_fn=<SoftmaxBackward0>)


### Masking addtional attention weights with dropouts

In [21]:
## Example how droptouts to be implimneted in the matrix
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6,6)
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


In [23]:
print(dropout(attn_weight_masked))

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0040, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6686, 0.6682, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5010, 0.5007, 0.0000, 0.0000, 0.0000],
        [0.3826, 0.3873, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3412, 0.3408, 0.0000, 0.3228, 0.3302]],
       grad_fn=<MulBackward0>)


### creating a compact class for causal attention

In [32]:
class causal_attention(nn.Module):
    def __init__(self,d_in,d_out,context_length,dropout,qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        # self.d_in = d_in
        self.w_query = nn.Linear(d_in,d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in,d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in,d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))

    def forward(self,x):
        b,num_tokens,d_in = x.shape
        keys = self.w_key(x)
        queries = self.w_query(x)
        values = self.w_value(x)

        attn_score = queries @ keys.transpose(1,2)
        attn_score.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens],-torch.inf)
        attn_weight = torch.softmax(attn_score/keys.shape[-1]**0.5,dim=-1)
        attn_weight = dropout(attn_weight)

        context_vector = attn_weight @ values
        return context_vector

In [33]:
batch = torch.stack((inputs,inputs),dim=0)
d_in = 3
d_out =2
print(batch.shape)


torch.Size([2, 6, 3])


In [34]:
torch.manual_seed(123)
context_length  = batch.shape[1]
ca = causal_attention(d_in,d_out,context_length,0.0)
context_vecs = ca(batch)
print(context_vecs.shape)

torch.Size([2, 6, 2])


In [35]:
context_vecs

tensor([[[ 0.0000,  0.0000],
         [-0.4368,  0.2142],
         [-0.7751,  0.0077],
         [-0.9140, -0.2769],
         [ 0.0000,  0.0000],
         [-0.6906, -0.0974]],

        [[-0.9038,  0.4432],
         [ 0.0000,  0.0000],
         [-0.2883,  0.1414],
         [-0.9140, -0.2769],
         [-0.4416, -0.1410],
         [-0.5272, -0.1706]]], grad_fn=<UnsafeViewBackward0>)