In [41]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import random
import time
import math
import tiktoken
import inspect
import os
from dataclasses import dataclass



from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

In [42]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
enc = tiktoken.get_encoding("gpt2")
SEED = 1337

# Intended checkpoint interval in optimizer steps (not referenced by the training loop in this notebook).
checkpoints_frequency = 2000


data_dir = "edu_fineweb10B"
log_dir = "log"
log_file = os.path.join(log_dir, "log.txt")
val_log_file = os.path.join(log_dir, "val_log.txt")
# The training loop appends to log_file; create the directory up front so the first write cannot fail.
os.makedirs(log_dir, exist_ok=True)

# Seed every RNG the notebook imports so runs are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == 'cuda':
    torch.cuda.manual_seed(SEED)

# assert batch_size % (mini_batches * time_stamps) == 0, "batch_size is not devided by B and T"
# mini_epochs = int(batch_size / (mini_batches * time_stamps)) #number of mini-batches to get 0.5M batch


@dataclass
class ModelConfig:
    """Hyper-parameters for the MoE Transformer, its data pipeline and its training run.

    Field order is the positional-argument order of the generated ``__init__``.
    ``__post_init__`` rejects head / expert settings that the model layers cannot
    build or run with, so misconfigurations fail at construction time.

    Raises:
        ValueError: if ``num_dims`` is not divisible by ``num_heads``, the head
            dimension is odd (RoPE rotates channel pairs), ``num_heads`` is not a
            multiple of the effective ``num_kv_heads``, or ``moe_topk`` is not in
            ``[1, num_experts]``.
    """
    vocab_size: int

    # Model width / depth.
    num_dims: int
    num_heads: int
    num_kv_heads: int  # GroupedQueryAttention treats None as "same as num_heads"
    num_layers: int
    ffn_hidden_dims: int  # not read by any layer in this notebook

    # Batch geometry.
    batch_size: int  # NOTE(review): appears to be tokens per optimizer step (2**19) — confirm
    mini_batches: int  # sequences per micro-batch (B passed to DataLoader)
    time_stamps: int  # tokens per sequence (T passed to DataLoader)
    context_len: int
    use_cache: bool

    # Mixture-of-experts routing.
    num_experts: int
    moe_topk: int
    moe_eps: float = 1e-6
    moe_aux_loss_coef: float = 0.007

    rmsnorm_eps: float = 1e-6
    rope_theta: float = 1e5

    def __post_init__(self):
        """Validate that the attention and MoE shapes are mutually consistent."""
        if self.num_dims % self.num_heads != 0:
            raise ValueError(
                f"num_dims ({self.num_dims}) must be divisible by num_heads ({self.num_heads})"
            )
        head_dim = self.num_dims // self.num_heads
        if head_dim % 2 != 0:
            raise ValueError(f"head dimension ({head_dim}) must be even for rotary embeddings")
        kv_heads = self.num_heads if self.num_kv_heads is None else self.num_kv_heads
        if kv_heads <= 0 or self.num_heads % kv_heads != 0:
            raise ValueError(
                f"num_heads ({self.num_heads}) must be a positive multiple of num_kv_heads ({kv_heads})"
            )
        if not 1 <= self.moe_topk <= self.num_experts:
            raise ValueError(
                f"moe_topk ({self.moe_topk}) must be between 1 and num_experts ({self.num_experts})"
            )





In [43]:
# A RANK environment variable (set by a distributed launcher such as torchrun) switches on DDP mode.
ddp = int(os.environ.get('RANK', -1)) != -1
if not ddp:
    # Single-process run: behave as rank 0 of a world of size 1.
    ddp_rank, ddp_local_rank, ddp_world_size = 0, 0, 1
    master_process = True
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"device: {device}")
else:
    # One process per GPU: join the NCCL group and pin this process to its local GPU.
    init_process_group(backend="nccl")
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    # Only rank 0 logs and writes files.
    master_process = ddp_rank == 0
    

