# milklm 


In [None]:
%%file model.py
import torch
import torch.nn as nn
from modules import Block
from config import ModelConfig

class Transformer(nn.Module):
    def __init__(self, config: ModelConfig, gradient_checkpointing=False):
        super().__init__()
        self.config = config
        self.gradient_checkpointing = gradient_checkpointing
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.embedding_size),
            'h': nn.ModuleList([Block(config, i, gradient_checkpointing) for i in range(config.num_layers)]),
            'ln_f': nn.LayerNorm(config.embedding_size),
        })
        self.lm_head = nn.Linear(config.embedding_size, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.max_context_length, f"Cannot forward sequence of length {T}, block size is only {self.config.max_context_length}"
        tok_emb = self.transformer.wte(idx)
        x = tok_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


In [None]:
%%file utils.py
import torch
import os
import json
import numpy as np
from tiktoken import get_encoding

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def load_tokens(filename, tokenizer_name='gpt-4o'):
    """Load and tokenize data from a file using tiktoken."""
    with open(filename, 'r', encoding='utf-8') as file:
        data = file.read()

    # Use tiktoken's get_encoding to obtain an encoder for GPT2
    encoder = get_encoding(tokenizer_name)
    tokens = encoder.encode(data)
    tokens = np.array(tokens, dtype=np.int32)

    return torch.tensor(tokens, dtype=torch.long)
 

def save_checkpoint(model, optimizer, model_config, train_config, data_config, epoch, loss):
    # Extract necessary details from train_config
    formatted_exp_num = f"{train_config.exp_num:03d}"
    exp_name = train_config.exp_name
    
    # Create directory structure using exp_name and formatted exp_num
    checkpoint_dir = os.path.join(exp_name, formatted_exp_num, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Save the model state
    model_path = os.path.join(checkpoint_dir, f'model_state_{epoch:04d}.pth')
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': {'model_config': model_config.__dict__, 'data_config': data_config.__dict__, 'train_config': train_config.__dict__},
        'epoch': epoch,
        'loss': loss
    }, model_path)

    # Save the optimizer state
    optimizer_path = os.path.join(checkpoint_dir, f'optimizer_state_{epoch:04d}.pth')
    torch.save({
        'optimizer_state_dict': optimizer.state_dict()
    }, optimizer_path)


def load_checkpoint(exp_name, exp_num, epoch, model, optimizer):
    formatted_exp_num = f"{exp_num:03d}"
    checkpoint_dir = os.path.join(exp_name, formatted_exp_num, 'checkpoints')
    model_path = os.path.join(checkpoint_dir, f'model_state_{epoch:04d}.pth')
    optimizer_path = os.path.join(checkpoint_dir, f'optimizer_state_{epoch:04d}.pth')

    # Load model state
    model_checkpoint = torch.load(model_path)
    model.load_state_dict(model_checkpoint['model_state_dict'])

    # Load optimizer state
    optimizer_checkpoint = torch.load(optimizer_path)
    optimizer.load_state_dict(optimizer_checkpoint['optimizer_state_dict'])

    return model_checkpoint.get('epoch'), model_checkpoint.get('loss'), model_checkpoint.get('config')




In [None]:
%%file eval.py
import torch
import argparse
from model import Transformer
from config import ModelConfig
from utils import load_checkpoint
from custom_data import DataLoaderLite

