## Masked Attention

### Imports

In [31]:
import math
import torch
import numpy as np
import torch.nn.functional as F

### Masked Attention Class

In [None]:
class MaskedAttention(torch.nn.Module):
    def __init__(self, embed_dim, context_len):
        super().__init__()
        self.embed_dim = embed_dim
        self.context_len = context_len
        
        self.K = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.Q = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.V = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False)

        self.mask = torch.tril(torch.ones(self.context_len, self.context_len))
    
    def forward(self, x):
        B,T,C = x.shape

        Key = self.K(x)
        Query = self.Q(x)
        Value = self.V(x)

        print(Key.shape)
        print(Query.shape)

        attention_scores = Query@Key.transpose(-2, -1)
        print(attention_scores)
        print(attention_scores.shape)

        scaled_attn_scores = attention_scores/math.sqrt(Key.size(-1))
        print(scaled_attn_scores)
        print(scaled_attn_scores.shape)

        masked_scaled_attn_scores = scaled_attn_scores.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        print(masked_scaled_attn_scores)
        print(masked_scaled_attn_scores.shape)

        normalized_masked_scaled_attn_scores = F.softmax(masked_scaled_attn_scores, dim=-1)
        print(normalized_masked_scaled_attn_scores)
        print(normalized_masked_scaled_attn_scores.shape)

        context_vectors = normalized_masked_scaled_attn_scores@Value

        return context_vectors


In [47]:
x = torch.rand((1, 4, 8), dtype=torch.float)

attn = MaskedAttention(8, 4)
print(attn.forward(x))
context_vectors = attn.forward(x)
print(context_vectors)
print(context_vectors.shape)

torch.Size([1, 4, 8])
torch.Size([1, 4, 8])
tensor([[[-0.0484,  0.0866,  0.0112, -0.0115],
         [-0.0723,  0.2632,  0.1180,  0.0911],
         [-0.1217,  0.2157,  0.0313,  0.0124],
         [-0.0566,  0.4067,  0.1926,  0.2012]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 4, 4])
tensor([[[-0.0171,  0.0306,  0.0040, -0.0041],
         [-0.0256,  0.0930,  0.0417,  0.0322],
         [-0.0430,  0.0763,  0.0111,  0.0044],
         [-0.0200,  0.1438,  0.0681,  0.0711]]], grad_fn=<DivBackward0>)
torch.Size([1, 4, 4])
tensor([[[-0.0171,    -inf,    -inf,    -inf],
         [-0.0256,  0.0930,    -inf,    -inf],
         [-0.0430,  0.0763,  0.0111,    -inf],
         [-0.0200,  0.1438,  0.0681,  0.0711]]], grad_fn=<MaskedFillBackward0>)
torch.Size([1, 4, 4])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4704, 0.5296, 0.0000, 0.0000],
         [0.3142, 0.3540, 0.3317, 0.0000],
         [0.2291, 0.2698, 0.2502, 0.2509]]], grad_fn=<SoftmaxBackward0>)
torch.Size([1, 4, 4])
tensor([[[