def get_lr(epoch, warmup_lr_steps, max_lr, min_lr, epochs):
    """Learning rate for optimizer step ``epoch``: linear warmup, then cosine decay.

    Args:
        epoch: zero-based optimizer step index.
        warmup_lr_steps: number of warmup steps; step ``i < warmup_lr_steps`` gets
            ``max_lr * (i + 1) / warmup_lr_steps``.
        max_lr: peak learning rate, reached on the last warmup step.
        min_lr: floor of the cosine decay, returned for every step ``>= epochs``.
        epochs: step at which the cosine decay reaches ``min_lr``.

    Returns:
        The learning rate as a float.
    """
    # 1) linear warmup
    if epoch < warmup_lr_steps:
        return max_lr * (epoch + 1) / warmup_lr_steps
    # 2) schedule finished. `>=` (rather than `>`) returns the same value at epoch == epochs
    #    and avoids a 0/0 below when warmup_lr_steps == epochs.
    if epoch >= epochs:
        return min_lr
    # 3) cosine decay from max_lr (progress 0) to min_lr (progress 1)
    progress = (epoch - warmup_lr_steps) / (epochs - warmup_lr_steps)
    coef = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + coef * (max_lr - min_lr)

device: cuda


In [44]:
def load_tokens(filename):
    """Load a ``.npy`` token shard and return it as a 1-D-or-same-shape ``torch.long`` tensor.

    The array is widened to int32 first, then copied into an int64 torch tensor.
    """
    as_int32 = np.load(filename).astype(np.int32)
    return torch.tensor(as_int32, dtype=torch.long)

class DataLoader():
    """Streams ``(x, y)`` token batches from sharded ``.npy`` files.

    Each process reads a disjoint ``B * T`` window: process ``cur_process`` starts at
    offset ``cur_process * B * T``, and every call advances by ``B * T * num_processes``.
    When this process's next window would run past the end of the current shard, the
    loader moves to the next shard, wrapping around after the last one.

    Raises:
        FileNotFoundError: if ``data_dir`` has no file whose name contains ``split``.
    """

    def __init__(self, B, T, cur_process, num_processes, data_dir, split):
        self.B = B
        self.T = T
        self.cur_process = cur_process
        self.num_processes = num_processes
        self.data_dir = data_dir

        # Shards are the files in data_dir whose name contains `split`, in sorted order.
        shard_names = sorted(s for s in os.listdir(self.data_dir) if split in s)
        if not shard_names:
            raise FileNotFoundError(f"no shards containing {split!r} found in {data_dir!r}")
        self.shards = [os.path.join(self.data_dir, s) for s in shard_names]

        # Sets cur_shard, tokens and current_step.
        self.reset()

        # Rough estimate: assumes every shard holds as many tokens as the first one.
        print(f"loaded ~{len(self.tokens)*len(self.shards)} tokens")

    def reset(self):
        """Rewind this process to its first window in the first shard."""
        self.cur_shard = 0
        self.tokens = load_tokens(self.shards[self.cur_shard])
        self.current_step = self.B * self.T * self.cur_process

    def next_batch(self):
        """Return the next ``(x, y)`` pair, each of shape ``(B, T)``; ``y`` is ``x`` shifted by one token."""
        B, T = self.B, self.T

        # Read the window at the current position *before* advancing, so the first
        # window of every shard is used rather than skipped.
        buf = self.tokens[self.current_step:self.current_step + B * T + 1]
        x = buf[:-1].view(B, T)
        y = buf[1:].view(B, T)

        # Skip over the windows read by all processes in this step.
        self.current_step += B * T * self.num_processes
        # If this process's next window (B*T inputs + 1 target) would not fit, go to the next shard.
        if self.current_step + B * T + 1 > len(self.tokens):
            self.cur_shard = (self.cur_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.cur_shard])
            self.current_step = B * T * self.cur_process
        return x, y

In [45]:
def repeat_kv(vct: torch.Tensor, n_times: int):
    """Repeat every key/value head ``n_times`` along dim 2 so K/V match the query head count.

    ``vct`` is unpacked as ``(batch, seq, kv_heads, head_dim)``; the result has shape
    ``(batch, seq, kv_heads * n_times, head_dim)`` with each original head appearing
    ``n_times`` times consecutively. For ``n_times == 1`` the input tensor itself is returned.
    """
    batch, seq, heads, head_dim = vct.shape
    if n_times == 1:
        return vct
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the head axis.
    widened = vct.unsqueeze(3).expand(batch, seq, heads, n_times, head_dim)
    return widened.reshape(batch, seq, heads * n_times, head_dim)


