In [None]:
# Environment setup cell: installs the packages the rest of the notebook imports.
# (Disabled) download of a Project Gutenberg text file (pg100.txt).
#!wget https://www.gutenberg.org/cache/epub/100/pg100.txt > data.txt
# Uninstall any existing `tensorflow` package, then install `tensorflow-cpu` alongside
# tiktoken and datasets.
# NOTE(review): presumably done so TensorFlow does not compete for the TPU -- confirm.
!pip uninstall tensorflow -y
!pip install tiktoken datasets tensorflow-cpu
# Install torch 2.6.x and the matching torch_xla TPU extra; the -f flags add the
# libtpu release/wheel indexes as extra places to look for wheels.
!pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
# Install or upgrade the Weights & Biases client (quiet output).
!pip install wandb -qU

In [None]:
# Log in to your W&B account
import wandb
import random
import math

import os

# Opt in to the wandb "core" backend (kept from the original notebook).
wandb.require("core")
# Never hard-code the API key in the notebook: read it from the WANDB_API_KEY
# environment variable (e.g. populated from the platform's secret store).
# When the variable is unset, key=None lets wandb fall back to its own
# credential lookup / interactive prompt.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
import tiktoken
import torch_xla as xla

# Tokenizer: the GPT-2 BPE vocabulary shipped with tiktoken.
tokenizer = tiktoken.get_encoding("r50k_base")

# Model hyperparameters.
layers = 12
emb_size = 768
n_heads = 12
dropout = 0.1
ctx_size = 2048

# Optimisation hyperparameters. One optimizer step consumes
# batch_size * grad_accumulation_steps sequences of ctx_size tokens.
batch_size = 16
grad_accumulation_steps = 16
epochs = 1
# ds_size = 9508395
eval_batch_size = 48
vocab_size = tokenizer.n_vocab

# Start the W&B run. Every hyperparameter that affects the result is recorded,
# including the gradient-accumulation factor, so the effective batch size can
# be reconstructed from the run page.
wandb.init(
    project="gpt-2-pretrain",
    config={
        "epochs": epochs,
        "dataset": "FineWeb",
        "layers": layers,
        "emb_size": emb_size,
        "n_heads": n_heads,
        "dropout": dropout,
        "ctx_size": ctx_size,
        "batch_size": batch_size,
        "grad_accumulation_steps": grad_accumulation_steps,
        "eval_batch_size": eval_batch_size,
        "vocab_size": vocab_size,
    },
    # To resume an existing run, uncomment and set the run id:
    # id="a92llyiu",
    # resume="must"
)

In [None]:
import os

import numpy as np
import tiktoken
import torch
from torch import nn
from torch.nn import functional as F
import torch_xla as xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
import torch_xla.runtime as xr
from torch_xla.amp import autocast, syncfree

# Switch the torch_xla runtime into SPMD mode; train() below builds an SPMD
# mesh and wraps the model in FSDPv2.
# NOTE(review): torch_xla expects this to run before XLA tensors are created -- confirm.
xr.use_spmd()


