# Test02: TrainingRun Mid-Epoch Save/Resume

This notebook tests the TrainingRun save/load functionality:
1. Train a small GPTv1 model for a few batches
2. Save checkpoint mid-epoch
3. Resume training from exact position
4. Create FinalModel for deployment

In [1]:
import sys
import os
import torch
from pathlib import Path

def get_project_info() -> Path:
  current = Path.cwd().resolve()
  root = current
  for parent in [current, *current.parents]:
    if (parent / "toy_transformers").exists():
      root = parent
      break
  return root, current

if 'ROOT_DIR' not in globals():
	ROOT_DIR, EXPERIMENT_DIR = get_project_info()
	if str(ROOT_DIR) not in sys.path:
		sys.path.append(str(ROOT_DIR))
	if Path.cwd() != ROOT_DIR:
		os.chdir(ROOT_DIR)

print(f"Root: {ROOT_DIR}")
print(f"Experiment: {EXPERIMENT_DIR}")

Root: /Users/sriman/dev/apps/toy-transformers
Experiment: /Users/sriman/dev/apps/toy-transformers/experiments/test02


In [None]:
from toy_transformers.models import gptv1
from toy_transformers.training.training_run import TrainingRun
from toy_transformers.training.optimizer import OptimizerConfig, AdamWConfig, NoSchedulerConfig
from toy_transformers.utilities import io
from toy_transformers.data import tokenization
from toy_transformers.utilities.reproducibility import set_all_seeds

print("✓ Imports successful")

✓ Imports successful


In [3]:
# Load dataset from run01
vocab_path = ROOT_DIR / "experiments/test02/artifacts/vocab256"
data_path = ROOT_DIR / "experiments/test02/artifacts/data"

if not vocab_path.exists() or not data_path.exists():
  tokenization.tokenize_from_scratch(
  	"data/input.txt",
  	vocab_path,
  	data_path,
  	500,
  	verbose=True
	)

train_data = io.load(str(data_path))
train_dataset = tokenization.TokenizedData.from_state_dict(train_data)
print(f"Dataset loaded: {len(train_dataset.data):,} tokens")

counting word frequencies...


Counting words: 100%|██████████| 1.12M/1.12M [00:00<00:00, 15.8MB/s]


found 15294 unique words
base vocabulary size: 59


merging tokens: 100%|██████████| 500/500 [00:00<00:00, 4213.52it/s]


Dataset loaded: 433,344 tokens


In [4]:
# Set reproducibility
set_all_seeds(42, deterministic=True)

# Small model for testing
config = gptv1.GPTv1Config(
  batch_size=4,
  block_size=32,
  n_heads=2,
  n_embed=64,
  n_layers=2,
  dropout=0.1,
  device="mps" if torch.backends.mps.is_available() else "cpu"
)
vocab_size = 256

print(f"Config: batch_size={config.batch_size}, block_size={config.block_size}")
print(f"Architecture: n_layers={config.n_layers}, n_embed={config.n_embed}, n_heads={config.n_heads}")
print(f"Device: {config.device}")

Config: batch_size=4, block_size=32
Architecture: n_layers=2, n_embed=64, n_heads=2
Device: mps


In [5]:
# Create training run
training_run = TrainingRun(
	model_class=gptv1.LanguageModel,
	config_class=gptv1.GPTv1Config,
	model_config=config,
	optimizer_config=OptimizerConfig(
		optimizer_type="adamw",
		optimizer_params=AdamWConfig(lr=1e-3, weight_decay=0.01),
		scheduler=NoSchedulerConfig()
	),
	base_seed=42,
	dataset=train_dataset,
	block_size=config.block_size,
	batch_size=config.batch_size,
	vocab_size=vocab_size
)

model = training_run.create_model().to(config.device)
optimizer = training_run.create_optimizer(model)
num_params = sum(p.numel() for p in model.parameters())

print(f"TrainingRun created")
print(f"Dataset hash: {training_run.dataset_hash}")
print(f"Model parameters: {num_params:,}")

TrainingRun created
Dataset hash: -2075907010083267406
Model parameters: 134,784


In [6]:
# Train for a few batches
print("Training initial batches...")
save_after_batches = 2
checkpoint_path = EXPERIMENT_DIR / "artifacts/checkpoints"

training_run.epoch = 0
dataloader = training_run.create_dataloader(train_dataset, epoch=0)