def evaluate():
    parser = argparse.ArgumentParser(description="Evaluate a GPT-like model.")
    parser.add_argument("--config", type=str, required=True, help="Path to config file.")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint.")
    args = parser.parse_args()

    config = torch.load(args.config)
    model_config = ModelConfig(**config['model_config'])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Transformer(model_config).to(device)
    optimizer = torch.optim.AdamW(model.parameters())
    load_checkpoint(args.checkpoint, model, optimizer)

    val_loader = DataLoaderLite(config['data_config']['batch_size'], config['data_config']['block_size'], config['data_config']['dataset_dir'], "val")
    model.eval()
    val_loader.reset()
    total_loss = 0
    with torch.no_grad():
        for _ in range(len(val_loader.tokens) // (config['data_config']['batch_size'] * config['data_config']['block_size'])):
            x, y = val_loader.next_batch()
            x, y = x.to(device), y.to(device)
            logits, loss = model(x, y)
            total_loss += loss.item()
    
    print(f"Validation Loss: {total_loss / (len(val_loader.tokens) // (config['data_config']['batch_size'] * config['data_config']['block_size']))}")

if __name__ == "__main__":
    evaluate()


In [None]:
%%file modules.py 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

class CausalSelfAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        assert config.embedding_size % config.num_heads == 0
        self.c_attn = nn.Linear(config.embedding_size, 3 * config.embedding_size)
        self.c_proj = nn.Linear(config.embedding_size, config.embedding_size)
        self.n_head = config.num_heads
        self.n_embd = config.embedding_size
        self.layer_idx = layer_idx
        self.debug = config.debug
        self.proj_ratio = config.proj_ratio
        self.proj_ratio_min = config.proj_ratio_min
        self.num_keep_boundary_chunk = config.num_keep_boundary_chunk

        # Initialize cache for projection matrices
        self.proj_matrix_cache = {}

    def _apply_rotary_embedding(self, x, seq_len):
        dim = x.shape[-1]
        dtype = x.dtype

        theta = 10000.0
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        freqs = torch.outer(torch.arange(seq_len), freqs).to(x.device)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

        if dtype not in [torch.float32, torch.float64]:
            x_temp = x.float()
        else:
            x_temp = x

        x_temp = x_temp.view(*x_temp.shape[:-1], x_temp.shape[-1] // 2, 2)
        x_temp = torch.view_as_complex(x_temp)

        x_temp = x_temp * freqs_cis
        x_rot = torch.view_as_real(x_temp).view(*x.shape[:-1], -1)
        if x.dtype != x_rot.dtype:
            x_rot = x_rot.to(dtype)

        return x_rot

    def _generate_projection_matrix(self, T, P, exponent=2):
        key = (T, P, exponent)
        if key in self.proj_matrix_cache:
            return self.proj_matrix_cache[key]

        start = 0
        end = 2595 * np.log10(1 + (T / 2) / 700.0)
        steps = np.linspace(start, end, P + 2)
        scales = 700 * (10**(steps / 2595) - 1)

        bins = np.floor((T + 1) * scales / (T / 2)).astype(int)
        bins = np.clip(bins, 0, T)  # Ensure bins are within the valid range
        basis_matrix = np.zeros((P, T))

        for p in range(1, P + 1):
            start_bin = bins[p - 1]
            mid_bin = bins[p]
            end_bin = bins[min(p + 1, len(bins) - 1)]

            for t in range(start_bin, mid_bin):
                if bins[p] - bins[p - 1] != 0:
                    basis_matrix[p - 1, t] = ((t - bins[p - 1]) / (bins[p] - bins[p - 1])) ** exponent
            for t in range(mid_bin, end_bin):
                if bins[min(p + 1, len(bins) - 1)] - bins[p] != 0:
                    basis_matrix[p - 1, t] = ((bins[min(p + 1, len(bins) - 1)] - t) / (bins[min(p + 1, len(bins) - 1)] - bins[p])) ** exponent

        self.proj_matrix_cache[key] = basis_matrix
        return basis_matrix

    def _get_boundary_tokens(self, T):
        chunk_size = T // self.proj_ratio
        num_keep = self.num_keep_boundary_chunk
        tokens_front_boundary_chunk = torch.arange(0, num_keep * chunk_size)
        tokens_last_boundary_chunk = torch.arange(T - num_keep * chunk_size, T)
        return tokens_front_boundary_chunk, tokens_last_boundary_chunk

    def _get_body_tokens(self, T):
        chunk_size = T // self.proj_ratio
        num_keep = self.num_keep_boundary_chunk
        tokens_body_chunk = torch.arange(num_keep * chunk_size, T - num_keep * chunk_size)
        return tokens_body_chunk

    def _project_tokens(self, x, tokens):
        P = tokens.numel() // self.proj_ratio
        projection_matrix = self._generate_projection_matrix(tokens.numel(), P, exponent=3)
        projection_matrix = torch.tensor(projection_matrix, dtype=torch.float32, device=x.device).T
        x_se = torch.einsum('bhtd,tp->bhpd', x[:, :, tokens], projection_matrix)
        return x_se

    def forward(self, x):
        B, T, C = x.size()
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of x {x.shape} input ")

        qkv = self.c_attn(x)
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of x {x.shape} after qkv projection")

        q, k, v = qkv.split(self.n_embd, dim=2)
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of Q {q.shape} after split ")
            print(f"{self.layer_idx}L - mhsa - shape of K {k.shape} after split")
            print(f"{self.layer_idx}L - mhsa - shape of V {v.shape} after split")

        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of Q {q.shape} after view.transpose(1,2) ")
            print(f"{self.layer_idx}L - mhsa - shape of K {k.shape} after view.transpose(1,2)")
            print(f"{self.layer_idx}L - mhsa - shape of V {v.shape} after view.transpose(1,2)")

        q = self._apply_rotary_embedding(q, T)
        k = self._apply_rotary_embedding(k, T)
        # v = self._apply_rotary_embedding(v, T)

        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of Q {q.shape} after RoPE ")
            print(f"{self.layer_idx}L - mhsa - shape of K {k.shape} after RoPE")
            print(f"{self.layer_idx}L - mhsa - shape of V {v.shape} after RoPE")

        tokens_front_boundary_chunk, tokens_last_boundary_chunk = self._get_boundary_tokens(T)
        tokens_body_chunk = self._get_body_tokens(T)

        k_se_front_boundary = k[:, :, tokens_front_boundary_chunk]
        v_se_front_boundary = v[:, :, tokens_front_boundary_chunk]

        k_se_body = self._project_tokens(k, tokens_body_chunk)
        v_se_body = self._project_tokens(v, tokens_body_chunk)

        k_se_last_boundary = k[:, :, tokens_last_boundary_chunk]
        v_se_last_boundary = v[:, :, tokens_last_boundary_chunk]

        k_se = torch.cat([k_se_front_boundary, k_se_body, k_se_last_boundary], dim=-2)
        v_se = torch.cat([v_se_front_boundary, v_se_body, v_se_last_boundary], dim=-2)

        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of k_se {k_se.shape} after projection")
            print(f"{self.layer_idx}L - mhsa - shape of v_se {v_se.shape} after projection")

        att = (q @ k_se.transpose(-2, -1)) * (1.0 / math.sqrt(k_se.size(-1)))
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of att score matrix {att.shape} after matmul")

        mask = torch.tril(torch.ones((T, k_se.size(-2)), device=x.device, dtype=torch.bool))
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of mask {mask.shape}")

        mask = mask[None, None, :, :]  # Add dimensions for batch and head
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of mask {mask.shape} after reshape")

        att = att.masked_fill(~mask[:, None, None, :], float('-inf'))

        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of att score matrix {att.shape} after mask fill")

        #att = F.sigmoid(att)
        att =  F.softmax(att, dim=-1)
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of att score matrix {att.shape} after sigmoid")

        x = att @ v_se  # (B,nh,T,P) x (B,nh,P,hs) --> ( B, nh, T, hs)
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of x {x.shape} after matmul")

        x = x.transpose(1, 2).contiguous().view(B, T, C)
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of x {x.shape} after view(B,T,C)")

        x = self.c_proj(x)
        if self.debug:
            print(f"{self.layer_idx}L - mhsa - shape of x {x.shape} after out projection")
        return x

# Rest of the code remains unchanged



class MLP(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.c_fc = nn.Linear(config.embedding_size, int(config.embedding_size * config.ffn_scale_ratio))
        self.gelu = nn.GELU()
        self.silu = F.silu
        self.c_proj = nn.Linear(int(config.embedding_size * config.ffn_scale_ratio), config.embedding_size)
        self.layer_idx = layer_idx
        self.debug = config.debug

    def forward(self, x):
        x_fc = self.c_fc(x)
        if self.debug:
            print(f"{self.layer_idx}L - mlp - shape of x_fc {x_fc.shape}")
        x_gelu = self.silu(x_fc)
        x_proj = self.c_proj(x_gelu)
        if self.debug:
            print(f"{self.layer_idx}L - mlp - shape of x_proj {x_proj.shape}")
        return x_proj

class Block(nn.Module):
    def __init__(self, config, layer_idx, gradient_checkpointing=False):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.embedding_size)
        self.attn = CausalSelfAttention(config, layer_idx)
        self.ln_2 = nn.LayerNorm(config.embedding_size)
        self.mlp = MLP(config, layer_idx)
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, x):
        if self.gradient_checkpointing and self.training:
            x = x + checkpoint(self.attn, self.ln_1(x))
            x = x + checkpoint(self.mlp, self.ln_2(x))
        else:
            x = x + self.attn(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
        return x


In [71]:
%%file config.py
from dataclasses import dataclass
import torch

@dataclass
class ModelConfig:
    embedding_size: int = 1024
    num_layers: int = 32
    num_heads: int = 32
    ffn_scale_ratio: float = 0.5
    max_context_length: int = 8192
    vocab_size: int = 200019  # 200960 for gpt-4o | 251264 for gpt2 with 1000 token buffers
    debug: bool = False  # Add debug flag
    proj_ratio_min: int = 2
    proj_ratio: int = 10
    num_keep_boundary_chunk: int = 2

 
@dataclass
class DataConfig:
    batch_size: int = 60
    block_size: int = 64
    buffer_size: int = 1 
    tokenizer_name: str = "gpt-4o"
    dataset_dir: str = "/mnt/e/jupyter/gpt2_scratch2/datasets/finewebedu/CC-MAIN-2024-10"
    truncate_limit: int = 1000000000
    shuffle_tokens: bool = False   
    mask_random: bool = False
    chunk_sizes: list = (8, 16, 32, 64 )  # Dynamic chunk sizes
    hop_size_factor: int = 4  # Factor to determine hop size from chunk size


@dataclass
class TrainConfig:
    learning_rate: float = 1e-3
    dtype: str = "bfloat16"
    optimizer_options: dict = None
    exp_name: str = "experiment"
    exp_num: int = 8
    save_interval_epoch: int = 100
    save_interval_iter: int = 20000
    print_interval_iteration: int = 1
    eval_interval_iteration: int = 250
    num_epochs: int = 200
    max_total_iters: int = 3000000
    num_warmup_steps: int = 4000
    gradient_checkpointing: bool = True
    value_clip_grad_norm: float = 0.5
    lr_step: int = 1000
    print_token_len: int = 64


Overwriting config.py


In [77]:
%%file dummy_process.py

import torch
import torch.nn as nn
import torch.optim as optim
from model import Transformer
from config import ModelConfig
from utils import count_parameters
import pynvml
import time

def get_max_batch_size(model, device, chunk_size):
    dummy_input = torch.randint(0, model.config.vocab_size, (1, chunk_size)).to(device)
    dummy_target = torch.randint(0, model.config.vocab_size, (1, chunk_size)).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95), eps=1e-8)
    criterion = nn.CrossEntropyLoss()
    
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    
    batch_size = 0
    forward_times = []
    backward_times = []
    
    while True:
        inputs = dummy_input.repeat(batch_size + 1, 1)
        targets = dummy_target.repeat(batch_size + 1, 1)
        try:
            # Forward pass
            start_forward = time.time()
            logits, loss = model(inputs, targets)
            forward_duration = (time.time() - start_forward) * 1000  # in ms

            # Backward pass
            optimizer.zero_grad()
            start_backward = time.time()
            loss.backward()
            optimizer.step()
            backward_duration = (time.time() - start_backward) * 1000  # in ms
            
            forward_times.append(forward_duration)
            backward_times.append(backward_duration)

            info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            current_memory = info.used / (1024 ** 3)
            free_memory = info.free / (1024 ** 3)
            print(f"Batch Size: {batch_size + 1}, Used Memory: {current_memory:.2f} GiB, Free Memory: {free_memory:.2f} GiB, "
                  f"Forward Time: {forward_duration:.2f} ms, Backward Time: {backward_duration:.2f} ms")
            if free_memory > 1:  # Keep 1 GiB free
                batch_size += 1
            else:
                break
        except torch.cuda.OutOfMemoryError:
            break

    avg_forward_time = sum(forward_times) / len(forward_times)
    avg_backward_time = sum(backward_times) / len(backward_times)

    return batch_size, avg_forward_time, avg_backward_time

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_config = ModelConfig()

    chunk_sizes = [8, 16, 32, 64]
    max_batch_sizes = []
    forward_times = []
    backward_times = []

    for chunk_size in chunk_sizes:
        torch.cuda.empty_cache()  # Clear the GPU cache
        model = Transformer(model_config).to(device)  # Reset model
        optimizer = optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95), eps=1e-8)  # Reset optimizer
        print(f"Detecting max batch size for chunk size: {chunk_size}")
        max_batch_size, avg_forward_time, avg_backward_time = get_max_batch_size(model, device, chunk_size)
        print(f"Max Batch Size for chunk size {chunk_size}: {max_batch_size}, "
              f"Avg Forward Time: {avg_forward_time:.2f} ms, Avg Backward Time: {avg_backward_time:.2f} ms")
        max_batch_sizes.append(max_batch_size)
        forward_times.append(avg_forward_time)
        backward_times.append(avg_backward_time)

    print(f"Detected max batch sizes: {max_batch_sizes}")
    print(f"Average Forward Times: {forward_times}")
    print(f"Average Backward Times: {backward_times}")
    torch.save(max_batch_sizes, "max_batch_sizes.pt")