class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Args:
        n_dims: size of the trailing (feature) dimension being normalised.
    """

    def __init__(self, n_dims):
        super().__init__()
        scale = torch.ones(n_dims)
        shift = torch.zeros(n_dims)
        # Registration order (weight, then bias) is kept so state_dict layout is unchanged.
        self.weight = nn.Parameter(scale)
        self.bias = nn.Parameter(shift)

    def forward(self, x):
        """Normalise `x` over its last dimension, then apply weight and bias."""
        normalized_shape = self.weight.shape
        return F.layer_norm(x, normalized_shape, weight=self.weight, bias=self.bias)


class Attention(nn.Module):
    """Multi-head causal self-attention.

    Args:
        d_model: width of the residual stream; must be divisible by `heads`.
        heads: number of attention heads.
        p_dropout: dropout probability applied to the attention weights and to
            the output projection.
        c_size: maximum sequence length; sizes the causal mask buffer.

    Raises:
        ValueError: if `d_model` is not divisible by `heads`.

    Note:
        Two defects of the previous implementation are fixed here, so outputs
        differ from checkpoints trained with the old code:
        * the per-head outputs are transposed back to (B, T, heads, head_dim)
          before being merged; the old `.view` on the (B, heads, T, head_dim)
          tensor mixed different positions (including future ones) into each
          output row, breaking causality;
        * scores are scaled by sqrt(head_dim) rather than sqrt(d_model).
    """

    def __init__(self, d_model, heads, p_dropout, c_size):
        super().__init__()
        if d_model % heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by heads ({heads})")
        self.d_model = d_model
        self.n_heads = heads
        self.head_dim = d_model // heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.o = nn.Linear(d_model, d_model)
        self.o.type = "proj" # Needed for initialization
        self.dropout = nn.Dropout(p_dropout)
        mask = torch.tril(torch.ones(c_size, c_size))
        # Register the mask as a buffer but don't save it in state_dict to save memory
        self.register_buffer("bias", mask, persistent=False)

    def forward(self, x):
        """Apply causal self-attention to `x` of shape (B, T, C) and return (B, T, C)."""
        B, T, C = x.size()
        # Project and split into heads: (B, T, C) -> (B, heads, T, head_dim).
        q = self.q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product scores, scaled by the per-head key dimension.
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Causal mask: position t may only attend to positions <= t.
        mask = self.bias[:T, :T]
        scores = scores.masked_fill(mask == 0, float("-inf"))
        probs = F.softmax(scores, dim=-1)
        probs = self.dropout(probs)
        out = probs @ v  # (B, heads, T, head_dim)
        # Move the head axis next to head_dim *before* merging, so every output
        # row only contains features of its own position.
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.o(out)
        out = self.dropout(out)
        return out

class Layer(nn.Module):
    """One pre-LayerNorm transformer block: self-attention followed by a GELU MLP.

    Both sub-layers are wrapped in residual connections. The MLP expands the
    width to 4 * d_model and projects back to d_model.

    Args:
        d_model: width of the residual stream.
        heads: number of attention heads.
        p_dropout: dropout probability (attention and MLP output).
        c_size: maximum sequence length, forwarded to the attention module.
    """

    def __init__(self, d_model, heads, p_dropout, c_size):
        super().__init__()
        hidden_dim = 4 * d_model
        self.layerNorm1 = LayerNorm(d_model)
        self.attention = Attention(d_model, heads, p_dropout, c_size)
        self.layerNorm2 = LayerNorm(d_model)
        self.mlp = nn.Linear(d_model, hidden_dim)
        self.mlp_proj = nn.Linear(hidden_dim, d_model)
        # Tag read by the Transformer's weight initialisation.
        self.mlp_proj.type = "proj"
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x):
        """Run the block on `x` and return a tensor of the same shape."""
        # Attention sub-layer with residual connection.
        attended = x + self.attention(self.layerNorm1(x))
        # Feed-forward sub-layer with residual connection.
        hidden = F.gelu(self.mlp(self.layerNorm2(attended)))
        return attended + self.dropout(self.mlp_proj(hidden))

class Transformer(nn.Module):
    """GPT-2 style decoder-only language model.

    Args:
        n_layers: number of transformer blocks.
        d_model: embedding / residual stream width.
        heads: attention heads per block.
        p_dropout: dropout probability.
        v_size: vocabulary size.
        c_size: maximum context length (size of the positional table).
    """

    def __init__(self, n_layers, d_model, heads, p_dropout, v_size, c_size):
        super().__init__()
        self.wte = nn.Embedding(v_size, d_model)
        self.wpe = nn.Embedding(c_size, d_model)
        # Needed because wte and wpe are initialized differently
        self.wte.type = "wte"
        self.wpe.type = "wpe"
        self.layers = nn.ModuleList([Layer(d_model, heads, p_dropout, c_size) for _ in range(n_layers)])
        self.dropout = nn.Dropout(p_dropout)
        self.ln = LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, v_size, bias=False)
        self.wte.weight = self.lm_head.weight # Tying weights improves performance and reduces learnable params
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module (called for every module via `self.apply`).

        Linear layers tagged with ``module.type == "proj"`` (the residual
        output projections) use std 0.02 / sqrt(2 * n_layers); every other
        Linear uses std 0.02 and biases are zeroed. The token embedding uses
        std 0.02 and any other embedding (wpe) std 0.01.
        """
        # Weight initialization per the GPT-2 code
        # https://github.com/openai/gpt-2/blob/master/src/model.py
        # NOTE: every nn.Module has a built-in `.type()` *method*, so
        # `hasattr(module, "type")` is always True. Compare the tag's value
        # instead; for untagged modules `role` is that bound method and simply
        # matches none of the string tags.
        role = getattr(module, "type", None)
        if isinstance(module, nn.Linear):
            if role == "proj":
                std = 0.02/(2*len(self.layers))**0.5
            else:
                std = 0.02
            nn.init.normal_(module.weight, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # wpe is initialized with stddev of 0.01
            std = 0.02 if role == "wte" else 0.01
            nn.init.normal_(module.weight, std=std)

    def forward(self, x, y=None):
        """Compute logits, and the cross-entropy loss when targets are given.

        Args:
            x: token ids of shape (B, T).
            y: optional target token ids of shape (B, T).

        Returns:
            (logits, loss): logits has shape (B, T, v_size); loss is None when
            `y` is None.
        """
        B, T = x.shape
        # Create the position ids directly on the input's device.
        positions = torch.arange(T, dtype=torch.long, device=x.device)
        token_embs = self.wte(x)
        pos_embs = self.wpe(positions)
        x = token_embs + pos_embs
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x)
        x = self.ln(x)
        x = self.lm_head(x)
        if y is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), y.view(-1))
        else:
            loss = None
        return x, loss

    @torch.no_grad()
    def generate(self, x, max_tokens):
        """Sample `max_tokens` tokens autoregressively from the model.

        Args:
            x: 1-D tensor of prompt token ids.
            max_tokens: number of tokens to append.

        Returns:
            1-D tensor holding the prompt followed by the sampled tokens.
        """
        # The positional table only covers this many positions; feed at most
        # that many trailing tokens so long generations cannot index past it.
        max_ctx = self.wpe.num_embeddings
        for _ in range(max_tokens):
            context = x[-max_ctx:]
            logits, _ = self(context.unsqueeze(0))
            # Distribution over the next token, taken from the last position.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probs, 1).view(-1)
            x = torch.cat([x, next_token], dim=-1)
        return x

