In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional

In [17]:
@dataclass
class ModelArgs:
    """Hyper-parameters shared by every module of the model."""

    dim: int = 4096  # model (embedding) width
    n_heads: int = 32  # number of query heads
    n_layers: int = 32  # number of EncoderBlocks stacked in the Transformer
    n_kv_heads: Optional[int] = None  # number of key/value heads; None means "same as n_heads"
    vocab_size: int = -1  # must be set (e.g. from the tokenizer) before building a Transformer
    multiple_of: int = 256  # the FFN hidden width is rounded up to a multiple of this
    ffn_dim_multiplier: Optional[float] = None  # optional extra scale applied to the FFN hidden width
    norm_eps: float = 1e-5  # epsilon used inside RMSNorm

    # KV-cache capacity
    max_batch_size: int = 32
    max_seq_len: int = 2048

    device: Optional[str] = None  # device for precomputed tensors; None uses torch's default device

    

In [18]:
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta = 10000.0):
    """Precompute the complex rotary-embedding (RoPE) factors for every position.

    For position m and pair index i (i = 0 .. head_dim/2 - 1) the rotation angle is
    m * theta_i, with theta_i = theta ** (-2i / head_dim).

    Args:
        head_dim: size of one attention head; must be even because features are
            rotated in pairs.
        seq_len: number of positions to precompute (positions 0 .. seq_len - 1).
        device: device on which to create the tensors (None -> torch's default).
        theta: base of the geometric frequency progression.

    Returns:
        complex64 tensor of shape (seq_len, head_dim // 2) holding exp(1j * m * theta_i).
    """
    assert head_dim % 2 == 0, "RoPE requires an even head dimension"

    # exponents 2i / head_dim for i = 0 .. head_dim/2 - 1
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    # `**` is exponentiation (`^` would be bitwise XOR, which is invalid for floats)
    inv_freq = 1.0 / (theta ** exponents)

    positions = torch.arange(seq_len, device=device).float()
    # angles[m, i] = m * theta_i
    angles = torch.outer(positions, inv_freq)

    # unit-modulus complex numbers cos(angle) + i*sin(angle)
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device:str):
    """Rotate consecutive feature pairs of `x` by precomputed RoPE factors.

    Args:
        x: real tensor whose last dimension (the head dimension) is even. Because
            `freqs_complex` is broadcast over dims 0 and 2, this expects the layout
            (batch, seq_len, n_heads, head_dim).
        freqs_complex: complex tensor of shape (seq_len, head_dim // 2), e.g. a slice
            of `precompute_theta_pos_frequencies`.
        device: device to move the result to.

    Returns:
        Tensor with the same shape and dtype as `x`, on `device`.
    """
    # Treat each pair (x0, x1) as the complex number x0 + i*x1:
    # (..., head_dim) -> (..., head_dim // 2) complex
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))

    # (seq_len, head_dim/2) -> (1, seq_len, 1, head_dim/2) to broadcast over batch and heads
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)

    # A rotation is multiplication by a unit-modulus complex number (not addition)
    x_rotated = x_complex * freqs_complex

    # back to interleaved real pairs with the original shape
    x_out = torch.view_as_real(x_rotated).reshape(*x.shape)
    return x_out.type_as(x).to(device)



class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable per-feature gain.

    out = gamma * x / sqrt(mean(x**2, dim=-1) + eps)
    """

    def __init__(self, dim:int, eps:float = 1e-5):
        super().__init__()  # must run before any nn.Parameter is assigned
        self.eps = eps
        self.dim = dim
        # learnable gain, initialised to 1 so the layer starts as pure normalisation
        self.gamma = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # divide by the RMS of the last dimension; eps keeps rsqrt finite for all-zero rows
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim = True)+self.eps)

    def forward(self, x):
        # normalise in float32, cast back to the input dtype, then apply the gain
        return self.gamma * self._norm(x.float()).type_as(x)

In [19]:
class EncoderBlock(nn.Module):
    """One pre-norm transformer layer.

    h   = x + attention(attention_norm(x))
    out = h + feed_forward(ffn_norm(h))
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads

        # SelfAttention / FeedForward are defined in later cells; the names are only
        # resolved when a block is instantiated, so those cells must have been run first.
        self.attention = SelfAttention(args)
        self.feed_forward = FeedForward(args)

        # RMSNorm applied before attention
        self.attention_norm = RMSNorm(args.dim, eps = args.norm_eps)
        # RMSNorm applied before the feed-forward network
        self.ffn_norm = RMSNorm(args.dim, eps = args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        """Apply the attention and feed-forward sub-layers, each with a residual connection.

        Args:
            x: hidden states (passed to attention_norm and the residual sum).
            start_pos: position of the first token in `x`, forwarded to attention.
            freqs_complex: rotary factors for the positions in `x`, forwarded to attention.

        Returns:
            Residual sum; its shape follows `x` provided the sub-modules preserve it.
        """
        # Call the sub-modules (not `.forward`) so nn.Module hooks are honoured.
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_complex)
        return h + self.feed_forward(self.ffn_norm(h))




