## 🛡️ 03 · Fault Tolerance in Ray Train  
In this module you’ll learn how **Ray Train** handles failures and how to make your training jobs **resilient** with checkpointing and recovery. You’ll see both **automatic retries** and **manual restoration**, and how to modify the training loop so it can safely resume from the latest checkpoint.  

### What you’ll learn & take away  
* How Ray Train uses **automatic retries** to restart failed workers and resume from the latest checkpoint instead of starting over  
* How to modify the training loop with **`get_checkpoint()`** to enable checkpoint loading  
* How to save additional state (e.g., optimizer and epoch) alongside the model for full recovery  
* How to configure **`FailureConfig`** to set retry behavior  
* How to perform a **manual restoration** if retries are exhausted, resuming training from the last checkpoint  
* Why checkpointing to **persistent storage** is essential for reliable recovery  

> With fault tolerance enabled, you can run long, large-scale training jobs confidently — knowing they can recover from failures without starting over.  

<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-summit/stable-diffusion/diagrams/fault_tolerant_cropped_v2.png" width=800>


## 01 · Modify Training Loop to Enable Checkpoint Loading  

To support fault tolerance, we extend the training loop so it can **resume from a previously saved checkpoint**.  

Key additions:  
- Call `ray.train.get_checkpoint()` to check if a checkpoint is available.  
- If found, restore:  
  * The **model state** (`model.pt`)  
  * The **optimizer state** (`optimizer.pt`)  
  * The **last completed epoch** (`extra_state.pt`)  
- Update `start_epoch` so training resumes from the correct place.  

The rest of the loop (forward pass, backward pass, optimizer step, and metrics reporting) is the same, except it now starts from `start_epoch` instead of 0.  
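
The resume decision described above comes down to one question: start at 0, or at the last completed epoch plus one? Here is a minimal, framework-free sketch of that decision. It uses `pickle` as a stand-in for `torch.save` / `torch.load` so it runs without PyTorch or Ray, and the helper name `resolve_start_epoch` is hypothetical (it is not part of Ray Train).

```python
import os
import pickle
import tempfile
from typing import Optional


def resolve_start_epoch(ckpt_dir: Optional[str]) -> int:
    """Return the epoch to resume from (hypothetical helper, not a Ray API)."""
    if ckpt_dir is None:
        # No checkpoint available: this is a fresh run
        return 0
    with open(os.path.join(ckpt_dir, "extra_state.pt"), "rb") as f:
        extra_state = pickle.load(f)  # stand-in for torch.load
    # The checkpoint stores the last *completed* epoch, so resume at the next one
    return extra_state["epoch"] + 1


# Fresh run: nothing to restore
fresh_start = resolve_start_epoch(None)

# Simulate a checkpoint written after epoch 2 finished
with tempfile.TemporaryDirectory() as ckpt_dir:
    with open(os.path.join(ckpt_dir, "extra_state.pt"), "wb") as f:
        pickle.dump({"epoch": 2}, f)  # stand-in for torch.save
    resumed_start = resolve_start_epoch(ckpt_dir)

print(fresh_start, resumed_start)
```

In the real loop, `ray.train.get_checkpoint()` plays the role of the `ckpt_dir is None` check.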

In [None]:
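# 01. Training loop with checkpoint loading for fault tolerance

# Imports used in this module (harmless to re-run if earlier modules already loaded them).
# Helpers such as load_model_ray_train() are defined in the earlier modules.
import os
import shutil
import tempfile

import ray.train
import torch
from ray.train.torch import TorchTrainer
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
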
def train_loop_ray_train_with_checkpoint_loading(config: dict):
    # Same setup as before: loss, model, optimizer
    criterion = CrossEntropyLoss()
    model = load_model_ray_train()
    optimizer = Adam(model.parameters(), lr=1e-3)

    # Same data loader logic as before
    global_batch_size = config["global_batch_size"]
    batch_size = global_batch_size // ray.train.get_context().get_world_size()
    data_loader = build_data_loader_ray_train_ray_data(batch_size=batch_size)

    # Default: start at epoch 0 unless a checkpoint is available
    start_epoch = 0

    # Attempt to load from latest checkpoint
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        # Continue training from a previous checkpoint
        with checkpoint.as_directory() as ckpt_dir:
            # Restore the model weights (the model is DDP-wrapped, so load into .module)
            model_state_dict = torch.load(
                os.path.join(ckpt_dir, "model.pt"),
            )
            model.module.load_state_dict(model_state_dict)

            # Restore the optimizer state
            optimizer.load_state_dict(
                torch.load(os.path.join(ckpt_dir, "optimizer.pt"))
            )

            # Resume from last epoch + 1
            start_epoch = (
                torch.load(os.path.join(ckpt_dir, "extra_state.pt"))["epoch"] + 1
            )

    # Same training loop as before except it starts at a parameterized start_epoch
    for epoch in range(start_epoch, config["num_epochs"]):
        for batch in data_loader:
            outputs = model(batch["image"])
            loss = criterion(outputs, batch["label"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Report metrics and save model + optimizer + epoch state
        metrics = print_metrics_ray_train(loss, epoch)

        # We now save the optimizer and epoch state in addition to the model
        save_checkpoint_and_metrics_ray_train_with_extra_state(
            model, metrics, optimizer, epoch
        )

## 02 · Save Full Checkpoint with Extra State  

To support fault-tolerant recovery, we extend checkpoint saving to include not just the model, but also the **optimizer state** and the **current epoch**.  

- **`model.pt`** → model weights (unwrap DDP with `.module`).  
- **`optimizer.pt`** → optimizer state for resuming training seamlessly.  
- **`extra_state.pt`** → stores metadata (here, the current epoch).  

Only the **rank-0 worker** writes the checkpoint to avoid duplication, but all workers still call `ray.train.report()` to keep the loop synchronized.  

This ensures that if training is interrupted, Ray Train can restore **model weights, optimizer progress, and the correct epoch** before continuing.  
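
Because recovery needs all three files, it can help to verify a checkpoint directory before loading from it. The sketch below is a hypothetical helper (`missing_checkpoint_files` is not a Ray Train function) that uses only the standard library. The file names are the ones listed above, and the empty files merely simulate saved state.

```python
import os
import tempfile

# File names used by this tutorial's checkpoints
REQUIRED_FILES = ("model.pt", "optimizer.pt", "extra_state.pt")


def missing_checkpoint_files(ckpt_dir: str) -> list[str]:
    """Return the required files that are absent from ckpt_dir (hypothetical helper)."""
    return [
        name
        for name in REQUIRED_FILES
        if not os.path.isfile(os.path.join(ckpt_dir, name))
    ]


with tempfile.TemporaryDirectory() as ckpt_dir:
    # Simulate a model-only checkpoint (no optimizer or epoch state)
    open(os.path.join(ckpt_dir, "model.pt"), "wb").close()
    missing_before = missing_checkpoint_files(ckpt_dir)

    # Add the extra state introduced in this section
    for name in ("optimizer.pt", "extra_state.pt"):
        open(os.path.join(ckpt_dir, name), "wb").close()
    missing_after = missing_checkpoint_files(ckpt_dir)

print(missing_before, missing_after)
```

A check like this could run inside `checkpoint.as_directory()` before any `torch.load` call, so an incomplete checkpoint fails with a clear message rather than a missing-file error halfway through restoration.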

In [None]:
# 02. Save checkpoint with model, optimizer, and epoch state

def save_checkpoint_and_metrics_ray_train_with_extra_state(
    model: torch.nn.Module,
    metrics: dict[str, float],
    optimizer: torch.optim.Optimizer,
    epoch: int,
) -> None:

    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        checkpoint = None
        # Only the rank-0 worker saves files to disk
        if ray.train.get_context().get_world_rank() == 0:
            # Save all state required for full recovery
            torch.save(
                model.module.state_dict(),  # unwrap DDP before saving
                os.path.join(temp_checkpoint_dir, "model.pt"),
            )
            torch.save(
                optimizer.state_dict(),  # include optimizer state
                os.path.join(temp_checkpoint_dir, "optimizer.pt"),
            )
            torch.save(
                {"epoch": epoch},  # store last completed epoch
                os.path.join(temp_checkpoint_dir, "extra_state.pt"),
            )
            # Package into a Ray checkpoint
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)

        # All workers report metrics; only rank 0 attaches a checkpoint
        ray.train.report(metrics, checkpoint=checkpoint)

## 03 · Configure Automatic Retries with `FailureConfig`  

Now that the training loop can load from checkpoints, we can enable **automatic retries** in case of worker or node failures.  

- **`FailureConfig(max_failures=3)`** → allows the job to retry up to 3 times before giving up (`0`, the default, disables retries; `-1` retries indefinitely).  
- Pass this `failure_config` into `RunConfig` so Ray Train knows how to handle failures.  
- When a failure happens, Ray will:  
  1. Restart the failed workers.  
  2. Reload the latest checkpoint.  
  3. Resume training from the epoch after the last checkpointed one.  

This setup makes training jobs resilient to transient hardware or cluster issues without requiring manual intervention.  
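
To make the restart, reload, resume cycle concrete, here is a toy simulation in plain Python. It is a sketch of the behavior described above, not how Ray Train is implemented: `run_with_retries`, `SimulatedWorkerFailure`, and the `crash_at` parameter are all hypothetical, and a single variable stands in for persistent checkpoint storage.

```python
from typing import Optional


class SimulatedWorkerFailure(Exception):
    """Stands in for a worker crash or node preemption."""


def run_with_retries(num_epochs: int, max_failures: int, crash_at: set[int]) -> list[int]:
    """Simulate restart -> reload -> resume; return the epochs that completed, in order."""
    last_checkpointed_epoch: Optional[int] = None  # plays the role of persistent storage
    completed: list[int] = []
    failures = 0
    pending_crashes = set(crash_at)

    while True:
        # "Reload the latest checkpoint": resume after the last saved epoch
        start_epoch = 0 if last_checkpointed_epoch is None else last_checkpointed_epoch + 1
        try:
            for epoch in range(start_epoch, num_epochs):
                if epoch in pending_crashes:
                    pending_crashes.discard(epoch)  # each simulated crash happens once
                    raise SimulatedWorkerFailure(f"crash during epoch {epoch}")
                completed.append(epoch)
                last_checkpointed_epoch = epoch  # "checkpoint" at the end of each epoch
            return completed
        except SimulatedWorkerFailure:
            failures += 1
            if failures > max_failures:
                raise  # retries exhausted: the failure surfaces to the caller


# Two crashes (during epochs 1 and 3) are absorbed by a budget of 3 retries
epochs_run = run_with_retries(num_epochs=4, max_failures=3, crash_at={1, 3})
print(epochs_run)
```

Each epoch appears exactly once in the result because a crashed epoch never reaches its checkpoint, so the next attempt repeats only that epoch. With `max_failures=0`, the first crash would propagate to the caller instead.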

In [None]:
# 03. Configure TorchTrainer with fault-tolerance enabled

# Allow up to 3 automatic retries if workers fail
failure_config = ray.train.FailureConfig(max_failures=3)

experiment_name = "fault-tolerant-cifar-vit"

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_ray_train_with_checkpoint_loading,  # fault-tolerant loop
    train_loop_config={   # hyperparameters
        "num_epochs": 1,
        "global_batch_size": 512,
    },
    scaling_config=scaling_config,  # resource scaling as before
    run_config=ray.train.RunConfig(
        name=experiment_name,           # run name; reuse it to restore this run later
        storage_path=storage_path,      # persistent checkpoint storage
        failure_config=failure_config,  # enable automatic retries
    ),
    datasets=datasets,  # Ray Dataset shard for each worker
)

## 04 · Launch Fault-Tolerant Training  

Finally, call `trainer.fit()` to start the training job.  

With the **fault-tolerant loop** and **`FailureConfig`** in place, Ray Train will:  
- Run the training loop on all workers.  
- If a failure occurs (e.g., worker crash, node preemption), automatically restart workers.  
- Reload the latest checkpoint and continue training, losing at most the work done since that checkpoint.  

This makes your training job robust against transient infrastructure failures.  

In [None]:
# 04. Start the fault-tolerant training job

# Launches training with checkpointing + automatic retries enabled
# If workers fail, Ray will reload the latest checkpoint and resume
trainer.fit()

## 05 · Manual Restoration from Checkpoints  

If the maximum number of retries is reached, you can still **manually restore training** by creating a new `TorchTrainer` with the same configuration:  

- Use the same `train_loop_ray_train_with_checkpoint_loading` so the loop can resume from a checkpoint.  
- Provide the same `run_config` (name, storage path, and failure config).  
- Pass in the same dataset and scaling configuration.  

Ray Train will detect the latest checkpoint in the specified `storage_path` and resume training from that point.  
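
Ray Train handles checkpoint discovery for you, but the idea is easy to picture. The sketch below assumes, purely for illustration, that each checkpoint lives in a subdirectory named `checkpoint_` followed by a zero-padded index. The actual layout under `storage_path` depends on your Ray version, and `latest_checkpoint_dir` is a hypothetical helper, not a Ray API.

```python
import os
import tempfile
from typing import Optional


def latest_checkpoint_dir(run_dir: str) -> Optional[str]:
    """Return the highest-numbered checkpoint_* subdirectory, or None (hypothetical helper)."""
    candidates = sorted(
        entry
        for entry in os.listdir(run_dir)
        if entry.startswith("checkpoint_") and os.path.isdir(os.path.join(run_dir, entry))
    )
    # Zero-padded indices sort correctly as plain strings
    return os.path.join(run_dir, candidates[-1]) if candidates else None


with tempfile.TemporaryDirectory() as run_dir:
    empty_result = latest_checkpoint_dir(run_dir)  # no checkpoints yet
    for index in range(3):
        # Assumed naming scheme, for illustration only
        os.makedirs(os.path.join(run_dir, f"checkpoint_{index:06d}"))
    latest_name = os.path.basename(latest_checkpoint_dir(run_dir))

print(empty_result, latest_name)
```

This is also why the run `name` and `storage_path` must match the original run: together they determine where the lookup happens.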

In [None]:
# 05. Manually restore a trainer from the last checkpoint

restored_trainer = TorchTrainer(
    train_loop_per_worker=train_loop_ray_train_with_checkpoint_loading,  # loop supports checkpoint loading
    train_loop_config={   # hyperparameters must match
        "num_epochs": 1,
        "global_batch_size": 512,
    },
    scaling_config=scaling_config,  # same resource setup as before
    run_config=ray.train.RunConfig(
        name=experiment_name,            # must match previous run name
        storage_path=storage_path,       # path where checkpoints are saved
        failure_config=failure_config,   # still allow retries
    ),
    datasets=datasets,  # same dataset as before
)

## 06 · Resume Training from the Last Checkpoint  

Calling `restored_trainer.fit()` will continue training from the most recent checkpoint found in the specified storage path.  

- If all epochs were already completed in the previous run, the trainer will terminate immediately.  
- If training was interrupted mid-run, it will resume from the saved epoch, restoring both the **model** and **optimizer** state.  
- The returned `Result` object exposes the final metrics and the latest checkpoint, which you can inspect to confirm that training resumed as expected.  
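
The "terminates immediately" case follows directly from how the loop iterates over `range(start_epoch, num_epochs)`. Here is a small worked example in plain Python; `remaining_epochs` is a hypothetical helper that mirrors the loop bounds and is not part of Ray Train.

```python
from typing import Optional


def remaining_epochs(last_completed_epoch: Optional[int], num_epochs: int) -> list[int]:
    """Epochs the resumed loop still has to run (hypothetical helper)."""
    start_epoch = 0 if last_completed_epoch is None else last_completed_epoch + 1
    return list(range(start_epoch, num_epochs))


# This tutorial's config: num_epochs=1, and epoch 0 was already checkpointed
already_done = remaining_epochs(last_completed_epoch=0, num_epochs=1)

# A longer run interrupted after epoch 4 of 10
interrupted = remaining_epochs(last_completed_epoch=4, num_epochs=10)

print(already_done, interrupted)
```

With `num_epochs=1` and a checkpoint from epoch 0, the range is empty, so the restored trainer has nothing left to do.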

In [None]:
# 06. Resume training from the last checkpoint

# Fit the restored trainer → continues from last saved epoch
# If all epochs are already complete, training ends immediately
result = restored_trainer.fit()

# Display final training results (metrics, checkpoints, etc.)
result

## 07 · Clean Up Cluster Storage  

The final cell removes the tutorial artifacts from **persistent cluster storage**:  

- The **downloaded MNIST dataset** (`/mnt/cluster_storage/MNIST`).  
- The **training outputs** (`/mnt/cluster_storage/training`).  
- The **Parquet dataset** used for Ray Data (`/mnt/cluster_storage/cifar10.parquet`).  

This keeps your shared storage clean and stops leftover files from taking up space.  
Run it only when you’re sure you no longer need the data, checkpoints, or Parquet files.  

In [None]:
# 07. Cleanup Cluster Storage

# Paths to remove → include MNIST data, training outputs, and cifar10.parquet
paths_to_delete = [
    "/mnt/cluster_storage/MNIST",
    "/mnt/cluster_storage/training",
    "/mnt/cluster_storage/cifar10.parquet",
]

for path in paths_to_delete:
    if os.path.exists(path):
        # Handle directories vs. files
        if os.path.isdir(path):
            shutil.rmtree(path)       # recursively delete directory
        else:
            os.remove(path)           # delete single file
        print(f"Deleted: {path}")
    else:
        print(f"Not found: {path}")

## 🎉 Wrapping Up & Next Steps  

Nice work -- you completed a full, production-style workflow with **Ray Train on Anyscale**, then extended it with **Ray Data**, and finally added **fault tolerance**. Here’s what you accomplished across the three modules:

---

### ✅ Module 01 · Introduction to Ray Train  
- Scaled PyTorch DDP with **`TorchTrainer`** using **`ScalingConfig`** and **`RunConfig`**  
- Wrapped code for multi-GPU with **`prepare_model()`** and **`prepare_data_loader()`**  
- Reported **metrics** and saved **checkpoints** via `ray.train.report(...)` (rank-0 checkpointing best practice)  
- Inspected results from the **`Result`** object and served **GPU inference** with a Ray actor  

---

### ✅ Module 02 · Integrating Ray Train with Ray Data  
- Prepared MNIST as **Parquet** and loaded it as a **Ray Dataset**  
- Streamed batches with **`iter_torch_batches()`** and consumed dict batches in the training loop  
- Passed datasets to the trainer via **`datasets={"train": ...}`**  
- Decoupled CPU preprocessing from GPU training for **better utilization and throughput**  

---

### ✅ Module 03 · Fault Tolerance in Ray Train  
- Enabled resume-from-checkpoint using **`ray.train.get_checkpoint()`**  
- Saved full state (model, **optimizer**, **epoch**) for robust restoration  
- Configured **`FailureConfig(max_failures=...)`** for automatic retries  
- Performed **manual restoration** by re-creating a trainer with the same `RunConfig`  

---

### 🚀 Where to go next  
- **Scale up**: Increase `num_workers`, try multi-node clusters, or switch to **FSDP** via `prepare_model(parallel_strategy="fsdp")`.  
- **Input pipelines**: Add augmentations, caching, and windowed shuffles in **Ray Data**; try multi-file Parquet or lakehouse sources.  
- **Experiment tracking**: Log metrics to external systems (Weights & Biases, MLflow) alongside `ray.train.report()`.  
- **Larger models**: Integrate **DeepSpeed** or parameter-efficient fine-tuning templates.  
- **Productionization**: Store checkpoints in cloud storage (S3/GCS/Azure), wire up alerts/dashboards, and add CI for smoke tests.  

---

### 📚 Next Tutorials in the Course  
In the next tutorials, you’ll find **end-to-end workload examples** for using Ray Train on Anyscale (e.g., recommendation systems, vision, NLP, generative models).  

👉 You only need to pick **one** of these workloads to work through in the course — but you can explore more if you’re curious!  

---

> With these patterns—**distributed training**, **scalable data ingestion**, and **resilient recovery**—you’re ready to run larger, longer, and more reliable training jobs on Anyscale.  