def create_data_loader(batch_size, ctx_size, shard_len, shard_dir, num_shards, is_training=True, stride=None, num_workers=20):
    """Build a sequential DataLoader over token shards stored as .npy files.

    The shards in `shard_dir` (taken in sorted filename order) are treated as
    one long token stream. Sample `i` is the window of `ctx_size + 1` tokens
    starting at token `i * stride`; `x` is its first `ctx_size` tokens and `y`
    the same window shifted by one token (next-token targets).

    Args:
        batch_size: number of windows per batch.
        ctx_size: number of tokens in each input sequence.
        shard_len: number of tokens in every shard file.
        shard_dir: directory containing the shard files.
        num_shards: number of shard files to use.
        is_training: when True, the last incomplete batch is dropped.
        stride: distance in tokens between consecutive windows. Defaults to
            `ctx_size` (non-overlapping windows), so one pass over the loader
            covers the whole stream once. Pass ``stride=1`` to get the previous
            behaviour of one window per token offset.
        num_workers: number of DataLoader worker processes.

    Returns:
        (loader, sampler): the DataLoader and its SequentialSampler.

    Raises:
        ValueError: if `stride` is not positive, if fewer than `num_shards`
            files exist in `shard_dir`, or (while reading) if a shard holds
            fewer than `shard_len` tokens.
    """
    if stride is None:
        stride = ctx_size
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    # Create a simple dataset that generates inputs based on indices
    class TokenDataset(torch.utils.data.Dataset):
        def __init__(self, ctx_size, shard_len, shard_dir, num_shards, stride):
            self.ctx_size = ctx_size
            self.shard_len = shard_len
            self.shard_dir = shard_dir
            self.num_shards = num_shards
            self.stride = stride
            # List and sort the directory once instead of on every __getitem__.
            names = sorted(os.listdir(shard_dir))
            if len(names) < num_shards:
                raise ValueError(
                    f"expected at least {num_shards} shard files in {shard_dir}, found {len(names)}"
                )
            self.shard_paths = [os.path.join(shard_dir, name) for name in names[:num_shards]]

        def __len__(self):
            # Every window needs ctx_size + 1 tokens (inputs plus the final target).
            last_start = self.shard_len * self.num_shards - self.ctx_size - 1
            if last_start < 0:
                return 0
            return last_start // self.stride + 1

        def _read_tokens(self, start, count):
            """Return `count` consecutive tokens from global position `start` as int64."""
            shard_idx, offset = divmod(start, self.shard_len)
            pieces = []
            remaining = count
            while remaining > 0:
                take = min(remaining, self.shard_len - offset)
                shard = np.load(self.shard_paths[shard_idx], mmap_mode="r")
                # Copy out of the memory map so the mapping can be released right away.
                piece = np.array(shard[offset:offset + take], dtype=np.int64)
                del shard
                if piece.shape[0] != take:
                    raise ValueError(
                        f"shard {self.shard_paths[shard_idx]} holds fewer than shard_len={self.shard_len} tokens"
                    )
                pieces.append(piece)
                remaining -= take
                # Continue at the beginning of the next shard if the window spans shards.
                shard_idx += 1
                offset = 0
            return pieces[0] if len(pieces) == 1 else np.concatenate(pieces)

        def __getitem__(self, idx):
            if idx < 0:
                idx += len(self)
            if not 0 <= idx < len(self):
                raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")
            window = self._read_tokens(idx * self.stride, self.ctx_size + 1)
            x = torch.tensor(window[:-1], dtype=torch.long)
            y = torch.tensor(window[1:], dtype=torch.long)
            return x, y

    dataset = TokenDataset(ctx_size, shard_len, shard_dir, num_shards, stride)
    sampler = torch.utils.data.SequentialSampler(
        dataset
    )

    # Create the dataloader with the sampler
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=is_training,
        num_workers=num_workers,  # worker processes; 0 loads in the main process
    )

    return loader, sampler


