0. Prepare

In [1]:
import gc
import math
import os
from dataclasses import dataclass, field

import einops
import matplotlib.pyplot as plt
import rich
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import ByteLevel
from torch.utils.data import Dataset
from tqdm import tqdm

In [2]:
def get_device(verbose: bool = False):
    """Pick the torch device to run on.

    CUDA is preferred, then Apple's MPS backend, and finally the CPU.

    Args:
        verbose: When True, print a one-line message naming the chosen backend.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        backend, message = "cuda", "Using GPU"
    elif torch.backends.mps.is_available():
        backend, message = "mps", "Using Apple Silicon GPU"
    else:
        backend, message = "cpu", "Using CPU"

    chosen = torch.device(backend)
    if verbose:
        print(message)
    return chosen

def clear_cuda():
    """Run the Python garbage collector, then release cached CUDA memory.

    The CUDA calls are skipped entirely when no CUDA device is available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
        
def tensor_to_device(*tensors, device="cpu", non_blocking=True):
    """Move one or more tensors to ``device``.

    Args:
        *tensors: Objects exposing ``.to(device, non_blocking=...)``.
        device: Target device forwarded to ``.to``.
        non_blocking: Forwarded to ``.to``.

    Returns:
        The moved tensor itself when exactly one tensor is given; otherwise a
        tuple of the moved tensors in input order. Calling with no tensors
        returns an empty tuple.
    """
    moved = tuple(t.to(device, non_blocking=non_blocking) for t in tensors)
    # Unwrap only the single-tensor case so that a call with no tensors
    # yields () instead of raising IndexError on moved[0].
    return moved[0] if len(moved) == 1 else moved


def print_color(text: str, color: str = "green"):
    """Print ``text`` through ``rich`` wrapped in a ``[color]...[/color]`` markup tag.

    Args:
        text: Content to print; it is interpolated into the markup as-is.
        color: Name used for both the opening and the closing tag.
    """
    markup = "[{0}]{1}[/{0}]".format(color, text)
    rich.print(markup)

1. Transformer

1.1 Model Config

In [3]:
@dataclass
class ModelConfig:
    """Hyperparameters for the encoder-decoder Transformer.

    Raises:
        ValueError: If a size field is not positive, ``num_layers`` is
            negative, ``dropout`` lies outside ``[0, 1]``, or ``d_model`` is
            not divisible by ``num_heads``.
    """

    # Number of rows in the token embedding table and width of the output logits.
    vocab_size: int = 10_000
    # Number of positions for which positional encodings are precomputed.
    max_seq_len: int = 128

    # Width of the token representations.
    d_model: int = 512
    # Hidden width of the feed-forward network.
    d_ff: int = 2048
    # Number of attention heads; each head receives d_model // num_heads channels.
    num_heads: int = 8
    # Number of blocks stacked in the encoder and in the decoder.
    num_layers: int = 6
    # Dropout probability. NOTE(review): confirm where this value is consumed.
    dropout: float = 0.1

    def __post_init__(self):
        # Fail early with a readable message instead of an opaque reshape
        # error deep inside the attention layers.
        for name in ("vocab_size", "max_seq_len", "d_model", "d_ff", "num_heads"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        if self.num_layers < 0:
            raise ValueError(f"num_layers must be non-negative, got {self.num_layers}")
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout}")
        if self.d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})"
            )

1.2 Word Embedding

In [4]:
class WordEmbedding(nn.Module):
    """Learned lookup table mapping token ids to ``embed_size``-dimensional vectors."""

    def __init__(self, vocab_size:int, embed_size:int):
        """Create the table with ``vocab_size`` rows of width ``embed_size``."""
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
        )

    def forward(self, x:torch.Tensor)->torch.Tensor:
        """Return the embedding vectors for the token ids in ``x``."""
        vectors = self.embedding(x)
        return vectors
        

1.3 Position Embedding

