# milklm 


In [None]:
%%file custom_data.py
import torch
from torch.utils.data import IterableDataset, DataLoader
from datasets import load_dataset
from tiktoken import encoding_for_model  # Assuming this is your tokenizer

class StreamingParquetDataset(IterableDataset):
    """Tokenize a streamed text dataset and yield fixed-size next-token batches.

    Every item of ``dataset`` must provide a ``'text'`` field. Documents are
    tokenized and concatenated into one token stream. Each ``next()`` returns
    ``(x, y)`` long tensors of shape ``(batch_size, block_size)``, where ``y`` is
    ``x`` shifted left by one token.

    Consecutive batches overlap by exactly one token: the last target of a batch
    is the first input of the next. No token of the stream is skipped.
    Iteration ends once fewer than ``batch_size * block_size + 1`` tokens remain.
    """

    def __init__(self, dataset, config):
        self.dataset = dataset
        self.block_size = config.block_size
        self.batch_size = config.batch_size
        self.tokenizer_name = config.tokenizer_name
        # config.buffer_size counts batches; convert it to the refill target in tokens.
        self.buffer_size = config.buffer_size * self.batch_size * (self.block_size + 1)
        self.tokenizer = encoding_for_model(self.tokenizer_name)
        self.buffer = []
        self.iterator = iter(self.dataset)
        # No eager prefetch here: __iter__ resets the buffer, so tokens loaded now
        # would be discarded. __next__ fills the buffer on demand instead.

    def load_next_buffer(self, min_tokens=0):
        """Append tokenized documents until the buffer is large enough.

        The buffer is filled until it holds at least
        ``max(self.buffer_size, min_tokens)`` tokens, or until the source
        iterator is exhausted. Tokens already in the buffer are kept, so the
        partial batch left over from the previous refill is not lost.
        Exhaustion is not an error here; ``__next__`` decides when to stop.
        """
        target = max(self.buffer_size, min_tokens)
        while len(self.buffer) < target:
            try:
                item = next(self.iterator)
            except StopIteration:
                break
            self.buffer.extend(self.tokenizer.encode(item['text']))

    def __iter__(self):
        # Restart from the beginning of the source with an empty buffer.
        self.buffer = []
        self.iterator = iter(self.dataset)
        return self

    def __next__(self):
        span = self.batch_size * self.block_size
        needed = span + 1  # B*T inputs plus one extra token for the final target
        if len(self.buffer) < needed:
            self.load_next_buffer(needed)
        if len(self.buffer) < needed:
            raise StopIteration

        chunk = self.buffer[:needed]
        # Advance by `span`, not `needed`: the last target token becomes the next
        # batch's first input, so no token is dropped between batches.
        del self.buffer[:span]

        x = torch.tensor(chunk[:-1], dtype=torch.long).view(self.batch_size, self.block_size)
        y = torch.tensor(chunk[1:], dtype=torch.long).view(self.batch_size, self.block_size)
        return x, y

class DataLoaderLite:
    """Thin iterator wrapper around an iterable dataset.

    Every ``iter()`` call re-creates the underlying iterator, and every ``next()``
    forwards to it. ``batch_size`` and ``block_size`` are copied from ``config``
    for callers to inspect; this class does no batching itself.
    """

    def __init__(self,  dataset , config):
        self.dataset = dataset
        self.batch_size, self.block_size = config.batch_size, config.block_size
        self.iterator = iter(dataset)

    def __iter__(self):
        self.iterator = iter(self.dataset)
        return self

    def __next__(self):
        source = self.iterator
        return next(source)



In [None]:
%%file model.py
import torch
import torch.nn as nn
from modules import Block
from config import ModelConfig

class Transformer(nn.Module):
    """Decoder-only language model: token embedding -> Blocks -> LayerNorm -> tied LM head.

    No positional embedding is added here; the attention in modules.py applies
    RoPE to queries and keys instead. The embedding matrix and ``lm_head``
    weight are the same parameter (weight tying).
    """

    def __init__(self, config: ModelConfig, gradient_checkpointing=False):
        super().__init__()
        self.config = config
        self.gradient_checkpointing = gradient_checkpointing
        width = config.embedding_size
        vocab = config.vocab_size
        # Submodules are created in the same order as before (embedding, blocks,
        # final norm) so RNG consumption during construction is unchanged.
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(vocab, width),
            h=nn.ModuleList([Block(config, layer, gradient_checkpointing) for layer in range(config.num_layers)]),
            ln_f=nn.LayerNorm(width),
        ))
        self.lm_head = nn.Linear(width, vocab, bias=False)
        # Weight tying: the input embedding reuses the output projection matrix.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-2 style init: N(0, 0.02) weights for Linear/Embedding, zero Linear biases.

        A Linear layer carrying a ``NANOGPT_SCALE_INIT`` attribute has its std
        multiplied by ``(2 * num_layers) ** -0.5``.
        """
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            return
        if not isinstance(module, nn.Linear):
            return
        std = 0.02
        if hasattr(module, 'NANOGPT_SCALE_INIT'):
            std *= (2 * self.config.num_layers) ** -0.5
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is omitted.

        ``idx`` must be 2-D ``(batch, seq)`` token ids (enforced by the unpacking
        below). ``logits`` has shape ``(batch, seq, vocab_size)``.
        """
        _, T = idx.size()
        assert T <= self.config.max_context_length, f"Cannot forward sequence of length {T}, block size is only {self.config.max_context_length}"
        hidden = self.transformer.wte(idx)
        for block in self.transformer.h:
            hidden = block(hidden)
        logits = self.lm_head(self.transformer.ln_f(hidden))
        if targets is None:
            return logits, None
        flat_logits = logits.view(-1, logits.size(-1))
        loss = nn.functional.cross_entropy(flat_logits, targets.view(-1))
        return logits, loss


