In [1]:
import torch
import torch.nn as nn

In [2]:
inputs = torch.tensor(
   [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)
print('-' * 100)
print('T x Embed_Size')
print('-' * 100)
print(inputs.shape)
print('-' * 100)

----------------------------------------------------------------------------------------------------
T x Embed_Size
----------------------------------------------------------------------------------------------------
torch.Size([6, 3])
----------------------------------------------------------------------------------------------------


In [4]:
torch.manual_seed(123)

d_in = inputs.shape[1]
d_out = inputs.shape[1]

Wq = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
Wk = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
Wv = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [5]:
queries = inputs @ Wq
keys = inputs @ Wk
values = inputs @ Wv

queries.shape, keys.shape, values.shape

(torch.Size([6, 3]), torch.Size([6, 3]), torch.Size([6, 3]))

In [8]:
attention_scores = queries @ keys.T
attention_scores, attention_scores.shape


(tensor([[0.7616, 0.8765, 0.8746, 0.4349, 0.5941, 0.4877],
         [1.7872, 2.0141, 2.0091, 0.9952, 1.3538, 1.1227],
         [1.7646, 1.9901, 1.9852, 0.9834, 1.3383, 1.1091],
         [1.0664, 1.1947, 1.1916, 0.5897, 0.8004, 0.6667],
         [0.8601, 0.9968, 0.9950, 0.4947, 0.6817, 0.5516],
         [1.3458, 1.4957, 1.4915, 0.7374, 0.9968, 0.8366]]),
 torch.Size([6, 6]))

### Causal attention
Causal attention, or masked attention, is a variant of self-attention where the model only looks at past and current tokens,
unlike standard self-attention which considers the entire input sequence. This ensures that during attention computation,
only tokens occurring before or at the current position are factored in.

#### Steps for computing masked attention weights
**Attention_Scores (unnormzalied)** -> (1) Mask with -∞ above diagonal -> **Masked_Attention_Scores** -> (2) Apply Softmax -> **Masked_Attention_Weight** 



In [16]:
T = inputs.shape[0]   # Context length
mask = torch.triu(torch.ones(T, T), diagonal=1)
mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [22]:
masked = attention_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[0.7616,   -inf,   -inf,   -inf,   -inf,   -inf],
        [1.7872, 2.0141,   -inf,   -inf,   -inf,   -inf],
        [1.7646, 1.9901, 1.9852,   -inf,   -inf,   -inf],
        [1.0664, 1.1947, 1.1916, 0.5897,   -inf,   -inf],
        [0.8601, 0.9968, 0.9950, 0.4947, 0.6817,   -inf],
        [1.3458, 1.4957, 1.4915, 0.7374, 0.9968, 0.8366]])

#### Normalize to get attention weights

In [24]:
d_k = keys.shape[-1]
attention_weights = torch.softmax( masked / d_k ** 0.5 , dim=1 )
attention_weights, attention_weights.shape

(tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4673, 0.5327, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3054, 0.3478, 0.3468, 0.0000, 0.0000, 0.0000],
         [0.2557, 0.2753, 0.2748, 0.1942, 0.0000, 0.0000],
         [0.2051, 0.2220, 0.2217, 0.1661, 0.1851, 0.0000],
         [0.1837, 0.2003, 0.1998, 0.1293, 0.1501, 0.1369]]),
 torch.Size([6, 6]))

### Use modified attention weights to compute context vector

In [25]:
Z = attention_weights @ values
Z, Z.shape

(tensor([[0.4976, 0.9655, 0.7614],
         [0.7159, 1.1712, 1.1589],
         [0.7789, 1.2294, 1.2769],
         [0.7244, 1.1291, 1.1867],
         [0.6756, 1.0523, 1.1366],
         [0.6783, 1.0441, 1.1252]]),
 torch.Size([6, 3]))

## Compact Causal Attention