# https://github.com/hkproj/pytorch-llama/blob/main/model.py
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    """Precompute rotary-embedding rotations as unit complex numbers.

    Returns a complex tensor of shape ``(seq_len, head_dim // 2)`` whose entry ``[m, i]``
    is ``exp(1j * m * theta ** (-2 * i / head_dim))``, created on ``device``.
    """
    assert head_dim % 2 == 0, "Dimension must be divisible by 2"
    # One inverse frequency per channel pair: theta^(-2i / head_dim).
    pair_exponent = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta ** pair_exponent).to(device)
    positions = torch.arange(seq_len, device=device)
    # angles[m, i] = m * inv_freq[i]
    angles = torch.outer(positions, inv_freq).float()
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_pos(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    """Rotate adjacent channel pairs of ``x`` by the matching RoPE angles.

    Adjacent values of the last dimension of ``x`` are treated as (real, imag) parts;
    ``freqs_complex`` is broadcast after ``unsqueeze(0).unsqueeze(2)``. In this notebook
    it is called with ``x`` of shape ``(B, T, heads, head_dim)`` and ``freqs_complex`` of
    shape ``(T, head_dim // 2)``. The result has the shape and dtype of ``x`` and is moved
    to ``device``.
    """
    # Rotation is done in float32 regardless of x's dtype.
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rotation = freqs_complex.unsqueeze(0).unsqueeze(2)
    rotated = torch.view_as_complex(pairs) * rotation
    flat = torch.view_as_real(rotated).reshape(*x.shape)
    return flat.type_as(x).to(device)

In [46]:
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension with a learned per-channel gain ``g``.

    No mean-centering and no bias; ``eps`` is taken from ``config.rmsnorm_eps``.
    """

    def __init__(self, config):
        super().__init__()
        # Gain starts at 1, so the layer initially performs plain RMS normalization.
        self.g = nn.Parameter(torch.ones(config.num_dims))
        self.eps = config.rmsnorm_eps

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), computed along the last dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalize in float32, cast back to x's dtype, then scale by the gain.
        normed = self._norm(x.float()).type_as(x)
        return self.g * normed
    

class GroupedQueryAttention(nn.Module):
    """Causal self-attention with grouped-query K/V heads and rotary position embeddings.

    ``num_heads`` query heads share ``num_kv_heads`` key/value heads (``None`` means one
    K/V head per query head). Each K/V head is repeated ``num_heads // num_kv_heads``
    times before the dot product.
    """

    def __init__(self, config, device=None):
        # `device` is not used by the active code; it only fed the commented-out KV-cache
        # buffers below. It defaults to None instead of a module-level global so the class
        # can be defined on its own; callers that pass it explicitly are unaffected.
        super().__init__()

        self.use_cache = config.use_cache  # stored, but the KV cache itself is disabled below

        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_heads if config.num_kv_heads is None else config.num_kv_heads
        self.num_rep = self.num_heads // self.num_kv_heads  # query heads per K/V head
        self.head_dim = config.num_dims // self.num_heads

        self.wq = nn.Linear(config.num_dims, config.num_dims, bias=False)
        self.wk = nn.Linear(config.num_dims, self.num_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(config.num_dims, self.num_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(config.num_dims, config.num_dims, bias=False)

        # KV cache (disabled; if re-enabled, pass an explicit device):
        # self.cache_k = torch.zeros(
        #     (
        #         config.batch_size,
        #         config.context_len,
        #         self.num_kv_heads,
        #         self.head_dim
        #     ), device=device
        # )

        # self.cache_v = torch.zeros(
        #     (
        #         config.batch_size,
        #         config.context_len,
        #         self.num_kv_heads,
        #         self.head_dim
        #     ), device=device
        # )

    def forward(self, x, freqs_complex, start_pos = 0):
        """Attend causally over ``x``.

        Args:
            x: input unpacked as ``(batch, seq, num_dims)``.
            freqs_complex: RoPE rotations for these positions, ``(seq, head_dim // 2)``.
            start_pos: position offset; only meaningful for the disabled KV cache.

        Returns:
            Tensor of shape ``(batch, seq, num_dims)``.
        """
        c_batch_size, c_context_len, c_dim = x.shape

        q = self.wq(x).view(c_batch_size, c_context_len, self.num_heads, self.head_dim)  # B, T, qh, hs
        k = self.wk(x).view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim)  # B, T, kh, hs
        v = self.wv(x).view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim)  # B, T, kh, hs

        queries = apply_rotary_pos(q, freqs_complex, device=x.device)
        keys = apply_rotary_pos(k, freqs_complex, device=x.device)

        # if self.use_cache:
        #     self.cache_k[:c_batch_size, start_pos:start_pos+c_context_len] = keys
        #     self.cache_v[:c_batch_size, start_pos:start_pos+c_context_len] = v
        #     keys = self.cache_k[:c_batch_size, :start_pos+c_context_len]
        #     v = self.cache_v[:c_batch_size, :start_pos+c_context_len]

        keys = repeat_kv(keys, self.num_rep)
        values = repeat_kv(v, self.num_rep)

        queries = queries.transpose(1, 2)  # B, qh, T, hs
        keys = keys.transpose(1, 2)        # B, qh, S, hs
        values = values.transpose(1, 2)    # B, qh, S, hs

        scores = torch.matmul(queries, keys.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))

        # Causal mask built from positions, not from score values: query i may attend to
        # keys 0 .. i + (S - T). Masking on `scores == 0` would also hide genuine zero
        # scores and turns a row of all-zero scores into NaN after the softmax.
        q_len, k_len = scores.shape[-2], scores.shape[-1]
        allowed = torch.ones(q_len, k_len, dtype=torch.bool, device=scores.device).tril(diagonal=k_len - q_len)
        scores = scores.masked_fill(~allowed, float("-inf"))

        attention = F.softmax(scores, dim=-1).type_as(queries)
        output = torch.matmul(attention, values)  # B, qh, T, hs

        output = output.transpose(1, 2).contiguous().view(c_batch_size, c_context_len, c_dim)
        return self.wo(output)


class FeedForward(nn.Module):
    """SwiGLU feed-forward layer: ``w2(silu(w1(x)) * w3(x))``.

    Hidden width is ``int(2 * num_dims / 3)`` rounded up to a multiple of 4
    (684 for ``num_dims=1024``).

    NOTE(review): LLaMA-style reference code typically takes 2/3 of ``4 * num_dims``
    (2732 for ``num_dims=1024``); here the 2/3 is applied to ``num_dims`` directly,
    giving a ~4x narrower layer, and ``config.ffn_hidden_dims`` is not read. Confirm
    this is intended.
    """

    def __init__(self, config):
        super().__init__()
        round_to = 4
        width_multiplier = None  # optional extra scale on the hidden width (disabled)

        width = int(2 * config.num_dims / 3)
        if width_multiplier is not None:
            width = int(width_multiplier * width)
        # Round up to the next multiple of `round_to`.
        width = round_to * ((width + round_to - 1) // round_to)

        self.w1 = nn.Linear(config.num_dims, width, bias=False)
        self.w2 = nn.Linear(width, config.num_dims, bias=False)
        self.w3 = nn.Linear(config.num_dims, width, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))

class FFNwMoE(nn.Module):
    """Mixture-of-experts SwiGLU feed-forward with softmax top-k routing.

    Every token goes to its ``moe_topk`` highest-probability experts. Expert outputs
    are weighted by the (not renormalized) router probabilities and summed. ``forward``
    also returns a Switch-Transformer-style load-balancing loss computed from each
    token's top-1 expert.
    """

    def __init__(self, config):
        super().__init__()
        self.moe_aux_loss_coef = config.moe_aux_loss_coef
        self.moe_eps = config.moe_eps  # not used by the active forward pass

        # Expert hidden width: int(2 * num_dims / 3) rounded up to a multiple of 4.
        # NOTE(review): LLaMA-style reference code typically takes 2/3 of 4 * num_dims;
        # this is ~4x narrower — confirm intended.
        round_to = 4
        width_multiplier = None  # optional extra scale on the hidden width (disabled)
        width = int(2 * config.num_dims / 3)
        if width_multiplier is not None:
            width = int(width_multiplier * width)
        width = round_to * ((width + round_to - 1) // round_to)
        self.hidden_dim = width

        self.top_k = config.moe_topk

        self.num_experts = config.num_experts
        self.router = nn.Linear(config.num_dims, self.num_experts, bias=False)
        # Each expert is the [w1, w2, w3] triple of a SwiGLU block: w2(silu(w1(x)) * w3(x)).
        self.experts = nn.ModuleList(
            nn.ModuleList([
                nn.Linear(config.num_dims, width, bias=False),
                nn.Linear(width, config.num_dims, bias=False),
                nn.Linear(config.num_dims, width, bias=False),
            ])
            for _ in range(self.num_experts)
        )

    def forward(self, x):
        """Route each token of ``x`` (unpacked as ``(batch, seq, dim)``) through its top-k experts.

        Returns:
            ``(output, aux_loss)``: ``output`` has the shape of ``x``; ``aux_loss`` is a scalar
            ``coef * num_experts * sum_e(top1_fraction[e] * mean_router_prob[e])``.
        """
        batch, seq, dim = x.shape
        tokens = x.view(-1, dim)  # (batch * seq, dim)

        gate_probs = F.softmax(self.router(tokens), dim=-1)
        top_probs, top_experts = gate_probs.topk(self.top_k, dim=-1)

        mixed = torch.zeros_like(tokens)

        # Load-balancing loss over the top-1 assignments.
        top1_fraction = F.one_hot(top_experts[:, 0], self.num_experts).float().mean(dim=0)
        mean_gate_prob = gate_probs.mean(dim=0)
        aux_loss = self.moe_aux_loss_coef * torch.sum(top1_fraction * mean_gate_prob) * self.num_experts

        # Accumulate slot 0 for all experts, then slot 1, and so on.
        for slot in range(self.top_k):
            slot_expert = top_experts[:, slot]
            slot_prob = top_probs[:, slot]

            for expert_id, (w1, w2, w3) in enumerate(self.experts):
                rows = (expert_id == slot_expert).nonzero().squeeze()
                if rows.numel() == 0:
                    continue
                chosen = tokens[rows]
                expert_out = w2(F.silu(w1(chosen)) * w3(chosen))
                mixed[rows] += expert_out * slot_prob[rows].unsqueeze(-1)

        return mixed.view(batch, seq, dim), aux_loss



class Block(nn.Module):
    """Pre-norm Transformer layer: ``h = x + attn(norm(x))`` then ``h + moe(norm(h))``.

    ``forward`` returns ``(hidden, aux_loss)``, where ``aux_loss`` is this layer's MoE
    load-balancing loss.
    """

    def __init__(self, config, device=device):
        super().__init__()

        self.attention = GroupedQueryAttention(config, device=device)
        self.ffn = FFNwMoE(config)

        self.norm_attention = RMSNorm(config)
        self.norm_ffn = RMSNorm(config)

    def forward(self, x, freqs_complex, start_pos):
        attn_in = self.norm_attention(x)
        hidden = x + self.attention(attn_in, freqs_complex, start_pos)

        ffn_out, aux_loss = self.ffn(self.norm_ffn(hidden))
        return hidden + ffn_out, aux_loss

In [47]:
class Transformer(nn.Module):
    """Decoder-only language model: token embedding -> ``num_layers`` x Block -> RMSNorm -> tied LM head."""

    def __init__(self, config):
        super().__init__()

        self.vocab_size = config.vocab_size
        self.num_dims = config.num_dims
        self.num_heads = config.num_heads
        self.num_layers = config.num_layers
        self.context_len = config.context_len

        self.tokens_embedding = nn.Embedding(self.vocab_size, self.num_dims)

        self.blocks = nn.ModuleList()
        for _ in range(self.num_layers):
            self.blocks.append(Block(config))

        self.norm = RMSNorm(config)
        self.ll_head = nn.Linear(self.num_dims, self.vocab_size, bias=False)

        # Weight tying: the input embedding and the output projection share one matrix.
        self.tokens_embedding.weight = self.ll_head.weight

        # RoPE table for 2 * context_len positions, using the configured base frequency.
        # It is a plain attribute (not a buffer), so `.to()` does not move it; `forward`
        # copies the slice it needs to the input's device.
        self.freqs_complex = precompute_theta_pos_frequencies(
            self.num_dims // self.num_heads, self.context_len * 2, device=device, theta=config.rope_theta
        )

    # I have taken this function [configure_optimizers] from Karpathy's nanoGPT
    # https://github.com/karpathy/nanoGPT
    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Build AdamW with weight decay on >=2-D parameters only.

        ``device_type`` may be ``'cuda'``, an indexed device string such as ``'cuda:1'``,
        or a ``torch.device``; fused AdamW is used for any CUDA device when available.
        """
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available.
        # startswith() so that 'cuda:N' (the DDP device string) also qualifies.
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and str(device_type).startswith('cuda')
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def forward(self, x, targets=None, start_pos=0):
        """Run the model on token ids ``x`` (unpacked as ``(batch, seq)``).

        Returns:
            ``(logits, loss, tt_loss)``:
              * ``logits`` — ``(batch, seq, vocab)`` without targets, flattened to
                ``(batch * seq, vocab)`` when ``targets`` is given;
              * ``loss`` — ``tt_loss`` plus the summed MoE auxiliary losses, or None without targets;
              * ``tt_loss`` — token cross-entropy, or None without targets.

        Raises:
            ValueError: if ``start_pos + seq`` exceeds the precomputed RoPE positions.
        """
        _, seq_len = x.shape
        if start_pos + seq_len > self.freqs_complex.shape[0]:
            raise ValueError(
                f"sequence of {seq_len} tokens at start_pos {start_pos} exceeds the "
                f"{self.freqs_complex.shape[0]} precomputed RoPE positions"
            )

        x = self.tokens_embedding(x)
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len].to(x.device)

        total_aux_loss = 0
        for block in self.blocks:
            x, aux_loss = block(x, freqs_complex=freqs_complex, start_pos=start_pos)
            total_aux_loss += aux_loss

        x = self.norm(x)
        logits = self.ll_head(x)

        if targets is None:
            # Both losses are undefined without targets.
            loss = None
            tt_loss = None
        else:
            c_batch_size, c_context_len, c_dim = logits.shape
            logits = logits.view(c_batch_size*c_context_len, c_dim)
            targets = targets.view(c_batch_size*c_context_len)
            tt_loss = F.cross_entropy(logits, targets)
            loss = tt_loss + total_aux_loss

        return logits, loss, tt_loss

    @torch.no_grad()
    def generate(self, x, max_tokens, use_cache=False):
        """Append ``max_tokens`` sampled tokens to ``x`` (unpacked as ``(batch, seq)``).

        Each step crops the running sequence to its last ``context_len`` tokens, runs the
        full model and samples the next token from the softmax of the last position's logits.

        Raises:
            NotImplementedError: if ``use_cache`` is True. The attention KV cache is commented
                out, so feeding only the newest token would ignore the preceding context.
        """
        if use_cache:
            raise NotImplementedError(
                "KV cache is disabled in GroupedQueryAttention; call generate with use_cache=False"
            )
        for _ in range(max_tokens):
            context = x[:, -self.context_len:]
            logits, _, _ = self.forward(context)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, next_token), dim=1)
        return x