In [None]:
%%file utils.py
import torch
import os
import json
import numpy as np
from tiktoken import get_encoding

def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters.

    Only parameters with ``requires_grad=True`` are counted. A parameter shared
    between modules (e.g. tied weights) is counted once, because
    ``Module.parameters()`` de-duplicates.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def load_tokens(filename, tokenizer_name='gpt-4o'):
    """Read a UTF-8 text file and tokenize it with tiktoken.

    Args:
        filename: path of the text file to read.
        tokenizer_name: a tiktoken *model* name (e.g. ``'gpt-4o'``, the default)
            or an *encoding* name (e.g. ``'o200k_base'``, ``'gpt2'``).

    Returns:
        1-D ``torch.long`` tensor of token ids (empty for an empty file).
    """
    from tiktoken import encoding_for_model

    with open(filename, 'r', encoding='utf-8') as file:
        data = file.read()

    # get_encoding() only accepts encoding names, so the default model name
    # 'gpt-4o' must be resolved through encoding_for_model(). Fall back to
    # get_encoding() for plain encoding names that are not model names.
    try:
        encoder = encoding_for_model(tokenizer_name)
    except KeyError:
        encoder = get_encoding(tokenizer_name)
    tokens = encoder.encode(data)

    return torch.tensor(tokens, dtype=torch.long)
 

def save_checkpoint(model, optimizer, model_config, train_config, data_config, epoch, loss):
    """Write model and optimizer checkpoints for ``epoch``.

    Files go under ``<train_config.exp_name>/<exp_num:03d>/checkpoints``, which
    is created if missing:

    * ``model_state_<epoch:04d>.pth``: the model ``state_dict``, the three
      configs (as their ``__dict__``), ``epoch`` and ``loss``.
    * ``optimizer_state_<epoch:04d>.pth``: the optimizer ``state_dict``.
    """
    ckpt_dir = os.path.join(train_config.exp_name, f"{train_config.exp_num:03d}", 'checkpoints')
    os.makedirs(ckpt_dir, exist_ok=True)

    all_configs = {
        'model_config': model_config.__dict__,
        'data_config': data_config.__dict__,
        'train_config': train_config.__dict__,
    }
    model_payload = {
        'model_state_dict': model.state_dict(),
        'config': all_configs,
        'epoch': epoch,
        'loss': loss,
    }
    torch.save(model_payload, os.path.join(ckpt_dir, f'model_state_{epoch:04d}.pth'))

    optimizer_payload = {'optimizer_state_dict': optimizer.state_dict()}
    torch.save(optimizer_payload, os.path.join(ckpt_dir, f'optimizer_state_{epoch:04d}.pth'))


def load_checkpoint(exp_name, exp_num, epoch, model, optimizer, map_location='cpu'):
    """Restore model and optimizer state written by ``save_checkpoint``.

    The files read are
    ``<exp_name>/<exp_num:03d>/checkpoints/model_state_<epoch:04d>.pth`` and
    ``optimizer_state_<epoch:04d>.pth``. Their contents are loaded into
    ``model`` and ``optimizer`` in place.

    Args:
        exp_name: experiment directory name.
        exp_num: experiment number (formatted with 3 digits).
        epoch: checkpoint index (formatted with 4 digits).
        model: module whose ``state_dict`` is restored.
        optimizer: optimizer whose ``state_dict`` is restored.
        map_location: forwarded to ``torch.load``. The default ``'cpu'`` lets a
            checkpoint saved on GPU be opened on a CPU-only host.
            ``load_state_dict`` then copies the tensors into the existing
            parameters, wherever those live.

    Returns:
        ``(epoch, loss, config)`` as stored in the model checkpoint; each item
        is ``None`` if absent.

    NOTE: ``torch.load`` unpickles the files, so only load checkpoints you trust.
    """
    checkpoint_dir = os.path.join(exp_name, f"{exp_num:03d}", 'checkpoints')
    model_path = os.path.join(checkpoint_dir, f'model_state_{epoch:04d}.pth')
    optimizer_path = os.path.join(checkpoint_dir, f'optimizer_state_{epoch:04d}.pth')

    model_checkpoint = torch.load(model_path, map_location=map_location)
    model.load_state_dict(model_checkpoint['model_state_dict'])

    optimizer_checkpoint = torch.load(optimizer_path, map_location=map_location)
    optimizer.load_state_dict(optimizer_checkpoint['optimizer_state_dict'])

    return model_checkpoint.get('epoch'), model_checkpoint.get('loss'), model_checkpoint.get('config')




