In [2]:
import math
from typing import List, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn

In [3]:
class VocabEmbeddings(nn.Module):
    """Learned lookup table that maps token ids to dense vectors.

    Args:
        vocab_size: Number of rows in the table, i.e. valid ids are
            ``0 .. vocab_size - 1``.
        embed_size: Width of each embedding vector.
    """

    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_size)

    def forward(self, x: Union[List[int], torch.Tensor]):
        """Return the embedding vector for every id in ``x``.

        Args:
            x: Token ids, either an integer tensor of any shape or a
                (possibly nested) Python list of ints.

        Returns:
            Tensor of shape ``(*ids.shape, embed_size)``.
        """
        if not torch.is_tensor(x):
            # nn.Embedding rejects plain Python lists, so build an index
            # tensor on the same device as the table before the lookup.
            x = torch.as_tensor(
                x, dtype=torch.long, device=self.embeddings.weight.device
            )
        return self.embeddings(x)
        
        

In [4]:
class PositionalEmbeddings(nn.Module):
    """Learned lookup table that maps position indices to dense vectors.

    Args:
        context_length: Number of rows in the table, i.e. valid indices are
            ``0 .. context_length - 1``.
        embed_size: Width of each embedding vector.
    """

    def __init__(self, context_length, embed_size):
        super().__init__()
        # The original spelled this argument ``context_lenth``, which raised
        # NameError as soon as the module was constructed.
        self.embeddings = nn.Embedding(context_length, embed_size)

    def forward(self, x: Union[List[int], torch.Tensor]):
        """Return the embedding vector for every index in ``x``.

        Args:
            x: Indices into the table, either an integer tensor of any shape
                or a (possibly nested) Python list of ints.

        Returns:
            Tensor of shape ``(*indices.shape, embed_size)``.
        """
        if not torch.is_tensor(x):
            # nn.Embedding rejects plain Python lists, so build an index
            # tensor on the same device as the table before the lookup.
            x = torch.as_tensor(
                x, dtype=torch.long, device=self.embeddings.weight.device
            )
        return self.embeddings(x)
    

In [5]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Queries, keys and values are separate linear projections of the input.
    Each is split into ``heads`` slices of width ``embed_size // heads``,
    scaled dot-product attention is applied per head under a lower-triangular
    mask, and the concatenated heads go through a final linear projection.

    Args:
        embed_size: Feature width of the input and the output.
        heads: Number of attention heads; must divide ``embed_size``.
        context_length: Longest sequence the precomputed mask supports.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        AssertionError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size: int, heads: int, context_length: int, dropout:float = 0.1):
        super().__init__()
        # Validate before any layer is allocated.
        assert embed_size % heads == 0, "embed_size should be completely divisible by heads"

        # Construction order is kept (q, k, v, out) so seeded initialisation
        # produces the same weights as before.
        self.qw = nn.Linear(embed_size, embed_size)
        self.kw = nn.Linear(embed_size, embed_size)
        self.vw = nn.Linear(embed_size, embed_size)

        self.heads = heads
        self.head_size = embed_size // heads

        self.outproj = nn.Linear(embed_size, embed_size)

        # Lower-triangular ones: position i may attend to positions <= i.
        # Registered as a buffer so it follows the module across devices.
        mask = torch.tril(torch.ones(1, 1, context_length, context_length))
        self.register_buffer("mask", mask)

        self.dropout = nn.Dropout(p = dropout)

        # Placeholder attribute; forward() below neither reads nor writes it.
        self.kv_cache = (None, None)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, embed_size)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, embed_size)``.

        Raises:
            RuntimeError: If ``seq_len`` is larger than the ``context_length``
                the mask was built for.
        """
        batch_size, seq_len, _ = x.shape

        max_len = self.mask.size(-1)
        if seq_len > max_len:
            # Without this guard the sliced mask is smaller than the score
            # matrix: normally that surfaces as an opaque broadcast
            # RuntimeError, and for context_length == 1 the (1,1,1,1) mask
            # broadcasts silently and disables causal masking altogether.
            raise RuntimeError(
                f"sequence length {seq_len} exceeds context_length {max_len}"
            )

        def split_heads(t):
            # (batch_size, seq_len, embed_size) -> (batch_size, heads, seq_len, head_size)
            return t.reshape(batch_size, seq_len, self.heads, self.head_size).transpose(1, 2)

        q = split_heads(self.qw(x))
        k = split_heads(self.kw(x))
        v = split_heads(self.vw(x))

        # (batch_size, heads, seq_len, head_size) @ (batch_size, heads, head_size, seq_len)
        #   -> (batch_size, heads, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(-1, -2))
        scores = scores / math.sqrt(self.head_size)

        # -inf on future positions gives them exactly zero weight after softmax.
        causal = self.mask[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(causal == 0, float("-inf"))

        weights = F.softmax(scores, dim = -1)
        weights = self.dropout(weights)

        # (batch_size, heads, seq_len, seq_len) @ (batch_size, heads, seq_len, head_size)
        #   -> (batch_size, heads, seq_len, head_size)
        context = torch.matmul(weights, v)

        # Merge heads: (batch_size, seq_len, heads, head_size) -> (batch_size, seq_len, embed_size)
        context = context.transpose(1, 2).reshape(batch_size, seq_len, -1)

        return self.outproj(context)

    
            
        


        

        
        
        

In [10]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then a feed-forward network.

    Both sub-layers are applied to a layer-normalised copy of the input and
    their result is added back onto the un-normalised input (residual).

    Args:
        embed_size: Feature width of the input and the output.
        heads: Number of attention heads passed to ``MultiHeadAttention``.
        context_length: Maximum sequence length passed to ``MultiHeadAttention``.
        feed_forward_depth: Multiplier for the hidden width of the
            feed-forward network (``embed_size * feed_forward_depth``).
        dropout: Dropout probability used in the attention sub-layer and
            after the feed-forward network.
    """

    def __init__(self, embed_size: int, heads: int, context_length: int, feed_forward_depth:int, dropout:float = 0.1):
        super().__init__()
        # Forward ``dropout`` explicitly: previously it was omitted, so the
        # attention sub-layer always ran with its own default of 0.1
        # regardless of what the caller asked for.
        self.block = MultiHeadAttention(embed_size, heads, context_length, dropout=dropout)

        self.layernorm1 = nn.LayerNorm(embed_size)
        self.layernorm2 = nn.LayerNorm(embed_size)

        hidden_size = embed_size * feed_forward_depth
        self.FFN = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Run attention and feed-forward sub-layers with residual connections.

        Args:
            x: Input tensor whose last dimension is ``embed_size``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.block(self.layernorm1(x))
        x = x + self.FFN(self.layernorm2(x))
        return x

