In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import time
import math
import tiktoken
import os
from dataclasses import dataclass
from huggingface_hub import PyTorchModelHubMixin
from contextlib import nullcontext



from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group
from datatrove.utils.dataset import DatatroveFolderDataset
from transformers import AutoTokenizer

from typing import Tuple, Optional


In [2]:
# ---- run-level constants ----
SEED = 1337
checkpoints_frequency = 2000

# ---- data / log locations ----
data_dir = "edu_fineweb10B"
log_dir = "log"
log_file = os.path.join(log_dir, "log.txt")
val_log_file = os.path.join(log_dir, "val_log.txt")

# ---- device and GPT-2 BPE encoder ----
device = 'cuda' if torch.cuda.is_available() else 'cpu'
enc = tiktoken.get_encoding("gpt2")

# Seed the CPU generator always, and the CUDA generator when a GPU is used,
# so that weight initialisation is reproducible between runs.
torch.manual_seed(SEED)
if device == 'cuda':
    torch.cuda.manual_seed(SEED)


@dataclass
class ModelConfig:
    """
    Architecture hyper-parameters for `Transformer`.

    Validated in __post_init__ so that inconsistent head / expert settings fail
    at construction time instead of deep inside a tensor `view`.
    Note: Transformer.__init__ overwrites `ffn_hidden_dims` with its own computed width.
    """
    device: str
    vocab_size: int

    num_dims: int                       # number of dimensions
    num_heads: int                      # number of query heads
    num_kv_heads: int                   # number of key/value heads
    num_layers: int                     # total transformer layers
    ffn_hidden_dims: int                # hidden dimension for FFN/FFNwMoE

    context_len: int                    # maximum context length
    use_cache: bool                     # enable KV-caching
    use_flash: bool                     # use Flash Attention
    use_moe: bool                       # enable mixture-of-experts

    moe_num_experts: int                # total number of experts
    moe_routed_experts: int             # number of experts per token (top_k)
    moe_eps: float = 1e-6               # epsilon for router stability
    moe_aux_loss_coef: float = 0.01     # coefficient for auxiliary loss
    moe_shared_experts: int = 0         # number of shared experts (DeepSeekMoE)
    use_lossfreebalance: bool = False   # use Auxiliary-loss-free load balancing strategy for mixture-of-experts from DeepSeek https://arxiv.org/pdf/2408.15664

    rmsnorm_eps: float = 1e-6
    rope_theta: float = 1e5

    ffn_dim_multiplier: Optional[float] = None    # optional (possibly fractional) multiplier to compute ffn_hidden_dims

    def __post_init__(self):
        """Reject head/expert combinations the model code cannot handle."""
        if self.num_dims % self.num_heads != 0:
            raise ValueError(f"num_dims ({self.num_dims}) must be divisible by num_heads ({self.num_heads})")
        # GroupedQueryAttention treats num_kv_heads=None as "same as num_heads"
        kv_heads = self.num_heads if self.num_kv_heads is None else self.num_kv_heads
        if self.num_heads % kv_heads != 0:
            raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({kv_heads})")
        # RoPE rotates feature pairs, so each head needs an even width
        if (self.num_dims // self.num_heads) % 2 != 0:
            raise ValueError(f"head dimension ({self.num_dims // self.num_heads}) must be even for RoPE")
        if self.use_moe and not (1 <= self.moe_routed_experts <= self.moe_num_experts):
            raise ValueError(
                f"moe_routed_experts ({self.moe_routed_experts}) must be in [1, moe_num_experts={self.moe_num_experts}]"
            )





In [5]:

# Helper for grouped-query attention: expands K/V heads to match the query heads.
def repeat_kv(vct: torch.Tensor, n_times: int):
    """
    Repeat every key/value head `n_times` along the head axis.

    Input shape (batch, seq, kv_heads, head_dim) -> (batch, seq, kv_heads * n_times, head_dim).
    Copies of one head are adjacent: head h maps to output heads h*n_times .. h*n_times + n_times - 1.
    With n_times == 1 the input tensor itself is returned.
    """
    bsz, seq_len, kv_heads, head_dim = vct.shape
    if n_times != 1:
        expanded = vct[:, :, :, None, :].expand(bsz, seq_len, kv_heads, n_times, head_dim)
        return expanded.reshape(bsz, seq_len, kv_heads * n_times, head_dim)
    return vct


def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    """
    Build the RoPE rotation table.

    Returns a complex tensor of shape (seq_len, head_dim // 2) whose entry [m, i]
    equals exp(1j * m * theta ** (-2i / head_dim)): unit magnitude, with an angle
    that grows linearly with position m.
    """
    assert head_dim % 2 == 0, "dimensions must be divisible by 2"

    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents).to(device)
    positions = torch.arange(seq_len, device=device)

    angles = torch.outer(positions, inv_freq).float()
    unit = torch.ones_like(angles)
    return torch.polar(unit, angles)