In [5]:
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encodings precomputed for ``max_seq_len`` positions.

    Even channels hold ``sin(pos * w_i)`` and odd channels hold
    ``cos(pos * w_i)`` with ``w_i = exp(-2i * ln(10000) / d_model)``.
    The table is stored as a non-persistent buffer named ``pe``.
    """

    def __init__(self, config:ModelConfig):
        """Build the encoding table from ``config.max_seq_len`` and ``config.d_model``."""
        super().__init__()
        max_len, d_model = config.max_seq_len, config.d_model

        # Explicit float dtypes: the outer product must not depend on
        # int64/float32 type promotion being supported by torch.outer.
        pos_index = torch.arange(0, max_len, dtype=torch.float32)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * -(math.log(10000.0)) / d_model
        )
        angles = torch.outer(pos_index, div_term)  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model).float()
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer odd channel than even channels,
        # so the cosine table is trimmed to d_model // 2 columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        pe = pe.unsqueeze(0)  # Shape: (1, max_seq_len, d_model)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x:torch.Tensor)->torch.Tensor:
        """Return the encodings for the first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence length; only its
                shape is read.

        Returns:
            Slice of the table with shape ``(1, seq_len, d_model)``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Slicing would silently return a shorter table and fail later
            # with a confusing broadcast error.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_len}"
            )
        return self.pe[:, :seq_len, :]  # Shape: (1, seq_len, d_model)
        
        

1.4 Layer Normalization

In [6]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable scale and shift.

    ``gamma`` starts at ones and ``beta`` at zeros, both of length ``dim``.
    """

    def __init__(self, dim:int, eps:float=1e-6):
        """Create the scale/shift parameters and store the stabilising ``eps``."""
        super().__init__()
        self.eps=eps
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x:torch.Tensor)->torch.Tensor:
        """Normalize ``x`` along its last axis, then apply ``gamma`` and ``beta``."""
        mu = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance; eps is added inside the square root.
        sigma = torch.sqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
        normalized = (x - mu) / sigma
        return normalized * self.gamma + self.beta
        

1.5 Feed Forward Network

In [7]:
class FFN(nn.Module):
    """Two-layer feed-forward network: ``d_model -> d_ff -> d_model`` with a ReLU between."""

    def __init__(self, config):
        """Build both linear layers from ``config.d_model`` and ``config.d_ff``."""
        super().__init__()
        model_width, hidden_width = config.d_model, config.d_ff
        self.fc1 = nn.Linear(model_width, hidden_width)
        self.fc2 = nn.Linear(hidden_width, model_width)

    def forward(self, x):
        """Expand ``x`` to the hidden width, apply ReLU, and project back."""
        hidden = self.fc1(x)
        activated = F.relu(hidden)
        return self.fc2(activated)

1.6 Multi Head Attention

In [8]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into
    ``num_heads`` heads of ``d_model // num_heads`` channels, attended
    independently, re-merged and passed through an output projection.

    Args:
        config: Supplies ``d_model`` and ``num_heads``.
        is_causal: When True, query position ``i`` may only attend to key
            positions ``<= i`` (lower-triangular mask).

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, config:ModelConfig, is_causal=False):
        super().__init__()
        if config.d_model % config.num_heads != 0:
            # Otherwise head_dim is truncated and the head split in forward()
            # fails with an opaque reshape error.
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by num_heads ({config.num_heads})"
            )
        self.num_heads = config.num_heads
        self.head_dim = config.d_model // config.num_heads
        self.is_causal = is_causal

        self.q_proj = nn.Linear(config.d_model, config.d_model)
        self.k_proj = nn.Linear(config.d_model, config.d_model)
        self.v_proj = nn.Linear(config.d_model, config.d_model)

        self.out_proj = nn.Linear(config.d_model, config.d_model)

    def construct_mask(self, pad_mask, q_len:int, k_len:int, device):
        """Combine the causal mask and the key-padding mask.

        Args:
            pad_mask: Optional tensor that is truthy at padded key positions;
                expected shape ``(batch, k_len)`` (it is inverted and given two
                singleton axes after the batch axis).
            q_len: Number of query positions.
            k_len: Number of key positions.
            device: Device on which the causal mask is created.

        Returns:
            A boolean mask broadcastable to ``(batch, heads, q_len, k_len)``
            that is True for valid positions and False for masked ones, or
            ``None`` when neither masking source is active.
        """
        mask = None

        # Causal mask: lower triangle (including the diagonal) is visible.
        if self.is_causal:
            mask = torch.tril(torch.ones((q_len,k_len),device=device,dtype=torch.bool))[None, None,:,:]

        if pad_mask is not None:
            # pad_mask marks padding, so invert it to obtain "key is valid".
            key_valid = (~pad_mask.bool())[:,None,None,:]
            mask = key_valid if mask is None else (mask & key_valid)

        return mask

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Queries of shape ``(batch, q_len, d_model)``.
            k: Keys of shape ``(batch, kv_len, d_model)``.
            v: Values of shape ``(batch, kv_len, d_model)``.
            mask: Optional key-padding mask, truthy where a key is padding.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        device = q.device
        b, q_len, _ = q.shape
        _, kv_len,_ = k.shape
        mask = self.construct_mask(mask, q_len, kv_len, device)

        # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
        q = self.q_proj(q).view(b,q_len,self.num_heads,self.head_dim).transpose(1,2)
        k = self.k_proj(k).view(b,kv_len,self.num_heads,self.head_dim).transpose(1,2)
        v = self.v_proj(v).view(b,kv_len,self.num_heads,self.head_dim).transpose(1,2)

        logits = torch.matmul(q,k.transpose(-2,-1))/math.sqrt(self.head_dim)

        if mask is not None:
            logits = logits.masked_fill(~mask,float("-inf"))

        score = F.softmax(logits,dim=-1)

        if mask is not None:
            # A query row whose keys are all masked is softmax over all -inf,
            # which yields NaN. Give such rows zero attention weight so the
            # NaN does not propagate through the rest of the network.
            no_valid_key = ~mask.any(dim=-1, keepdim=True)
            score = score.masked_fill(no_valid_key, 0.0)

        attn_output = torch.matmul(score,v)
        attn_output = einops.rearrange(attn_output,"batch heads seq_len d_k -> batch seq_len (heads d_k)",)

        attn_output = self.out_proj(attn_output)

        return attn_output
            
        
        

1.7 Encoder Block

In [9]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention, then a feed-forward network.

    Each sub-layer is followed by a residual addition and a LayerNorm.
    """

    def __init__(self, config:ModelConfig):
        """Create the attention, feed-forward and normalization sub-modules."""
        super().__init__()

        # Non-causal: every position may attend to every other position.
        self.attn = MultiHeadAttention(config, is_causal=False)
        self.ln1 = LayerNorm(config.d_model)

        self.ffn = FFN(config)
        self.ln2 = LayerNorm(config.d_model)

    def forward(self, x:torch.Tensor, mask=None)->torch.Tensor:
        """Run the block on ``x``; ``mask`` is forwarded to the attention layer."""
        # Self-attention: queries, keys and values are all x.
        attended = self.ln1(x + self.attn(x, x, x, mask=mask))
        return self.ln2(attended + self.ffn(attended))
        