for batch_idx, (x, y) in enumerate(dataloader):
    x, y = x.to(config.device), y.to(config.device)
    
    logits, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    training_run.step += 1
    training_run.batches_completed += 1
    training_run.log_step(train_loss=loss.item(), lr=optimizer.get_lr())
    
    print(f"Batch {batch_idx}, Step {training_run.step}, Loss: {loss.item():.4f}")
    
    if batch_idx == save_after_batches:
        print(f"\n⚡ Saving checkpoint at batch {batch_idx}...")
        io.save(
            training_run.to_state_dict(model, optimizer),
            str(checkpoint_path)
        )
        print(f"✓ Checkpoint saved")
        print(f"State: epoch={training_run.epoch}, step={training_run.step}, batches_completed={training_run.batches_completed}")
        break
    
    if batch_idx >= 9:
        break

Training initial batches...
Batch 0, Step 1, Loss: 3.8792
Batch 1, Step 2, Loss: 4.3184
Batch 2, Step 3, Loss: 4.2239

⚡ Saving checkpoint at batch 2...
✓ Checkpoint saved
State: epoch=0, step=3, batches_completed=3


In [7]:
# Resume from checkpoint
print("\n" + "="*60)
print("Testing Resume from Checkpoint")
print("="*60 + "\n")

loaded = io.load(str(checkpoint_path))
resumed_run, model_state, opt_state = TrainingRun.from_state_dict(loaded)
print("✓ Checkpoint loaded")

# Verify dataset
resumed_run.verify_dataset(train_dataset)
print(f"✓ Dataset verified (hash: {resumed_run.dataset_hash})")
print(f"Resuming from: epoch={resumed_run.epoch}, step={resumed_run.step}, batches_completed={resumed_run.batches_completed}")

# Restore model and optimizer
model = resumed_run.create_model().to(config.device)
model.load_state_dict(model_state)
optimizer = resumed_run.create_optimizer(model)
optimizer.load_state_dict(opt_state)
print("✓ Model and optimizer restored")


Testing Resume from Checkpoint

✓ Checkpoint loaded
✓ Dataset verified (hash: -2075907010083267406)
Resuming from: epoch=0, step=3, batches_completed=3
✓ Model and optimizer restored


In [8]:
# Continue training
print(f"\nResuming training from batch {resumed_run.batches_completed}...")

dataloader = resumed_run.create_dataloader(train_dataset, resumed_run.epoch)
for batch_idx, (x, y) in enumerate(dataloader, start=resumed_run.batches_completed):
    x, y = x.to(config.device), y.to(config.device)
    
    logits, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    resumed_run.step += 1
    resumed_run.batches_completed += 1
    resumed_run.log_step(train_loss=loss.item(), lr=optimizer.get_lr())
    
    print(f"Resumed - Batch {batch_idx}, Step {resumed_run.step}, Loss: {loss.item():.4f}")
    
    if batch_idx >= resumed_run.batches_completed + 5:
        print("\nStopping after a few resumed batches")
        break


Resuming training from batch 3...
Resumed - Batch 3, Step 4, Loss: 4.1431
Resumed - Batch 4, Step 5, Loss: 4.6351
Resumed - Batch 5, Step 6, Loss: 4.0246
Resumed - Batch 6, Step 7, Loss: 3.9792
Resumed - Batch 7, Step 8, Loss: 4.1612
Resumed - Batch 8, Step 9, Loss: 4.2653
Resumed - Batch 9, Step 10, Loss: 4.3622
Resumed - Batch 10, Step 11, Loss: 3.9521
Resumed - Batch 11, Step 12, Loss: 4.1943
Resumed - Batch 12, Step 13, Loss: 4.0568
Resumed - Batch 13, Step 14, Loss: 4.1620
Resumed - Batch 14, Step 15, Loss: 4.1277
Resumed - Batch 15, Step 16, Loss: 3.9290
Resumed - Batch 16, Step 17, Loss: 4.0221
Resumed - Batch 17, Step 18, Loss: 3.9685
Resumed - Batch 18, Step 19, Loss: 3.9574
Resumed - Batch 19, Step 20, Loss: 3.6438
Resumed - Batch 20, Step 21, Loss: 4.0119
Resumed - Batch 21, Step 22, Loss: 4.1739
Resumed - Batch 22, Step 23, Loss: 4.1003
Resumed - Batch 23, Step 24, Loss: 4.0907
Resumed - Batch 24, Step 25, Loss: 3.7849
Resumed - Batch 25, Step 26, Loss: 4.3483
Resumed - Ba

In [9]:
# Verify dataloader resumption produces identical batches
print("\n" + "="*60)
print("Verifying DataLoader Resumption Correctness")
print("="*60 + "\n")

# Create two dataloaders:
# 1. Fresh dataloader starting from batch 0
# 2. Resumed dataloader that skips to batches_completed

# Reset to the saved checkpoint state
checkpoint_epoch = 0
checkpoint_batches = 3