In [33]:
import torch.nn as nn
torch.manual_seed(143)

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super(CausalAttention, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.Wq = nn.Linear(d_in, d_out)
        self.Wv = nn.Linear(d_in, d_out)
        self.Wk = nn.Linear(d_in, d_out)

    def forward(self, x):
        # x -> (T, d_in) or (T, embed_size)
        T = x.shape[0]
        queries = self.Wq(x)  # (T, d_out)
        keys = self.Wk(x)     # (T, d_out)
        values = self.Wv(x)   # (T, d_out)

        # compute attention scores
        attention_scores = queries @ keys.T  # (T, T)

        # compute masked attention weights
        mask = torch.triu(torch.ones(T, T), diagonal=1) # (T, T)
        masked_attention_scores = attention_scores.masked_fill(mask.bool(), -torch.inf)
        
        # compute attention weights
        attention_weights = torch.softmax( masked_attention_scores / self.d_out ** 0.5 , 1) # (T, T)
        print('-'* 100)
        print('Masked Attention Weights')
        print(attention_weights)
        print('-'* 100)

        # compute context vector
        # Z_2 = (a2_1 * v_1) + (a2_2 * v_2) + .. (a2_T * v_T)
        Z = attention_weights @ values
        return Z

selfattn = CausalAttention(d_in=5, d_out=5, qkv_bias=False)
X = torch.rand(6, 5)
Z = selfattn(X)
print('-'* 100)
print('Z')
print(Z)
print(Z.shape)
print('-'* 100)

----------------------------------------------------------------------------------------------------
Masked Attention Weights
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4902, 0.5098, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3391, 0.3367, 0.3242, 0.0000, 0.0000, 0.0000],
        [0.2667, 0.2456, 0.2456, 0.2421, 0.0000, 0.0000],
        [0.2062, 0.2073, 0.2015, 0.1890, 0.1961, 0.0000],
        [0.1847, 0.1625, 0.1674, 0.1682, 0.1513, 0.1659]],
       grad_fn=<SoftmaxBackward0>)
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
Z
tensor([[-0.8628,  0.0772,  0.2997, -0.1688,  0.8274],
        [-0.7920,  0.1168,  0.2625, -0.1527,  0.7410],
        [-0.8819,  0.1180,  0.2414, -0.1056,  0.7409],
        [-0.9319,  0.0779,  0.2426, -0.0683,  0.7029],
        [-0.9264,  0.1476,  0.2210, -0.0710,  0.6686],
        [-0

### Dropout in attention weights

In [None]:
import torch.nn as nn
torch.manual_seed(143)

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False, dropout=0.5):
        super(CausalAttention, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.Wq = nn.Linear(d_in, d_out)
        self.Wv = nn.Linear(d_in, d_out)
        self.Wk = nn.Linear(d_in, d_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x -> (T, d_in) or (T, embed_size)
        T = x.shape[0]
        queries = self.Wq(x)  # (T, d_out)
        keys = self.Wk(x)     # (T, d_out)
        values = self.Wv(x)   # (T, d_out)

        # compute attention scores
        attention_scores = queries @ keys.T  # (T, T)

        # compute masked attention weights
        mask = torch.triu(torch.ones(T, T), diagonal=1) # (T, T)
        masked_attention_scores = attention_scores.masked_fill(mask.bool(), -torch.inf)
        
        # compute attention weights
        attention_weights = torch.softmax( masked_attention_scores / self.d_out ** 0.5 , 1) # (T, T)
        print('-'* 100)
        print('Masked Attention Weights')
        print(attention_weights)
        print('-'* 100)

        # compute context vector
        # Z_2 = (a2_1 * v_1) + (a2_2 * v_2) + .. (a2_T * v_T)
        Z = attention_weights @ values
        return Z

selfattn = CausalAttention(d_in=5, d_out=5, qkv_bias=False)
X = torch.rand(6, 5)
Z = selfattn(X)
print('-'* 100)
print('Z')
print(Z)
print(Z.shape)
print('-'* 100)