In [6]:
import torch
from torch import nn
import math
import matplotlib.pyplot as plt
import pandas as pd
import torch.nn.functional as F

In [7]:
class TokenEmbeddings(nn.Module):
    """Learned lookup table that maps integer token ids to dense vectors."""

    def __init__(self, vocab_size: int, d_model: int) -> None:
        """Create an embedding table with ``vocab_size`` rows of width ``d_model``."""
        super().__init__()
        self.table = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        """Return the embedding row for every index in ``x``.

        The result has the shape of ``x`` with one trailing dimension of
        size ``d_model`` appended.
        """
        embedded = self.table(x)
        return embedded
        

In [8]:
class PositionalEmbeddings(nn.Module):
    """Learned lookup table that maps integer position indices to dense vectors."""

    def __init__(self, max_seq_length: int, d_model: int) -> None:
        """Create an embedding table with ``max_seq_length`` rows of width ``d_model``."""
        super().__init__()
        self.table = nn.Embedding(num_embeddings=max_seq_length, embedding_dim=d_model)

    def forward(self, x):
        """Return the embedding row for every position index in ``x``.

        Indices must lie in ``[0, max_seq_length)``; the result has the shape
        of ``x`` with one trailing dimension of size ``d_model`` appended.
        """
        embedded = self.table(x)
        return embedded

In [9]:
class Attention(nn.Module):
    """Single-head causal self-attention with an optional key/value cache.

    Args:
        embedding_dim: Size of the last dimension of the input given to
            ``forward``.
        attention_dim: Output size of the query, key and value projections.
        dropout_val: Dropout probability applied to the attention weights.
        max_sequence_length: Side length of the precomputed lower-triangular
            causal mask. Cached plus new positions may never exceed it.
    """

    def __init__(self, embedding_dim: int, attention_dim: int, dropout_val:float, max_sequence_length: int):
        super().__init__()
        self.qw = nn.Linear(embedding_dim, attention_dim)
        self.kw = nn.Linear(embedding_dim, attention_dim)
        self.vw = nn.Linear(embedding_dim, attention_dim)

        self.dropout = nn.Dropout(p=dropout_val)

        # Lower-triangular ones: row i may attend to columns 0..i. Stored as a
        # buffer so it follows the module across devices without being trained.
        mask = torch.tril(torch.ones(1, max_sequence_length, max_sequence_length))
        self.register_buffer('mask', mask)

    def forward(self, x, cache=None):
        """Attend over ``x`` and, if given, over previously cached keys/values.

        Args:
            x: Input whose sequence axis is dim 1 and whose last axis has size
                ``embedding_dim``; expected layout is (batch, seq, embedding_dim).
            cache: Optional ``(prev_k, prev_v)`` pair as returned by an earlier
                call. The new keys/values are appended to it along dim 1.

        Returns:
            A tuple ``(output, new_kv)`` where ``output`` holds the attended
            values for the positions in ``x`` and ``new_kv`` is the
            ``(keys, values)`` pair covering cached and new positions.

        Raises:
            ValueError: If cached plus new positions exceed the
                ``max_sequence_length`` the mask was built for.
        """
        q = self.qw(x)
        k = self.kw(x)
        v = self.vw(x)

        if cache is not None:
            prev_k, prev_v = cache
            k = torch.cat([prev_k, k], dim=1)
            v = torch.cat([prev_v, v], dim=1)
        new_kv = (k, v)

        new_length = q.shape[1]
        total_length = k.shape[1]
        past_length = total_length - new_length

        # Slicing past the end of the mask silently truncates it, which leads
        # to an opaque broadcast error or (for a 1x1 mask) to no masking at
        # all, so reject over-long sequences explicitly.
        max_length = self.mask.shape[-1]
        if total_length > max_length:
            raise ValueError(
                f"sequence length {total_length} (cached {past_length} + new "
                f"{new_length}) exceeds max_sequence_length {max_length}"
            )

        # Scaled dot-product scores; the scale is the projection width.
        scores = torch.matmul(q, torch.transpose(k, -1, -2))
        scores = scores / math.sqrt(q.shape[-1])

        # Query rows sit at absolute positions past_length..total_length-1 and
        # may see keys 0..their own position. Without a cache past_length is 0,
        # so this single slice covers both the cached and uncached cases.
        mask = self.mask[:, past_length:total_length, :total_length]
        scores = scores.masked_fill(mask == 0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        attended = torch.matmul(weights, v)
        return attended, new_kv

        
        

