In [None]:
!pip install torch

In [188]:
import torch

inputs = torch.tensor([
    [0.43, 0.15, 0.891],  # "Your"
    [0.55, 0.87, 0.66],   # "journey"
    [0.57, 0.85, 0.641],  # "starts"
    [0.22, 0.58, 0.331], # "with"
    [0.77, 0.25, 0.101],  # "one"
    [0.05, 0.80, 0.551]   # "step"
])

print(inputs.shape)


torch.Size([6, 3])


In [189]:
import torch.nn as nn

class SelfAttention(nn.Module):
  def __init__(self,in_dim,out_dim):
    super().__init__()
    self.W_K = nn.Linear(in_dim,out_dim,bias=False)
    self.W_V = nn.Linear(in_dim,out_dim,bias=False)
    self.W_Q = nn.Linear(in_dim,out_dim,bias=False)
  def forward(self,x):
    keys =self.W_K(x)
    values = self.W_V(x)
    queries =self.W_Q(x)

    attention_scores = queries @ keys.T
    attention_weights = torch.softmax(attention_scores/keys.shape[-1]**0.5,dim=-1)

    return attention_weights @ values

In [190]:
self_attention = SelfAttention(inputs.shape[-1],2)
keys =self_attention.W_K(inputs)
values = self_attention.W_V(inputs)
queries =self_attention.W_Q(inputs)

In [191]:
attention_scores = queries @ keys.T
attention_weights = torch.softmax(attention_scores/keys.shape[-1]**0.5,dim=-1)
attention_weights

tensor([[0.1745, 0.1637, 0.1639, 0.1650, 0.1702, 0.1626],
        [0.1675, 0.1694, 0.1697, 0.1624, 0.1705, 0.1606],
        [0.1671, 0.1696, 0.1698, 0.1624, 0.1704, 0.1607],
        [0.1670, 0.1685, 0.1686, 0.1640, 0.1689, 0.1630],
        [0.1607, 0.1714, 0.1714, 0.1646, 0.1665, 0.1653],
        [0.1705, 0.1671, 0.1674, 0.1633, 0.1706, 0.1612]],
       grad_fn=<SoftmaxBackward0>)

In [192]:
attention_weights

tensor([[0.1745, 0.1637, 0.1639, 0.1650, 0.1702, 0.1626],
        [0.1675, 0.1694, 0.1697, 0.1624, 0.1705, 0.1606],
        [0.1671, 0.1696, 0.1698, 0.1624, 0.1704, 0.1607],
        [0.1670, 0.1685, 0.1686, 0.1640, 0.1689, 0.1630],
        [0.1607, 0.1714, 0.1714, 0.1646, 0.1665, 0.1653],
        [0.1705, 0.1671, 0.1674, 0.1633, 0.1706, 0.1612]],
       grad_fn=<SoftmaxBackward0>)

In [193]:
mask = torch.triu(torch.full(attention_weights.shape, -torch.inf), diagonal=1)
attention_scores = attention_weights+mask

In [194]:
attention_weights = torch.softmax(attention_scores/keys.shape[-1]**0.5,dim=-1)
attention_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4996, 0.5004, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3329, 0.3335, 0.3336, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2503, 0.2503, 0.2495, 0.0000, 0.0000],
        [0.1991, 0.2006, 0.2006, 0.1997, 0.1999, 0.0000],
        [0.1671, 0.1667, 0.1667, 0.1663, 0.1671, 0.1660]],
       grad_fn=<SoftmaxBackward0>)

In [195]:
#Masked attention
attention_weights @ values

tensor([[-0.7395,  0.1178],
        [-0.8368,  0.3166],
        [-0.8662,  0.3838],
        [-0.7741,  0.3621],
        [-0.7177,  0.3745],
        [-0.7021,  0.3596]], grad_fn=<MmBackward0>)

In [196]:
dropout = nn.Dropout(0.35)

In [197]:
dropout(attention_weights)

tensor([[1.5385, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7687, 0.7698, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5122, 0.5131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3850, 0.3850, 0.0000, 0.0000, 0.0000],
        [0.3063, 0.3087, 0.0000, 0.3072, 0.3076, 0.0000],
        [0.0000, 0.2565, 0.2565, 0.2558, 0.2571, 0.2554]],
       grad_fn=<MulBackward0>)

In [198]:
#now gonna prcess inputs in batches
input_batch = torch.stack((inputs,inputs),dim=0)

In [210]:
class MaskedSelfAttention(nn.Module):
  def __init__(self,in_dim,out_dim,context_len,bias=False,dropout=0.25):
    super().__init__()
    self.d_out = out_dim
    self.W_K = nn.Linear(in_dim,out_dim,bias= False)
    self.W_Q = nn.Linear(in_dim,out_dim,bias= False)
    self.W_V = nn.Linear(in_dim,out_dim,bias= False)
    self.dropout=nn.Dropout(dropout)
    self.register_buffer('mask',torch.triu(torch.ones(context_len,context_len),diagonal=1))
  def forward(self,x):
    b, num_tokens , d_in = x.shape
    keys = self.W_K(x)
    queries = self.W_Q(x)
    values = self.W_V(x)
    attention_scores = queries @ keys.transpose(1,2)# here if shape is (2,6,2) , it will swap 1st index dimension and 2nd index dimension.
    attention_scores.masked_fill_(
        self.mask.bool()[:num_tokens,:num_tokens],-torch.inf
    )
    attention_weights  = torch.softmax(attention_scores/keys.shape[-1]**0.5,dim=-1)
    attention_weights =self.dropout(attention_weights)
    return attention_weights

In [212]:
context_len = input_batch.shape[1]
attention = MaskedSelfAttention(3,2,context_len)
attention.forward(input_batch)

tensor([[[1.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.6692, 0.6641, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4479, 0.4426, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3262, 0.0000, 0.3365, 0.0000, 0.0000, 0.0000],
         [0.2801, 0.0000, 0.2579, 0.2733, 0.0000, 0.0000],
         [0.0000, 0.2264, 0.2267, 0.2209, 0.2299, 0.2178]],

        [[1.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.6692, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.4426, 0.4429, 0.0000, 0.0000, 0.0000],
         [0.3262, 0.3362, 0.3365, 0.3344, 0.0000, 0.0000],
         [0.2801, 0.2581, 0.2579, 0.0000, 0.0000, 0.0000],
         [0.2116, 0.2264, 0.0000, 0.0000, 0.2299, 0.2178]]],
       grad_fn=<MulBackward0>)