# Training a Decoder-Only Transformer Model

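This notebook trains the decoder-only transformer defined in `transformer.py` on character-level text read from `DATA_DIR`, saves a checkpoint after every epoch, and then samples text from the trained model.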

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
from pathlib import Path
from transformer import Config, DecoderOnlyTransformer
from torch.nn import functional as F
import numpy as np
from tqdm.notebook import tqdm

In [None]:
from ipywidgets import Widget

# Close every live widget so stale progress-bar widgets from earlier runs are
# not carried along with the notebook.
# `Widget.widgets` is a plain dict of widget instances; it has no
# `clear_state()` / `save_state()` methods, so calling them raised
# AttributeError and stopped a Restart-and-Run-All at this cell.
# NOTE(review): saving/clearing *notebook* widget state appears to be a
# front-end action rather than a Python API -- confirm against the
# ipywidgets documentation for the installed version.
Widget.close_all()

## Configuration

In [6]:
# Training hyperparameters
BATCH_SIZE = 48
LEARNING_RATE = 3e-4
NUM_EPOCHS = 2
DATA_DIR = "data/"  # Directory containing text files
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility: weight initialisation, DataLoader shuffling, dropout and
# torch.multinomial sampling all draw from PyTorch's RNG, so seed it once here
# before anything stochastic runs.
SEED = 42
torch.manual_seed(SEED)

# Model configuration
# NOTE(review): TextDataset in this notebook tokenises per character, so only
# the first `dataset.vocab_size` token ids are ever produced. vocab_size=50257
# is presumably carried over from a BPE tokenizer and mostly allocates unused
# embedding / output rows -- confirm whether it should follow the dataset.
config = Config(
    vocab_size=50257,
    max_seq_len=256,
    dim=768,
    num_layers=8,
    num_heads=8,
    dropout=0.1
)

## Data Processing

In [7]:
class TextDataset(Dataset):
    """Character-level next-token dataset over every ``*.txt`` file in a directory.

    All files are concatenated (in sorted path order) into one string, each
    distinct character is assigned an integer id, and samples are sliding
    windows of ``seq_length`` ids paired with the same window shifted one
    position to the right.

    Args:
        data_dir: Directory that is searched (non-recursively) for ``*.txt`` files.
        seq_length: Number of tokens in each input/target window.
        encoding: Text encoding used to read the files. Passed explicitly so
            the vocabulary does not depend on the platform's locale default.

    Attributes set by the constructor:
        files: Sorted list of the text file paths that were read.
        text: Concatenated contents of ``files``.
        chars: Sorted list of the distinct characters in ``text``.
        vocab_size: ``len(chars)``.
        char_to_idx / idx_to_char: Mappings between characters and token ids.
        data: 1-D ``torch.long`` tensor holding the token id of every character.
    """

    def __init__(self, data_dir, seq_length, encoding='utf-8'):
        self.data_dir = Path(data_dir)
        self.seq_length = seq_length
        # glob() order is filesystem dependent; sort so the corpus (and hence
        # every sample index) is the same on every run and machine.
        self.files = sorted(self.data_dir.glob('*.txt'))

        # Read each file exactly once, and always close the handle.
        pieces = []
        for file in self.files:
            with open(file, 'r', encoding=encoding) as f:
                pieces.append(f.read())
        self.text = ''.join(pieces)

        # Simple tokenization (character-level for this example)
        self.chars = sorted(set(self.text))
        self.vocab_size = len(self.chars)
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}

        # Convert text to indices
        self.data = torch.tensor([self.char_to_idx[ch] for ch in self.text], dtype=torch.long)

    def __len__(self):
        # A window needs seq_length inputs plus one extra token for the shifted
        # target. Clamp at zero so a short or empty corpus reports "no samples"
        # instead of making len() raise on a negative value.
        return max(0, len(self.data) - self.seq_length)

    def __getitem__(self, idx):
        """Return ``(x, y)`` where ``y`` is ``x`` shifted one token to the right.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``. This
                also lets plain ``for x, y in dataset`` iteration terminate.
        """
        num_samples = len(self)
        if idx < 0:
            idx += num_samples
        if not 0 <= idx < num_samples:
            raise IndexError(f'index out of range for dataset with {num_samples} samples')

        # Get sequence and target
        x = self.data[idx:idx + self.seq_length]
        y = self.data[idx + 1:idx + self.seq_length + 1]
        return x, y

## Model Setup

In [8]:
# Build the network from `config`, then move its parameters onto the target device
model = DecoderOnlyTransformer(config)
model = model.to(DEVICE)

# Helper for reporting the model size
def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Parameters whose ``requires_grad`` flag is False are left out of the total.
    """
    trainable_elements = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_elements += param.numel()
    return trainable_elements

# Report the trainable-parameter count of the freshly built model
num_params = count_parameters(model)
print(f'Number of parameters: {num_params:,}')

# AdamW over every model parameter, using the configured learning rate
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

Number of parameters: 134,095,872


## Training Loop

In [9]:
def train_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(x)``. Its output is flattened to
            ``(-1, logits.size(-1))``, i.e. the last dimension is treated as
            the class (vocabulary) dimension.
        dataloader: Iterable of ``(x, y)`` batches; ``y`` holds the target
            class index for every position of ``x``.
        optimizer: Optimizer stepped once per batch.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        The arithmetic mean of the per-batch cross-entropy losses.

    Raises:
        ValueError: If ``dataloader`` yields no batches, since a mean loss is
            undefined in that case.
    """
    model.train()
    total_loss = 0.0
    # Count batches as they arrive instead of relying on len(dataloader), which
    # is not available for iterable-style loaders.
    num_batches = 0

    progress_bar = tqdm(dataloader, desc='Training')
    for x, y in progress_bar:
        # Move data to device
        x = x.to(device)
        y = y.to(device)

        # Forward pass
        logits = model(x)

        # Compute loss over every position at once. reshape() is used instead
        # of view() so a non-contiguous model output does not raise.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update progress bar with the running mean
        total_loss += loss.item()
        num_batches += 1
        progress_bar.set_postfix({'loss': total_loss / num_batches})

    if num_batches == 0:
        raise ValueError('dataloader produced no batches; cannot compute an average loss')

    return total_loss / num_batches

## Training Execution

In [10]:
# Create dataset and dataloader
dataset = TextDataset(DATA_DIR, config.max_seq_len)

# Fail early with a readable message: target ids beyond the model's vocabulary
# would otherwise only surface as an index error inside the embedding or loss.
if dataset.vocab_size > config.vocab_size:
    raise ValueError(
        f'dataset has {dataset.vocab_size} distinct characters but the model '
        f'only supports {config.vocab_size} token ids'
    )

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# NOTE(review): the recorded run below reports an average loss of 0.0455 and
# then 0.0072 per character, yet the sampled text at the end of the notebook is
# not readable. A loss this low is either memorisation of the heavily
# overlapping windows or the model seeing future tokens -- verify the causal
# mask in transformer.py and measure loss on held-out text. TODO confirm.

# Training loop
for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')
    avg_loss = train_epoch(model, dataloader, optimizer, DEVICE)
    print(f'Average loss: {avg_loss:.4f}')

    # Save checkpoint. The character vocabulary is stored alongside the weights
    # because token ids are derived from the corpus; without it the checkpoint
    # can only be decoded by re-reading exactly the same files.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
        'chars': dataset.chars,
    }, f'checkpoint_epoch_{epoch+1}.pt')


Epoch 1/2


Training:   0%|          | 0/23233 [00:00<?, ?it/s]

Average loss: 0.0455

Epoch 2/2


Training:   0%|          | 0/23233 [00:00<?, ?it/s]

Average loss: 0.0072


## Generate Text (Inference)

In [11]:
def generate_text(model, dataset, start_text, max_length=100, temperature=1.0):
    """Autoregressively extend ``start_text`` one character at a time.

    Args:
        model: Module called as ``model(tokens)`` with a ``(1, T)`` long tensor;
            the logits of the last position are read as ``logits[:, -1, :]``.
        dataset: Object providing ``char_to_idx``, ``idx_to_char`` and
            ``vocab_size``. If it also has ``seq_length``, the context fed to
            the model is limited to that many most-recent tokens.
            NOTE(review): this assumes ``seq_length`` does not exceed the
            model's maximum context, as when the dataset is built with
            ``config.max_seq_len`` -- confirm for other callers.
        start_text: Non-empty prompt; every character must be in
            ``dataset.char_to_idx``.
        max_length: Number of characters to generate after the prompt.
        temperature: Logit divisor used before sampling. ``0`` selects the most
            likely character at every step (greedy decoding).

    Returns:
        ``start_text`` followed by ``max_length`` generated characters.

    Raises:
        ValueError: If ``start_text`` is empty or ``temperature`` is negative.
        KeyError: If ``start_text`` contains a character unknown to ``dataset``.
    """
    if not start_text:
        raise ValueError('start_text must contain at least one character')
    if temperature < 0:
        raise ValueError('temperature must be non-negative')

    model.eval()

    # Run on whatever device the model lives on rather than a notebook global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Convert start text to indices and add the batch dimension
    context = torch.tensor(
        [dataset.char_to_idx[ch] for ch in start_text],
        dtype=torch.long,
        device=device,
    ).unsqueeze(0)

    max_context = getattr(dataset, 'seq_length', None)
    if max_context:
        # Keep only the most recent tokens so the model never receives a
        # sequence longer than the window it was trained on.
        context = context[:, -max_context:]

    generated = list(start_text)

    with torch.no_grad():
        for _ in range(max_length):
            # Get predictions for the last position only. The model's output
            # may be wider than the dataset's vocabulary; drop ids that have no
            # character so idx_to_char can always decode the sample.
            logits = model(context)
            logits = logits[:, -1, :dataset.vocab_size]

            if temperature == 0:
                # Greedy decoding; avoids dividing the logits by zero.
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                # Sample from the temperature-scaled distribution
                probs = F.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            # Add the predicted token to the sequence
            generated.append(dataset.idx_to_char[next_token.item()])
            context = torch.cat([context, next_token], dim=1)
            if max_context:
                context = context[:, -max_context:]

    return ''.join(generated)

# Sample a continuation of a fixed prompt from the trained model and display it
start_text = "Once upon a time"
generated_text = generate_text(model=model, dataset=dataset, start_text=start_text)
print(generated_text)

Once upon a time ceetiiit niittiietctttinute mtiieintttatnain ueautttat titiiettt tucetitinutanttuneitinintiitttinii
