In [18]:
import torch
from torch import nn
import math
import torch.nn.functional as F

In [57]:
class Attention(nn.Module):
    """Single causal attention head that consumes keys/values supplied by the caller.

    The head owns only the query projection (``qw``) and a projection applied
    to the attended values (``qkvw``); keys and values are passed to
    ``forward`` already projected to ``attention_dim``.

    Args:
        d_model: Size of the last dimension of the input ``x``.
        attention_dim: Width of the queries and of the head output.
        max_sequence_length: Side length of the precomputed causal mask, i.e.
            the largest key length ``forward`` accepts.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, attention_dim, max_sequence_length, dropout):
        super().__init__()
        # Query projection: d_model -> attention_dim.
        self.qw = nn.Linear(d_model, attention_dim)

        # Projection applied to the attended values of this head.
        self.qkvw = nn.Linear(attention_dim, attention_dim)

        self.dropout = nn.Dropout(p = dropout)

        # Lower-triangular ones: row i marks the key positions (<= i) that the
        # query at position i is allowed to see.
        mask = torch.tril(torch.ones(max_sequence_length, max_sequence_length))

        # Registered as a buffer so it is part of the module state without
        # being a trainable parameter.
        self.register_buffer("mask", mask)

    def forward(self, x, k, v, inference = False):
        """Attend from ``x`` over the given keys and values with a causal mask.

        Args:
            x: Input of shape (batch, new_length, d_model); projected to
                queries inside this method.
            k: Keys of shape (batch, seq_length, attention_dim).
            v: Values of shape (batch, seq_length, attention_dim).
            inference: When true, ``k``/``v`` may be longer than ``x`` and the
                queries are treated as the LAST ``new_length`` positions of
                the key sequence. When false, ``x`` and ``k`` must have the
                same sequence length.

        Returns:
            Tensor of shape (batch, new_length, attention_dim).

        Raises:
            ValueError: If the key sequence is longer than the causal mask,
                shorter than the query sequence, or (outside inference) not
                the same length as the query sequence.
        """
        q = self.qw(x)

        # Unpacking three values also enforces that k is 3-D.
        _, seq_length, attention_dim = k.shape
        new_length = q.shape[1]
        # Number of key positions that precede the first query position.
        past_length = seq_length - new_length

        max_length = self.mask.shape[0]
        if seq_length > max_length:
            # Without this check the mask slice is silently truncated and the
            # failure surfaces later as an opaque broadcasting error.
            raise ValueError(
                f"key sequence length {seq_length} exceeds max_sequence_length {max_length}"
            )
        if past_length < 0:
            # A negative start index would wrap around and pick wrong mask rows.
            raise ValueError(
                f"query length {new_length} is longer than key length {seq_length}"
            )
        if not inference and past_length != 0:
            raise ValueError(
                f"query length {new_length} must equal key length {seq_length} "
                "when inference is False"
            )

        scores = q @ torch.transpose(k, -1, -2)

        # Rows of the causal mask that belong to the query positions; outside
        # inference past_length is 0, so this is the top-left square.
        mask = self.mask[past_length:seq_length, :seq_length]

        scores = scores.masked_fill(mask == 0, float("-inf"))

        # Scale by sqrt of the key width; masked entries stay at -inf.
        scores = scores / math.sqrt(attention_dim)

        weights = F.softmax(scores, dim = -1)

        weights = self.dropout(weights)

        attended = weights @ v

        return self.qkvw(attended)
        


       

In [59]:
class MultiHeadAttention(nn.Module):
    """Several ``Attention`` heads run over one shared key/value projection.

    Keys and values are projected once (``kw`` / ``vw``) and the same tensors
    are handed to every head; each head owns its own query projection. The
    head outputs are concatenated along the last dimension.

    Args:
        d_model: Size of the last dimension of the input; must be divisible
            by ``heads``.
        max_sequence_length: Passed to every head as its causal-mask size.
        heads: Number of attention heads.
        dropout: Dropout probability passed to every head.

    Attributes set here:
        kv_cache: ``(keys, values)`` accumulated by calls with
            ``inference=True``, or ``(None, None)`` when empty.
    """

    def __init__(self, d_model, max_sequence_length, heads, dropout):
        super().__init__()

        assert d_model % heads == 0, "d_model should be perfectly divisible by heads"

        attention_dim = d_model // heads

        self.kw = nn.Linear(d_model, attention_dim)
        self.vw = nn.Linear(d_model, attention_dim)

        self.attention_heads = nn.ModuleList([Attention(d_model, attention_dim, max_sequence_length, dropout) for _ in range(heads)])
        self.kv_cache = (None,None)

    def reset_cache(self):
        """Drop the cached keys/values so the next inference call starts a new sequence."""
        self.kv_cache = (None, None)

    def forward(self, x, inference = False):
        """Run every head on ``x`` and concatenate their outputs.

        Args:
            x: Input whose last dimension is ``d_model`` and whose dimension 1
                is the sequence axis.
            inference: When true, the keys/values of ``x`` are appended (along
                dimension 1) to those cached by earlier inference calls and
                the cache is updated. Later calls must therefore use the same
                batch size until ``reset_cache`` is called. When false, the
                cache is neither read nor written.

        Returns:
            Concatenation of all head outputs along the last dimension.
        """
        k = self.kw(x)
        v = self.vw(x)
        if inference:
            k_prev, v_prev = self.kv_cache
            if k_prev is not None:
                k = torch.cat([k_prev, k], dim = 1)
                v = torch.cat([v_prev, v], dim = 1)
            # Cache detached copies: keeping the autograd history would retain
            # the graph of every earlier step for as long as the cache lives.
            self.kv_cache = (k.detach(), v.detach())

        head_outputs = [attention(x, k, v, inference) for attention in self.attention_heads]

        return torch.cat(head_outputs, -1)
            
        

        

In [61]:
# Random demo input with sizes (3, 2, 6); the last size matches the d_model
# given to the attention layer built in the next cell.
rand = torch.rand((3, 2, 6))
rand

tensor([[[0.8173, 0.0368, 0.3520, 0.0292, 0.4106, 0.2075],
         [0.6018, 0.9990, 0.0660, 0.8565, 0.5285, 0.7071]],

        [[0.7028, 0.4620, 0.4774, 0.1843, 0.2528, 0.5330],
         [0.2417, 0.5211, 0.0090, 0.7061, 0.4615, 0.9414]],

        [[0.1067, 0.7722, 0.1677, 0.8867, 0.3147, 0.4183],
         [0.0540, 0.6101, 0.3978, 0.9347, 0.1291, 0.5395]]])

In [63]:
# Layer under test; keyword arguments spell out what each number means.
mha = MultiHeadAttention(
    d_model = 6,
    max_sequence_length = 10,
    heads = 3,
    dropout = 0.1,
)
mha

MultiHeadAttention(
  (kw): Linear(in_features=6, out_features=2, bias=True)
  (vw): Linear(in_features=6, out_features=2, bias=True)
  (attention_heads): ModuleList(
    (0-2): 3 x Attention(
      (qw): Linear(in_features=6, out_features=2, bias=True)
      (qkvw): Linear(in_features=2, out_features=2, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)

In [65]:
# Clear the KV cache first so re-running this cell does not keep appending to
# the keys/values cached by the previous run (each run adds 2 positions and
# would eventually exceed the layer's 10-position causal mask).
# NOTE(review): `mha` has not been switched to eval mode in the cells shown,
# so dropout is presumably still active here - confirm whether that is intended.
mha.kv_cache = (None, None)
mha(rand, inference = True)

tensor([[[ 0.3353,  0.6073, -0.2341, -0.2125,  0.0043,  0.2958],
         [ 0.2257,  0.5095, -0.4416,  0.0356, -0.4351,  0.3016]],

        [[ 0.2523,  0.5537, -0.2957, -0.1468, -0.1158,  0.3171],
         [ 0.1761,  0.4954, -0.3627, -0.0784, -0.2791,  0.2574]],

        [[-0.0721,  0.5081, -0.4360,  0.0534, -0.4521,  0.2744],
         [-0.0627,  0.5083, -0.4326,  0.0478, -0.4433,  0.2764]]],
       grad_fn=<CatBackward0>)