if __name__ == "__main__":
    main()


Overwriting dummy_process.py


In [None]:
!python dummy_process.py

In [66]:
%%file buffer_data_loader.py

from datasets import load_dataset

class BufferDataLoader:
    def __init__(self, dataset_dir, buffer_size):
        self.dataset = load_dataset('parquet', data_files=f'{dataset_dir}/**/*.parquet', split='train', streaming=True)
        self.buffer_size = buffer_size
        self.iterator = iter(self.dataset)

    def __iter__(self):
        return self

    def __next__(self):
        buffer = []
        try:
            while len(buffer) < self.buffer_size:
                item = next(self.iterator)
                buffer.append(item['text'])
        except StopIteration:
            if not buffer:
                raise
        return buffer


Overwriting buffer_data_loader.py


In [72]:
%%file chunk_data_loader.py

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tiktoken import encoding_for_model

class ChunkDataLoader:
    def __init__(self, buffer, config, max_batch_sizes, chunk_size):
        self.buffer = buffer
        self.tokenizer = encoding_for_model(config.tokenizer_name)
        self.chunk_size = chunk_size
        self.hop_size = chunk_size // config.hop_size_factor
        self.max_batch_sizes = max_batch_sizes

        self.buffer_tokens = []
        for text in buffer:
            self.buffer_tokens.extend(self.tokenizer.encode(text))
        
        self.num_chunks = (len(self.buffer_tokens) - chunk_size) // self.hop_size + 1
        self.batch_size = max_batch_sizes[config.chunk_sizes.index(chunk_size)]
        self.current_chunk = 0

    def __iter__(self):
        self.current_chunk = 0
        return self

    def __next__(self):
        if self.current_chunk >= self.num_chunks:
            raise StopIteration
        
        start = self.current_chunk * self.hop_size
        end = start + self.chunk_size + 1

        if end > len(self.buffer_tokens):
            raise StopIteration

        chunk = self.buffer_tokens[start:end]
        inputs = torch.tensor(chunk[:-1], dtype=torch.long).unsqueeze(0)
        targets = torch.tensor(chunk[1:], dtype=torch.long).unsqueeze(0)
        
        self.current_chunk += 1

        return inputs, targets, self.current_chunk, self.num_chunks