def train(tokenizer, layers, emb_size, n_heads, dropout, vocab_size, ctx_size, epochs, batch_size, eval_batch_size, steps, gradient_accumulation_steps):
    """Pretrain the Transformer on the sharded token dataset using torch_xla SPMD.

    Args:
        tokenizer: tokenizer used to encode the sample prompt and decode samples.
        layers, emb_size, n_heads, dropout, vocab_size, ctx_size: model
            hyperparameters forwarded to `Transformer`.
        epochs: number of passes over the training loader.
        batch_size: sequences per micro-batch.
        eval_batch_size: sequences per evaluation batch.
        steps: TOTAL number of optimizer steps across all epochs; it drives the
            learning-rate schedule and is split evenly between epochs.
        gradient_accumulation_steps: micro-batches accumulated per optimizer step.

    Side effects:
        Logs metrics to the active wandb run, prints progress, and writes
        ``./ckpt-<step>.pt`` every 300 steps and once more at the end.

    Learning-rate schedule: linear warmup to 6e-4 over the first 1% of `steps`,
    then cosine decay to 6e-5 over the remaining steps. The warmup learning rate
    is set BEFORE the optimizer step it applies to, so the first update does
    not run at the full peak rate.
    """
    device = xla.device()

    model = Transformer(layers, emb_size, n_heads, dropout, vocab_size, ctx_size).to(device)
    lr = 6e-4
    #SPMD Stuff
    # All devices sit on the 'fsdp' mesh axis; the 'model' axis has size 1.
    num_devices = xr.global_runtime_device_count()
    mesh_shape = (num_devices, 1)
    device_ids = np.array(range(num_devices))
    mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'model'))
    model = FSDPv2(module=model, mesh=mesh)
    # Performance gain
    # model = torch.compile(model, backend='openxla')

    # Using GPT-3 hyperparams for betas and weight decay
    optimizer = syncfree.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1)
    # Integer step counts: a fractional warmup length made the last warmup
    # factor exceed 1.0 and handed the cosine scheduler a slightly-off start value.
    warmup_steps = max(1, int(steps * 0.01))
    decay_steps = max(1, steps - warmup_steps)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=decay_steps, eta_min=6e-5)
    # `steps` is the total across all epochs, so each epoch gets an equal share.
    steps_per_epoch = steps // epochs if epochs > 0 else 0
    # Set up train loaders
    train_loader, _ = create_data_loader(batch_size, ctx_size, 200000000, "/kaggle/input/fineweb-edu-10bt-for-gpt2/train", 49, is_training=True)
    eval_loader, _ = create_data_loader(eval_batch_size, ctx_size, 150000000, "/kaggle/input/fineweb-edu-10bt-for-gpt2/test", 1, is_training=False)
    # Wrap the loaders with the parallel loader; batches are sharded along the 'fsdp' axis
    train_device_loader = pl.MpDeviceLoader(train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('fsdp', None)))
    eval_device_loader = pl.MpDeviceLoader(eval_loader, device, input_sharding=xs.ShardingSpec(mesh, ('fsdp', None)))
    # To resume from a checkpoint, load it here; the training position (`i`,
    # the data iterator and the scheduler) must then be fast-forwarded as well.
    # checkpoint = torch.load("/kaggle/input/")
    # model.module.load_state_dict(checkpoint['model'])
    # optimizer.load_state_dict(checkpoint['optim'])
    i = 0  # global optimizer-step counter across epochs
    for _ in range(epochs):
        train_iter = iter(train_device_loader)
        for step in range(steps_per_epoch):
            model.train()

            # Linear warmup: set the learning rate for THIS step before stepping.
            if i < warmup_steps:
                warmup_lr = lr * (i + 1) / warmup_steps
                for param_group in optimizer.param_groups:
                    param_group["lr"] = warmup_lr
            current_lr = optimizer.param_groups[0]['lr']

            train_loss = 0
            for _ in range(gradient_accumulation_steps):
                x, y = next(train_iter)
                # Forward pass
                with autocast(device, dtype=torch.bfloat16):
                    _, loss = model(x, y)
                train_loss += loss.item()
                # Scale so the accumulated gradient is the mean over micro-batches.
                loss = loss/gradient_accumulation_steps

                # Backward pass and optimization
                loss.backward()
            train_loss /= gradient_accumulation_steps
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            xm.optimizer_step(optimizer)
            optimizer.zero_grad()
            # Log on master process only
            if xm.is_master_ordinal():
                xm.master_print(f"Step {i}, Loss: {train_loss:.4f}, LR: {current_lr:.6f}")

            # Cosine decay: once the peak has been used (last warmup step),
            # advance the scheduler so the next step gets the decayed rate.
            if i >= warmup_steps - 1:
                scheduler.step()

            # Evaluation and checkpointing
            if step % 300 == 0:
                model.eval()

                # Generate sample text
                prompt = torch.tensor(tokenizer.encode("<|endoftext|>Once upon a time,", allowed_special={'<|endoftext|>'})[:50], device=device)
                with torch.no_grad():
                    generated = model.module.generate(prompt, max_tokens=20)
                decoded = tokenizer.decode(generated.cpu().tolist())
                xm.master_print(f"Sample generation:\n{decoded}")

                # Evaluate on validation set (at most 50 batches)
                total_loss = 0.0
                count = 0
                for eval_x, eval_y in eval_device_loader:
                    with torch.no_grad():
                        _, loss_eval = model(eval_x, eval_y)
                    total_loss += loss_eval.item()
                    count += 1
                    if count == 50:
                        break
                # Average loss
                avg_loss = total_loss / count if count > 0 else 0.0
                wandb.log({"train_loss": train_loss, "val_loss": avg_loss, "lr": current_lr, "sample": decoded})
                xm.master_print(f"Step {i}, Test loss: {avg_loss:.4f}")
                # Save checkpoint
                state_dict = {
                    "model": model.module.state_dict(),
                    "optim": optimizer.state_dict(),
                    "step": i
                }
                torch.save(state_dict, f"./ckpt-{i}.pt")
                xm.master_print(f"Checkpoint saved to ./ckpt-{i}.pt")
            else:
                wandb.log({"train_loss": train_loss, "lr": current_lr})
            i+=1
    # Save checkpoint at very end
    state_dict = {
        "model": model.module.state_dict(),
        "optim": optimizer.state_dict(),
        "step": i
    }
    torch.save(state_dict, f"./ckpt-{i}.pt")
    xm.master_print(f"FINAL Checkpoint saved to ./ckpt-{i}.pt")

In [None]:
import os

# Drop TPU_PROCESS_ADDRESSES from the environment before training starts.
# Passing a default makes the cell idempotent: re-running it, or running it
# where the variable was never set, no longer raises KeyError.
os.environ.pop('TPU_PROCESS_ADDRESSES', None)

In [None]:
# Total optimizer steps: each step consumes batch_size * grad_accumulation_steps
# sequences of ctx_size tokens, out of 49 training shards of 200,000,000 tokens.
total_train_tokens = 200000000 * 49
tokens_per_step = batch_size * grad_accumulation_steps * ctx_size
steps_per_epoch = (total_train_tokens - ctx_size) // tokens_per_step
steps = epochs * steps_per_epoch
print(f"Total steps: {steps}")

In [None]:
# Launch pretraining with the hyperparameters defined in the configuration cell.
train(
    tokenizer=tokenizer,
    layers=layers,
    emb_size=emb_size,
    n_heads=n_heads,
    dropout=dropout,
    vocab_size=vocab_size,
    ctx_size=ctx_size,
    epochs=epochs,
    batch_size=batch_size,
    eval_batch_size=eval_batch_size,
    steps=steps,
    gradient_accumulation_steps=grad_accumulation_steps,
)