## 02 · Save Full Checkpoint with Extra State  

To support fault-tolerant recovery, we extend checkpoint saving to include not just the model, but also the **optimizer state** and the **current epoch**.  

- **`model.pt`** → model weights (unwrap DDP with `.module`).  
- **`optimizer.pt`** → optimizer state for resuming training seamlessly.  
- **`extra_state.pt`** → stores metadata (here, the current epoch).  

Only the **rank-0 worker** writes the checkpoint to avoid duplication, but all workers still call `ray.train.report()` to keep the loop synchronized.  

This ensures that if training is interrupted, Ray Train can restore **model weights, optimizer progress, and the correct epoch** before continuing.  

In [None]:
# 02. Save checkpoint with model, optimizer, and epoch state

def save_checkpoint_and_metrics_ray_train_with_extra_state(
    model: torch.nn.Module,
    metrics: dict[str, float],
    optimizer: torch.optim.Optimizer,
    epoch: int,
) -> None:

    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        checkpoint = None
        # Only rank-0 worker saves files to disk
        if ray.train.get_context().get_world_rank() == 0:
            # Save all state required for full recovery
            torch.save(
                model.module.state_dict(),  # unwrap DDP before saving
                os.path.join(temp_checkpoint_dir, "model.pt"),
            )
            torch.save(
                optimizer.state_dict(),     # include optimizer state
                os.path.join(temp_checkpoint_dir, "optimizer.pt"),
            )
            torch.save(
                {"epoch": epoch},           # store last completed epoch
                os.path.join(temp_checkpoint_dir, "extra_state.pt"),
            )
            # Package into a Ray checkpoint
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)

        # Report metrics and attach checkpoint (only rank-0 attaches checkpoint).
        # Call report() inside the `with` block so the temporary directory
        # still exists while the checkpoint is persisted.
        ray.train.report(metrics, checkpoint=checkpoint)

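The save function above has a natural counterpart on the restore side. The following is a minimal sketch of it, not the notebook's actual loading code: `load_full_checkpoint` is a hypothetical helper name, it assumes the model passed in is not yet wrapped in DDP (or that `model.module` is passed), and it assumes `extra_state.pt` holds the last completed epoch, so training resumes at the next one. Inside a Ray Train loop, the directory would typically come from `ray.train.get_checkpoint()` followed by `checkpoint.as_directory()`; here a local temporary directory stands in so the round trip runs on CPU.

```python
import os
import tempfile

import torch


def load_full_checkpoint(
    checkpoint_dir: str,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
) -> int:
    """Restore model and optimizer in place; return the epoch to resume at.

    Hypothetical helper: expects the three files written by the save step.
    """
    model.load_state_dict(
        torch.load(os.path.join(checkpoint_dir, "model.pt"), map_location="cpu")
    )
    optimizer.load_state_dict(
        torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location="cpu")
    )
    extra_state = torch.load(os.path.join(checkpoint_dir, "extra_state.pt"))
    # Assumption: the stored value is the last completed epoch.
    return extra_state["epoch"] + 1


# Round trip with a tiny model (illustrative only).
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(8, 4)).sum().backward()
optimizer.step()  # populate the optimizer's momentum buffers

with tempfile.TemporaryDirectory() as checkpoint_dir:
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, "model.pt"))
    torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
    torch.save({"epoch": 3}, os.path.join(checkpoint_dir, "extra_state.pt"))

    # Fresh objects with different initial values, as after a restart.
    restored_model = torch.nn.Linear(4, 2)
    restored_optimizer = torch.optim.SGD(restored_model.parameters(), lr=0.5)
    start_epoch = load_full_checkpoint(
        checkpoint_dir, restored_model, restored_optimizer
    )
```

After the call, `restored_model` carries the saved weights, `restored_optimizer` carries the saved hyperparameters and momentum buffers, and `start_epoch` tells the loop where to pick up.
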
## 03 · Configure Automatic Retries with `FailureConfig`  

Now that the training loop can load from checkpoints, we can enable **automatic retries** in case of worker or node failures.  

- **`FailureConfig(max_failures=3)`** → allows the job to retry up to 3 times before giving up.  
- Pass this `failure_config` into `RunConfig` so Ray Train knows how to handle failures.  
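- When a failure happens, Ray Train will:  
  1. Restart the training workers.  
  2. Make the latest reported checkpoint available through `ray.train.get_checkpoint()`.  
  3. Rerun the training loop, which loads that checkpoint and resumes from the last saved epoch.  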

This setup makes training jobs resilient to transient hardware or cluster issues without requiring manual intervention.  
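
To make the retry behavior concrete, here is a toy, pure-Python simulation of the restart, reload, and resume steps. It is a sketch of the idea only, not Ray Train's implementation: `SimulatedWorkerFailure`, `make_flaky_train_loop`, and `run_with_retries` are hypothetical names, and an in-memory dictionary stands in for the persisted checkpoint.

```python
class SimulatedWorkerFailure(RuntimeError):
    """Stand-in for a worker or node crash (hypothetical, for illustration)."""


def make_flaky_train_loop(num_epochs: int, fail_at_epochs: set[int]):
    """Build a toy loop that crashes once at each epoch in `fail_at_epochs`."""
    pending_failures = set(fail_at_epochs)
    epochs_run: list[int] = []

    def train_loop(checkpoint: dict[str, int]) -> None:
        # Resume right after the last epoch recorded in the checkpoint.
        for epoch in range(checkpoint["epoch"] + 1, num_epochs):
            if epoch in pending_failures:
                pending_failures.discard(epoch)  # crash only once per epoch
                raise SimulatedWorkerFailure(f"crash during epoch {epoch}")
            epochs_run.append(epoch)
            checkpoint["epoch"] = epoch  # "save" after each completed epoch

    return train_loop, epochs_run


def run_with_retries(train_loop, max_failures: int):
    """Rerun `train_loop` from the latest checkpoint, up to `max_failures` retries."""
    checkpoint = {"epoch": -1}  # -1 means no epoch has completed yet
    failures = 0
    while True:
        try:
            train_loop(checkpoint)
            return checkpoint, failures
        except SimulatedWorkerFailure:
            failures += 1
            # A negative limit is treated as "retry forever" in this sketch.
            if 0 <= max_failures < failures:
                raise


train_loop, epochs_run = make_flaky_train_loop(num_epochs=5, fail_at_epochs={2, 4})
final_checkpoint, failures = run_with_retries(train_loop, max_failures=3)
print(epochs_run, final_checkpoint, failures)
```

With two simulated crashes and a limit of three, the run finishes after two retries, and every epoch is executed exactly once because each retry starts from the epoch after the last checkpoint. The limit in this sketch mirrors how `max_failures` is documented for `FailureConfig`, where `0` disables retries and `-1` retries indefinitely.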

In [None]:
# 03. Configure TorchTrainer with fault-tolerance enabled

# Allow up to 3 automatic retries if workers fail
failure_config = ray.train.FailureConfig(max_failures=3)

experiment_name = "fault-tolerant-cifar-vit"

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_ray_train_with_checkpoint_loading,  # fault-tolerant loop
    train_loop_config={   # hyperparameters
        "num_epochs": 1,
        "global_batch_size": 512,
    },
    scaling_config=scaling_config,  # resource scaling as before
    run_config=ray.train.RunConfig(
        name=experiment_name,           # reuse the name defined above
        storage_path=storage_path,      # persistent checkpoint storage
        failure_config=failure_config,  # enable automatic retries
    ),
    datasets=datasets,  # Ray Dataset shard for each worker
)