Overwriting chunk_data_loader.py


In [73]:
%%file train.py

import torch
import torch.nn as nn
import torch.optim as optim
import time
import argparse
from torch.optim.lr_scheduler import LambdaLR
from buffer_data_loader import BufferDataLoader
from chunk_data_loader import ChunkDataLoader
from model import Transformer
from config import ModelConfig, DataConfig, TrainConfig
from tiktoken import encoding_for_model

def train_iter(model, dataloader, optimizer, scheduler, criterion, device, total_iters, max_total_iters, model_config, train_config, data_config, tokenizer):
    model.train()
    total_loss = 0
    start_time = time.time()

    for batch_idx, (inputs, targets, chunk_idx, total_chunks) in enumerate(dataloader):
        if total_iters >= max_total_iters:
            break

        inputs, targets = inputs.to(device), targets.to(device)
        start_forward = time.time()
        logits, loss = model(inputs, targets)
        forward_duration = (time.time() - start_forward) * 1000
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        backward_duration = (time.time() - start_forward - forward_duration / 1000) * 1000

        perplexity = torch.exp(loss).item()
        total_loss += loss.item()

        # Decode input and generated text with line breaks replaced by \n
        input_text = tokenizer.decode(inputs[0].tolist()).replace('\n', '\\n')
        gt_text = tokenizer.decode(targets[0].tolist()).replace('\n', '\\n')
        pred_text = tokenizer.decode(torch.argmax(logits[0], dim=-1).tolist()).replace('\n', '\\n')

        print(f"Iter {total_iters}/{max_total_iters}, Loss: {loss.item():.6f}, Perplexity: {perplexity:.6f}, "
              f"Chunk Size: {dataloader.chunk_size}, Batch Size: {inputs.size(0)}, "
              f"Input: {input_text[:50]}..., Target: {gt_text[:50]}..., Prediction: {pred_text[:50]}..., "
              f"Forward Time: {forward_duration:.2f} ms, Backward Time: {backward_duration:.2f} ms")

        total_iters += 1

    return total_loss / (batch_idx + 1), total_iters