In [50]:

# def main():
config = ModelConfig(
    vocab_size = 50304,

    num_dims = 1024,
    num_heads = 16,
    num_kv_heads = 4,
    num_layers = 16,
    ffn_hidden_dims = 4 * 1024,

    rmsnorm_eps = 1e-6,
    rope_theta = 1e5,

    batch_size = 2**19,
    mini_batches = 2,
    time_stamps = 512,
    context_len = 1024,
    
    use_cache = False,

    num_experts = 6,
    moe_topk = 2,
    moe_eps = 1e-6,
    moe_aux_loss_coef = 0.007
)

m = Transformer(config)
model = m.to(device)
model = torch.compile(model)
# NOTE(review): DDP wrapping is disabled, so under a multi-process launch each rank trains its
# own copy with no gradient all-reduce (the training loop only averages the logged loss).
# Re-enabling it needs the unwrapped module for configure_optimizers, and probably
# find_unused_parameters=True because an MoE expert can receive no tokens in a step.
# if ddp:
#     model = DDP(model, device_ids=[ddp_local_rank]) 
# raw_m = model.module if ddp else model

# Gradient-accumulation steps so that one optimizer step sees `config.batch_size` tokens across
# all processes. Each micro-step reads mini_batches * time_stamps tokens per process.
tokens_per_micro_step = config.mini_batches * config.time_stamps * ddp_world_size
assert config.batch_size % tokens_per_micro_step == 0, \
    "batch_size must be divisible by mini_batches * time_stamps * world_size"