1.8 Decoder Block

In [10]:
class DecoderBlock(nn.Module):
    """One decoder layer: causal self-attention, cross-attention, feed-forward.

    Each sub-layer is followed by a residual addition and a LayerNorm.
    """

    def __init__(self, config:ModelConfig):
        """Create the three sub-layers and their normalization modules."""
        super().__init__()

        # Self-attention over the target must be causal so that a position
        # cannot look at later target tokens.
        self.self_attn = MultiHeadAttention(config, is_causal=True)
        self.ln1 = LayerNorm(config.d_model)

        # Cross-attention reads the encoder output and must NOT be causal:
        # every target position may attend to every (non-padded) source token.
        self.cross_attn = MultiHeadAttention(config, is_causal=False)
        self.ln2 = LayerNorm(config.d_model)

        self.ffn = FFN(config)
        self.ln3 = LayerNorm(config.d_model)

    def forward(
            self,
            x:torch.Tensor,
            enc_output:torch.Tensor,
            src_mask=None,
            tgt_mask=None)->torch.Tensor:
        """Run the decoder layer.

        Args:
            x: Target-side hidden states.
            enc_output: Encoder output used as keys and values in cross-attention.
            src_mask: Padding mask for the source keys, forwarded to cross-attention.
            tgt_mask: Padding mask for the target keys, forwarded to self-attention.

        Returns:
            Updated target-side hidden states with the same shape as ``x``.
        """
        self_attn_output = self.self_attn(x,x,x,mask=tgt_mask)
        x = self.ln1(x+self_attn_output)

        # Queries come from the decoder; keys and values from the encoder.
        cross_attn_output = self.cross_attn(x,enc_output,enc_output,mask=src_mask)
        x = self.ln2(x+cross_attn_output)

        ffn_output = self.ffn(x)
        x = self.ln3(x+ffn_output)

        return x
        

1.9 Encoder and Decoder