def train(model, buffer_loader, optimizer, scheduler, criterion, device, model_config, train_config, data_config, max_batch_sizes):
    total_iters = 0
    tokenizer = encoding_for_model(data_config.tokenizer_name)
    max_total_iters = train_config.max_total_iters

    for buffer_idx, buffer in enumerate(buffer_loader):
        buffer_tokens = sum([len(tokenizer.encode(text)) for text in buffer])
        print(f"Buffer Index: {buffer_idx}, Total Tokens: {buffer_tokens}, Decoded Buffer: {' '.join(buffer)[:500]}...")

        for chunk_size in data_config.chunk_sizes:
            chunk_loader = ChunkDataLoader(buffer, data_config, max_batch_sizes, chunk_size)
            train_loss, total_iters = train_iter(model, chunk_loader, optimizer, scheduler, criterion, device, total_iters, max_total_iters, model_config, train_config, data_config, tokenizer)
            print(f'Total Iterations: {total_iters}/{max_total_iters} | Loss: {train_loss:.5f}')

def main():
    parser = argparse.ArgumentParser(description="Train a GPT-like model.")
    parser.add_argument("--batch_size", type=int, help="Batch size for training.")
    parser.add_argument("--block_size", type=int, help="Block size for training.")
    parser.add_argument("--learning_rate", type=float, help="Learning rate for training.")
    parser.add_argument("--max_total_iters", type=int, help="Maximum number of iterations for training.")
    parser.add_argument("--gradient_checkpointing", action='store_true', help="Enable gradient checkpointing.")
    args = parser.parse_args()

    model_config = ModelConfig()
    data_config = DataConfig()
    train_config = TrainConfig()

    if args.batch_size:
        data_config.batch_size = args.batch_size
    if args.block_size:
        data_config.block_size = args.block_size
    if args.learning_rate:
        train_config.learning_rate = args.learning_rate
    if args.max_total_iters:
        train_config.max_total_iters = args.max_total_iters
    if args.gradient_checkpointing:
        train_config.gradient_checkpointing = args.gradient_checkpointing

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(1337)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)

    model = Transformer(model_config, train_config.gradient_checkpointing).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=train_config.learning_rate, betas=(0.9, 0.95), eps=1e-8)
    criterion = nn.CrossEntropyLoss()

    scheduler = LambdaLR(optimizer, lr_lambda=lambda step: min((step + 1) / train_config.num_warmup_steps, 1.0))

    buffer_loader = BufferDataLoader(data_config.dataset_dir, data_config.buffer_size)

    # Load max_batch_sizes from the precomputed file
    max_batch_sizes = torch.load("max_batch_sizes.pt")
    
    train(model, buffer_loader, optimizer, scheduler, criterion, device, model_config, train_config, data_config, max_batch_sizes)

