# Checkpoint & Resume - Fault Tolerance

## Why Checkpointing?

### Without Checkpoints
```
Job runs 8 hours → GPU fails at 7.5 hours → LOSE EVERYTHING ❌
Restart from scratch → Another 8 hours → 15.5 hours total 😢
```

### With Checkpoints (every 30 min)
```
Job runs 8 hours → GPU fails at 7.5 hours → Checkpoint at 7.0 hours ✓
Resume from 7.0 hours → 1 hour to finish → 8.5 hours total 🎉
Saved: 7 hours (~45% time savings)
```
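A small helper makes this arithmetic explicit: a failure costs the work done since the last completed checkpoint, plus the time already spent. This is a minimal sketch that assumes restarts are instant and writing a checkpoint takes no time; `total_hours` is an illustrative name, not part of `src`.

```python
import math
from typing import Optional


def total_hours(job_hours: float, fail_at: float, interval: Optional[float]) -> float:
    """Wall-clock hours to finish a job that fails once at `fail_at` hours.

    `interval` is the checkpoint period in hours; None means no checkpoints.
    Assumes zero restart overhead and zero checkpoint-write time.
    """
    if interval is None:
        last_ckpt = 0.0  # no checkpoint: restart from scratch
    else:
        # Last checkpoint completed strictly before the failure
        # (a failure at exactly 7.5 h beats the 7.5 h checkpoint)
        last_ckpt = max(0.0, (math.ceil(fail_at / interval) - 1) * interval)
    remaining = job_hours - last_ckpt
    return fail_at + remaining


without = total_hours(8, 7.5, None)    # 7.5 + 8.0
with_ckpt = total_hours(8, 7.5, 0.5)   # 7.5 + 1.0
saved = without - with_ckpt
print(without, with_ckpt, saved, f"{saved / without:.0%}")  # 15.5 8.5 7.0 45%
```

Shrinking `interval` reduces the worst-case loss, but in practice each save also costs time, which this sketch ignores.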

## Checkpoint Contents

A checkpoint saves:
- ✅ **Model weights** (state_dict)
- ✅ **Optimizer state** (momentum, lr)
- ✅ **Training step** (epoch, global_step)
- ✅ **Random state** (reproducibility)
- ✅ **Best metrics** (accuracy, loss)
- ✅ **Scheduler state** (learning rate schedule)
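To make the list above concrete, here is one way to bundle every item into a single dictionary using standard PyTorch `state_dict()` / `load_state_dict()` calls and the Python and torch RNG getters. This is a sketch, not how `src.pipeline.checkpoint_manager` lays out its files (that code is not shown in this notebook); the key names and both helper functions are hypothetical.

```python
import random

import torch


def capture_checkpoint(model, optimizer, scheduler, epoch, global_step, best_metric):
    """Bundle everything needed to resume (hypothetical layout).

    state_dict() returns references to live tensors, so serialize the result
    right away, e.g. torch.save(ckpt, path), before training continues.
    """
    return {
        "model": model.state_dict(),          # weights
        "optimizer": optimizer.state_dict(),  # e.g. Adam moments, per-group lr
        "scheduler": scheduler.state_dict(),  # position in the LR schedule
        "epoch": epoch,
        "global_step": global_step,
        "best_metric": best_metric,
        "rng": {
            "python": random.getstate(),
            "torch": torch.get_rng_state(),
            "cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        },
    }


def restore_checkpoint(ckpt, model, optimizer, scheduler):
    """Load a dict produced by capture_checkpoint into freshly built objects."""
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    random.setstate(ckpt["rng"]["python"])
    torch.set_rng_state(ckpt["rng"]["torch"])
    if ckpt["rng"]["cuda"] is not None:
        torch.cuda.set_rng_state_all(ckpt["rng"]["cuda"])
    return ckpt["epoch"], ckpt["global_step"], ckpt["best_metric"]
```

The model, optimizer, and scheduler must be rebuilt with the same architecture and hyperparameters before `restore_checkpoint` is called; only their state comes from the checkpoint.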

In [2]:
import sys
sys.path.insert(0, '..')

from src.pipeline.checkpoint_manager import CheckpointManager
from src.scheduler.job_queue import JobConfig, get_job_queue, Priority
from src.monitoring.logger import setup_logging

setup_logging(level="INFO")
print("✓ Setup complete!")

2025-10-09 21:21:15 - root - [32mINFO[0m - Logging initialized at level INFO [logger.py:202]
✓ Setup complete!


## 1. Checkpoint Configuration

In [18]:
# Configure checkpoint strategy
checkpoint_manager = CheckpointManager(
    checkpoint_dir="./checkpoints/demo_job",
    max_checkpoints=3,  # Keep last 3
    save_best_only=False  # Save all checkpoints
)

print("Checkpoint Configuration:")
print(f"  Directory: {checkpoint_manager.checkpoint_dir}")
print(f"  Max checkpoints: {checkpoint_manager.max_checkpoints}")
print(f"  Save best only: {checkpoint_manager.save_best_only}")

2025-10-09 21:30:18 - src.pipeline.checkpoint_manager - [32mINFO[0m - Checkpoint manager initialized: ./checkpoints/demo_job (max=3, best_only=False) [checkpoint_manager.py:52]
Checkpoint Configuration:
  Directory: checkpoints/demo_job
  Max checkpoints: 3
  Save best only: False
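`save_best_only` depends on a notion of "better". A minimal sketch of that decision, assuming a lower-is-better metric such as `eval_loss` by default and that ties keep the earlier checkpoint; `should_save` is a hypothetical helper, not the `CheckpointManager` implementation.

```python
from typing import Optional, Tuple


def should_save(metric: float, best: Optional[float], save_best_only: bool,
                mode: str = "min") -> Tuple[bool, bool]:
    """Return (save_this_checkpoint, is_new_best)."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    if best is None:
        is_best = True  # first checkpoint is trivially the best so far
    elif mode == "min":
        is_best = metric < best  # strict: a tie does not replace the incumbent
    else:
        is_best = metric > best
    return (is_best or not save_best_only), is_best


print(should_save(0.189, None, save_best_only=True))    # (True, True)
print(should_save(0.473, 0.189, save_best_only=True))   # (False, False)
print(should_save(0.473, 0.189, save_best_only=False))  # (True, False)
```

With `save_best_only=False`, as configured here, every call saves and `is_new_best` only decides which checkpoint gets tracked as best.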


## 2. Create Checkpoint Manager

In [6]:
# Initialize checkpoint manager
checkpoint_manager = CheckpointManager(
    checkpoint_dir="./checkpoints/demo_job",
    max_checkpoints=3,
    save_best_only=False
)

print("Checkpoint Manager Initialized")
print(f"  Checkpoint directory: {checkpoint_manager.checkpoint_dir}")

# Checkpoint directory already created by CheckpointManager.__init__
print("✓ Checkpoint directory ready")

2025-10-09 21:23:24 - src.pipeline.checkpoint_manager - [32mINFO[0m - Checkpoint manager initialized: ./checkpoints/demo_job (max=3, best_only=False) [checkpoint_manager.py:52]
Checkpoint Manager Initialized
  Checkpoint directory: checkpoints/demo_job
✓ Checkpoint directory ready


## 3. Saving Checkpoints

In [8]:
# Simulate training and save checkpoints
import torch

# Create dummy model and optimizer (for demo)
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Save checkpoint
checkpoint_path = checkpoint_manager.save_checkpoint(
    model=model,
    optimizer=optimizer,
    epoch=1,
    step=1000,
    metrics={"eval_loss": 0.189, "eval_accuracy": 0.92},
    train_loss=0.234,
    best_metric=0.189
)

if checkpoint_path:
    print(f"✓ Checkpoint saved: {checkpoint_path}")
    print("  Epoch: 1, Step: 1000")
    print("  Eval loss: 0.189")
    print("  Eval accuracy: 92%")
else:
    print("✓ Checkpoint skipped (not better than previous)")

2025-10-09 21:24:30 - src.pipeline.checkpoint_manager - [32mINFO[0m - Saving checkpoint to checkpoints/demo_job/checkpoint_epoch1_step1000 [checkpoint_manager.py:96]
2025-10-09 21:24:30 - src.pipeline.checkpoint_manager - [32mINFO[0m - New best checkpoint: checkpoints/demo_job/checkpoint_epoch1_step1000 (metric=0.1890) [checkpoint_manager.py:134]
2025-10-09 21:24:30 - src.pipeline.checkpoint_manager - [32mINFO[0m - Checkpoint saved: checkpoints/demo_job/checkpoint_epoch1_step1000 [checkpoint_manager.py:139]
✓ Checkpoint saved: checkpoints/demo_job/checkpoint_epoch1_step1000
  Epoch: 1, Step: 1000
  Eval loss: 0.189
  Eval accuracy: 92%
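A checkpoint only helps if a crash during the save cannot leave a half-written directory that a later resume picks up. One common pattern, sketched here as general practice rather than a description of what `save_checkpoint` does, is to stage files in a temporary directory and atomically rename it into place. The directory name follows the `checkpoint_epoch{epoch}_step{step}` pattern visible in the logs; `atomic_save_metadata` and `metadata.json` are hypothetical names.

```python
import json
import os
import tempfile
from pathlib import Path


def atomic_save_metadata(checkpoint_dir, epoch, step, metrics):
    """Publish a checkpoint directory in one rename so readers never see it half-written."""
    root = Path(checkpoint_dir)
    root.mkdir(parents=True, exist_ok=True)
    final = root / f"checkpoint_epoch{epoch}_step{step}"

    # Stage in a temp dir on the same filesystem, so the rename below stays atomic
    staging = Path(tempfile.mkdtemp(prefix=".tmp_", dir=str(root)))
    with open(staging / "metadata.json", "w") as f:
        json.dump({"epoch": epoch, "step": step, "metrics": metrics}, f)
        f.flush()
        os.fsync(f.fileno())  # flush file contents to disk before publishing
    # Model/optimizer files would also be written into `staging` here.

    # Directory rename is atomic on POSIX; it fails if `final` exists and is
    # non-empty (on Windows, if it exists at all).
    os.rename(staging, final)
    return final
```

Any code that lists or loads checkpoints should then ignore names starting with `.tmp_`, which are leftovers from an interrupted save.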


In [10]:
# Save multiple checkpoints (simulating training progress)
for epoch, step in enumerate([2000, 3000, 4000, 5000], start=2):
    eval_loss = 0.189 * (5000 / step)  # Simulated loss: drops from ~0.47 (step 2000) to 0.189 (step 5000)
    
    checkpoint_path = checkpoint_manager.save_checkpoint(
        model=model,
        optimizer=optimizer,
        epoch=epoch,
        step=step,
        metrics={"eval_loss": eval_loss}
    )
    
    if checkpoint_path:
        print(f"Checkpoint saved - Epoch {epoch}, Step {step}: loss={eval_loss:.3f}")

print("\n✓ Multiple checkpoints saved")

2025-10-09 21:25:15 - src.pipeline.checkpoint_manager - [32mINFO[0m - Saving checkpoint to checkpoints/demo_job/checkpoint_epoch2_step2000 [checkpoint_manager.py:96]
2025-10-09 21:25:15 - src.pipeline.checkpoint_manager - [32mINFO[0m - Checkpoint saved: checkpoints/demo_job/checkpoint_epoch2_step2000 [checkpoint_manager.py:139]
Checkpoint saved - Epoch 2, Step 2000: loss=0.473
2025-10-09 21:25:15 - src.pipeline.checkpoint_manager - [32mINFO[0m - Saving checkpoint to checkpoints/demo_job/checkpoint_epoch3_step3000 [checkpoint_manager.py:96]
2025-10-09 21:25:15 - src.pipeline.checkpoint_manager - [32mINFO[0m - Checkpoint saved: checkpoints/demo_job/checkpoint_epoch3_step3000 [checkpoint_manager.py:139]
Checkpoint saved - Epoch 3, Step 3000: loss=0.315
2025-10-09 21:25:15 - src.pipeline.checkpoint_manager - [32mINFO[0m - Saving checkpoint to checkpoints/demo_job/checkpoint_epoch4_step4000 [checkpoint_manager.py:96]
2025-10-09 21:25:15 - src.pipeline.checkpoint_manager - [32mINF

## 4. List Available Checkpoints

In [12]:
# Get all checkpoints
checkpoints = checkpoint_manager.list_checkpoints()

print(f"Available Checkpoints ({len(checkpoints)}):")
for ckpt in checkpoints:
    print(f"  {ckpt}")  # each entry is a dict: path, epoch, step, metrics

# Get best checkpoint path
best_path = checkpoint_manager.get_best_checkpoint_path()
if best_path:
    print(f"\nBest checkpoint: {best_path}")
else:
    print("\nNo best checkpoint tracked")

# Check if checkpoints exist
has_ckpt = checkpoint_manager.has_checkpoint()
print(f"Has checkpoints: {has_ckpt}")

Available Checkpoints (3):
  {'path': 'checkpoints/demo_job/checkpoint_epoch3_step3000', 'epoch': 3, 'step': 3000, 'metrics': {'eval_loss': 0.315}}
  {'path': 'checkpoints/demo_job/checkpoint_epoch4_step4000', 'epoch': 4, 'step': 4000, 'metrics': {'eval_loss': 0.23625000000000002}}
  {'path': 'checkpoints/demo_job/checkpoint_epoch5_step5000', 'epoch': 5, 'step': 5000, 'metrics': {'eval_loss': 0.189}}

Best checkpoint: checkpoints/demo_job/checkpoint_epoch1_step1000
Has checkpoints: True
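In the output above, the best checkpoint (step 1000, `eval_loss` 0.189) is not among the three retained entries: none of the later saves beat 0.189, and with `max_checkpoints=3` only the newest three remain in the list. Whether its files still exist on disk depends on how `CheckpointManager` rotates, which this notebook does not show. A rotation policy that never deletes the best checkpoint can be sketched as follows; the entry shape mirrors the dicts printed above, and `select_for_deletion` is a hypothetical helper.

```python
from typing import Dict, List, Optional


def select_for_deletion(checkpoints: List[Dict], max_checkpoints: int,
                        best_path: Optional[str] = None) -> List[Dict]:
    """Return the checkpoints to delete, keeping the newest N plus the best one."""
    ordered = sorted(checkpoints, key=lambda c: c["step"])  # oldest first
    # Guard N == 0: ordered[-0:] would keep everything
    keep = {c["path"] for c in ordered[-max_checkpoints:]} if max_checkpoints > 0 else set()
    if best_path is not None:
        keep.add(best_path)  # never rotate away the best checkpoint
    return [c for c in ordered if c["path"] not in keep]


history = [
    {"path": f"checkpoints/demo_job/checkpoint_epoch{e}_step{s}", "epoch": e, "step": s}
    for e, s in [(1, 1000), (2, 2000), (3, 3000), (4, 4000), (5, 5000)]
]
doomed = select_for_deletion(
    history, max_checkpoints=3,
    best_path="checkpoints/demo_job/checkpoint_epoch1_step1000",
)
print([c["step"] for c in doomed])  # [2000]
```

The trade-off is that up to N+1 checkpoints can sit on disk. The Hugging Face `TrainingArguments` documentation describes the same behavior for `save_total_limit` when `load_best_model_at_end=True`.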


## 5. Resume from Checkpoint

In [14]:
# Load checkpoint to resume training
print("Simulating job crash and resume...\n")

# Create fresh model and optimizer (simulating restart)
new_model = torch.nn.Linear(10, 2)
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-4)

print("Before resume:")
print("  Model state: random initialization")
print("  Global step: 0")
print()

# Load checkpoint (pass model and optimizer to load state into them)
metadata = checkpoint_manager.load_checkpoint(
    model=new_model,
    optimizer=new_optimizer,
    load_best=False  # Load latest checkpoint
)

# Get restored state from metadata
start_step = metadata.get("step", 0)
start_epoch = metadata.get("epoch", 0)

print("After resume:")
print("  Model state: restored from checkpoint")
print(f"  Global step: {start_step}")
print(f"  Epoch: {start_epoch}")
print(f"  Training continues from step {start_step + 1}")
print()
print("✓ Training resumed successfully!")

Simulating job crash and resume...

Before resume:
  Model state: random initialization
  Global step: 0

2025-10-09 21:27:36 - src.pipeline.checkpoint_manager - [32mINFO[0m - Loading checkpoint from checkpoints/demo_job/checkpoint_epoch5_step5000 [checkpoint_manager.py:186]
2025-10-09 21:27:36 - src.pipeline.checkpoint_manager - [32mINFO[0m - Checkpoint loaded: epoch=5, step=5000 [checkpoint_manager.py:212]
After resume:
  Model state: restored from checkpoint
  Global step: 5000
  Epoch: 5
  Training continues from step 5001

✓ Training resumed successfully!
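With `load_best=False` the manager picked `checkpoint_epoch5_step5000`, the highest step. Assuming checkpoint directories always follow the `checkpoint_epoch{epoch}_step{step}` naming seen in the logs, a "latest" lookup can be sketched by parsing the step from each name and comparing numerically. A plain string sort would wrongly rank `step5000` above `step10000`. `find_latest_checkpoint` is a hypothetical helper.

```python
import re
from pathlib import Path
from typing import Optional

# Matches names like checkpoint_epoch5_step5000 (pattern taken from the logs)
CKPT_RE = re.compile(r"^checkpoint_epoch(\d+)_step(\d+)$")


def find_latest_checkpoint(checkpoint_dir) -> Optional[Path]:
    """Return the checkpoint directory with the highest step, or None if there is none."""
    root = Path(checkpoint_dir)
    if not root.is_dir():
        return None
    candidates = []
    for entry in root.iterdir():
        m = CKPT_RE.match(entry.name)
        if entry.is_dir() and m:  # skips temp/partial dirs whose names don't match
            candidates.append((int(m.group(2)), entry))
    if not candidates:
        return None
    return max(candidates, key=lambda t: t[0])[1]


latest = find_latest_checkpoint("./checkpoints/demo_job")
print(latest)  # None if nothing has been saved yet
```

Returning `None` instead of raising lets a job start fresh when no checkpoint exists.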


## 6. Job Configuration with Checkpointing

In [15]:
# Job config with checkpoint settings
job_with_checkpoints = JobConfig(
    job_id="checkpoint-demo-001",
    user_id="demo-user",
    job_type="fine_tuning",
    
    pool_type="production",
    num_gpus=1,
    is_preemptible=False,
    
    model_name="bert-base-uncased",
    dataset_path="./data/train.csv",
    output_dir="./output/checkpoint_demo",
    
    priority="MEDIUM",
    estimated_duration=3600,
    
    config={
        # Checkpointing
        "checkpoint_dir": "./checkpoints/checkpoint_demo",
        "save_steps": 500,  # Every 500 steps
        "save_total_limit": 3,  # Keep 3 checkpoints
        "save_strategy": "steps",  # "steps" or "epoch"
        
        # Resume settings
        "resume_from_checkpoint": True,  # Auto-resume if checkpoint exists
        "ignore_data_skip": False,  # False: on resume, skip batches the interrupted run already processed
        
        # Evaluation
        "evaluation_strategy": "steps",
        "eval_steps": 500,
        "load_best_model_at_end": True,  # Load best checkpoint when done
        "metric_for_best_model": "eval_loss",
        
        # Training
        "num_train_epochs": 3,
        "per_device_train_batch_size": 16,
    }
)

print("Job Configuration with Checkpointing:")
print(f"  Save checkpoints every: {job_with_checkpoints.config['save_steps']} steps")
print(f"  Keep checkpoints: {job_with_checkpoints.config['save_total_limit']}")
print(f"  Auto-resume: {job_with_checkpoints.config['resume_from_checkpoint']}")
print(f"  Load best at end: {job_with_checkpoints.config['load_best_model_at_end']}")

Job Configuration with Checkpointing:
  Save checkpoints every: 500 steps
  Keep checkpoints: 3
  Auto-resume: True
  Load best at end: True
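The key names in `config` match Hugging Face `TrainingArguments`. Assuming they are forwarded there, `ignore_data_skip=False` means the data loader is fast-forwarded on resume, so the run sees the same batch sequence it would have seen without the crash. The idea can be sketched with `itertools.islice` over a deterministically shuffled epoch. Both helpers and the in-memory dataset are illustrative.

```python
import random
from itertools import islice


def epoch_batches(data, batch_size, epoch, seed=0):
    """Deterministic shuffle per epoch, so a restarted run can replay the same order."""
    order = list(range(len(data)))
    random.Random(seed + epoch).shuffle(order)
    for i in range(0, len(order), batch_size):
        yield [data[j] for j in order[i:i + batch_size]]


def resume_iterator(data, batch_size, num_epochs, completed_steps, seed=0):
    """Yield (epoch, global_step, batch), skipping the first `completed_steps` batches."""
    steps_per_epoch = -(-len(data) // batch_size)  # ceiling division
    start_epoch, skip = divmod(completed_steps, steps_per_epoch)
    global_step = completed_steps
    for epoch in range(start_epoch, num_epochs):
        batches = epoch_batches(data, batch_size, epoch, seed)
        if epoch == start_epoch:
            batches = islice(batches, skip, None)  # fast-forward within the epoch
        for batch in batches:
            global_step += 1  # first resumed step is completed_steps + 1
            yield epoch, global_step, batch


data = list(range(10))
full = list(resume_iterator(data, batch_size=3, num_epochs=2, completed_steps=0))
resumed = list(resume_iterator(data, batch_size=3, num_epochs=2, completed_steps=5))
print(resumed == full[5:])  # True
```

Skipped batches are still generated and discarded, so fast-forwarding a large dataset costs time. Setting `ignore_data_skip=True` trades that time for a different data order.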


## 7. Checkpoint Management

In [17]:
from pathlib import Path

# Inspect retained checkpoints and their disk usage
print("Checkpoint Management:\n")

# List checkpoints
checkpoints = checkpoint_manager.list_checkpoints()
print(f"Total checkpoints: {len(checkpoints)}")

# Cleanup is automatic: save_checkpoint removes the oldest checkpoints
# beyond max_checkpoints, so there is nothing to call here
print(f"Max checkpoints setting: {checkpoint_manager.max_checkpoints}")
print(f"Remaining checkpoints: {len(checkpoints)}\n")

# Calculate storage used
total_size = 0
for ckpt in checkpoints:
    ckpt_path = Path(ckpt['path'])
    if ckpt_path.exists():
        # Calculate directory size
        total_size += sum(f.stat().st_size for f in ckpt_path.rglob('*') if f.is_file())

print(f"Storage used: {total_size / (1024**2):.2f} MB")

Checkpoint Management:

Total checkpoints: 3
Max checkpoints setting: 3
Remaining checkpoints: 3

Storage used: 0.01 MB
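The demo's `Linear(10, 2)` checkpoints are tiny, but real ones are not. Here is a back-of-the-envelope estimate. It assumes fp32 weights and Adam, which keeps two extra tensors per parameter (`exp_avg` and `exp_avg_sq`), so a checkpoint is roughly 3× the weight bytes. Scheduler state, RNG state, and metadata are ignored. The figures below use `bert-base-uncased` (about 110M parameters) with `save_total_limit=3`, as in the job config above. Both helper names are illustrative.

```python
def checkpoint_bytes(num_params: int, bytes_per_value: int = 4,
                     optimizer_copies: int = 2) -> int:
    """Approximate size of weights plus optimizer state (Adam keeps 2 extra copies)."""
    return num_params * bytes_per_value * (1 + optimizer_copies)


def fmt_gb(n: int) -> str:
    return f"{n / 1024**3:.2f} GB"


per_ckpt = checkpoint_bytes(110_000_000)
print("per checkpoint:", fmt_gb(per_ckpt))      # per checkpoint: 1.23 GB
print("3 retained:", fmt_gb(3 * per_ckpt))      # 3 retained: 3.69 GB
print("demo Linear(10, 2):", checkpoint_bytes(10 * 2 + 2), "bytes")  # 264 bytes
```

The retention limit multiplies this figure, so `save_total_limit` is as much a disk budget as a safety setting.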
