In [13]:
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

In [14]:
@dataclass
class ModelArgs:
    """Hyper-parameters read by the model's modules.

    Declared as a dataclass so that individual fields can be overridden at
    construction time, e.g. ``ModelArgs(d_model=512, vocab_size=32000)``.
    Every field keeps its previous default, so ``ModelArgs()`` and class-level
    access such as ``ModelArgs.d_model`` behave as before.
    """

    d_model: int = 4096  # width of the embeddings / residual stream
    n_layers: int = 32  # number of stacked blocks
    num_heads: int = 32  # number of query heads
    num_kv_heads: Optional[int] = None  # number of key/value heads; unset by default
    vocab_size: int = -1  # placeholder; must be overridden before a Transformer is built
    multiple_of: int = 256  # FeedForward rounds its hidden size up to a multiple of this
    ffn_dim_multiplier: Optional[float] = None  # optional scale applied to the FFN hidden size
    norm_eps: float = 1e-5  # epsilon passed to the RMSNorm layers

    # Needed for KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048

    device: Optional[str] = None

In [11]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Computes ``x / sqrt(mean(x ** 2, dim=-1) + eps) * g``.

    Args:
        dim: size of the last dimension of the inputs; ``g`` has this many entries.
        eps: constant added to the mean square before the square root, for
            numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.g = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Return ``x`` rescaled to (approximately) unit RMS, multiplied by ``g``."""
        # rsqrt gives 1 / RMS, so multiplying by it *divides* x by its RMS.
        inv_rms = torch.rsqrt(torch.square(x).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.g

In [3]:
def position_freqs(max_seq_length: int, head_dim: int):
    """Precompute the rotary-embedding rotations as unit complex numbers.

    Args:
        max_seq_length: number of positions to tabulate (rows of the result).
        head_dim: per-head feature size; must be even because features are
            rotated in pairs.

    Returns:
        Complex tensor of shape ``(max_seq_length, head_dim // 2)`` whose entry
        ``[m, i]`` is ``exp(1j * m * theta_i)`` with
        ``theta_i = 10000 ** (-2 * i / head_dim)``.

    Raises:
        AssertionError: if ``head_dim`` is odd.
    """
    assert head_dim % 2 == 0, "head_dim must be even"
    # Float positions keep torch.outer from mixing integer and float inputs.
    positions = torch.arange(0, max_seq_length, dtype=torch.float32)
    # Exponent 2i / head_dim for i = 0 .. head_dim / 2 - 1; the parentheses
    # matter, since ``10000 ** 2 * i / head_dim`` would square the base first.
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    theta = 1.0 / (10000.0 ** exponents)
    freqs = torch.outer(positions, theta)
    # Magnitude 1, angle m * theta_i.
    return torch.polar(torch.ones_like(freqs), freqs)
    
def apply_rope(x: torch.Tensor, freqs_complex: torch.Tensor):
    """Rotate consecutive feature pairs of ``x`` by the given complex factors.

    Args:
        x: 4-D tensor whose last dimension is even; the two ``unsqueeze`` calls
            below make the factors broadcast over dimensions 0 and 2 of ``x``.
        freqs_complex: complex tensor of shape ``(x.shape[1], x.shape[-1] // 2)``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    # (..., d) -> (..., d / 2, 2) -> complex (..., d / 2): each adjacent pair of
    # features becomes the real and imaginary part of one complex number.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # (seq, d / 2) -> (1, seq, 1, d / 2) so it broadcasts over dims 0 and 2.
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    x_rotated = x_complex * freqs_complex
    # Back to real pairs, then flatten the pairs into the original layout.
    x_out = torch.view_as_real(x_rotated).reshape(*x.shape)
    return x_out.type_as(x)


In [18]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every head of ``x`` ``n_rep`` times along the head dimension.

    Args:
        x: tensor of shape ``(batch_size, seq_len, num_kv_heads, head_dim)``.
        n_rep: number of copies of each head.

    Returns:
        ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
        ``(batch_size, seq_len, num_kv_heads * n_rep, head_dim)`` in which the
        copies of a head are adjacent (same values as
        ``torch.repeat_interleave(x, n_rep, dim=2)``).
    """
    batch_size, seq_len, num_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # Insert a length-1 axis after the head axis, broadcast it to n_rep, then
    # fold it into the head axis.
    expanded = x[:, :, :, None, :].expand(batch_size, seq_len, num_kv_heads, n_rep, head_dim)
    return expanded.reshape(batch_size, seq_len, num_kv_heads * n_rep, head_dim)

In [19]:
class GroupedQueryAttention(nn.Module):
    """Attention with fewer key/value heads than query heads, plus a KV cache.

    Each key/value head is shared by ``n_rep = num_heads // num_kv_heads`` query
    heads. Keys and values of every call are written into fixed-size caches so
    later calls can attend to earlier positions without recomputing them.

    Args:
        args: configuration; reads ``d_model``, ``num_heads``, ``num_kv_heads``,
            ``max_batch_size`` and ``max_seq_len``. When ``num_kv_heads`` is
            ``None`` it falls back to ``num_heads``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # The config default for num_kv_heads is None; treat that as one K/V
        # head per query head instead of failing on ``% None``.
        num_kv_heads = args.num_heads if args.num_kv_heads is None else args.num_kv_heads
        assert args.num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"

        self.num_heads = args.num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = args.d_model // args.num_heads
        self.n_rep = self.num_heads // self.num_kv_heads

        self.Wq = nn.Linear(args.d_model, self.num_heads * self.head_dim)
        self.Wk = nn.Linear(args.d_model, self.num_kv_heads * self.head_dim)
        self.Wv = nn.Linear(args.d_model, self.num_kv_heads * self.head_dim)
        self.Wo = nn.Linear(self.num_heads * self.head_dim, args.d_model)

        # Registered as non-persistent buffers so the caches follow the module
        # through .to() / .cuda() / .half() but are not written to state_dict.
        cache_shape = (args.max_batch_size, args.max_seq_len, self.num_kv_heads, self.head_dim)
        self.register_buffer("cache_k", torch.zeros(cache_shape), persistent=False)
        self.register_buffer("cache_v", torch.zeros(cache_shape), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        """Attend from the new positions to every cached position so far.

        Args:
            x: input of shape ``(batch_size, seq_len, d_model)``.
            start_pos: cache index at which the keys/values of ``x`` are stored.
            freqs_complex: rotary factors handed to ``apply_rope`` for both
                queries and keys.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_model)``.

        Note:
            No causal mask is applied: every query sees all cache entries in
            ``[0, start_pos + seq_len)``.
        """
        batch_size, seq_len, _ = x.shape

        # Project, then split the feature dimension into heads.
        Q = self.Wq(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = self.Wk(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        V = self.Wv(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)

        # Rotary embedding goes on queries and keys only.
        Q = apply_rope(Q, freqs_complex)
        K = apply_rope(K, freqs_complex)

        # Store the new keys/values, then read back everything up to the end of
        # this chunk.
        end_pos = start_pos + seq_len
        self.cache_k[:batch_size, start_pos:end_pos] = K
        self.cache_v[:batch_size, start_pos:end_pos] = V
        keys = self.cache_k[:batch_size, :end_pos]
        values = self.cache_v[:batch_size, :end_pos]

        # Duplicate K/V heads so there is one per query head.
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        Q = Q.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        scores = torch.matmul(Q, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        # Softmax in float32, then cast back to the dtype of Q.
        scores = F.softmax(scores.float(), dim=-1).type_as(Q)
        output = torch.matmul(scores, values)
        # Merge the heads back into a single feature dimension.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.Wo(output)

In [7]:
class GLU(nn.Module):
    """Gated linear unit: one linear map gated by the sigmoid of another.

    Args:
        in_size: feature size; both linear layers map ``in_size -> in_size``.
    """

    def __init__(self, in_size) -> None:
        super().__init__()
        self.linear1 = nn.Linear(in_size, in_size)  # value branch
        self.linear2 = nn.Linear(in_size, in_size)  # gate branch

    def forward(self, x):
        """Return ``linear1(x) * sigmoid(linear2(x))``."""
        value = self.linear1(x)
        gate = torch.sigmoid(self.linear2(x))
        return value * gate

In [8]:
class Swish(nn.Module):
    """Swish activation ``x * sigmoid(beta * x)`` with a trainable ``beta``.

    Args:
        beta: initial value of the slope; it is wrapped in ``nn.Parameter``, so
            it is registered as a parameter of the module.
    """

    def __init__(self, beta: torch.tensor):
        super().__init__()
        self.beta = nn.Parameter(beta)

    def forward(self, x):
        """Apply the activation element-wise."""
        gate = torch.sigmoid(self.beta * x)
        return x * gate
    
class SwiGLU(nn.Module):
    """Gated linear unit whose gate uses ``Swish`` instead of a sigmoid.

    Args:
        in_size: feature size; both linear layers map ``in_size -> in_size``.
        beta: initial slope forwarded to the ``Swish`` gate.
    """

    def __init__(self, in_size, beta: torch.tensor) -> None:
        super().__init__()
        self.linear1 = nn.Linear(in_size, in_size)  # value branch
        self.linear2 = nn.Linear(in_size, in_size)  # gate branch
        self.swish = Swish(beta)

    def forward(self, x):
        """Return ``linear1(x) * swish(linear2(x))``."""
        value = self.linear1(x)
        gate = self.swish(self.linear2(x))
        return value * gate

In [16]:
class FeedForward(nn.Module):
    """Gated feed-forward block: ``W2(silu(W1(x)) * W3(x))``.

    The hidden width starts at two thirds of ``4 * d_model``, is optionally
    scaled by ``args.ffn_dim_multiplier``, and is finally rounded up to the next
    multiple of ``args.multiple_of``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        width = int(2 * (4 * args.d_model) / 3)
        if args.ffn_dim_multiplier is not None:
            width = int(args.ffn_dim_multiplier * width)
        # Ceil-divide by the step, then multiply back: rounds up to a multiple.
        step = args.multiple_of
        width = step * ((width + step - 1) // step)

        self.W1 = nn.Linear(args.d_model, width)
        self.W2 = nn.Linear(width, args.d_model)
        self.W3 = nn.Linear(args.d_model, width)

    def forward(self, x: torch.tensor):
        """Project up through two branches, gate one with SiLU, project back."""
        gate = F.silu(self.W1(x))
        up = self.W3(x)
        return self.W2(gate * up)
        

In [20]:
class Encoder(nn.Module):
    """One transformer block: pre-norm attention and pre-norm feed-forward,
    each wrapped in a residual connection.

    Args:
        args: configuration; reads ``num_heads``, ``d_model`` and ``norm_eps``
            here and is passed on to the attention and feed-forward sub-modules.
    """

    def __init__(self, args: ModelArgs):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise the first ``self.attention = ...`` raises AttributeError.
        super().__init__()
        self.num_heads = args.num_heads
        self.d_model = args.d_model
        self.head_dim = args.d_model // args.num_heads
        self.attention = GroupedQueryAttention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.d_model, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.d_model, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        """Run the block.

        Args:
            x: block input.
            start_pos: forwarded to the attention layer.
            freqs_complex: forwarded to the attention layer.

        Returns:
            ``h + feed_forward(ffn_norm(h))`` where
            ``h = x + attention(attention_norm(x), start_pos, freqs_complex)``.
        """
        # Sub-modules are called directly (not via .forward) so hooks still run.
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_complex)
        # The feed-forward branch works on the post-attention residual h, not
        # on the raw block input x.
        return h + self.feed_forward(self.ffn_norm(h))

In [22]:
class Transformer(nn.Module):
    """Token embedding, a stack of ``Encoder`` blocks, a final norm and an
    output projection to vocabulary logits.

    Args:
        args: configuration; ``vocab_size`` must already hold the real
            vocabulary size because it sizes the embedding and output layers.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.embeddings = nn.Embedding(self.vocab_size, args.d_model)

        self.layers = nn.ModuleList([Encoder(args) for _ in range(args.n_layers)])

        self.norm = RMSNorm(args.d_model, eps=args.norm_eps)
        self.output = nn.Linear(args.d_model, self.vocab_size)

        # position_freqs takes (max_seq_length, head_dim), in that order. The
        # table covers twice max_seq_len positions.
        head_dim = args.d_model // args.num_heads
        self.freqs_complex = position_freqs(args.max_seq_len * 2, head_dim)

    def forward(self, tokens: torch.Tensor, start_pos: int):
        """Compute logits for a single new token per sequence.

        Args:
            tokens: token ids of shape ``(batch_size, 1)``.
            start_pos: position of that token; selects the rotary factors and
                is forwarded to every layer.

        Returns:
            Float tensor of logits with shape ``(batch_size, 1, vocab_size)``.

        Raises:
            AssertionError: if more than one token per sequence is passed.
        """
        _, seq_len = tokens.shape
        assert seq_len == 1, "only one token per sequence can be processed at a time"

        h = self.embeddings(tokens)
        # The table is a plain attribute (not a buffer), so it has to be moved
        # to the activations' device explicitly.
        freqs_complex = self.freqs_complex[start_pos : start_pos + seq_len].to(h.device)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)
        h = self.norm(h)
        return self.output(h).float()