In [20]:
class SelfAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings, shared K/V heads and a KV cache.

    There are `args.n_heads` query heads and `args.n_kv_heads` key/value heads
    (defaulting to `n_heads`). Each K/V head is repeated `n_rep = n_heads // n_kv_heads`
    times to line up with the query heads. Keys/values of every processed position are
    written into `cache_k` / `cache_v`, so each call only projects the new tokens.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_heads_q = args.n_heads
        assert self.n_heads_q % self.n_kv_heads == 0, "n_heads must be a multiple of n_kv_heads"

        # number of query heads served by each key/value head
        self.n_rep = self.n_heads_q // self.n_kv_heads

        self.head_dim = args.dim // args.n_heads

        # Projection widths are (number of heads) * head_dim.
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias = False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias = False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias = False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias = False)

        # KV cache of shape (max_batch_size, max_seq_len, n_kv_heads, head_dim).
        # Non-persistent buffers: module.to(...) moves them, but they stay out of state_dict.
        cache_shape = (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)
        self.register_buffer("cache_k", torch.zeros(cache_shape, device=args.device), persistent=False)
        self.register_buffer("cache_v", torch.zeros(cache_shape, device=args.device), persistent=False)

    @staticmethod
    def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat each K/V head n_rep times: (B, S, n_kv, D) -> (B, S, n_kv * n_rep, D)."""
        if n_rep == 1:
            return x
        batch_size, seq_len, n_kv_heads, head_dim = x.shape
        return (
            x[:, :, :, None, :]
            .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
            .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
        )

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        """Attend from the new tokens to every cached position up to and including themselves.

        Args:
            x: (batch, seq_len, dim) hidden states of the new tokens.
            start_pos: sequence position of the first token in `x`.
            freqs_complex: rotary factors for positions start_pos .. start_pos + seq_len - 1,
                shape (seq_len, head_dim // 2).

        Returns:
            (batch, seq_len, dim) attention output.

        NOTE(review): the cache is written in place, so this is meant for inference
        (e.g. under torch.inference_mode / torch.no_grad).
        """
        batch_size, seq_len, _ = x.shape
        end_pos = start_pos + seq_len

        # (B, S, dim) -> (B, S, heads, head_dim)
        xq = self.wq(x).view(batch_size, seq_len, self.n_heads_q, self.head_dim)
        xk = self.wk(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        xv = self.wv(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

        # Positional information is injected into queries and keys only.
        xq = apply_rotary_embeddings(xq, freqs_complex, device=x.device)
        xk = apply_rotary_embeddings(xk, freqs_complex, device=x.device)

        # Store the new keys/values, then read back everything seen so far.
        self.cache_k[:batch_size, start_pos:end_pos] = xk
        self.cache_v[:batch_size, start_pos:end_pos] = xv
        keys = self._repeat_kv(self.cache_k[:batch_size, :end_pos].type_as(xq), self.n_rep)
        values = self._repeat_kv(self.cache_v[:batch_size, :end_pos].type_as(xq), self.n_rep)

        # (B, S, H, D) -> (B, H, S, D)
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # (B, H, seq_len, end_pos)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if seq_len > 1:
            # Causal mask: query i (absolute position start_pos + i) may not see keys after it.
            mask = torch.full((seq_len, end_pos), float("-inf"), device=x.device)
            scores = scores + torch.triu(mask, diagonal=start_pos + 1)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        # (B, H, seq_len, D) -> (B, seq_len, H * D)
        output = torch.matmul(scores, values)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.wo(output)



In [21]:
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: w2(silu(w1(x)) * w3(x)).

    Hidden width: start from 4 * dim, scale by 2/3, optionally multiply by
    args.ffn_dim_multiplier, then round UP to a multiple of args.multiple_of.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        hidden_dim = 4 * args.dim
        hidden_dim = int(2 * hidden_dim / 3)
        if args.ffn_dim_multiplier is not None:
            hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
        # round up to the nearest multiple of args.multiple_of
        hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)

        self.w1 = nn.Linear(args.dim, hidden_dim, bias = False)
        self.w2 = nn.Linear(hidden_dim, args.dim, bias = False)
        self.w3 = nn.Linear(args.dim, hidden_dim, bias = False)

    def forward(self, x):
        # gate = SiLU(x W1), value = x W3; combine element-wise and project back to dim
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)

In [22]:
class Transformer(nn.Module):
    """Decoder stack: token embedding -> n_layers EncoderBlocks -> RMSNorm -> vocabulary logits."""

    def __init__(self,args: ModelArgs) -> None:
        super().__init__()

        assert args.vocab_size != -1, "vocab_size must be set (e.g. from the tokenizer) first"

        self.args = args
        self.n_heads = args.n_heads
        self.vocab_size = args.vocab_size
        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

        self.layers = nn.ModuleList([EncoderBlock(args) for _ in range(args.n_layers)])

        self.norm = RMSNorm(args.dim, eps = args.norm_eps)
        self.output = nn.Linear(args.dim, self.vocab_size, bias = False)

        # Rotary factors for 2 * max_seq_len positions. This is a plain attribute, not a
        # registered buffer, so module.to()/.cuda() does not move it; forward() moves the
        # slice it needs onto the activations' device.
        self.freqs_complex = precompute_theta_pos_frequencies(
            self.args.dim // self.args.n_heads,
            self.args.max_seq_len * 2,
            device = self.args.device,
        )

    def forward(self, tokens: torch.Tensor, start_pos: int):
        """Run one decoding step.

        Args:
            tokens: (batch, seq_len) token ids; seq_len must be 1.
            start_pos: position of `tokens` in the sequence (number of tokens already cached).

        Returns:
            float32 logits of shape (batch, seq_len, vocab_size).
        """
        batch_size, seq_len = tokens.shape
        assert seq_len == 1, "only one token can be processed at a time"
        assert start_pos + seq_len <= self.args.max_seq_len, "sequence exceeds the KV-cache capacity (max_seq_len)"

        # (batch, seq_len) -> (batch, seq_len, dim)
        h = self.tok_embeddings(tokens)

        # rotary factors for positions start_pos .. start_pos + seq_len - 1
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len].to(h.device)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)
        h = self.norm(h)
        return self.output(h).float()