# Create fresh dataloader from start
from toy_transformers.data.dataset import create_dataloader
epoch_seed = resumed_run.get_epoch_seed(checkpoint_epoch)
fresh_loader = create_dataloader(
    dataset=train_dataset,
    block_size=resumed_run.block_size,
    batch_size=resumed_run.batch_size,
    shuffle=True,
    seed=epoch_seed,
    num_workers=0,
    pin_memory=False,
    drop_last=True,
    batches_completed=0
)

# Manually skip to checkpoint position
fresh_iter = iter(fresh_loader)
for _ in range(checkpoint_batches):
    next(fresh_iter)

# Get next batch from fresh (manually skipped) loader
fresh_x, fresh_y = next(fresh_iter)

# Create resumed dataloader (automatically skips via create_dataloader)
test_run = TrainingRun(
    model_class=gptv1.LanguageModel,
    config_class=gptv1.GPTv1Config,
    model_config=config,
    optimizer_config=resumed_run.optimizer_config,
    base_seed=resumed_run.base_seed,
    dataset=train_dataset,
    block_size=resumed_run.block_size,
    batch_size=resumed_run.batch_size,
    vocab_size=vocab_size,
    epoch=checkpoint_epoch,
    batches_completed=checkpoint_batches
)
resumed_loader = test_run.create_dataloader(train_dataset, checkpoint_epoch)
resumed_x, resumed_y = next(resumed_loader)

# Compare batches
print(f"Fresh batch shape: {fresh_x.shape}, {fresh_y.shape}")
print(f"Resumed batch shape: {resumed_x.shape}, {resumed_y.shape}")
print(f"\nBatches identical (x): {torch.equal(fresh_x, resumed_x)}")
print(f"Batches identical (y): {torch.equal(fresh_y, resumed_y)}")

if torch.equal(fresh_x, resumed_x) and torch.equal(fresh_y, resumed_y):
    print("\n✓ DataLoader resumption is correct!")
    print(f"  Both dataloaders produce identical batch #{checkpoint_batches + 1}")
else:
    print("\n✗ ERROR: Batches differ!")
    print(f"  Max difference (x): {torch.abs(fresh_x - resumed_x).max()}")
    print(f"  Max difference (y): {torch.abs(fresh_y - resumed_y).max()}")


Verifying DataLoader Resumption Correctness

Fresh batch shape: torch.Size([4, 32]), torch.Size([4, 32])
Resumed batch shape: torch.Size([4, 32]), torch.Size([4, 32])

Batches identical (x): True
Batches identical (y): True

✓ DataLoader resumption is correct!
  Both dataloaders produce identical batch #4


In [10]:
# Create FinalModel
print("\nCreating FinalModel...")
final_model = FinalModel.from_training_run(resumed_run, include_logs=True)
final_model_path = EXPERIMENT_DIR / "artifacts/final_model"
io.save(final_model.to_state_dict(model), str(final_model_path))

print(f"✓ FinalModel saved to {final_model_path}")
print(f"Final state: epoch={final_model.final_epoch}, step={final_model.final_step}")
print(f"Logs: {len(final_model.logs)} entries")


Creating FinalModel...
✓ FinalModel saved to /Users/sriman/dev/apps/toy-transformers/experiments/test02/artifacts/final_model
Final state: epoch=0, step=3385
Logs: 3385 entries


In [None]:
# Generate text from the trained model
print("\n" + "="*60)
print("Text Generation Demo")
print("="*60 + "\n")

# Load vocabulary for decoding
vocab = io.load(str(vocab_path))
from toy_transformers.data.bpe import Vocabulary
vocab_obj = Vocabulary.from_state_dict(vocab)

# Create seed prompt
seed_text = "The "
seed_tokens = vocab_obj.encode(seed_text)
idx = torch.tensor([seed_tokens], dtype=torch.long, device=config.device)

print(f"Seed: '{seed_text}'")
print(f"Generating {300} tokens...\n")

# Set model to eval mode
model.eval()

# Generate text
print(seed_text, end="", flush=True)
with torch.no_grad():
  for token in model.generate(idx, max_new_tokens=300):
    print(vocab_obj.decode([token.item()])[0], end="")


Text Generation Demo

Seed: 'The '
Generating 300 tokens...

The ms of done fight

VIRGIANay halfuseed sur good matteraitchet sender will o'ereys but the fear mash and If an helagest aubribil'd with sger mings o the fears pe The pet natsMet incesShall's tre least not Tru matlyThey slikesor and fait o thy pats newouldsw ourswont a moks-fter it not so cury's fries

MARCOMIN'ThesMINO held shall is cou strailewer is it serve to growsu Romeuard war the so val is a purmon RomansMore-houseed not beid a vowleatords are the fERIrovre is a general pattle businessAnd I cove FIDes I repin hewer

MENENENENENENENENENct a mark for remengeMoreld and cir clublyBut is't
✓ Text generation complete!