if __name__ == "__main__":
    main()



Overwriting train.py


In [78]:
!python dummy_process.py

Detecting max batch size for chunk size: 8
Batch Size: 1, Used Memory: 7.88 GiB, Free Memory: 16.11 GiB, Forward Time: 130.52 ms, Backward Time: 276.27 ms
Batch Size: 2, Used Memory: 8.66 GiB, Free Memory: 15.33 GiB, Forward Time: 51.56 ms, Backward Time: 49.86 ms
Batch Size: 3, Used Memory: 8.66 GiB, Free Memory: 15.33 GiB, Forward Time: 80.40 ms, Backward Time: 65.27 ms
Batch Size: 4, Used Memory: 9.43 GiB, Free Memory: 14.55 GiB, Forward Time: 67.21 ms, Backward Time: 48.56 ms
Batch Size: 5, Used Memory: 9.44 GiB, Free Memory: 14.55 GiB, Forward Time: 82.57 ms, Backward Time: 56.22 ms
Batch Size: 6, Used Memory: 9.45 GiB, Free Memory: 14.54 GiB, Forward Time: 65.81 ms, Backward Time: 46.93 ms
Batch Size: 7, Used Memory: 9.46 GiB, Free Memory: 14.53 GiB, Forward Time: 71.66 ms, Backward Time: 50.42 ms
Batch Size: 8, Used Memory: 9.46 GiB, Free Memory: 14.53 GiB, Forward Time: 111.77 ms, Backward Time: 57.45 ms
Batch Size: 9, Used Memory: 9.47 GiB, Free Memory: 14.52 GiB, Forward Time