In [11]:
class Encoder(nn.Module):
    """Stack of ``config.num_layers`` encoder blocks applied in sequence."""

    def __init__(self, config:ModelConfig):
        """Instantiate the encoder blocks."""
        super().__init__()
        blocks = [EncoderBlock(config) for _ in range(config.num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x:torch.Tensor, mask=None)->torch.Tensor:
        """Pass ``x`` through every block, giving each one the same ``mask``."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return hidden
    
class Decoder(nn.Module):
    """Stack of ``config.num_layers`` decoder blocks applied in sequence."""

    def __init__(self, config:ModelConfig):
        """Instantiate the decoder blocks."""
        super().__init__()
        blocks = [DecoderBlock(config) for _ in range(config.num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(
            self,
            x:torch.Tensor,
            enc_output:torch.Tensor,
            src_mask=None,
            tgt_mask=None)->torch.Tensor:
        """Pass ``x`` through every block.

        Each block receives the same ``enc_output``, ``src_mask`` and
        ``tgt_mask``.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, enc_output, src_mask, tgt_mask)
        return hidden
        
        

1.10 Full Model

In [14]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing vocabulary logits.

    Source and target share one token embedding table, and the output
    projection is tied to that same weight matrix.
    """

    def __init__(self, config:ModelConfig):
        """Build embeddings, encoder, decoder and output projection; then init and tie weights."""
        super().__init__()

        self.vocab_embedding = WordEmbedding(config.vocab_size, config.d_model)
        self.positional_embedding = PositionalEmbedding(config)

        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

        self.output_proj = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialise first, then tie: the shared matrix ends up being the
        # embedding weight with its embedding initialisation.
        self.apply(self._init_weights)
        self._tie_weights()

    def _tie_weights(self):
        """Make the output projection reuse the token embedding weight."""
        shared_weight = self.vocab_embedding.embedding.weight
        # Both are (vocab_size, d_model), so the parameter can be shared directly.
        self.output_proj.weight = shared_weight

    def _init_weights(self, module):
        """Initialise one sub-module; intended to be used with ``nn.Module.apply``."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            bias = module.bias
            if bias is not None:
                nn.init.zeros_(bias)
            return

        if isinstance(module, nn.Embedding):
            nn.init.xavier_normal_(module.weight)
            return

        if isinstance(module, LayerNorm):
            nn.init.ones_(module.gamma)
            nn.init.zeros_(module.beta)

    def _embed(self, token_ids:torch.Tensor)->torch.Tensor:
        """Return scaled token embeddings plus positional encodings for ``token_ids``."""
        # Token embeddings are scaled by sqrt(embedding_dim) before the
        # positional encodings are added.
        scale = math.sqrt(self.vocab_embedding.embedding.embedding_dim)
        return self.vocab_embedding(token_ids) * scale + self.positional_embedding(token_ids)

    def forward(
            self,
            src_input:torch.Tensor,
            tgt_input:torch.Tensor,
            src_mask=None,
            tgt_mask=None)->torch.Tensor:
        """Compute logits for ``tgt_input`` conditioned on ``src_input``.

        Args:
            src_input: Source token ids.
            tgt_input: Target token ids.
            src_mask: Optional source mask, forwarded to the encoder and decoder.
            tgt_mask: Optional target mask, forwarded to the decoder.

        Returns:
            Logits with the vocabulary size as the last dimension.
        """
        src_embeddings = self._embed(src_input)
        tgt_embeddings = self._embed(tgt_input)

        memory = self.encoder(src_embeddings, mask=src_mask)
        decoded = self.decoder(tgt_embeddings, memory, src_mask=src_mask, tgt_mask=tgt_mask)

        return self.output_proj(decoded)
        
            
        

1.11 A Small Example

In [15]:
# Smoke test: push a random batch through the model and check the logits shape.
DEVICE = get_device(verbose=True)

config = ModelConfig()

# Draw token ids from the configured vocabulary (instead of a hard-coded
# 10000) so they always index the embedding table, and create them directly
# on the target device.
src = torch.randint(0, config.vocab_size, (2, 10), device=DEVICE)
tgt = torch.randint(0, config.vocab_size, (2, 10), device=DEVICE)

model = Transformer(config).to(DEVICE)

logits = model(src, tgt)

# Expected shape: (batch, tgt_len, vocab_size)
print(logits.shape)

Using GPU
torch.Size([2, 10, 10000])