In [None]:
%%file eval.py
import torch
import argparse
from model import Transformer
from config import ModelConfig
from utils import load_checkpoint
from custom_data import DataLoaderLite

def evaluate():
    """Evaluate a saved model on a streamed parquet text dataset and print the mean loss.

    Command-line arguments:
        --config       file loaded with ``torch.load``. It holds ``'model_config'``
                       and ``'data_config'`` dicts, either at top level or nested
                       under ``'config'``. The nested layout is what
                       ``utils.save_checkpoint`` writes, so a model checkpoint
                       file may be passed here too.
        --checkpoint   model checkpoint whose ``'model_state_dict'`` is loaded.
        --data_dir     directory searched recursively for ``*.parquet`` files.
                       Defaults to ``data_config['dataset_dir']``.
                       NOTE(review): that default is the training directory;
                       pass a held-out directory to get a real validation loss.
        --max_batches  optional cap on the number of evaluated batches. The
                       stream has no known length.
    """
    from config import DataConfig
    from custom_data import StreamingParquetDataset
    from datasets import load_dataset

    parser = argparse.ArgumentParser(description="Evaluate a GPT-like model.")
    parser.add_argument("--config", type=str, required=True, help="Path to config file.")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint.")
    parser.add_argument("--data_dir", type=str, default=None,
                        help="Directory with validation parquet files (default: data_config dataset_dir).")
    parser.add_argument("--max_batches", type=int, default=None,
                        help="Maximum number of batches to evaluate (default: whole stream).")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config = torch.load(args.config, map_location="cpu")
    if 'config' in config:
        # Checkpoints written by save_checkpoint nest the configs under 'config'.
        config = config['config']
    model_config = ModelConfig(**config['model_config'])
    data_config = DataConfig(**config['data_config'])

    # Inference needs only the weights; no optimizer state is restored.
    model = Transformer(model_config).to(device)
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    data_dir = args.data_dir if args.data_dir is not None else data_config.dataset_dir
    # A plain data_files pattern is exposed by `datasets` as the 'train' split.
    dataset = load_dataset('parquet', data_files=f'{data_dir}/**/*.parquet', split='train', streaming=True)
    val_loader = DataLoaderLite(StreamingParquetDataset(dataset, data_config), data_config)

    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for x, y in val_loader:
            if args.max_batches is not None and num_batches >= args.max_batches:
                break
            x, y = x.to(device), y.to(device)
            _, loss = model(x, y)
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        print("Validation Loss: no complete batch available")
        return
    print(f"Validation Loss: {total_loss / num_batches}")

if __name__ == "__main__":
    evaluate()


In [None]:
%%file train.py

import os
import time
import torch
import argparse
import torch.nn as nn
import torch.optim as optim
from model import Transformer
from config import ModelConfig, DataConfig, TrainConfig
from custom_data import StreamingParquetDataset, DataLoaderLite
from utils import count_parameters, save_checkpoint, load_checkpoint, load_tokens
from tiktoken import encoding_for_model
from torch.cuda.amp import autocast, GradScaler
from datasets import load_dataset

from torch.optim.lr_scheduler import LambdaLR

_INVALID_DECODE_TEXT = "include <INVALID TOKEN ID>"


def _safe_decode(tokenizer, token_ids):
    """Decode ``token_ids``, returning a placeholder if any id cannot be decoded.

    Argmax predictions can land on ids the tokenizer does not know when the
    model's vocab_size is larger than the tokenizer's vocabulary.
    """
    try:
        return tokenizer.decode(token_ids)
    except Exception:
        # Some tiktoken versions report unknown ids as an ordinary exception.
        return _INVALID_DECODE_TEXT
    except BaseException as exc:
        # Other tiktoken versions panic in Rust. pyo3 surfaces that as
        # pyo3_runtime.PanicException, which derives from BaseException.
        # Anything else (KeyboardInterrupt, SystemExit) is re-raised.
        if type(exc).__name__ != "PanicException":
            raise
        return _INVALID_DECODE_TEXT


