# Training a Modern Transformer from Scratch on an 80GB A100 GPU

This notebook implements and trains a **modern, optimized transformer** from scratch using PyTorch, incorporating state-of-the-art improvements over GPT-2:

**Modern Architecture Features:**
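- **Flash Attention** via PyTorch's `F.scaled_dot_product_attention` (fused, memory-efficient attention when a suitable kernel is available)
- **RMSNorm** instead of LayerNorm (no mean subtraction or bias, so cheaper to compute)
- **SwiGLU** gated feed-forward activation instead of GELU
- **Rotary Position Embeddings (RoPE)** instead of learned positional embeddings
- **Grouped Query Attention (GQA)** support via `n_kv_head` (fewer key/value heads than query heads, which shrinks the K/V projections)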
- **Optimized for 80GB A100** with mixed precision training

This is a **custom, improved architecture** designed for maximum performance on modern hardware, not a recreation of existing models.
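
To make the RoPE bullet concrete, here is a minimal pure-Python sketch (no PyTorch, a single head vector) of the rotate-half form of rotary embeddings. The helper names `rope` and `dot` are illustrative and not part of the notebook. It checks the property that motivates RoPE: the query-key dot product depends only on the distance between the two positions, not on their absolute values.

```python
import math

def rope(vec, pos, theta=10000.0):
    """Rotate-half RoPE for one head vector at integer position `pos`."""
    dim = len(vec)
    half = dim // 2
    out = [0.0] * dim
    for i in range(half):
        # Pair (i, i + half) is rotated by an angle that grows with position.
        angle = pos * theta ** (-2.0 * i / dim)
        c, s = math.cos(angle), math.sin(angle)
        out[i] = vec[i] * c - vec[i + half] * s
        out[i + half] = vec[i + half] * c + vec[i] * s
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.7, 0.5]
k = [1.0, 0.4, -0.6, 0.9]

# Same offset (2 positions apart), different absolute positions.
near = dot(rope(q, 3), rope(k, 1))
far = dot(rope(q, 10), rope(k, 8))
print(abs(near - far) < 1e-9)
```

The model code applies the same rotation to whole tensors at once, using precomputed cosine and sine tables instead of a Python loop.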

In [None]:
# Install dependencies
!pip install torch --index-url https://download.pytorch.org/whl/cu121
!pip install transformers datasets tqdm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from datasets import load_dataset
from transformers import GPT2Tokenizer
from tqdm import tqdm

In [None]:
class TransformerConfig:
    def __init__(self, vocab_size, block_size, n_layer, n_head, n_embd, n_kv_head=None, dropout=0.0):
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        # Grouped Query Attention: fewer KV heads than Q heads
        self.n_kv_head = n_kv_head if n_kv_head is not None else n_head
        self.dropout = dropout
        self.head_dim = n_embd // n_head

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (more efficient than LayerNorm)"""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * norm * self.weight

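def apply_rotary_emb(x, cos, sin):
    """Apply Rotary Position Embeddings (rotate-half convention)"""
    # x: (batch, seq_len, n_head, head_dim); cos/sin broadcast as (1, seq_len, 1, head_dim)
    x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
    # Rotate: element i is paired with element i + head_dim/2
    rotated = torch.cat([-x2, x1], dim=-1)
    # Cast back so q/k keep the dtype of v under mixed precision
    return ((x * cos) + (rotated * sin)).type_as(x)

def precompute_freqs_cis(dim, max_seq_len, theta=10000.0):
    """Precompute RoPE cos/sin tables of shape (max_seq_len, dim)"""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len).float()
    freqs = torch.outer(t, freqs)  # (max_seq_len, dim/2)
    # Duplicate the angles so both halves of a head share them and the
    # tables match head_dim in apply_rotary_emb
    freqs = torch.cat([freqs, freqs], dim=-1)
    cos = torch.cos(freqs)
    sin = torch.sin(freqs)
    return cos, sin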

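class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA): each group of query heads shares one K/V head"""
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        assert config.n_head % config.n_kv_head == 0, "n_head must be a multiple of n_kv_head"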
        
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = config.head_dim
        self.n_rep = self.n_head // self.n_kv_head  # repetition factor for KV heads
        
        # Q projection for all heads, K/V for fewer heads
        self.wq = nn.Linear(config.n_embd, config.n_head * config.head_dim, bias=False)
        self.wk = nn.Linear(config.n_embd, config.n_kv_head * config.head_dim, bias=False)
        self.wv = nn.Linear(config.n_embd, config.n_kv_head * config.head_dim, bias=False)
        self.wo = nn.Linear(config.n_head * config.head_dim, config.n_embd, bias=False)
        
        self.dropout = nn.Dropout(config.dropout)
        
        # Causal mask
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
                           .view(1, 1, config.block_size, config.block_size))

    def forward(self, x, freqs_cos, freqs_sin):
        B, T, C = x.size()
        
        # Project and reshape
        q = self.wq(x).view(B, T, self.n_head, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_head, self.head_dim)
        
        # Apply RoPE (freqs_cos/freqs_sin arrive already sliced to T and shaped for broadcasting)
        q = apply_rotary_emb(q, freqs_cos, freqs_sin)
        k = apply_rotary_emb(k, freqs_cos, freqs_sin)
        
        # Repeat KV heads to match Q heads (for GQA)
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=2)
            v = v.repeat_interleave(self.n_rep, dim=2)
        
        # Transpose for attention computation
        q = q.transpose(1, 2)  # (B, n_head, T, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        # Scaled dot-product attention with Flash Attention if available
        if hasattr(F, 'scaled_dot_product_attention'):
            # Use PyTorch's fused SDPA, which dispatches to a FlashAttention kernel when supported
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=True
            )
        else:
            # Fallback to manual attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
            att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.dropout(att)
            y = att @ v
        
        # Reshape and project output
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.wo(y)
        return y

class SwiGLU(nn.Module):
    """SwiGLU activation (better than GELU for language models)"""
    def __init__(self, config):
        super().__init__()
        hidden_dim = int(8 * config.n_embd / 3)  # Common FFN expansion
        hidden_dim = ((hidden_dim + 255) // 256) * 256  # Round to multiple of 256 for efficiency
        
        self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, config.n_embd, bias=False)
        self.w3 = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # SwiGLU: swish(W1·x) ⊙ (W3·x) then project with W2
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention_norm = RMSNorm(config.n_embd)
        self.attention = GroupedQueryAttention(config)
        self.ffn_norm = RMSNorm(config.n_embd)
        self.ffn = SwiGLU(config)

    def forward(self, x, freqs_cos, freqs_sin):
        # Pre-norm architecture with residual connections
        x = x + self.attention(self.attention_norm(x), freqs_cos, freqs_sin)
        x = x + self.ffn(self.ffn_norm(x))
        return x

class ModernTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Token embeddings only (no learned positional embeddings - using RoPE instead)
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        
        # Precompute RoPE frequencies
        freqs_cos, freqs_sin = precompute_freqs_cis(config.head_dim, config.block_size)
        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])
        self.norm = RMSNorm(config.n_embd)
        
        # Output projection
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        # Weight tying
        self.token_emb.weight = self.lm_head.weight
        
        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, f"Sequence length {T} exceeds block size {self.config.block_size}"
        
        # Token embeddings (no positional - using RoPE in attention)
        x = self.token_emb(idx)
        x = self.dropout(x)
        
        # Transform through blocks
        freqs_cos = self.freqs_cos[:T].unsqueeze(0).unsqueeze(2)
        freqs_sin = self.freqs_sin[:T].unsqueeze(0).unsqueeze(2)
        
        for block in self.blocks:
            x = block(x, freqs_cos, freqs_sin)
        
        x = self.norm(x)
        logits = self.lm_head(x)
        
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=50):
        for _ in range(max_new_tokens):
            # Crop context if needed
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            
            # Forward pass
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            
            # Top-k sampling
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        
        return idx

In [None]:
def get_data(batch_size=8, block_size=1024):
    # Load dataset
    print("Downloading Wikitext dataset... This may take a few minutes.")
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

    # Load tokenizer
    print("Loading GPT-2 tokenizer...")
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    # Concatenate all texts
    print("Concatenating texts...")
    train_text = "\n\n".join(dataset['train']['text'])
    val_text = "\n\n".join(dataset['validation']['text'])

    # Tokenize
    print("Tokenizing data... This may take some time.")
    train_tokens = tokenizer.encode(train_text)
    val_tokens = tokenizer.encode(val_text)

    # Chunk into blocks
    print("Chunking tokens into blocks...")
    def chunk_tokens(tokens, block_size):
        chunks = []
        for i in range(0, len(tokens) - block_size + 1, block_size):
            chunks.append(tokens[i:i + block_size])
        return chunks

    train_chunks = chunk_tokens(train_tokens, block_size)
    val_chunks = chunk_tokens(val_tokens, block_size)

    # Convert to tensors
    print("Converting to tensors...")
    train_data = torch.tensor(train_chunks, dtype=torch.long)
    val_data = torch.tensor(val_chunks, dtype=torch.long)

    # Create data loaders
    print("Creating data loaders...")
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

    print("Data preparation complete!")
    return train_loader, val_loader, tokenizer

In [None]:
def train(model, train_loader, val_loader, optimizer, scheduler, device, epochs=1, grad_accum_steps=1, use_amp=True):
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    model.train()
    
    print(f"\nStarting training on {device}")
    if device.type == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Initial GPU memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
        print(f"Initial GPU memory reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB\n")
    
    for epoch in range(epochs):
        total_loss = 0
        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")):
            # Handle both tensor and dict inputs
            if isinstance(batch, dict):
                input_ids = batch['input_ids'].to(device)
            else:
                input_ids = batch.to(device)
                
            targets = input_ids.clone()
            targets[:, :-1] = input_ids[:, 1:]
            targets[:, -1] = -1  # ignore last token

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits, loss = model(input_ids, targets)
                loss = loss / grad_accum_steps

            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if (i + 1) % grad_accum_steps == 0:
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

            total_loss += loss.item() * grad_accum_steps
            
            # Print memory usage every 100 steps
            if device.type == 'cuda' and (i + 1) % 100 == 0:
                print(f"\nStep {i+1}: GPU memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

        avg_loss = total_loss / len(train_loader)
        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_loss:.4f}")
        
        if device.type == 'cuda':
            print(f"Peak GPU memory allocated: {torch.cuda.max_memory_allocated(0) / 1024**3:.2f} GB")
            print(f"Peak GPU memory reserved: {torch.cuda.max_memory_reserved(0) / 1024**3:.2f} GB")

        # Validation
        print(f"\nRunning validation...")
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                # Handle both tensor and dict inputs
                if isinstance(batch, dict):
                    input_ids = batch['input_ids'].to(device)
                else:
                    input_ids = batch.to(device)
                    
                targets = input_ids.clone()
                targets[:, :-1] = input_ids[:, 1:]
                targets[:, -1] = -1
                _, loss = model(input_ids, targets)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch+1}/{epochs} - Val Loss: {avg_val_loss:.4f}")
        print(f"{'='*60}\n")
        model.train()

        # Save checkpoint
        checkpoint_path = f'checkpoint_epoch_{epoch+1}.pt'
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}\n")

In [None]:
# Configuration for 80GB A100 GPU - Optimized for Maximum Performance
print("="*60)
print("Configuration for 80GB A100 GPU - Training Large Model")
print("="*60)

# Choose model size: 'medium' (355M), 'large' (774M), or 'xl' (1.5B)
MODEL_SIZE = 'large'  # Change to 'xl' for even bigger model

if MODEL_SIZE == 'medium':
    # GPT-2 Medium (355M parameters)
    n_layer = 24
    n_head = 16
    n_embd = 1024
    batch_size = 12
    grad_accum_steps = 8
    print("\n🚀 Training GPT-2 Medium (355M parameters)")
    
elif MODEL_SIZE == 'large':
    # GPT-2 Large (774M parameters)
    n_layer = 36
    n_head = 20
    n_embd = 1280
    batch_size = 6
    grad_accum_steps = 12
    print("\n🚀 Training GPT-2 Large (774M parameters)")
    
elif MODEL_SIZE == 'xl':
    # GPT-2 XL (1.5B parameters)
    n_layer = 48
    n_head = 25
    n_embd = 1600
    batch_size = 4
    grad_accum_steps = 16
    print("\n🚀 Training GPT-2 XL (1.5B parameters)")

block_size = 1024
epochs = 1
lr = 3e-4
use_amp = True  # Mixed precision training for A100

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if device.type == 'cuda':
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"PyTorch Version: {torch.__version__}")
    
# Estimated memory usage
print(f"\nModel Configuration:")
print(f"  Layers: {n_layer}, Heads: {n_head}, Embedding: {n_embd}")
print(f"  Batch size: {batch_size}, Block size: {block_size}")
print(f"  Gradient accumulation steps: {grad_accum_steps}")
print(f"  Effective batch size: {batch_size * grad_accum_steps}")

# Calculate approximate model size: ~12 * n_embd^2 per block (attention + SwiGLU)
# plus one tied embedding/output matrix; RoPE adds no learned position parameters
params = 12 * n_layer * n_embd**2 + n_embd * 50257
model_size_gb = params * 4 / 1024**3  # 4 bytes per float32
optimizer_size_gb = params * 8 / 1024**3  # AdamW uses 2x model size (first and second moments)
gradients_size_gb = params * 4 / 1024**3
activations_gb = batch_size * block_size * n_embd * n_layer * 4 / 1024**3

total_estimated_gb = model_size_gb + optimizer_size_gb + gradients_size_gb + activations_gb

print(f"\n📊 Memory Estimates:")
print(f"  Model parameters: ~{model_size_gb:.2f} GB")
print(f"  Optimizer states: ~{optimizer_size_gb:.2f} GB")
print(f"  Gradients: ~{gradients_size_gb:.2f} GB")
print(f"  Activations (approx): ~{activations_gb:.2f} GB")
print(f"  Total estimated: ~{total_estimated_gb:.2f} GB")
print(f"  Available: 80 GB")
print(f"  Safety margin: ~{80 - total_estimated_gb:.2f} GB")

if total_estimated_gb > 70:
    print(f"\n⚠️  Warning: High memory usage. Consider reducing batch_size (and raising grad_accum_steps to keep the effective batch size).")
else:
    print(f"\n✅ Configuration should fit comfortably on an 80GB A100 GPU.")
print("="*60 + "\n")

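# Model initialization
print("Initializing model...")
# n_kv_head is left at its default (= n_head), i.e. standard multi-head attention;
# pass n_kv_head=<a divisor of n_head> to enable grouped-query attention
config = TransformerConfig(vocab_size=50257, block_size=block_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd, dropout=0.1)
model = ModernTransformer(config).to(device)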

# Count actual parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")
print(f"Trainable parameters: {trainable_params:,} ({trainable_params/1e6:.2f}M)")

if device.type == 'cuda':
    print(f"Model loaded. GPU memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB\n")

# Data loading
print("Loading and preparing data...")
train_loader, val_loader, tokenizer = get_data(batch_size=batch_size, block_size=block_size)
print(f"Train batches: {len(train_loader)}, Validation batches: {len(val_loader)}\n")

# Optimizer and scheduler
print("Setting up optimizer and scheduler...")
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.01, betas=(0.9, 0.95))
scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader) * epochs // grad_accum_steps)
print("Setup complete!\n")

# Train
train(model, train_loader, val_loader, optimizer, scheduler, device, epochs, grad_accum_steps, use_amp)

In [None]:
# Generation
def generate_text(model, tokenizer, prompt, max_new_tokens=50, temperature=1.0, top_k=50):
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        # ModernTransformer.generate always samples, so it takes no do_sample argument
        output = model.generate(input_ids, max_new_tokens, temperature=temperature, top_k=top_k)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

# Example
prompt = "The future of AI is"
generated = generate_text(model, tokenizer, prompt)
print(generated)

## Instructions

1. **Install dependencies**: Run cell 2 to install PyTorch, transformers, datasets, and tqdm.
2. **Import libraries**: Run cell 3 to import all required libraries.
3. **Load model architecture**: Run cell 4 to define the model classes (`TransformerConfig`, `ModernTransformer`, and their building blocks).
4. **Load data preparation function**: Run cell 5 to define the data loading function.
5. **Load training function**: Run cell 6 to define the training loop.
6. **Configure and train**: Run cell 7 to configure the model size and start training.
   - **Model sizes available:**
     - `'medium'`: 355M parameters (~15-20 GB total memory)
     - `'large'`: 774M parameters (~30-35 GB total memory) - **Recommended**
     - `'xl'`: 1.5B parameters (~55-65 GB total memory) - Maximum capacity
7. **Generate text**: After training completes, run cell 8 to test text generation.

## Model Size Selection

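Each size preset reuses the GPT-2 layer, head, and embedding-width settings and pairs them with a batch size and gradient accumulation setting chosen for the 80GB A100 GPU. The parameter counts below are the nominal GPT-2 figures; the count printed in cell 7 for this architecture will differ somewhat: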

- **GPT-2 Medium** (355M): Fast training, good quality
- **GPT-2 Large** (774M): Best balance of speed and quality - **recommended for most use cases**
- **GPT-2 XL** (1.5B): Highest quality, slower training, pushes GPU to limits

To change model size, edit the `MODEL_SIZE` variable in cell 7 before running it.
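
A rough parameter count for this notebook's architecture can be sketched directly from the layer shapes defined in cell 4. `estimate_params` is a hypothetical helper, not part of the notebook, and it assumes tied embeddings, bias-free linear layers, and the SwiGLU hidden size rounded up to a multiple of 256, as in the model code.

```python
def estimate_params(n_layer, n_head, n_embd, n_kv_head=None, vocab_size=50257):
    """Rough parameter count for the notebook's transformer (illustrative)."""
    n_kv_head = n_head if n_kv_head is None else n_kv_head
    head_dim = n_embd // n_head

    # Attention: Q and output projections are full size, K and V shrink with GQA.
    attn = 2 * n_embd * n_head * head_dim + 2 * n_embd * n_kv_head * head_dim

    # SwiGLU: three matrices, hidden size ~8/3 * n_embd rounded up to 256.
    hidden = int(8 * n_embd / 3)
    hidden = ((hidden + 255) // 256) * 256
    ffn = 3 * n_embd * hidden

    norms = 2 * n_embd  # two RMSNorm weight vectors per block
    per_block = attn + ffn + norms

    # Tied token embedding / output head plus the final RMSNorm.
    return n_layer * per_block + vocab_size * n_embd + n_embd

presets = {
    "medium": (24, 16, 1024),
    "large": (36, 20, 1280),
    "xl": (48, 25, 1600),
}
for name, (n_layer, n_head, n_embd) in presets.items():
    total = estimate_params(n_layer, n_head, n_embd)
    print(f"{name}: {total / 1e6:.0f}M parameters")
```

Passing an `n_kv_head` smaller than `n_head` shows how much the K/V projections shrink under GQA.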

## Training Tips

- Training will take several hours depending on model size
- Checkpoints are saved after each epoch
- GPU memory usage is monitored and displayed during training
- Mixed precision (FP16) is enabled for faster training on A100
- If you encounter OOM errors, reduce `batch_size` and increase `grad_accum_steps` by the same factor to keep the effective batch size unchanged
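
The OOM tip above amounts to trading per-step batch size for more accumulation steps. A small sketch of that arithmetic follows; `rebalance` is a hypothetical helper, not part of the notebook, and the numbers are taken from the `'large'` preset in cell 7.

```python
import math

def rebalance(batch_size, grad_accum_steps, new_batch_size):
    """Return the accumulation steps that keep at least the same effective batch."""
    effective = batch_size * grad_accum_steps
    return math.ceil(effective / new_batch_size)

# The 'large' preset uses batch_size=6 and grad_accum_steps=12 (effective batch 72).
new_batch_size = 3
new_steps = rebalance(6, 12, new_batch_size)
tokens_per_update = new_batch_size * new_steps * 1024  # block_size = 1024
print(new_batch_size, new_steps, tokens_per_update)
```

Because cell 7 derives the scheduler's `T_max` from `grad_accum_steps`, rerunning that cell after changing either value keeps the learning-rate schedule consistent.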