def apply_rotary_pos(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    """
    Rotate consecutive feature pairs of `x` by the RoPE angles in `freqs_complex`.

    The attention caller passes x as (batch, seq, heads, head_dim) and freqs_complex
    as (seq, head_dim // 2); freqs is broadcast over the batch and head axes.
    The math runs in float32, and the result is cast back to x's dtype and moved to `device`.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rotation = freqs_complex.unsqueeze(0).unsqueeze(2)
    rotated = torch.view_as_complex(pairs) * rotation
    flat = torch.view_as_real(rotated).reshape(*x.shape)
    return flat.type_as(x).to(device)

In [53]:

class RMSNorm(torch.nn.Module):
    """
    Root-mean-square layer norm over the last axis, with a learnable gain `g`
    of size config.num_dims (initialised to ones) and epsilon config.rmsnorm_eps.
    Normalisation is computed in float32 and cast back to the input dtype before scaling.
    """
    def __init__(self, config):
        super().__init__()
        self.g = nn.Parameter(torch.ones(config.num_dims))
        self.eps = config.rmsnorm_eps

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalized = self._norm(x.float()).type_as(x)
        return self.g * normalized
    

class GroupedQueryAttention(nn.Module):
    """
    Causal multi-head attention with grouped key/value heads (GQA) and RoPE.

    `num_heads` query heads share `num_kv_heads` key/value heads; each K/V head
    serves `num_heads // num_kv_heads` query heads. When `use_cache` is set and
    autograd is disabled (e.g. inside a `torch.no_grad` generation loop), keys and
    values are kept in a KV cache of capacity `context_len`. That lets callers feed
    only new tokens, together with their absolute `start_pos`.
    """
    def __init__(self, config):
        super().__init__()

        self.use_cache = config.use_cache
        self.use_flash = config.use_flash
        self.context_len = config.context_len   # capacity of the KV cache

        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_heads if config.num_kv_heads is None else config.num_kv_heads

        self.num_rep = self.num_heads // self.num_kv_heads
        self.head_dim = config.num_dims // self.num_heads

        self.wq = nn.Linear(config.num_dims, config.num_dims, bias=False)
        self.wk = nn.Linear(config.num_dims, self.num_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(config.num_dims, self.num_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(config.num_dims, config.num_dims, bias=False)

        # KV cache, allocated lazily on the first cached forward (see _update_cache)
        self.cache_k = None
        self.cache_v = None

    def reset_cache(self):
        """Drop cached keys/values, e.g. before generating an unrelated sequence."""
        self.cache_k = None
        self.cache_v = None

    def _update_cache(self, keys, values, start_pos):
        """
        Write keys/values of shape (B, T, kv_heads, head_dim) into positions
        [start_pos, start_pos + T) and return every cached entry up to start_pos + T.

        Raises ValueError if the sequence would exceed `context_len`.
        """
        c_batch_size, c_context_len = keys.shape[:2]
        end_pos = start_pos + c_context_len
        if end_pos > self.context_len:
            raise ValueError(f"KV cache overflow: position {end_pos} exceeds context_len={self.context_len}")

        needs_alloc = (
            self.cache_k is None
            or self.cache_k.size(0) < c_batch_size
            or self.cache_k.device != keys.device
            or self.cache_k.dtype != keys.dtype
        )
        if needs_alloc:
            shape = (c_batch_size, self.context_len, self.num_kv_heads, self.head_dim)
            self.cache_k = torch.zeros(shape, device=keys.device, dtype=keys.dtype)
            self.cache_v = torch.zeros(shape, device=values.device, dtype=values.dtype)

        self.cache_k[:c_batch_size, start_pos:end_pos] = keys
        self.cache_v[:c_batch_size, start_pos:end_pos] = values
        return self.cache_k[:c_batch_size, :end_pos], self.cache_v[:c_batch_size, :end_pos]

    @staticmethod
    def _causal_mask(q_len, kv_len, device):
        """
        Boolean (q_len, kv_len) mask, True where attention is allowed.

        The queries are taken to be the last q_len of the kv_len positions, so query i
        may attend to keys 0 .. kv_len - q_len + i. When q_len == kv_len this is the
        ordinary lower-triangular mask.
        """
        return torch.ones(q_len, kv_len, dtype=torch.bool, device=device).tril(diagonal=kv_len - q_len)

    def forward(self, x, freqs_complex, start_pos = 0):
        c_batch_size, c_context_len, c_dim = x.shape

        q = self.wq(x).view(c_batch_size, c_context_len, self.num_heads, self.head_dim)          # B, T, qh, hs
        k = self.wk(x).view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim)       # B, T, kh, hs
        values = self.wv(x).view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim)  # B, T, vh, hs

        queries = apply_rotary_pos(q, freqs_complex, device=x.device)
        keys = apply_rotary_pos(k, freqs_complex, device=x.device)

        # Cache only while autograd is off: storing tensors that carry a graph
        # would break backward on the next training step.
        if self.use_cache and not torch.is_grad_enabled():
            keys, values = self._update_cache(keys, values, start_pos)
        kv_len = keys.size(1)

        # (B, T, heads, hs) -> (B, heads, T, hs): the layout that both SDPA and the manual path expect
        queries = queries.transpose(1, 2)
        keys = repeat_kv(keys, self.num_rep).transpose(1, 2)
        values = repeat_kv(values, self.num_rep).transpose(1, 2)

        if self.use_flash:
            if kv_len == c_context_len:
                output = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)
            else:
                # is_causal aligns its mask to the top-left, which is wrong when the queries are
                # the last tokens of a longer cached sequence; pass a bottom-right mask explicitly.
                mask = self._causal_mask(c_context_len, kv_len, x.device)
                output = F.scaled_dot_product_attention(queries, keys, values, attn_mask=mask)

        else: # Calculate Grouped Query Attention manually
            scores = torch.matmul(queries, keys.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
            # Mask by position, not by value, so that a genuine score of 0 is never masked out
            mask = self._causal_mask(c_context_len, kv_len, x.device)
            scores = scores.masked_fill(~mask, float("-inf"))

            attention = F.softmax(scores.float(), dim=-1).type_as(queries)
            output = torch.matmul(attention, values)

        output = output.transpose(1, 2).contiguous().view(c_batch_size, c_context_len, c_dim)
        return self.wo(output)


class FeedForward(nn.Module):
    """
    Default Feed Forward Layer (SwiGLU): w2(silu(w1 x) * w3 x).

    forward returns (output, None). The None takes the place of the auxiliary value
    that FFNwMoE returns, so Block can unpack both layer types the same way.
    """
    def __init__(self, config):
        super().__init__()

        dims = config.num_dims
        hidden = config.ffn_hidden_dims
        self.hidden_dim = hidden

        self.w1 = nn.Linear(dims, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dims, bias=False)
        self.w3 = nn.Linear(dims, hidden, bias=False)

    def forward(self, x: torch.Tensor):
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up), None


class FFNwMoE(nn.Module): 
    """
    Feed Forward with MoE with optional shared experts.

    A linear router sends every token to `moe_routed_experts` (top-k) SwiGLU
    experts. Shared experts (DeepSeekMoE) are applied to every token.

    Returns after forward:
        output: (B, T, D) combined outputs from experts
        aux_loss: default mode -> scalar load-balancing loss tensor;
                  loss-free-balance mode -> tuple (router_scores, topk_indices),
                  which the trainer uses to adjust `expert_biases`.
    """
    def __init__(self, config: "ModelConfig"):
        super().__init__()
        self.hidden_dim = config.ffn_hidden_dims

        self.moe_routed_experts = config.moe_routed_experts # top_k
        self.moe_aux_loss_coef = config.moe_aux_loss_coef
        self.moe_eps = config.moe_eps
        self.moe_shared_experts = config.moe_shared_experts
        self.num_experts = config.moe_num_experts

        self.use_lossfreebalance = config.use_lossfreebalance

        self.router = nn.Linear(config.num_dims, self.num_experts, bias=False)
        self.experts = nn.ModuleList(
            self._make_expert(config.num_dims, self.hidden_dim) for _ in range(self.num_experts)
        )

        # shared experts (for DeepSeekMoE)
        self.shared_experts = nn.ModuleList(
            self._make_expert(config.num_dims, self.hidden_dim) for _ in range(self.moe_shared_experts)
        )

        # Auxiliary-loss-free load balancing (https://arxiv.org/pdf/2408.15664).
        # The per-expert bias is adjusted by a sign rule in the trainer, not by the optimizer.
        # It is therefore a buffer (saved in state_dict under the same key), not a Parameter.
        if self.use_lossfreebalance:
            self.register_buffer("expert_biases", torch.zeros(self.num_experts))

    @staticmethod
    def _make_expert(num_dims, hidden_dim):
        """Build one SwiGLU expert as [w1, w2, w3], applied as w2(silu(w1 x) * w3 x)."""
        return nn.ModuleList([
            nn.Linear(num_dims, hidden_dim, bias=False),
            nn.Linear(hidden_dim, num_dims, bias=False),
            nn.Linear(num_dims, hidden_dim, bias=False),
        ])

    def forward(self, x: torch.Tensor):
        c_batch_size, c_context_len, c_dim = x.shape
        x_flat = x.reshape(-1, c_dim)          # c_batch_size * c_context_len, c_dim

        router_out = self.router(x_flat)
        router_probs = F.softmax(router_out, dim=-1)

        if self.use_lossfreebalance:
            # Per the paper, the bias only decides WHICH experts are picked;
            # the gating weights are computed from the unbiased scores below.
            selection_scores = router_out + self.expert_biases
        else:
            selection_scores = router_out
        _, topk_indices = selection_scores.topk(self.moe_routed_experts, dim=-1)

        aux_loss, topk_probs = self._compute_aux_loss(router_out, router_probs, topk_indices)

        output = self._compute_expert_outputs(x_flat, topk_indices, topk_probs, router_probs)

        return output.view(c_batch_size, c_context_len, c_dim), aux_loss

    def _compute_aux_loss(self, router_out, router_probs, topk_indices):
        """
        Return (aux_loss, topk_probs), where topk_probs holds the gating weight of each
        selected expert (rows and columns aligned with topk_indices).
        """
        if not self.use_lossfreebalance:
            topk_probs = router_probs.gather(-1, topk_indices)
            # Switch-Transformer-style balance loss computed on each token's first choice
            expert_mask = F.one_hot(topk_indices[:, 0], self.num_experts).float()
            density = expert_mask.mean(dim=0)
            router_prob_mean = router_probs.mean(dim=0)
            aux_loss = self.moe_aux_loss_coef * torch.sum(density * router_prob_mean) * self.num_experts

        else: # if use_lossfreebalance
            router_scores = torch.sigmoid(router_out) # from https://arxiv.org/pdf/2408.15664 paper
            topk_probs = router_scores.gather(-1, topk_indices)
            # clamp guards against a zero denominator when sigmoid underflows (e.g. in bf16)
            topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True).clamp_min(self.moe_eps)

            # In the case of Auxiliary-loss-free load balancing we pass router_scores, topk_indices as aux_loss for further calculations
            aux_loss = (router_scores, topk_indices)
        return aux_loss, topk_probs

    def _compute_expert_outputs(self, x_flat, topk_indices, topk_probs, router_probs):
        """
        Compute the gate-weighted sum of the selected experts for every token, plus the
        ungated output of each shared expert.
        """
        output = torch.zeros_like(x_flat)

        for slot in range(self.moe_routed_experts):
            expert_index = topk_indices[:, slot]
            expert_probs = topk_probs[:, slot]

            for expert_id, (w1, w2, w3) in enumerate(self.experts):
                token_idx = (expert_index == expert_id).nonzero(as_tuple=True)[0]
                if token_idx.numel() == 0:
                    continue
                x_for_expert = x_flat[token_idx]
                expert_output = w2(F.silu(w1(x_for_expert)) * w3(x_for_expert))
                output[token_idx] += expert_output * expert_probs[token_idx].unsqueeze(-1)

        # shared experts (for DeepSeekMoE)
        for w1, w2, w3 in self.shared_experts:
            output = output + w2(F.silu(w1(x_flat)) * w3(x_flat))

        return output


class Block(nn.Module):
    """
    One pre-norm transformer layer:
        h   = x + Attention(RMSNorm(x))
        out = h + FFN(RMSNorm(h))
    forward returns (out, aux_loss). aux_loss is whatever the FFN returns as its
    second value: None for FeedForward, a loss or routing info for FFNwMoE.
    """
    def __init__(self, config):
        super().__init__()

        self.attention = GroupedQueryAttention(config)
        self.ffn = FFNwMoE(config) if config.use_moe else FeedForward(config)

        self.norm_attention = RMSNorm(config)
        self.norm_ffn = RMSNorm(config)

    def forward(self, x, freqs_complex, start_pos):
        attn_out = self.attention(self.norm_attention(x), freqs_complex, start_pos)
        hidden = x + attn_out

        ffn_out, aux_loss = self.ffn(self.norm_ffn(hidden))
        return hidden + ffn_out, aux_loss
    


In [None]:

class Transformer(nn.Module, PyTorchModelHubMixin): # extending PyTorchModelHubMixin for save weights as safetensors
    """
    Decoder-only language model built from:
    - a token embedding tied to the LM head;
    - `num_layers` pre-norm Blocks (RoPE + GQA attention, and either a SwiGLU FFN or an MoE FFN);
    - a final RMSNorm.

    forward returns (logits, loss, third):
    - without targets: (B, T, V) logits, None, None;
    - with targets: logits flattened to (B*T, V), plus loss (cross-entropy, with the MoE aux loss
      added in default MoE mode). `third` is the plain cross-entropy, or, in loss-free-balance
      mode, the last block's (router_scores, topk_indices) routing info.
    """
    def __init__(self, config: ModelConfig):
        super().__init__()

        self.vocab_size = config.vocab_size
        self.num_dims = config.num_dims
        self.num_heads = config.num_heads
        self.num_layers = config.num_layers
        self.context_len = config.context_len
        self.use_moe = config.use_moe

        self.use_lossfreebalance = config.use_lossfreebalance and self.use_moe

        # Llama-style FFN width: 2/3 of 4*d (SwiGLU has three matrices instead of two),
        # optionally scaled, then rounded up to a multiple of `multiple_of`.
        # NOTE: this overwrites config.ffn_hidden_dims, which Blocks read below.
        multiple_of = 4
        hidden_dim = 4 * config.num_dims
        hidden_dim = int(2 * hidden_dim / 3)
        if config.ffn_dim_multiplier is not None:
            hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
        config.ffn_hidden_dims = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.tokens_embedding = nn.Embedding(self.vocab_size, self.num_dims)

        self.blocks = nn.ModuleList()
        for _ in range(self.num_layers):
            self.blocks.append(Block(config))

        self.norm = RMSNorm(config)
        self.ll_head = nn.Linear(self.num_dims, self.vocab_size, bias=False)

        # Weight tying: the embedding and the LM head share one matrix
        self.tokens_embedding.weight = self.ll_head.weight

        # RoPE table covering 2 * context_len positions, built with the configured base
        self.freqs_complex = precompute_theta_pos_frequencies(
            self.num_dims // self.num_heads,
            self.context_len * 2,
            device=config.device,
            theta=config.rope_theta,
        )

    def forward(self, x: torch.Tensor, targets: Optional[torch.Tensor] = None, start_pos: int = 0):
        _, seq_len = x.shape

        x = self.tokens_embedding(x)
        # .to() is a no-op when the devices already match; it is needed when the model
        # lives on a different device than config.device (e.g. another DDP rank's GPU)
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len].to(x.device)

        total_aux_loss = 0
        aux_loss = None

        for block in self.blocks:
            x, aux_loss = block(x, freqs_complex=freqs_complex, start_pos=start_pos)
            if self.use_moe and not self.use_lossfreebalance:
                total_aux_loss += aux_loss

        x = self.norm(x)
        logits = self.ll_head(x)

        if targets is None:
            return logits, None, None

        c_batch_size, c_context_len, c_dim = logits.shape
        logits = logits.view(c_batch_size * c_context_len, c_dim)
        targets = targets.reshape(c_batch_size * c_context_len)
        ce_loss = F.cross_entropy(logits, targets)

        if self.use_moe and not self.use_lossfreebalance:
            loss = ce_loss + total_aux_loss    # ce_loss stays the loss without aux_loss
        elif self.use_lossfreebalance:
            # Hand the last block's (router_scores, topk_indices) to the trainer for the bias update
            loss = ce_loss
            ce_loss = aux_loss
        else:
            loss = ce_loss

        return logits, loss, ce_loss

    @torch.no_grad()
    def generate(self, x: torch.Tensor, max_tokens: int, temperature: float = 1.0, top_k: int = 50, 
                 use_cache: bool = False):
        """
        Autoregressively sample `max_tokens` tokens after the (B, T) prompt `x`; returns the extended sequence.

        use_cache=True feeds only the newest token after the first step. This relies on the
        attention layers having a KV cache (config.use_cache=True) and the sequence staying
        within context_len. Without the cache, the input is cropped to the last
        context_len tokens so that it always fits the RoPE table.
        """
        for c_tkn_pos in range(max_tokens):
            if use_cache:
                if c_tkn_pos == 0:
                    logits, _, _ = self.forward(x, start_pos=0)
                else:
                    # Feed only the newest token (keeping the time axis) at its absolute position
                    logits, _, _ = self.forward(x[:, -1:], start_pos=x.size(1) - 1)
            else:
                x_cond = x if x.size(1) <= self.context_len else x[:, -self.context_len:]
                logits, _, _ = self.forward(x_cond)

            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                tkl, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < tkl[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, next_token), dim=1)
        return x
    

In [None]:

@dataclass
class TrainerConfig:
    """
    Training-loop hyper-parameters consumed by DataLoader and Trainer.

    `eval_log_file` is the attribute Trainer reads. When left as None it falls back to the
    older, misspelled `eval_eval_log_file`, which is kept for backward compatibility.
    """
    vocab_size: int
    num_epochs: int

    use_ddp: bool
    use_moe: bool
    use_lossfreebalance: bool
    clean_cuda_cache: bool = True  # Helps prevent OOM errors during eval on large models
    use_compile: bool = True

    seed: int = 1998
    max_seq_len: int = 1024
    batch_size: int = 1
    accumulation_steps: int = 1
    
    weight_decay: float = 0.1
    warmup_ratio: float = 0.01
    learning_rate: float = 1e-3
    betas: Tuple[float, float] = (0.90, 0.95)

    val_ratio: float = 0.005         # fraction of the dataset held out for validation
    steps_for_eval: int = 20
    eval_interval: int = 50

    checkpoints_frequency: int = 500
    path_to_checkpoints: str = "./model_testing"

    tokenized_dataset_path: str = "wiki_hindi_tok/"
    eval_eval_log_file: str = "logs/eval.txt"    # legacy name, kept for compatibility
    eval_log_file: Optional[str] = None          # None -> use eval_eval_log_file
    bias_update_rate: float = 1e-3               # step size u of the loss-free balancing bias update

    def __post_init__(self):
        if self.eval_log_file is None:
            self.eval_log_file = self.eval_eval_log_file



class DataLoader():
    """
    Sequential batch reader over a tokenized DatatroveFolderDataset.

    Index layout of the loaded dataset:
        [0, train_len_dataset)            train split, sharded evenly across ranks
        [train_len_dataset, len_dataset)  validation split, shared by all ranks
    Each sample's `input_ids` yields x = ids[:-1] and y = ids[1:] (next-token targets).
    """
    def __init__(self, config, rank=0, world_size=1):
        self.config = config
        self.current_epoch = 0
        self.seed = config.seed
        self.token_size = 2 if config.vocab_size < 65535 else 4
        self.rank = rank
        self.world_size = world_size

        self.load_dataset(self.seed)
        self.len_dataset = len(self.dataset)

        if rank == 0:
            print(f"{'Total tokens loaded: '} {self.len_dataset * config.max_seq_len:,}")

        self._init_split_indices(rank, world_size)

    def _init_split_indices(self, rank: int, world_size: int):
        """Compute the split boundaries and this rank's train shard, then rewind both cursors."""
        self.train_len_dataset = math.ceil((1 - self.config.val_ratio) * self.len_dataset)
        self.val_len_dataset = self.len_dataset - self.train_len_dataset

        # Shard only the train split, so no rank ever trains on validation samples
        shard_size = self.train_len_dataset // world_size
        self.train_start_idx = rank * shard_size
        self.train_end_idx = self.train_start_idx + shard_size
        self.train_current_idx = self.train_start_idx

        self.val_start_idx = self.train_len_dataset
        self.val_current_idx = self.val_start_idx

    def get_batch(self, current_idx: int, start_idx: int, end_idx: int, advance_epoch: bool = True):
        """
        Read up to batch_size samples from [current_idx, end_idx).

        Returns (x, y, next_idx). Once the range is exhausted, next_idx wraps back to
        start_idx; if advance_epoch is set, a new epoch also begins.
        Raises ValueError if the range is empty (dataset too small for this split/shard).
        """
        new_idx = current_idx + self.config.batch_size
        indices = range(current_idx, min(new_idx, end_idx))
        if len(indices) == 0:
            raise ValueError(f"No samples in index range [{current_idx}, {end_idx}); dataset too small for this split/shard")

        samples = [self.dataset[idx]['input_ids'] for idx in indices]
        x = torch.stack([s[:-1] for s in samples])
        y = torch.stack([s[1:] for s in samples])

        if new_idx >= end_idx:
            new_idx = start_idx
            if advance_epoch:
                self.new_epoch()

        return x, y, new_idx

    def next_batch(self, split):
        if split == "train":
            x, y, self.train_current_idx = self.get_batch(self.train_current_idx, self.train_start_idx, self.train_end_idx)
        else: # validation: cycling the val split must not reshuffle the data or bump the epoch
            x, y, self.val_current_idx = self.get_batch(self.val_current_idx, self.val_start_idx, self.len_dataset,
                                                        advance_epoch=False)
        return x, y

    def reset(self, rank: int = 0, world_size: int = 1):
        """Reload the dataset with the base seed and rewind to epoch 0 for the given rank/world size."""
        self.current_epoch = 0
        self.seed = self.config.seed
        self.rank = rank
        self.world_size = world_size
        self.load_dataset(self.seed)
        self.len_dataset = len(self.dataset)
        self._init_split_indices(rank, world_size)

    def new_epoch(self):
        # NOTE(review): reloading with a different seed may reorder samples across the
        # train/val boundary, depending on how DatatroveFolderDataset shuffles. Confirm that
        # validation samples cannot migrate into the train split between epochs.
        self.current_epoch += 1
        self.load_dataset(self.seed + self.current_epoch)

    def load_dataset(self, seed: int):
        self.dataset = DatatroveFolderDataset(
            folder_path=self.config.tokenized_dataset_path,
            filename_pattern=os.path.join(self.config.tokenized_dataset_path, "**", "*.ds"),
            seq_len=self.config.max_seq_len,
            token_size=self.token_size,
            recursive=True,
            shuffle=True,
            seed=seed
        )

    def num_train_steps(self):
        return math.ceil((self.train_end_idx-self.train_start_idx) / self.config.batch_size)


class Trainer():
    """
    Drives training of `model`. It supports:
    - optional torch.compile and optional DDP;
    - gradient accumulation;
    - AdamW with linear warmup followed by cosine decay;
    - periodic validation appended to a log file, and periodic checkpointing.
    Only the master process (rank 0) logs, evaluates and saves.
    """
    def __init__(self, config, model, tokenizer):
        self.config = config
        self.model = model
        self.tokenizer = tokenizer
        self.num_epochs = config.num_epochs

        self.use_moe = config.use_moe
        self.use_lossfreebalance = config.use_lossfreebalance if self.use_moe else False
        self.clean_cuda_cache = config.clean_cuda_cache
        self.steps_for_eval = config.steps_for_eval
        self.weight_decay = config.weight_decay
        # Step size `u` of the loss-free balancing bias update (1e-3 when the config has none)
        self.update_rate = getattr(config, "bias_update_rate", 1e-3)

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        n_gpus = 0  # stays 0 on CPU so the DDP check below is well-defined
        if self.device == 'cuda':
            torch.cuda.manual_seed(config.seed)
            n_gpus = torch.cuda.device_count()

        use_compile = self.config.use_compile and self.device == "cuda" and torch.__version__.startswith("2")
        if use_compile:
            self.model = torch.compile(self.model)

        # DDP
        if n_gpus > 1 and config.use_ddp:
            self.ddp = True
            init_process_group(backend="nccl")
            self.ddp_rank = int(os.environ['RANK'])
            self.ddp_local_rank = int(os.environ['LOCAL_RANK'])
            self.ddp_world_size = int(os.environ['WORLD_SIZE'])
            self.device = f'cuda:{self.ddp_local_rank}'
            torch.cuda.set_device(self.device)
            self.master_process = self.ddp_rank == 0

            # Each rank's replica must sit on its own GPU before DDP wraps it
            self.model.to(self.device)
            self.model = DDP(self.model, device_ids=[self.ddp_local_rank])
            self.raw_m = self.model.module
        else:
            self.ddp = False
            self.ddp_rank = 0
            self.ddp_world_size = 1
            self.master_process = True
            self.model.to(self.device)
            self.raw_m = self.model

        raw_model = self._unwrapped_model()
        print("Device:", self.device)
        print(f"Model's trainable params: {sum([p.data.numel() for p in self.model.parameters() if p.requires_grad]) / 1e6:.2f}M")
        print(f"Tokens per step: {self.config.batch_size * self.config.max_seq_len * self.ddp_world_size * self.config.accumulation_steps}")
        print(f"use {'torch.compile()'}: {use_compile}")
        print(f"Use MoE: {'Yes ' if self.use_moe else 'No'}")
        if self.use_moe:
            ffn = raw_model.blocks[0].ffn
            print(f"Number of experts: {ffn.num_experts}")
            print(f"Number of used experts during inference: {ffn.moe_routed_experts}")
            print(f"Method of aux_loss: {'loss-free-balance' if config.use_lossfreebalance else 'default'}")
            total_params = sum(p.numel() for p in raw_model.parameters() if p.requires_grad)
            # A token skips (num_experts - top_k) routed experts in every layer; those parameters stay idle
            params_per_expert = sum(p.numel() for p in ffn.experts.parameters()) / ffn.num_experts
            idle_params = params_per_expert * (ffn.num_experts - ffn.moe_routed_experts) * len(raw_model.blocks)
            print(f"Number of parameters will be used during inference: {(total_params - idle_params) / 1e6:.2f}M")

    def _unwrapped_model(self):
        """Return the underlying model, stripped of the DDP and torch.compile wrappers."""
        model = self.model.module if self.ddp else self.model
        return getattr(model, "_orig_mod", model)

    def _is_cuda(self):
        """True when training on any CUDA device ('cuda' or 'cuda:N')."""
        return str(self.device).startswith("cuda")

    def _autocast(self):
        """bf16 autocast context; autocast only accepts the device *type*, not 'cuda:N'."""
        return torch.autocast(device_type="cuda" if self._is_cuda() else "cpu", dtype=torch.bfloat16)

    @torch.no_grad()
    def _update_expert_biases(self, routing):
        """
        Loss-free balancing update (https://arxiv.org/pdf/2408.15664):
        b_i += u * sign(mean_count - count_i). Under-used experts are pushed up, over-used ones down.
        `routing` is the (router_scores, topk_indices) tuple that the model returns in this mode.
        """
        # NOTE(review): Transformer.forward hands back only the last block's routing, so
        # every block's biases are updated from that block's expert counts.
        topk_indices = routing[1].flatten()
        for block in self._unwrapped_model().blocks:
            ffn = block.ffn
            counts = torch.bincount(topk_indices, minlength=ffn.num_experts).float()
            error = counts.mean() - counts
            step = self.update_rate * torch.sign(error)
            ffn.expert_biases.add_(step.to(device=ffn.expert_biases.device, dtype=ffn.expert_biases.dtype))

    def _append_eval_log(self, line: str):
        """Append one line to the eval log, creating its directory if needed."""
        path = getattr(self.config, "eval_log_file", None) or getattr(self.config, "eval_eval_log_file", "logs/eval.txt")
        log_dir = os.path.dirname(path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(path, "a") as f:
            f.write(line)

    def step(self, data_loader, accumulation_steps: int,
              num_tokens: int, split: str = "train"):
        """
        Performs single forward/backward pass with gradient accumulation.
            Returns: (total_loss, cross_entropy_loss, number_of_processed_tokens)
        The returned loss is already divided by accumulation_steps.
        """
        x, y = data_loader.next_batch(split=split)
        x, y = x.to(self.device), y.to(self.device)
        num_tokens += torch.numel(x)
        with self._autocast():
            _, loss, ce_loss = self.model(x, y)
        loss = loss / accumulation_steps

        loss.backward()
        return loss, ce_loss, num_tokens

    def train(self, data_loader):
        num_steps_per_epoch = math.ceil(data_loader.num_train_steps() / self.config.accumulation_steps)

        # Configuration of optimizer and schedulers
        # Using AdamW with cosine decay and warmup - similar to Llama's training setup
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            betas=self.config.betas,
            weight_decay=self.weight_decay,
            fused=self._is_cuda()
        )

        warmup_steps = math.floor(self.config.warmup_ratio * num_steps_per_epoch * self.num_epochs)
        warmup_factor = lambda step: 0.05 + 0.95 * (step / max(warmup_steps, 1))
        warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=warmup_factor
        )

        cos_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=(num_steps_per_epoch * self.num_epochs) - warmup_steps,
            eta_min=0.1 * self.config.learning_rate
        )

        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cos_scheduler],
            milestones=[warmup_steps])

        last_step = num_steps_per_epoch - 1
        self.model.train()

        for epoch in range(self.num_epochs):
            for step in range(num_steps_per_epoch):
                global_step = epoch * num_steps_per_epoch + step
                t0 = time.perf_counter()
                accumulated_loss = 0.0
                num_tokens = 0

                # Skip the DDP gradient all-reduce on every micro-batch except the last
                ddp_nosync_ctx = self.model.no_sync() if self.ddp else nullcontext()
                with ddp_nosync_ctx:
                    for _ in range(self.config.accumulation_steps - 1):
                        loss, ce_loss, num_tokens = self.step(data_loader, self.config.accumulation_steps, num_tokens, split="train")
                        accumulated_loss += loss.detach()

                loss, ce_loss, num_tokens = self.step(data_loader, self.config.accumulation_steps, num_tokens, split="train")
                accumulated_loss += loss.detach()

                # Calculate expert biases using Auxiliary Loss-Free Balance method for MoE (https://arxiv.org/pdf/2408.15664)
                if self.use_moe and self.use_lossfreebalance:
                    self._update_expert_biases(ce_loss)

                norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

                t1 = time.perf_counter()
                tokens_per_sec = num_tokens / (t1 - t0)

                # Logging
                if self.master_process:
                    print(f"Epoch: {epoch} | Step: {step} |  loss: {accumulated_loss:.4f} | norm: {norm:.4f} |  tok/s: {tokens_per_sec}")

                # Evaluation
                if self.master_process and ((step > 0 and step % self.config.eval_interval == 0) or step == last_step):
                    self.model.eval()
                    val_loss = self.eval(data_loader)
                    self._append_eval_log(
                        f"Step: {global_step}, val_loss: {val_loss:.4f}, norm: {norm:.4f}, "
                        f"time: {(t1 - t0) * 1000:.2f}ms, tok/s: {tokens_per_sec:.1f} \n"
                    )
                    self.model.train()
                    if self.clean_cuda_cache and self._is_cuda():
                        torch.cuda.empty_cache()

                # Save checkpoints (named by global step so epochs never overwrite each other)
                if self.master_process and ((step % self.config.checkpoints_frequency == 0 and step > 0) or step == last_step):
                    self.save_checkpoints(self.config.path_to_checkpoints, name=str(global_step))

    def eval(self, data_loader):
        """
        Evaluates model on validation split using running average of first [steps_for_eval] batches
        """
        with torch.no_grad():
            val_loss_accum = 0.0
            for _ in range(self.steps_for_eval):
                x, y = data_loader.next_batch(split="val")
                x, y = x.to(self.device), y.to(self.device)
                with self._autocast():
                    _, loss, _ = self.model(x, y)
                val_loss_accum += loss.detach() / self.steps_for_eval
            return val_loss_accum

    def save_checkpoints(self, path: str, name: str):
        """Save the unwrapped model's state_dict (no 'module.'/'_orig_mod.' key prefixes)."""
        os.makedirs(path, exist_ok=True)
        checkpoint_path = os.path.join(path, f"model.checkpoint.{name}.pt")
        torch.save(self._unwrapped_model().state_dict(), checkpoint_path)
        print("Checkpoint saved")