In [79]:
!python train.py # 

Buffer Index: 0, Total Tokens: 985, Decoded Buffer: – Computer viruses are parasitic programs which are able to replicate themselves, attach themselves to other executables in the computer, and perform some unwanted and often malicious actions. A virus is not able to spread itself to another computers, some user actions are needed for it to infect a new computer. Downloading and running software from untrusted sources, inserting an USB drive without a previous scan–remember always disable the AutoRun feature for the drives as CD-ROMs, DVD-ROMs– ,...
Iter 0/3000000, Loss: 12.112149, Perplexity: 182070.484375, Chunk Size: 8, Batch Size: 1, Input: – Computer viruses are parasitic programs which..., Target:  Computer viruses are parasitic programs which are..., Prediction: Diagram.lookup Matters जनताTokens_async Examineruo..., Forward Time: 155.66 ms, Backward Time: 319.27 ms
Iter 1/3000000, Loss: 11.973177, Perplexity: 158447.250000, Chunk Size: 8, Batch Size: 1, Input:  viruses are paras

In [None]:
Iter 487/3000000, Loss: 8.020070, Perplexity: 3041.390381, Chunk Size: 8, Batch Size: 1, Input: ” or “registry cleaners” or visiting..., Target:  or “registry cleaners” or visiting malicious..., Prediction:  by by by by by by by by..., Forward Time: 41.12 ms, Backward Time: 99.67 ms
Iter 488/3000000, Loss: 10.079155, Perplexity: 23840.832031, Chunk Size: 8, Batch Size: 1, Input:  “registry cleaners” or visiting malicious sites..., Target: registry cleaners” or visiting malicious sites t..., Prediction:  by by by or by by by by..., Forward Time: 49.27 ms, Backward Time: 94.55 ms
Total Iterations: 489/3000000 | Loss: 9.22127
Traceback (most recent call last):
  File "/mnt/e/jupyter/milklm/code02/train.py", line 108, in <module>
    main()
  File "/mnt/e/jupyter/milklm/code02/train.py", line 105, in main
    train(model, buffer_loader, optimizer, scheduler, criterion, device, model_config, train_config, data_config, max_batch_sizes)
  File "/mnt/e/jupyter/milklm/code02/train.py", line 60, in train
    chunk_loader = ChunkDataLoader(buffer, data_config, max_batch_sizes, chunk_size)
  File "/mnt/e/jupyter/milklm/code02/chunk_data_loader.py", line 20, in __init__
    self.batch_size = self.max_batch_sizes[chunk_size]
IndexError: list index out of range