mini_epochs = config.batch_size // tokens_per_micro_step
if master_process:
    print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')
    print(f"gradient accumulation steps: {mini_epochs}")

data_loader = DataLoader(config.mini_batches, config.time_stamps, cur_process=ddp_rank, num_processes=ddp_world_size, data_dir=data_dir, split="train")
# val_loader = DataLoader(config.mini_batches, config.time_stamps, cur_process=ddp_rank, num_processes=ddp_world_size, data_dir=data_dir, split="val")

max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_lr_steps = 700
weight_decay = 0.1
beta1, beta2 = 0.9, 0.95

# configure_optimizers expects a device *type* ('cuda'), not an indexed device such as 'cuda:1'.
device_type = 'cuda' if str(device).startswith('cuda') else 'cpu'
optmizer = model.configure_optimizers(weight_decay, max_lr, (beta1, beta2), device_type)


# if __name__ == "__main__":
#     main()

127.108096 M parameters
loaded ~400000000 tokens
num decayed parameter tensors: 113, with 127,074,304 parameters
num non-decayed parameter tensors: 33, with 33,792 parameters
using fused AdamW: True


In [36]:
# test
# x, y = data_loader.next_batch()
# x, y = x.to(device), y.to(device)
# logits, loss = model(x, y)

In [None]:
epochs = 60
# torch.autocast and torch.cuda.synchronize need the device *type*; under DDP `device` is e.g. 'cuda:1'.
device_type = 'cuda' if str(device).startswith('cuda') else 'cpu'
# NOTE(review): with warmup_lr_steps = 700 (set above) and epochs = 60, get_lr never leaves linear
# warmup in this run; the largest LR used is max_lr * 60 / 700. Confirm the intended schedule.
tokens_per_step = data_loader.B * data_loader.T * mini_epochs * ddp_world_size

