# Train GPT on Knights and Knaves Puzzles

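This notebook trains a GPT model on Knights and Knaves puzzles. The full dataset contains 100M puzzles; as written, the notebook loads a 100k-puzzle subset for quick testing.
It is designed for 8xB200 GPUs but can be adapted to other configurations.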

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Imports
import os
import sys
import math
import time
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

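# Add the current working directory to the import path.
# '__file__' is deliberately a quoted literal (notebooks do not define __file__):
# abspath() resolves it against the working directory and dirname() strips it again.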
sys.path.append(os.path.dirname(os.path.abspath('__file__')))

from data.knights_knaves import get
from mingpt.dataset import CharDataset
from mingpt.model import GPT, GPTConfig
from mingpt.trainer import Trainer, TrainerConfig
from mingpt.utils import set_seed, sample

In [None]:
# Set random seed for reproducibility
# This runs before the dataset is loaded and the model is built. set_seed comes from
# mingpt.utils; which random number generators it seeds is not visible in this
# notebook -- TODO confirm it covers every library the run depends on.
set_seed(42)

## Load Dataset

We'll start with a smaller subset for testing, then scale up to the full 100M dataset.

In [None]:
# Load the Knights and Knaves puzzles.
# max_games=100000 (100k puzzles) keeps initial test runs quick;
# set max_games=None to train on all 100M puzzles.
dataset_options = dict(
    data_path="data/n_2.jsonl",
    max_games=100000,  # 100k-puzzle subset for testing
    val_split=0.1,     # 10% validation on the small subset; use 0.01 for the full dataset
    seed=42,
)
knights_knaves = get(**dataset_options)

In [None]:
# Wrap the loaded puzzles in the character-level dataset class
train_dataset = CharDataset(knights_knaves)

# Report the dataset dimensions, one line each
for summary_line in (
    f"Vocabulary size: {train_dataset.vocab_size}",
    f"Block size: {train_dataset.block_size}",
    f"Number of training sequences: {len(train_dataset)}",
):
    print(summary_line)

In [None]:
# Inspect one training puzzle
sample_idx = 0
sample_puzzle = knights_knaves.train[sample_idx]

# Entries equal to -100 are dropped; the remaining entries are joined into text
print("Sample puzzle (as characters):")
readable_chars = filter(lambda ch: ch != -100, sample_puzzle)
print(''.join(readable_chars))

# Ids of the first 50 entries; anything missing from the stoi table is shown as 0
print("\nTokenized:")
leading_ids = list(map(lambda ch: train_dataset.stoi.get(ch, 0), sample_puzzle[:50]))
print(leading_ids)

## Model Configuration

In [None]:
# Model configuration
# Sized for the 100M dataset on 8xB200 GPUs, where a larger model can be used
architecture = dict(
    n_layer=12,   # 12 layers for the larger dataset
    n_head=12,    # 12 attention heads
    n_embd=768,   # 768 embedding dimensions
)
dropout_rates = dict(embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1)

mconf = GPTConfig(
    vocab_size=train_dataset.vocab_size,
    block_size=train_dataset.block_size,
    **architecture,
    **dropout_rates,
)
model = GPT(mconf)

# Add up the element count of every parameter tensor, then report the
# footprint at 4 bytes per element
total_params = 0
for param in model.parameters():
    total_params += param.numel()

print(f"Total parameters: {total_params:,}")
print(f"Model size: {total_params * 4 / 1024**3:.2f} GB (assuming float32)")

## Training Configuration

In [None]:
# Training configuration
# Target hardware is 8xB200 GPUs: with several GPUs the global batch is the
# per-GPU batch multiplied by the number of visible CUDA devices.
batch_size_per_gpu = 1024
num_gpus = torch.cuda.device_count()

if num_gpus > 1:
    total_batch_size = batch_size_per_gpu * num_gpus
else:
    # Single-GPU testing uses a small batch. The same fallback covers machines
    # without any CUDA device: device_count() is 0 there, and scaling the
    # per-GPU batch by it would yield an unusable batch size of 0.
    total_batch_size = 64

print(f"Number of GPUs available: {num_gpus}")
print(f"Batch size per GPU: {batch_size_per_gpu}")
print(f"Total batch size: {total_batch_size}")

if num_gpus == 1:
    print(f"Using single GPU with batch size: {total_batch_size}")
elif num_gpus == 0:
    print(f"No GPU detected; using CPU with batch size: {total_batch_size}")

In [None]:
# Initialize trainer
max_epochs = 10
t_start = time.strftime("_%Y%m%d_%H%M%S")  # timestamp suffix for this run's checkpoint file names

# Tokens in one pass over the training set; warmup_tokens and final_tokens are
# both derived from it.
tokens_per_epoch = len(train_dataset) * train_dataset.block_size

# Make sure the checkpoint directory exists before training starts: writing a
# file into a missing directory fails, and nothing else in this notebook
# creates ./ckpts.
ckpt_path = f"./ckpts/gpt_knights_knaves{t_start}.ckpt"
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

tconf = TrainerConfig(
    max_epochs=max_epochs,
    batch_size=total_batch_size,
    learning_rate=6e-4,
    lr_decay=True,
    warmup_tokens=tokens_per_epoch * 0.1,  # 10% of one epoch's tokens
    final_tokens=tokens_per_epoch * max_epochs,
    num_workers=4,
    ckpt_path=ckpt_path,
)

# NOTE(review): the third positional argument is None -- presumably the
# evaluation-dataset slot, so the validation split is not used during
# training. Confirm against mingpt.trainer.Trainer.
trainer = Trainer(model, train_dataset, None, tconf)
device = trainer.device
print(f"Training on device: {device}")
print(f"Checkpoint will be saved to: {tconf.ckpt_path}")

## Train the Model

In [None]:
# Train the model
# Calls train() on the Trainer constructed earlier from the model, the training
# dataset and the TrainerConfig; no further arguments are passed here.
print("Starting training...")
trainer.train()

## Load Trained Model (Optional)

In [None]:
# Load a trained model from checkpoint
# checkpoint_path = "./ckpts/gpt_knights_knaves_YYYYMMDD_HHMMSS.ckpt"
# model.load_state_dict(torch.load(checkpoint_path))
# if torch.cuda.is_available():
#     model = model.to('cuda')

## Validation

Test the model's ability to solve Knights and Knaves puzzles.

In [None]:
def generate_solution(model, puzzle_text, device, temperature=0.1, max_new_tokens=10):
    """Generate a solution for a Knights and Knaves puzzle.

    The puzzle text is followed by the " => " separator, encoded with the
    module-level ``train_dataset.stoi`` table (unknown characters map to 0),
    and passed to ``sample``. Only the characters 'K' and 'N' from the newly
    generated tokens are kept.

    Args:
        model: Model passed through to ``sample``; it is switched to eval mode.
        puzzle_text: Puzzle statement without the " => " separator.
        device: Device the encoded prompt tensor is moved to.
        temperature: Forwarded to ``sample``.
        max_new_tokens: Number of generation steps requested from ``sample``.
            The default of 10 leaves headroom over the 2-3 characters a
            solution such as "KN" or "NNK" needs.

    Returns:
        str: The 'K'/'N' characters found in the generated continuation, in
        order (possibly empty).
    """
    model.eval()

    # Add the separator
    context = puzzle_text + " => "

    # Encode the context as a (1, prompt_len) tensor of token ids
    prompt_ids = [train_dataset.stoi.get(s, 0) for s in context]
    x = torch.tensor(prompt_ids, dtype=torch.long)[None, ...].to(device)

    # Generate the continuation
    with torch.no_grad():
        y = sample(model, x, max_new_tokens, temperature=temperature)[0]

    # minGPT-style samplers return the prompt followed by the new tokens. The
    # prompt must be removed before filtering: puzzle text such as "isKnight"
    # and "isKnave" itself contains 'K', which would otherwise leak into the
    # extracted solution. The prefix is only stripped when it is really there,
    # so a sampler that returns just the new tokens keeps working.
    prompt_len = x.size(1)
    if y.size(0) >= prompt_len and torch.equal(y[:prompt_len], x[0]):
        y = y[prompt_len:]

    # Decode the prediction
    completion = [train_dataset.itos[int(i)] for i in y if i != -1]

    # Extract just the solution part (K/N characters)
    solution = ''.join([c for c in completion if c in ['K', 'N']])

    return solution

In [None]:
# Validate on a sample of puzzles
model.eval()
num_test = min(100, len(knights_knaves.val))
correct = 0
evaluated = 0  # puzzles that contained the separator and were actually scored

for i in tqdm(range(num_test), desc="Validating"):
    # Get validation puzzle (entries equal to -100 are dropped)
    puzzle_chars = knights_knaves.val[i]
    puzzle_text = ''.join([c for c in puzzle_chars if c != -100])

    # Split puzzle and solution at the first separator. partition() never raises,
    # even if the separator occurs more than once; an empty `separator` means it
    # was absent and the entry cannot be scored.
    puzzle_part, separator, actual_solution = puzzle_text.partition(" => ")
    if not separator:
        continue
    evaluated += 1

    # Generate prediction
    predicted_solution = generate_solution(model, puzzle_part, device)

    # Compare
    is_correct = predicted_solution == actual_solution
    if is_correct:
        correct += 1

    # Show first few examples
    if i < 5:
        print(f"\nPuzzle {i}:")
        print(f"Input: {puzzle_part[:100]}..." if len(puzzle_part) > 100 else f"Input: {puzzle_part}")
        print(f"Actual: {actual_solution}")
        print(f"Predicted: {predicted_solution}")
        print(f"Correct: {is_correct}")

# Accuracy is computed over the puzzles that were scored; the guard avoids a
# ZeroDivisionError when the validation split is empty or nothing could be scored.
if evaluated:
    accuracy = correct / evaluated
    print(f"\nValidation accuracy: {correct}/{evaluated} = {accuracy:.2%}")
else:
    accuracy = 0.0
    print("\nValidation accuracy: n/a (no validation puzzles could be evaluated)")

if evaluated < num_test:
    print(f"Skipped {num_test - evaluated} puzzle(s) without a ' => ' separator")

## Interactive Testing

In [None]:
# Run the model on a hand-written puzzle
test_puzzle = "says 0 (isKnight 1), says 1 (isKnave 0)"
solution = generate_solution(model, test_puzzle, device, temperature=0.1)

for report_line in (
    f"Puzzle: {test_puzzle}",
    f"Predicted solution: {solution}",
    "\nInterpretation:",
):
    print(report_line)

# One line per character of the prediction: 'K' is shown as Knight, anything else as Knave
for i in range(len(solution)):
    s = solution[i]
    print(f"Agent {i}: {'Knight' if s == 'K' else 'Knave'}")

## Save Final Model

In [None]:
# Save the final trained model
final_checkpoint_path = f"./ckpts/gpt_knights_knaves_final{t_start}.ckpt"

# torch.save does not create missing parent directories, so make sure ./ckpts
# exists (it may not, e.g. when the training cell was skipped in this session).
os.makedirs(os.path.dirname(final_checkpoint_path), exist_ok=True)

torch.save(model.state_dict(), final_checkpoint_path)
print(f"Model saved to: {final_checkpoint_path}")