def train_iter(model, dataloader, optimizer, scheduler, criterion, device, total_iters, max_total_iters, model_config, train_config, data_config, batch_size, block_size, tokenizer):
    """Run one pass over ``dataloader``, stopping early at ``max_total_iters``.

    Each step does a bfloat16-autocast forward, backward, gradient clipping to
    ``train_config.value_clip_grad_norm``, an optimizer step and a scheduler
    step. Every ``print_interval_iteration`` steps a one-line log (loss,
    perplexity, grad norm, timings, throughput, target/prediction text) is
    printed. Every ``save_interval_iter`` steps a checkpoint is saved.

    ``criterion`` is accepted for interface compatibility but unused; the model
    computes its own loss.

    Returns:
        ``(mean_loss, total_iters)``: the mean loss over the steps run in this
        call (0.0 if none ran), and the updated global step counter.
    """
    model.train()
    loss_sum = 0.0   # becomes a tensor; avoids a host sync on every step
    iter_batch = 0   # steps completed in this call
    autocast_device = torch.device(device).type

    for batch in dataloader:
        if total_iters >= max_total_iters:
            break

        x, y = batch
        x, y = x.to(device), y.to(device)
        start_forward = time.perf_counter()
        with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        forward_duration = (time.perf_counter() - start_forward) * 1000

        start_backward = time.perf_counter()
        optimizer.zero_grad()
        loss.backward()
        # Clip *after* backward so it acts on (and reports) this step's gradients.
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.value_clip_grad_norm)
        optimizer.step()
        scheduler.step()   # Update learning rate

        backward_duration = (time.perf_counter() - start_backward) * 1000
        iter_duration = time.perf_counter() - start_forward
        loss_sum = loss_sum + loss.detach()

        # Print detailed log for each iteration
        if total_iters % train_config.print_interval_iteration == 0:
            perplexity = torch.exp(loss).item()
            loss_val = loss.item()
            norm_val = norm.item()

            tokens_total_seen = total_iters * (batch_size * block_size)
            tokens_iter = batch_size * block_size
            # Guard against a zero duration from a coarse clock.
            speed = (tokens_iter / max(iter_duration, 1e-9)) / 1000

            # Decode target and predicted text with line breaks shown as \n.
            gt_text = tokenizer.decode(y[0].tolist()).replace('\n', '\\n')
            pred_ids = torch.argmax(logits[0], dim=-1).tolist()
            pred_text = _safe_decode(tokenizer, pred_ids).replace('\n', '\\n')

            print_text_len = min(data_config.block_size,  train_config.print_token_len)
            str_iter = f'\nIter {total_iters:>6d}/{max_total_iters}, B{iter_batch + 1:>3d}, {tokens_total_seen / 1000000:3.1f} Mt'
            str_loss = f'Loss {loss_val:>9.6f}, PPL {perplexity:>9.2f} {norm_val:>9.6f} | F {forward_duration:>5.1f} ms, B {backward_duration:>5.1f} ms | {speed:>4.2f} Kt/s'
            str_text = f'O: {gt_text[-print_text_len:]} | P: {pred_text[-print_text_len:]} '
            print(f"{str_iter} | {str_loss} | {str_text} ||", end='')
        total_iters += 1
        iter_batch += 1

        # Save a checkpoint every train_config.save_interval_iter steps.
        if total_iters % train_config.save_interval_iter == 0:
            save_checkpoint(model, optimizer, model_config, train_config, data_config, total_iters, float(loss_sum) / iter_batch)

    mean_loss = float(loss_sum) / iter_batch if iter_batch else 0.0
    return mean_loss, total_iters

def train(model, train_loader, optimizer, scheduler, criterion, device, model_config, train_config, data_config):
    """Repeat passes over ``train_loader`` until ``train_config.max_total_iters`` steps have run.

    Each pass is delegated to ``train_iter``, and the loader is re-iterated for
    every pass. A one-line summary is printed after each pass. Training stops
    early if a pass makes no progress (the loader yielded no batches); without
    that check the loop would spin forever.
    """
    total_iters = 0
    tokenizer = encoding_for_model(data_config.tokenizer_name)
    max_total_iters = train_config.max_total_iters

    while total_iters < max_total_iters:
        iters_before = total_iters
        train_loss, total_iters = train_iter(model,  train_loader, optimizer, scheduler, criterion, device, total_iters, max_total_iters, model_config, train_config, data_config, data_config.batch_size, data_config.block_size, tokenizer)

        print(f'Total Iterations: {total_iters}/{max_total_iters} | Loss: {train_loss:>8.5f}')
        if total_iters == iters_before:
            print('Data loader yielded no batches; stopping training.')
            break