for epoch in range(epochs):
    t0 = time.time()

    model.train()
    accumulated_loss = 0.0
    optmizer.zero_grad()
    # Gradient accumulation: average the loss over `mini_epochs` micro-batches.
    for mini_epoch in range(mini_epochs):
        x, y = data_loader.next_batch()
        x, y = x.to(device), y.to(device)

        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            # forward returns (logits, loss incl. MoE aux loss, cross-entropy only)
            logits, loss, _ = model(x, y)
        loss = loss / mini_epochs
        accumulated_loss += loss.detach()

        if ddp:
            # Only all-reduce gradients on the last micro-step (effective once the model is DDP-wrapped).
            model.require_backward_grad_sync = (mini_epoch == mini_epochs-1)
        loss.backward()
    if ddp:
        dist.all_reduce(accumulated_loss, op=dist.ReduceOp.AVG)

    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # change lr
    lr = get_lr(epoch, warmup_lr_steps, max_lr, min_lr, epochs)
    for param_group in optmizer.param_groups:
        param_group['lr'] = lr
    optmizer.step()

    if device_type == 'cuda':
        # Wait for queued GPU work so dt covers the whole step.
        torch.cuda.synchronize()
    dt = time.time() - t0

    # write losses to the log file
    if master_process and epoch % 5 == 0:
        print(f"epoch: {epoch}, loss: {accumulated_loss.item():.5f}, norm: {norm:.5f}, time: {dt*1000:.2f}ms, tok/s: {tokens_per_step/dt:.2f}")
        with open(log_file, "a") as f:
            f.write(f"epoch:{epoch} loss:{accumulated_loss.item():.5f}\n")
if ddp:
    destroy_process_group()