# nanoEBM: Complete Training & Inference Pipeline on H100

This notebook walks through an end-to-end workflow for training, sampling from, and visualizing Energy-Based Models (EBMs) with nanoEBM.

## Sections:
1. Setup & Environment
2. Clone Repository & Install Dependencies
3. Data Preparation
4. Model Training
   - Baseline (no thinking)
   - With iterative refinement (thinking)
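5. Inference & Sampling
   - Greedy generation
   - Thinking + nucleus sampling
   - Interactive sampling (Python API)
6. Visualization & Analysis
7. Metrics Inspection
8. Model Checkpoint Analysis
9. Export Model for Production
10. Next Steps & Tips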

---
## 1. Setup & Environment
Verify GPU availability and environment setup.

In [None]:
import torch
import os
import sys
from pathlib import Path

# Check GPU availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Device count: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name(0)}")
    print(f"Device capability: {torch.cuda.get_device_capability(0)}")
    
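    # Hopper GPUs such as the H100 report compute capability 9.x; later architectures report higher
    if torch.cuda.get_device_capability(0)[0] >= 9:
        print("✓ Hopper-class (e.g. H100) or newer GPU detected")
    
    # Memory info
    mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Total GPU memory: {mem_gb:.1f} GB")
else:
    print("✗ CUDA not available: the cells below assume device='cuda'")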

---
## 2. Clone Repository & Install Dependencies

In [None]:
# Clone the repository
!git clone https://github.com/sdan/nanoEBM.git
%cd nanoEBM

In [None]:
# Verify we're in the right directory
!pwd
!ls -la

In [None]:
# Install dependencies. %pip installs into the environment of the running kernel,
# which !pip does not guarantee. To use uv instead, skip Option 1 and uncomment Option 2.

# Option 1: Using pip
%pip install torch numpy chz tiktoken pyarrow matplotlib wandb

# Option 2: Using uv (if available)
# !pip install uv
# !uv sync

In [None]:
# Import the nanoebm package
from nanoebm.config import Config, ModelConfig, DataConfig, TrainConfig
from nanoebm.model import EBTLanguageModel
from nanoebm.data import get_loader
from nanoebm.utils import Logger

print("✓ nanoEBM imports successful")

---
## 3. Data Preparation

### Option A: Use built-in Shakespeare dataset (character-level)

In [None]:
# Download the Tiny Shakespeare corpus if it is not already present
shakespeare_path = Path("shakespeare.txt")
if not shakespeare_path.exists():
    print("✗ Shakespeare dataset not found. Downloading...")
    !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O shakespeare.txt

print("✓ Shakespeare dataset ready")
print(f"  Size: {shakespeare_path.stat().st_size / 1e6:.2f} MB")

# Preview the first 500 characters
with open(shakespeare_path, 'r', encoding='utf-8') as f:
    preview = f.read(500)
print(f"\nPreview:\n{preview}...")
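
Option A works at the character level: every distinct character in the corpus becomes one token id. The sketch below assumes the vocabulary is the sorted set of characters (the same construction `generate_text` uses in 5.3; `train.py`'s own dataset code may differ), and `train_val_split` with a 90/10 split is a hypothetical illustration, not nanoEBM's API.

```python
def build_char_vocab(text):
    """Map each distinct character to an integer id, in sorted order."""
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    return stoi, itos

def encode(s, stoi):
    """Convert a string to a list of token ids."""
    return [stoi[c] for c in s]

def decode(ids, itos):
    """Convert a list of token ids back to a string."""
    return "".join(itos[i] for i in ids)

def train_val_split(ids, val_frac=0.1):
    """Hypothetical split: the first (1 - val_frac) of the stream trains, the rest validates."""
    n_train = int(len(ids) * (1 - val_frac))
    return ids[:n_train], ids[n_train:]

sample = "ROMEO:\nBut, soft! what light"
stoi, itos = build_char_vocab(sample)
ids = encode(sample, stoi)
assert decode(ids, itos) == sample  # encoding round-trips losslessly
train_ids, val_ids = train_val_split(ids)
print(f"vocab size: {len(stoi)}, train tokens: {len(train_ids)}, val tokens: {len(val_ids)}")
```

Characters that never appear in the training text have no id, so prompts for a char-level model must stay within the corpus alphabet.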

### Option B: Prepare BPE dataset (for production)

For larger-scale training, you can use BPE tokenization with parquet datasets.

In [None]:
# Example: Download and prepare a larger dataset (optional)
# Uncomment if you want to use a larger dataset

# DATA_DIR = Path("data/fineweb_sample")
# DATA_DIR.mkdir(parents=True, exist_ok=True)

# # Example: Download sample from HuggingFace datasets
# # You'll need to prepare parquet files or use a text file
# print("For BPE training, ensure you have parquet files or a large text corpus")

---
## 4. Model Training

### 4.1 Training Configuration

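We define two configurations: a baseline with no iterative refinement and a "thinking" model with two refinement steps. These `Config` objects document the intended settings but are not passed to `train.py`: the training cells below run the script from the shell, so every setting must be given as a command-line override, and anything omitted (for example `entropy_reg_tau`) falls back to the script's defaults.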

In [None]:
# Baseline Configuration (No thinking, character-level)
baseline_cfg = Config(
    model=ModelConfig(
        vocab_size=None,  # Auto-detected from dataset
        block_size=256,
        n_layer=6,
        n_head=6,
        n_embd=384,
        dropout=0.1,
        mcmc_num_steps=0,  # No iterative refinement
    ),
    data=DataConfig(
        dataset="shakespeare",
        data_path="shakespeare.txt",
        tokenizer="char",
        block_size=256,
        batch_size=64,
    ),
    train=TrainConfig(
        max_steps=5000,
        learning_rate=3e-4,
        eval_interval=500,
        log_interval=10,
        save_interval=1000,
        compile_model=True,  # Use torch.compile for speed
    ),
    device="cuda",
    seed=1337,
)

print("✓ Baseline config created")
print(f"  Training steps: {baseline_cfg.train.max_steps}")
print(f"  Batch size: {baseline_cfg.data.batch_size}")
print(f"  Model layers: {baseline_cfg.model.n_layer}")

In [None]:
# Advanced Configuration (With thinking, character-level)
thinking_cfg = Config(
    model=ModelConfig(
        vocab_size=None,  # Auto-detected
        block_size=256,
        n_layer=6,
        n_head=6,
        n_embd=384,
        dropout=0.1,
        # Energy-based thinking parameters
        mcmc_num_steps=2,  # 2 refinement steps
        mcmc_step_size=1.0,
        mcmc_step_size_learnable=True,  # Learn alpha
        entropy_reg_tau=0.5,
        truncate_mcmc=True,  # Don't backprop through all steps
        no_mcmc_detach=False,
    ),
    data=DataConfig(
        dataset="shakespeare",
        data_path="shakespeare.txt",
        tokenizer="char",
        block_size=256,
        batch_size=64,
    ),
    train=TrainConfig(
        max_steps=10000,
        learning_rate=3e-4,
        eval_interval=500,
        log_interval=10,
        save_interval=1000,
        compile_model=True,
    ),
    device="cuda",
    seed=1337,
)

print("✓ Thinking config created")
print(f"  MCMC steps: {thinking_cfg.model.mcmc_num_steps}")
print(f"  Learnable step size: {thinking_cfg.model.mcmc_step_size_learnable}")
print(f"  Training steps: {thinking_cfg.train.max_steps}")
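
`train.py` takes its settings as `section.key=value` overrides. A minimal sketch of a hypothetical helper (`to_cli_overrides` is not part of nanoEBM) that flattens a nested settings dict into that syntax, lowercasing booleans to match the `true` spelling used in the training cells:

```python
def to_cli_overrides(settings, prefix=""):
    """Flatten a nested dict into `section.key=value` strings (hypothetical helper)."""
    args = []
    for key, value in settings.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse into sub-sections, e.g. {"model": {...}} -> "model.<key>=..."
            args.extend(to_cli_overrides(value, prefix=f"{name}."))
        elif isinstance(value, bool):
            # Python's True/False become the lowercase true/false used on the CLI
            args.append(f"{name}={str(value).lower()}")
        else:
            args.append(f"{name}={value}")
    return args

thinking_overrides = {
    "data": {"data_path": "shakespeare.txt", "tokenizer": "char"},
    "model": {"mcmc_num_steps": 2, "mcmc_step_size_learnable": True, "truncate_mcmc": True},
    "train": {"max_steps": 10000, "compile_model": True},
    "out_dir": "out_ebt/thinking",
}
cli = " ".join(to_cli_overrides(thinking_overrides))
print(cli)
```

In IPython, `!python train.py {cli}` expands the Python variable into the shell command, which keeps the notebook's settings and the actual run in sync.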

### 4.2 Train Baseline Model

Train a standard EBM without iterative refinement.

In [None]:
# Train the baseline by running train.py as a shell command.
# Settings are passed as `section.key=value` overrides; this runs 5000 steps
# with character-level tokenization and no refinement (mcmc_num_steps=0).

!python train.py \
    data.data_path=shakespeare.txt \
    data.tokenizer=char \
    model.mcmc_num_steps=0 \
    train.max_steps=5000 \
    train.log_interval=10 \
    train.eval_interval=500 \
    train.compile_model=true \
    out_dir=out_ebt/baseline
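
With `mcmc_num_steps=0`, the model scores every vocabulary token with an energy in a single pass, and the sampling code in 5.3 negates those energies to obtain logits. Assuming training minimizes cross-entropy on those logits, the next-token distribution for context $x$ and the loss for the observed token $y^\star$ are:

$$
p_\theta(y \mid x) = \frac{\exp\!\big(-E_\theta(x, y)\big)}{\sum_{y'} \exp\!\big(-E_\theta(x, y')\big)},
\qquad
\mathcal{L}(x, y^\star) = -\log p_\theta(y^\star \mid x) = E_\theta(x, y^\star) + \log \sum_{y'} \exp\!\big(-E_\theta(x, y')\big)
$$

Minimizing $\mathcal{L}$ lowers the energy of the observed token while the log-sum-exp term raises the energies of the alternatives.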

### 4.3 Train Model with Thinking

Train with iterative refinement (2 MCMC steps).

In [None]:
# Train with thinking for 10000 steps
!python train.py \
    data.data_path=shakespeare.txt \
    data.tokenizer=char \
    model.mcmc_num_steps=2 \
    model.mcmc_step_size_learnable=true \
    model.truncate_mcmc=true \
    train.max_steps=10000 \
    train.log_interval=10 \
    train.eval_interval=500 \
    train.compile_model=true \
    out_dir=out_ebt/thinking
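
The exact refinement update inside nanoEBM is not shown in this notebook. A common formulation for energy-based transformers, assumed here, is gradient descent on the prediction, $\hat y_{k+1} = \hat y_k - \alpha \nabla_{\hat y} E(x, \hat y_k)$, where $\alpha$ plays the role of `mcmc_step_size` (learned when `mcmc_step_size_learnable=true`) and $k$ runs over `mcmc_num_steps`. This toy sketch uses a hypothetical quadratic energy so the gradient is analytic; it is not nanoEBM's code.

```python
def quadratic_energy(y, target):
    """Toy energy E(y) = 0.5 * ||y - target||^2, a stand-in for the learned energy."""
    return 0.5 * sum((yi - ti) ** 2 for yi, ti in zip(y, target))

def quadratic_grad(y, target):
    """Analytic gradient of the toy energy: dE/dy = y - target."""
    return [yi - ti for yi, ti in zip(y, target)]

def refine(y0, target, num_steps=2, step_size=0.5):
    """Apply `num_steps` updates y <- y - step_size * grad E(y); return y and the energy trace."""
    y = list(y0)
    energies = [quadratic_energy(y, target)]
    for _ in range(num_steps):
        # With an autograd-based energy, truncated refinement (truncate_mcmc) would
        # typically detach y between steps; here the gradient is analytic.
        g = quadratic_grad(y, target)
        y = [yi - step_size * gi for yi, gi in zip(y, g)]
        energies.append(quadratic_energy(y, target))
    return y, energies

y, energies = refine([0.0, 0.0, 0.0], target=[1.0, -2.0, 0.5], num_steps=2, step_size=0.5)
print(energies)  # energy falls with each refinement step
```

With `step_size=1.0` this particular toy energy is minimized in one step; a learned energy is not quadratic, so the number of steps and the step size interact.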

### 4.4 Optional: Train BPE Model (for production)

For larger-scale training, switch to BPE tokenization and a larger dataset.

In [None]:
# Uncomment to train with BPE tokenization
# Requires a larger dataset (parquet or large text file)

# !python train.py \
#     data.tokenizer=gpt2 \
#     data.data_path=/path/to/large/dataset \
#     model.vocab_size=50304 \
#     model.mcmc_num_steps=2 \
#     train.max_steps=50000 \
#     data.batch_size=32 \
#     out_dir=out_ebt/bpe_thinking
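
`model.vocab_size=50304` is GPT-2's 50,257-token BPE vocabulary rounded up to the next multiple of 64, a padding convention popularized by nanoGPT because dimensions that are multiples of 64 tend to map more efficiently onto GPU kernels. That nanoEBM pads for the same reason is an assumption; `pad_vocab` below is a hypothetical helper showing the arithmetic.

```python
def pad_vocab(vocab_size, multiple=64):
    """Round vocab_size up to the nearest multiple; the extra ids are never produced by the tokenizer."""
    return ((vocab_size + multiple - 1) // multiple) * multiple

print(pad_vocab(50257))  # → 50304
```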

---
## 5. Inference & Sampling

### 5.1 Greedy Generation (Baseline)

In [None]:
# Generate text using greedy decoding (no thinking)
# Uses the baseline checkpoint

!python sample.py \
    checkpoint=out_ebt/baseline/final.pt \
    prompt="ROMEO:" \
    max_new_tokens=300
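
Greedy decoding picks, at each position, the token with the lowest energy. Because softmax is monotonic, that is the same token that maximizes $\mathrm{softmax}(-E)$. A minimal sketch over a made-up energy vector (`greedy_token` and `softmax_neg` are hypothetical helpers):

```python
import math

def greedy_token(energies):
    """Index of the lowest-energy token (ties go to the lowest index)."""
    return min(range(len(energies)), key=lambda i: energies[i])

def softmax_neg(energies):
    """Probabilities proportional to exp(-E), shifted by the minimum energy for numerical stability."""
    e_min = min(energies)
    weights = [math.exp(-(e - e_min)) for e in energies]
    total = sum(weights)
    return [w / total for w in weights]

energies = [2.3, -0.7, 1.1, -0.2]  # hypothetical energies for a 4-token vocabulary
probs = softmax_neg(energies)
assert greedy_token(energies) == max(range(len(probs)), key=lambda i: probs[i])
print(greedy_token(energies))  # → 1
```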

### 5.2 Generation with Thinking (Iterative Refinement)

In [None]:
# Generate with thinking (4 refinement steps per token)
# Uses nucleus sampling for diversity

!python sample.py \
    checkpoint=out_ebt/thinking/final.pt \
    prompt="HAMLET:" \
    max_new_tokens=300 \
    use_thinking=true \
    think_steps=4 \
    topk=64 \
    sample=true \
    sample_temp=1.0 \
    sample_top_p=0.95
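
`sample_temp` and `sample_top_p` correspond to temperature scaling and nucleus (top-p) sampling: logits are divided by the temperature, and sampling is restricted to the smallest set of highest-probability tokens whose cumulative probability reaches `top_p`. How `sample.py` combines this with `topk=64` and `think_steps` is not shown here; the hypothetical `nucleus_sample` below sketches only the standard temperature + top-p step on logits $-E$.

```python
import math
import random

def nucleus_sample(logits, temperature=1.0, top_p=0.95, rng=random):
    """Sample a token id from `logits` with temperature scaling and top-p filtering."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]  # subtract the max for numerical stability
    total = sum(weights)
    probs = [w / total for w in weights]

    # Keep the most probable tokens until their cumulative mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # random.choices normalizes the weights, so the nucleus is implicitly renormalized.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]

energies = [2.3, -0.7, 1.1, -0.2]  # hypothetical energies
logits = [-e for e in energies]
rng = random.Random(0)
print([nucleus_sample(logits, temperature=1.0, top_p=0.9, rng=rng) for _ in range(10)])
```

Here token 0 holds under 3% of the probability mass and falls outside the 0.9 nucleus, so it is never sampled; a very small `top_p` reduces sampling to greedy decoding.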

### 5.3 Interactive Sampling (Python API)

In [None]:
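# Load the trained thinking model for interactive generation
import torch
from nanoebm.model import EBTLanguageModel
from nanoebm.config import Config, ModelConfig, DataConfig, TrainConfig

# Load checkpoint (weights_only=False because it also stores a config dict; use trusted files only)
ckpt_path = "out_ebt/thinking/final.pt"
checkpoint = torch.load(ckpt_path, map_location="cuda", weights_only=False)

# The saved config is a nested dict (see section 8), so rebuild each sub-config explicitly
cfg_dict = dict(checkpoint["config"])
config = Config(
    model=ModelConfig(**cfg_dict.pop("model")),
    data=DataConfig(**cfg_dict.pop("data")),
    train=TrainConfig(**cfg_dict.pop("train")),
    **cfg_dict,  # remaining top-level fields (e.g. device, seed)
)

# Create model
model = EBTLanguageModel(
    config.model.to_gpt_config(),
    config.model.to_ebt_config()
)
# Checkpoints saved from a torch.compile'd model prefix parameter names with "_orig_mod."
state_dict = {k.removeprefix("_orig_mod."): v for k, v in checkpoint["model"].items()}
model.load_state_dict(state_dict)
model = model.to("cuda")
model.eval()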

print(f"✓ Model loaded from {ckpt_path}")
print(f"  Vocab size: {config.model.vocab_size}")
print(f"  Block size: {config.model.block_size}")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

In [None]:
# Interactive generation function
import tiktoken

# Build the tokenizer once. For char-level models this assumes the vocabulary is the
# sorted set of characters in the training file, matching how it was built for training.
if config.data.tokenizer == "char":
    with open(config.data.data_path, 'r', encoding='utf-8') as f:
        chars = sorted(set(f.read()))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}
    encode = lambda s: [stoi[c] for c in s]
    decode = lambda ids: ''.join(itos[i] for i in ids)
else:
    enc = tiktoken.get_encoding(config.data.bpe_encoding)
    encode = enc.encode
    decode = enc.decode

def generate_text(prompt, max_tokens=200, temperature=1.0, top_p=0.95):
    """Sample a continuation of `prompt` using one energy evaluation per token
    (no iterative refinement), with temperature scaling and nucleus (top-p) filtering."""
    context = torch.tensor(encode(prompt), dtype=torch.long, device="cuda").unsqueeze(0)
    
    with torch.no_grad():
        for _ in range(max_tokens):
            # Crop context to the model's block size
            context_crop = context[:, -config.model.block_size:]
            
            # Energies for every candidate next token; lower energy means more likely
            h = model.backbone(context_crop)
            energies = model.energy(h[:, -1:, :])  # last position only
            logits = -energies.squeeze(1) / temperature  # (B, V)
            probs = torch.softmax(logits, dim=-1)
            
            # Nucleus filtering: keep the smallest set of tokens whose cumulative
            # probability reaches top_p (the most likely token is always kept)
            sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
            cumulative = torch.cumsum(sorted_probs, dim=-1)
            sorted_probs[(cumulative - sorted_probs) >= top_p] = 0.0
            choice = torch.multinomial(sorted_probs, num_samples=1)  # weights need not sum to 1
            next_token = torch.gather(sorted_idx, -1, choice)
            
            # Append to context
            context = torch.cat([context, next_token], dim=1)
    
    return decode(context[0].tolist())

# Test generation
prompt = "To be or not to be"
result = generate_text(prompt, max_tokens=100, temperature=0.8)
print(f"\n{'='*60}")
print(f"Prompt: {prompt}")
print(f"{'='*60}")
print(result)
print(f"{'='*60}")

---
## 6. Visualization & Analysis

### 6.1 Energy Landscape Visualization

In [None]:
# Visualize top-K lowest energy tokens
!python viz.py \
    --checkpoint=out_ebt/thinking/final.pt \
    --prompt="ROMEO:" \
    --mode=topk \
    --topk=20

# Display the generated plot
from IPython.display import Image, display
import matplotlib.pyplot as plt

if Path("energies_topk.png").exists():
    display(Image(filename="energies_topk.png"))
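
The `topk` mode shows the tokens with the lowest energies for the next position. Selecting them is a partial sort; a sketch using `heapq.nsmallest` over a hypothetical `{token: energy}` mapping with made-up values:

```python
import heapq

def lowest_energy_tokens(energies, k=20):
    """Return the k (token, energy) pairs with the smallest energy, lowest first."""
    return heapq.nsmallest(k, energies.items(), key=lambda item: item[1])

energies = {"a": 1.5, "e": -0.4, "t": 0.2, " ": -1.1, "\n": 3.0}  # made-up values
for token, energy in lowest_energy_tokens(energies, k=3):
    print(repr(token), energy)
```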

### 6.2 Energy vs Cross-Entropy Correlation

In [None]:
# Analyze correlation between energy and CE loss
!python viz.py \
    --checkpoint=out_ebt/thinking/final.pt \
    --mode=correlation \
    --batches=4

if Path("energy_vs_ce.png").exists():
    display(Image(filename="energy_vs_ce.png"))
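
The correlation mode relates the energy assigned to each target token to that token's cross-entropy loss. If logits are negative energies, the per-token loss is $E(x, y^\star) + \log \sum_{y'} \exp(-E(x, y'))$, so the two differ by a per-context log-normalizer and should be strongly, but not perfectly, correlated. The exact statistic `viz.py` reports is not shown here; a Pearson correlation over paired per-token values (made-up numbers) looks like:

```python
import math

def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)

target_energy = [-1.2, 0.3, -0.5, 1.8, 0.9]  # hypothetical per-token energies
ce_loss = [0.4, 1.7, 1.0, 3.1, 2.4]          # hypothetical per-token CE values
print(f"r = {pearson(target_energy, ce_loss):.3f}")
```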

### 6.3 Token Trajectory Analysis

In [None]:
# Visualize logit trajectories during refinement
!python viz.py \
    --checkpoint=out_ebt/thinking/final.pt \
    --prompt="HAMLET:" \
    --mode=trajectories \
    --steps=8

if Path("token_logit_trajectories.png").exists():
    display(Image(filename="token_logit_trajectories.png"))

### 6.4 3D Energy Surface

In [None]:
# Create 3D energy landscape visualization
!python viz.py \
    --checkpoint=out_ebt/thinking/final.pt \
    --prompt="The" \
    --mode=surface3d \
    --surface_grid=60 \
    --surface_span=6.0

if Path("energy_surface3d.png").exists():
    display(Image(filename="energy_surface3d.png"))

---
## 7. Metrics Inspection

### 7.1 Load and Plot Training Metrics

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

def load_metrics(metrics_path):
    """Load metrics from a JSONL file (one JSON object per line, blank lines skipped)."""
    with open(metrics_path, 'r') as f:
        return [json.loads(line) for line in f if line.strip()]

def series(metrics, key):
    """Return (steps, values) for the records that contain `key`."""
    pairs = [(m['step'], m[key]) for m in metrics if key in m]
    return [s for s, _ in pairs], [v for _, v in pairs]

def plot_metrics(metrics, title="Training Metrics"):
    """Plot training metrics; panels for keys that were never logged are hidden."""
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle(title, fontsize=16)
    
    # Loss (validation loss may appear only on eval steps, so every record is checked)
    axes[0, 0].plot(*series(metrics, 'loss'), label='Train Loss')
    val_steps, val_losses = series(metrics, 'val_loss')
    if val_steps:
        axes[0, 0].plot(val_steps, val_losses, label='Val Loss', marker='o')
    axes[0, 0].set_xlabel('Step')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Loss over Time')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Perplexity, energy gap (if logged), and learning rate
    panels = [
        (axes[0, 1], 'perplexity', 'Perplexity', 'Perplexity over Time'),
        (axes[1, 0], 'energy_gap', 'Energy Gap', 'Energy Gap (Initial - Final)'),
        (axes[1, 1], 'lr', 'Learning Rate', 'Learning Rate Schedule'),
    ]
    for ax, key, ylabel, panel_title in panels:
        steps, values = series(metrics, key)
        if not steps:
            ax.set_visible(False)
            continue
        ax.plot(steps, values)
        ax.set_xlabel('Step')
        ax.set_ylabel(ylabel)
        ax.set_title(panel_title)
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    return fig

# Load and plot baseline metrics
baseline_metrics_path = Path("out_ebt/baseline") / "metrics.jsonl"
if baseline_metrics_path.exists():
    baseline_metrics = load_metrics(baseline_metrics_path)
    fig = plot_metrics(baseline_metrics, "Baseline Model Training")
    plt.show()
else:
    print("Baseline metrics not found. Train the model first.")

In [None]:
# Load and plot thinking model metrics
thinking_metrics_path = Path("out_ebt/thinking") / "metrics.jsonl"
if thinking_metrics_path.exists():
    thinking_metrics = load_metrics(thinking_metrics_path)
    fig = plot_metrics(thinking_metrics, "Thinking Model Training")
    plt.show()
else:
    print("Thinking model metrics not found. Train the model first.")
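
Perplexity is the exponential of the mean per-token cross-entropy measured in nats; a character-level loss of 1.5, for example, corresponds to a perplexity of about 4.48. If `train.py` logs `perplexity` as `exp(loss)`, the two columns should agree. For the thinking model the logged `loss` may include extra terms (such as the entropy regularizer set by `entropy_reg_tau`), in which case they will not. A quick consistency check with a hypothetical helper, assuming records shaped like the JSONL entries loaded above:

```python
import math

def perplexity_mismatches(records, rel_tol=1e-3):
    """Return the steps where a logged perplexity differs from exp(loss) by more than rel_tol."""
    bad = []
    for r in records:
        if "loss" in r and "perplexity" in r:
            if not math.isclose(r["perplexity"], math.exp(r["loss"]), rel_tol=rel_tol):
                bad.append(r["step"])
    return bad

records = [  # hypothetical JSONL records
    {"step": 10, "loss": 1.5, "perplexity": 4.4817},
    {"step": 20, "loss": 1.4, "perplexity": 5.0},
]
print(perplexity_mismatches(records))  # → [20]
```

Run it as `perplexity_mismatches(thinking_metrics)` after the cell above has loaded the metrics.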

### 7.2 Compare Baseline vs Thinking

In [None]:
# Compare both models
if baseline_metrics_path.exists() and thinking_metrics_path.exists():
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Loss comparison
    baseline_steps = [m['step'] for m in baseline_metrics]
    baseline_losses = [m['loss'] for m in baseline_metrics]
    thinking_steps = [m['step'] for m in thinking_metrics]
    thinking_losses = [m['loss'] for m in thinking_metrics]
    
    axes[0].plot(baseline_steps, baseline_losses, label='Baseline', alpha=0.7)
    axes[0].plot(thinking_steps, thinking_losses, label='Thinking', alpha=0.7)
    axes[0].set_xlabel('Step')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training Loss Comparison')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Perplexity comparison
    baseline_ppl = [m['perplexity'] for m in baseline_metrics]
    thinking_ppl = [m['perplexity'] for m in thinking_metrics]
    
    axes[1].plot(baseline_steps, baseline_ppl, label='Baseline', alpha=0.7)
    axes[1].plot(thinking_steps, thinking_ppl, label='Thinking', alpha=0.7)
    axes[1].set_xlabel('Step')
    axes[1].set_ylabel('Perplexity')
    axes[1].set_title('Perplexity Comparison')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print final metrics
    print("\n" + "="*60)
    print("FINAL METRICS COMPARISON")
    print("="*60)
    print(f"\nBaseline (Step {baseline_metrics[-1]['step']}):")
    print(f"  Loss: {baseline_metrics[-1]['loss']:.4f}")
    print(f"  Perplexity: {baseline_metrics[-1]['perplexity']:.4f}")
    
    print(f"\nThinking (Step {thinking_metrics[-1]['step']}):")
    print(f"  Loss: {thinking_metrics[-1]['loss']:.4f}")
    print(f"  Perplexity: {thinking_metrics[-1]['perplexity']:.4f}")
    if 'energy_gap' in thinking_metrics[-1]:
        print(f"  Energy Gap: {thinking_metrics[-1]['energy_gap']:.4f}")
    print("="*60)
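
The two runs have different lengths (5,000 vs 10,000 steps), so their final values are not a like-for-like comparison. A hypothetical helper that pairs the runs' values at steps both of them logged:

```python
def values_at_common_steps(run_a, run_b, key):
    """Return (step, a_value, b_value) for every step where both runs logged `key`."""
    a = {r["step"]: r[key] for r in run_a if key in r}
    b = {r["step"]: r[key] for r in run_b if key in r}
    return [(s, a[s], b[s]) for s in sorted(a.keys() & b.keys())]

run_a = [{"step": 10, "loss": 2.0}, {"step": 20, "loss": 1.8}]
run_b = [{"step": 10, "loss": 2.1}, {"step": 20, "loss": 1.7}, {"step": 30, "loss": 1.6}]
print(values_at_common_steps(run_a, run_b, "loss"))
```

With the notebook's runs, `values_at_common_steps(baseline_metrics, thinking_metrics, 'loss')[-1]` gives the last step both runs share.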

---
## 8. Model Checkpoint Analysis

In [None]:
# Analyze checkpoint contents
def analyze_checkpoint(ckpt_path):
    """Analyze a model checkpoint."""
    ckpt = torch.load(ckpt_path, map_location='cpu')
    
    print(f"\n{'='*60}")
    print(f"Checkpoint: {ckpt_path}")
    print(f"{'='*60}")
    
    # Basic info
    print(f"\nTraining step: {ckpt.get('step', 'N/A')}")
    
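    # Model info (counts every tensor in the state dict, including persistent buffers;
    # weights shared between modules are counted once per key)
    if 'model' in ckpt:
        state_dict = ckpt['model']
        total_params = sum(p.numel() for p in state_dict.values())
        print(f"Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")
        
        # Parameter breakdown; 0-dim tensors (such as a scalar learnable step size,
        # if the model stores one) are shown by value
        print("\nParameter breakdown:")
        for name, param in state_dict.items():
            if param.dim() == 0:
                print(f"  {name:50s} {'scalar':20s} value={param.item():.4g}")
            else:
                print(f"  {name:50s} {str(tuple(param.shape)):20s} {param.numel():>10,}")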
    
    # Config
    if 'config' in ckpt:
        cfg = ckpt['config']
        print(f"\nConfiguration:")
        print(f"  Vocab size: {cfg['model']['vocab_size']}")
        print(f"  Block size: {cfg['model']['block_size']}")
        print(f"  Layers: {cfg['model']['n_layer']}")
        print(f"  Heads: {cfg['model']['n_head']}")
        print(f"  Embedding dim: {cfg['model']['n_embd']}")
        print(f"  MCMC steps: {cfg['model']['mcmc_num_steps']}")
    
    print(f"{'='*60}\n")

# Analyze thinking model checkpoint
thinking_ckpt = "out_ebt/thinking/final.pt"
if Path(thinking_ckpt).exists():
    analyze_checkpoint(thinking_ckpt)

---
## 9. Export Model for Production

Prepare model for deployment.

In [None]:
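# Export model to ONNX (optional, experimental)
# Note: this traces model.forward with a single token-id input. If forward() needs extra
# arguments (e.g. targets) or runs gradient-based refinement with autograd, tracing will
# likely fail; a baseline model (mcmc_num_steps=0) is the simpler case to export.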

def export_to_onnx(checkpoint_path, output_path="model.onnx"):
    """Export model to ONNX format."""
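    ckpt = torch.load(checkpoint_path, map_location="cuda", weights_only=False)
    cfg_dict = dict(ckpt["config"])  # nested dict of sub-configs
    config = Config(
        model=ModelConfig(**cfg_dict.pop("model")),
        data=DataConfig(**cfg_dict.pop("data")),
        train=TrainConfig(**cfg_dict.pop("train")),
        **cfg_dict,
    )
    
    model = EBTLanguageModel(
        config.model.to_gpt_config(),
        config.model.to_ebt_config()
    )
    # Strip the "_orig_mod." prefix torch.compile adds to saved parameter names
    model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in ckpt["model"].items()})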
    model = model.to("cuda")
    model.eval()
    
    # Create dummy input
    dummy_input = torch.randint(0, config.model.vocab_size, (1, config.model.block_size), device="cuda")
    
    # Export
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch', 1: 'sequence'}}
    )
    
    print(f"✓ Model exported to {output_path}")

# Uncomment to export
# export_to_onnx("out_ebt/thinking/final.pt", "nanoebm_thinking.onnx")

---
## 10. Next Steps & Tips

### Performance Optimization
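- Use `torch.compile()` (`train.compile_model=true`) for faster training; the speedup depends on model size and workload, so measure it on your setup
- Enable `bf16` mixed precision, which H100 tensor cores support natively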
- Increase batch size to maximize GPU utilization
- Use gradient accumulation for larger effective batch sizes
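
Gradient accumulation runs several micro-batches before each optimizer step, so the effective batch is the product of micro-batch size, accumulation steps, and number of GPUs. A quick calculation using this notebook's `batch_size=64` and `block_size=256`; `tokens_per_step` and the accumulation factor are hypothetical:

```python
def tokens_per_step(batch_size, block_size, grad_accum_steps=1, num_gpus=1):
    """Number of tokens contributing to each optimizer step."""
    return batch_size * block_size * grad_accum_steps * num_gpus

print(tokens_per_step(64, 256))                      # → 16384
print(tokens_per_step(64, 256, grad_accum_steps=4))  # → 65536
```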

### Hyperparameter Tuning
- **MCMC steps**: Try 2-4 steps for balance of performance/speed
- **Learning rate**: 3e-4 (used throughout this notebook) is a reasonable starting point for models of this size
- **Entropy tau**: 0.5-1.0 for exploration/exploitation balance
- **Step size**: Enable learnable alpha for automatic tuning

### Scaling Up
- Use BPE tokenization for production models
- Train on larger datasets (FineWeb, OpenWebText, etc.)
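- Increase model size: for a GPT-style backbone, 12 layers with 768-dim embeddings (the GPT-2-small shape) is roughly 85M parameters excluding embeddings, or about 124M with a 50,304-token BPE vocabulary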
- Multi-GPU training with DDP (Distributed Data Parallel)
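
For a GPT-style Transformer, each layer holds about $12 \cdot n_{\text{embd}}^2$ weights (attention's $4 n_{\text{embd}}^2$ plus a 4×-expansion MLP's $8 n_{\text{embd}}^2$), ignoring biases and LayerNorms. This sketch applies that rule of thumb. It assumes nanoEBM's backbone follows the standard GPT layout and does not count the energy head or other EBM-specific parameters; `approx_params` is a hypothetical helper, 65 is the character count of Tiny Shakespeare, and 1024 positions is GPT-2's context length.

```python
def approx_params(n_layer, n_embd, vocab_size=0, block_size=0):
    """Rule-of-thumb parameter count for a GPT-style backbone.

    Counts 12 * n_embd^2 per layer (Q/K/V/output projections + 4x MLP) plus token and
    position embeddings; biases, LayerNorms, and EBM-specific heads are ignored.
    """
    per_layer = 12 * n_embd ** 2
    embeddings = (vocab_size + block_size) * n_embd
    return n_layer * per_layer + embeddings

print(f"{approx_params(6, 384, vocab_size=65, block_size=256) / 1e6:.1f}M")      # this notebook's char config
print(f"{approx_params(12, 768, vocab_size=50304, block_size=1024) / 1e6:.1f}M")  # GPT-2-small shape
```

The first line prints about 10.7M and the second about 124.4M; the same 12-layer, 768-dim shape without embeddings is 84.9M.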

### Monitoring
- Watch energy gap metric during training
- Monitor perplexity on validation set
- Use W&B for cloud-based experiment tracking
- Compare thinking vs baseline periodically

### Inference Optimization
- Use KV cache for faster autoregressive generation
- Reduce think_steps at inference for speed
- Experiment with top-k/top-p sampling parameters
- Consider quantization (int8/int4) for deployment