In [None]:
torch.set_float32_matmul_precision('high')


config = ModelConfig(
    device = 'cuda' if torch.cuda.is_available() else 'cpu',
    vocab_size = 50304,

    num_dims = 1024,
    num_heads = 16,
    num_kv_heads = 4,
    num_layers = 16,
    ffn_hidden_dims = 1024 * 4,

    rmsnorm_eps = 1e-6,
    rope_theta = 1e5,

    context_len = 1024,
    
    use_cache = False,
    use_flash = False,
    use_moe = False,

    moe_num_experts = 6,
    moe_routed_experts = 1,
    moe_eps = 1e-6,
    moe_aux_loss_coef = 0.01,
    moe_shared_experts = 0,
    use_lossfreebalance = False,

)

# TrainerConfig has no defaults for these fields. Derive them from the model config
# so that the MoE switches of the model and the trainer always agree.
train_config = TrainerConfig(
    vocab_size = config.vocab_size,
    num_epochs = 1,
    # WORLD_SIZE is the same env var Trainer reads when it sets up DDP
    use_ddp = int(os.environ.get("WORLD_SIZE", "1")) > 1,
    use_moe = config.use_moe,
    use_lossfreebalance = config.use_lossfreebalance,
)

model = Transformer(config)
model = model.to(device)

print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

tokenizer_id = "HuggingFaceTB/SmolLM-360M"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
tokenizer.pad_token = tokenizer.eos_token

trainer = Trainer(train_config, model, tokenizer)
# Build the loader after the Trainer, so it shards by the rank/world size the Trainer resolved
data_loader = DataLoader(train_config, rank=trainer.ddp_rank, world_size=trainer.ddp_world_size)
trainer.train(data_loader)