def get_lr_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """Return a ``LambdaLR`` with linear warm-up followed by linear decay to zero.

    The LR multiplier is:

    * ``step / num_warmup_steps`` while ``step < num_warmup_steps``;
    * afterwards, a linear fall from 1.0 at ``num_warmup_steps`` to 0.0 at
      ``num_training_steps``;
    * clamped at 0.0 beyond ``num_training_steps``.

    Both denominators are floored at 1 to avoid division by zero.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step: int):
        if current_step >= num_warmup_steps:
            remaining = float(num_training_steps - current_step)
            return max(0.0, remaining / decay_span)
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda, last_epoch)
    

def main():
    """Training entry point.

    Parses CLI overrides, builds the model, optimizer, LR schedule and the
    streaming data pipeline, then calls ``train``. Defaults come from
    ``ModelConfig``, ``DataConfig`` and ``TrainConfig``. Each CLI flag
    overrides its field only when the flag is actually given.
    """
    print("start")
    parser = argparse.ArgumentParser(description="Train a GPT-like model.")
    parser.add_argument("--batch_size", type=int, help="Batch size for training.")
    parser.add_argument("--block_size", type=int, help="Block size for training.")
    parser.add_argument("--learning_rate", type=float, help="Learning rate for training.")
    parser.add_argument("--max_total_iters", type=int, help="Maximum number of iterations for training.")
    parser.add_argument("--gradient_checkpointing", action='store_true', help="Enable gradient checkpointing.")
    args = parser.parse_args()
    tic = time.time()
    # Load configurations from config.py
    model_config = ModelConfig()
    data_config = DataConfig()
    train_config = TrainConfig()

    # Update configurations with command-line arguments. Compare with None so
    # explicit zero values (e.g. --learning_rate 0) are not silently ignored.
    if args.batch_size is not None:
        data_config.batch_size = args.batch_size
    if args.block_size is not None:
        data_config.block_size = args.block_size
    if args.learning_rate is not None:
        train_config.learning_rate = args.learning_rate
    if args.max_total_iters is not None:
        train_config.max_total_iters = args.max_total_iters
    if args.gradient_checkpointing:
        train_config.gradient_checkpointing = args.gradient_checkpointing

    # Fail fast here rather than on the assert in Transformer.forward at the first batch.
    if data_config.block_size > model_config.max_context_length:
        parser.error(
            f"block_size {data_config.block_size} exceeds max_context_length {model_config.max_context_length}"
        )

    print(model_config)
    print(data_config)
    print(train_config)

    # Set the float32 matmul precision to 'medium'
    torch.set_float32_matmul_precision('medium')

    # Setup device and model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(1337)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)

    model = Transformer(model_config, train_config.gradient_checkpointing).to(device)
    toc = time.time()
    dur = toc - tic
    print(f"model configure {dur:4.1f}sec")

    tic = time.time()

    optimizer = optim.AdamW(model.parameters(), lr=train_config.learning_rate, betas=(0.9, 0.95), eps=1e-8)
    # Not used for the loss (the model computes it); kept for train()'s signature.
    criterion = nn.CrossEntropyLoss()

    # Define the learning rate scheduler
    num_training_steps = train_config.max_total_iters
    scheduler = get_lr_schedule_with_warmup(optimizer, num_warmup_steps=train_config.num_warmup_steps, num_training_steps=num_training_steps)

    toc = time.time()
    dur = toc - tic
    print(f"configure optimizer during {dur:4.1f}sec")
    print(f"{count_parameters(model) / 1000000} M Params")
    print(model)

    # Setup data loaders
    print("prepare dataset")
    tic = time.time()
    # Load and prepare data
    try:
        dataset = load_dataset('parquet', data_files=f'{data_config.dataset_dir}/**/*.parquet', split='train', streaming=True)
    except ValueError as e:
        print(e)
        return

    train_dataset = StreamingParquetDataset(dataset, data_config )

    toc = time.time()
    dur = toc - tic
    print(f"load Parquet Data {dur:4.2f}sec")

    tic = time.time()
    train_loader = DataLoaderLite(train_dataset,  data_config )
    toc = time.time()
    dur_data = toc - tic
    print(f"configure DataLoaderLite during {dur_data:4.1f} sec")

    # Train the model
    print("start train")
    train(model, train_loader, optimizer, scheduler, criterion, device, model_config, train_config, data_config)

if __name__ == "__main__":
    main()


In [None]:
%%file modules.py 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention over a time-compressed key/value sequence.

    Keys and values are split along time into three parts. Let
    ``chunk = T // proj_ratio`` and ``keep = num_keep_boundary_chunk``:

    * front boundary: the first ``keep * chunk`` positions, kept as-is;
    * body: the middle positions, compressed to ``len(body) // proj_ratio`` rows
      by a fixed triangular-bin projection (``_generate_projection_matrix``);
    * last boundary: the last ``keep * chunk`` positions, kept as-is.

    Queries keep full length ``T``. RoPE is applied to q and k before compression.

    Causality: query ``t`` may attend a compressed key only if every original
    position contributing to that key is ``<= t`` (see ``_key_source_positions``).
    NOTE(review): under this rule a body query cannot see its own token or its
    recent neighbours until a whole projection bin lies in its past. This is a
    limitation of the compression design, not of the mask.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        assert config.embedding_size % config.num_heads == 0
        self.c_attn = nn.Linear(config.embedding_size, 3 * config.embedding_size)
        self.c_proj = nn.Linear(config.embedding_size, config.embedding_size)
        self.n_head = config.num_heads
        self.n_embd = config.embedding_size
        self.layer_idx = layer_idx
        self.debug = config.debug
        self.proj_ratio = config.proj_ratio
        self.proj_ratio_min = config.proj_ratio_min  # NOTE(review): not used in this class
        self.num_keep_boundary_chunk = config.num_keep_boundary_chunk

        # numpy projection matrices keyed by (T, P, exponent)
        self.proj_matrix_cache = {}
        # boolean (T, K) causal masks keyed by (T, device)
        self._causal_mask_cache = {}

    def _apply_rotary_embedding(self, x, seq_len):
        """Rotate consecutive pairs of the last dim of ``x`` by position-dependent angles (RoPE).

        Positions ``0..seq_len-1`` line up with dim -2 of ``x``, and the
        frequency base is ``theta = 10000``. Inputs that are not float32/float64
        are rotated in float32 and cast back to their original dtype.
        NOTE(review): assumes the last dim is even.
        """
        dim = x.shape[-1]
        dtype = x.dtype

        theta = 10000.0
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        freqs = torch.outer(torch.arange(seq_len), freqs).to(x.device)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

        if dtype not in [torch.float32, torch.float64]:
            x_temp = x.float()
        else:
            x_temp = x

        x_temp = x_temp.view(*x_temp.shape[:-1], x_temp.shape[-1] // 2, 2)
        x_temp = torch.view_as_complex(x_temp)

        x_temp = x_temp * freqs_cis
        x_rot = torch.view_as_real(x_temp).view(*x.shape[:-1], -1)
        if x.dtype != x_rot.dtype:
            x_rot = x_rot.to(dtype)

        return x_rot

    def _generate_projection_matrix(self, T, P, exponent=2):
        """Build a ``(P, T)`` numpy matrix of triangular bins and cache it by ``(T, P, exponent)``.

        Bin edges are spaced with the mel-scale formula ``2595 * log10(1 + f / 700)``
        over ``[0, T/2]`` and rescaled to ``[0, T]``. Row ``p-1`` rises from
        ``bins[p-1]`` to ``bins[p]`` and falls to ``bins[p+1]``, with each side
        raised to ``exponent``.
        """
        key = (T, P, exponent)
        if key in self.proj_matrix_cache:
            return self.proj_matrix_cache[key]

        start = 0
        end = 2595 * np.log10(1 + (T / 2) / 700.0)
        steps = np.linspace(start, end, P + 2)
        scales = 700 * (10**(steps / 2595) - 1)

        bins = np.floor((T + 1) * scales / (T / 2)).astype(int)
        bins = np.clip(bins, 0, T)  # Ensure bins are within the valid range
        basis_matrix = np.zeros((P, T))

        for p in range(1, P + 1):
            start_bin = bins[p - 1]
            mid_bin = bins[p]
            end_bin = bins[min(p + 1, len(bins) - 1)]

            for t in range(start_bin, mid_bin):
                if bins[p] - bins[p - 1] != 0:
                    basis_matrix[p - 1, t] = ((t - bins[p - 1]) / (bins[p] - bins[p - 1])) ** exponent
            for t in range(mid_bin, end_bin):
                if bins[min(p + 1, len(bins) - 1)] - bins[p] != 0:
                    basis_matrix[p - 1, t] = ((bins[min(p + 1, len(bins) - 1)] - t) / (bins[min(p + 1, len(bins) - 1)] - bins[p])) ** exponent

        self.proj_matrix_cache[key] = basis_matrix
        return basis_matrix

    def _get_boundary_tokens(self, T):
        """Return index tensors of the front and last boundary positions kept uncompressed."""
        chunk_size = T // self.proj_ratio
        num_keep = self.num_keep_boundary_chunk
        tokens_front_boundary_chunk = torch.arange(0, num_keep * chunk_size)
        tokens_last_boundary_chunk = torch.arange(T - num_keep * chunk_size, T)
        return tokens_front_boundary_chunk, tokens_last_boundary_chunk

    def _get_body_tokens(self, T):
        """Return the index tensor of the body positions between the two boundaries."""
        chunk_size = T // self.proj_ratio
        num_keep = self.num_keep_boundary_chunk
        tokens_body_chunk = torch.arange(num_keep * chunk_size, T - num_keep * chunk_size)
        return tokens_body_chunk

    def _project_tokens(self, x, tokens):
        """Compress ``x[:, :, tokens]`` along time: ``(B, nh, len(tokens), hs) -> (B, nh, P, hs)``."""
        P = tokens.numel() // self.proj_ratio
        projection_matrix = self._generate_projection_matrix(tokens.numel(), P, exponent=3)
        projection_matrix = torch.tensor(projection_matrix, dtype=torch.float32, device=x.device).T
        x_se = torch.einsum('bhtd,tp->bhpd', x[:, :, tokens], projection_matrix)
        return x_se

    def _key_source_positions(self, T):
        """Return, per compressed key, the latest original position it depends on.

        The order matches the concatenation in ``forward``: front boundary, body
        projections, last boundary. A projection row with no non-zero weight
        gets position ``T``, so no query can attend it.
        """
        front, last = self._get_boundary_tokens(T)
        body = self._get_body_tokens(T)
        n_body = body.numel()
        # Same arguments as _project_tokens, so this hits the same cache entry.
        basis = self._generate_projection_matrix(n_body, n_body // self.proj_ratio, exponent=3)
        body_last = []
        for row in basis:
            support = np.nonzero(row)[0]
            body_last.append(int(body[int(support[-1])]) if support.size else T)
        return torch.cat([front, torch.tensor(body_last, dtype=torch.long), last])

    def _causal_mask(self, T, device):
        """Return a boolean ``(T, K)`` mask, ``True`` where query ``t`` may attend compressed key ``j``."""
        key = (T, str(device))
        mask = self._causal_mask_cache.get(key)
        if mask is None:
            query_pos = torch.arange(T).unsqueeze(1)                 # (T, 1)
            key_pos = self._key_source_positions(T).unsqueeze(0)     # (1, K)
            mask = (key_pos <= query_pos).to(device)
            self._causal_mask_cache[key] = mask
        return mask

    def forward(self, x):
        """Attend full-length queries to compressed keys/values; ``(B, T, C) -> (B, T, C)``."""
        B, T, C = x.size()
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of x {x.shape} input ")

        qkv = self.c_attn(x)
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of x {x.shape} after qkv projection")

        q, k, v = qkv.split(self.n_embd, dim=2)
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of Q {q.shape} after split ")
            print(f"{self.layer_idx}L - mhsa - shape of K {k.shape} after split")
            print(f"{self.layer_idx}L - mhsa - shape of V {v.shape} after split")

        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of Q {q.shape} after view.transpose(1,2) ")
            print(f"{self.layer_idx}L - mhsa - shape of K {k.shape} after view.transpose(1,2)")
            print(f"{self.layer_idx}L - mhsa - shape of V {v.shape} after view.transpose(1,2)")

        q = self._apply_rotary_embedding(q, T)
        k = self._apply_rotary_embedding(k, T)

        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of Q {q.shape} after RoPE ")
            print(f"{self.layer_idx}L - mhsa - shape of K {k.shape} after RoPE")
            print(f"{self.layer_idx}L - mhsa - shape of V {v.shape} after RoPE")

        tokens_front_boundary_chunk, tokens_last_boundary_chunk = self._get_boundary_tokens(T)
        tokens_body_chunk = self._get_body_tokens(T)

        k_se_front_boundary = k[:, :, tokens_front_boundary_chunk]
        v_se_front_boundary = v[:, :, tokens_front_boundary_chunk]

        k_se_body = self._project_tokens(k, tokens_body_chunk)
        v_se_body = self._project_tokens(v, tokens_body_chunk)

        k_se_last_boundary = k[:, :, tokens_last_boundary_chunk]
        v_se_last_boundary = v[:, :, tokens_last_boundary_chunk]

        k_se = torch.cat([k_se_front_boundary, k_se_body, k_se_last_boundary], dim=-2)
        v_se = torch.cat([v_se_front_boundary, v_se_body, v_se_last_boundary], dim=-2)

        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of k_se {k_se.shape} after projection")
            print(f"{self.layer_idx}L - mhsa - shape of v_se {v_se.shape} after projection")

        att = (q @ k_se.transpose(-2, -1)) * (1.0 / math.sqrt(k_se.size(-1)))
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of att score matrix {att.shape} after matmul")

        # Mask by the source positions of each compressed key. A plain tril over
        # compressed indices would let queries see future tokens.
        mask = self._causal_mask(T, x.device)
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of mask {mask.shape}")

        mask = mask[None, None, :, :]  # Add dimensions for batch and head
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of mask {mask.shape} after reshape")

        # mask is already (1, 1, T, K) and broadcasts over att's (B, nh, T, K).
        # Adding more dims here would make masked_fill broadcast att itself to 6-D.
        att = att.masked_fill(~mask, float('-inf'))

        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of att score matrix {att.shape} after mask fill")

        att = F.softmax(att, dim=-1)
        # A query with no visible key has an all -inf row, which softmax turns
        # into NaN; give such rows zero attention instead.
        att = att.masked_fill(~mask.any(dim=-1, keepdim=True), 0.0)
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of att score matrix {att.shape} after sigmoid")

        x = att @ v_se  # (B,nh,T,K) x (B,nh,K,hs) --> (B, nh, T, hs)
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of x {x.shape} after matmul")

        x = x.transpose(1, 2).contiguous().view(B, T, C)
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of x {x.shape} after view(B,T,C)")

        x = self.c_proj(x)
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of x {x.shape} after out projection")
        return x

# Feed-forward network (MLP) and pre-norm transformer Block.



class MLP(nn.Module):
    """Position-wise feed-forward layer: Linear -> SiLU -> Linear.

    The hidden width is ``int(embedding_size * ffn_scale_ratio)``. A ``GELU``
    submodule is registered, but ``forward`` applies SiLU.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        width = config.embedding_size
        hidden_width = int(width * config.ffn_scale_ratio)
        self.c_fc = nn.Linear(width, hidden_width)
        self.gelu = nn.GELU()  # registered but unused by forward
        self.silu = F.silu
        self.c_proj = nn.Linear(hidden_width, width)
        self.layer_idx = layer_idx
        self.debug = config.debug

    def forward(self, x):
        hidden = self.c_fc(x)
        if self.debug:
            print(f"{self.layer_idx}L - mlp - shape of x_fc {hidden.shape}")
        projected = self.c_proj(self.silu(hidden))
        if self.debug:
            print(f"{self.layer_idx}L - mlp - shape of x_proj {projected.shape}")
        return projected

class Block(nn.Module):
    """Pre-norm transformer block: ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``.

    When ``gradient_checkpointing`` is set and the module is in training mode,
    the attention and MLP sub-layers are wrapped in
    ``torch.utils.checkpoint.checkpoint``. Their activations are then recomputed
    during backward instead of being stored.
    """

    def __init__(self, config, layer_idx, gradient_checkpointing=False):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.embedding_size)
        self.attn = CausalSelfAttention(config, layer_idx)
        self.ln_2 = nn.LayerNorm(config.embedding_size)
        self.mlp = MLP(config, layer_idx)
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, x):
        if self.gradient_checkpointing and self.training:
            # PyTorch asks for use_reentrant to be passed explicitly, and
            # recommends the non-reentrant variant.
            x = x + checkpoint(self.attn, self.ln_1(x), use_reentrant=False)
            x = x + checkpoint(self.mlp, self.ln_2(x), use_reentrant=False)
        else:
            x = x + self.attn(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
        return x


In [None]:
%%file config.py
from dataclasses import dataclass
import torch

@dataclass
class ModelConfig:
    """Architecture hyper-parameters read by ``Transformer`` and the modules in modules.py."""
    embedding_size: int = 1024  # model width; must be divisible by num_heads (asserted in CausalSelfAttention)
    num_layers: int = 32  # number of Blocks in Transformer.transformer.h
    num_heads: int = 16  # attention heads per Block
    ffn_scale_ratio: float = 0.5  # MLP hidden width = int(embedding_size * ffn_scale_ratio)
    max_context_length: int = 8192  # longest sequence Transformer.forward accepts (asserted)
    vocab_size: int = 200019  # 200019 for gpt-4o (per the original note). NOTE(review): the note also listed 200960, and 251264 / 250257 "for gpt2 with 1000 token buffers" - confirm whether a padded size was intended
    debug: bool = False  # print intermediate tensor shapes in the attention and MLP forward passes
    proj_ratio_min : int = 2  # NOTE(review): stored by CausalSelfAttention but not used in the code shown
    proj_ratio : int = 10  # chunk size = T // proj_ratio; body keys/values are compressed by this factor
    num_keep_boundary_chunk : int = 2  # chunks kept uncompressed at each end of the sequence in attention
    

@dataclass
class DataConfig:
    """Data-pipeline settings: batch geometry, tokenizer and dataset location."""
    batch_size: int =60  # sequences per batch
    block_size: int = 64  # tokens per sequence fed to the model
    buffer_size : int = 10  # refill target of StreamingParquetDataset, in batches (converted to tokens there)
    tokenizer_name: str = "gpt-4o"  # tiktoken model name passed to encoding_for_model
    dataset_dir: str = "/mnt/e/jupyter/gpt2_scratch2/datasets/finewebedu/CC-MAIN-2024-10"  # *.parquet files are globbed recursively below this machine-specific path
    truncate_limit: int = 1000000000  # NOTE(review): not read by the code shown here
    shuffle_tokens: bool = False  # NOTE(review): not read by the code shown here
    mask_random: bool = False  # NOTE(review): not read by the code shown here

@dataclass
class TrainConfig:
    """Optimization, logging and checkpointing settings for train.py."""
    learning_rate: float = 1e-3  # peak AdamW learning rate (before the warm-up/decay multiplier)
    dtype: str = "bfloat16"  # NOTE(review): not read by the code shown; train_iter hard-codes bfloat16 autocast
    optimizer_options: dict = None  # NOTE(review): not read by the code shown; the annotation would be Optional[dict]
    exp_name: str = "experiment"  # checkpoints go to <exp_name>/<exp_num:03d>/checkpoints
    exp_num: int = 8
    save_interval_epoch: int = 10  # NOTE(review): not read by the code shown
    save_interval_iter: int = 2000  # save a checkpoint every this many training steps
    print_interval_iteration: int = 10  # print a log line every this many training steps
    eval_interval_iteration: int = 250  # NOTE(review): not read by the code shown
    num_epochs: int = 200  # NOTE(review): not read by the code shown; training length is max_total_iters
    max_total_iters : int = 3000000  # total training steps; also the end of the linear LR decay
    num_warmup_steps : int = 4000  # linear LR warm-up length in steps
    gradient_checkpointing : bool = True  # recompute Block activations during backward to save memory
    value_clip_grad_norm : float = 0.5  # max global gradient norm passed to clip_grad_norm_
    lr_step : int = 1000  # NOTE(review): not read by the code shown
    print_token_len : int = 64  # max characters of target/prediction text shown per log line


In [None]:
!python train.py # 