## 15 · Remote Inference Helper  
Define a lightweight Ray remote function that loads a chosen checkpoint into a fresh `resnet18`, runs inference on one image, and returns the predicted label, the true label, and the image index. Because the function requests one GPU (`num_gpus=1`), Ray schedules it on a node that has a GPU available.

In [None]:

# 15. Define single-image inference function

@ray.remote(num_gpus=1)
def run_inference_from_checkpoint(checkpoint_path, parquet_path, idx=0):

    # === Load model ===
    model = resnet18(num_classes=101)
    checkpoint = Checkpoint.from_directory(checkpoint_path)

    with checkpoint.as_directory() as ckpt_dir:
        state_dict = torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cuda")

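        # Strip the leading "module." prefix that DistributedDataParallel adds to parameter names.
        # Only a true prefix is removed, so keys containing "module." elsewhere stay intact.
        state_dict = {
            (k[len("module."):] if k.startswith("module.") else k): v
            for k, v in state_dict.items()
        }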

        model.load_state_dict(state_dict)

    model.eval().cuda()

    # === Define transform ===
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    # === Load dataset ===
    dataset = Food101Dataset(parquet_path, transform=transform)
    img, label = dataset[idx]
    img = img.unsqueeze(0).cuda()  # batch size 1

    with torch.no_grad():
        logits = model(img)
        pred = torch.argmax(logits, dim=1).item()

    return {"predicted_label": pred, "true_label": int(label), "index": idx}
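
The prefix-stripping step inside the helper can be tried in isolation. The sketch below is a minimal, framework-free illustration: `strip_ddp_prefix` is a hypothetical helper name, and the integer values stand in for the tensors a real `state_dict` would hold.

```python
from collections import OrderedDict


def strip_ddp_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Strip only when the key starts with the prefix, so names that
        # contain "module." somewhere else are left untouched.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Stand-in values; in the notebook these would be tensors.
ddp_style = {"module.conv1.weight": 1, "module.fc.bias": 2, "layer.module.x": 3}
print(list(strip_ddp_prefix(ddp_style)))
```

Checking for a real prefix rather than replacing the first occurrence keeps the helper safe for checkpoints saved without the DDP wrapper.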

### 16 · Run and Visualise Inference  
Using the checkpoint with the lowest validation loss, invoke the remote inference helper on a validation image. Then plot the image alongside the model’s prediction and the ground-truth class for a quick visual sanity check of a single prediction.

In [None]:
# 16. Run inference with the best trained model (the checkpoint with the lowest validation loss)

checkpoint_root = "/mnt/cluster_storage/food101_lite/results/food101_ft_resume"

checkpoint_dirs = sorted(
    [
        d for d in os.listdir(checkpoint_root)
        if d.startswith("checkpoint_") and os.path.isdir(os.path.join(checkpoint_root, d))
    ],
    reverse=True,
)

if not checkpoint_dirs:
    raise FileNotFoundError("No checkpoint directories found.")

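# `result` is the training result produced earlier in the notebook.
# To choose a checkpoint explicitly by metric, inspect `result.best_checkpoints`.
with result.checkpoint.as_directory() as ckpt_dir:
    print("Best checkpoint contents:", os.listdir(ckpt_dir))
    # The path is reused after this block, which assumes the checkpoint lives on
    # shared storage and is not a temporary download.
    best_ckpt_path = ckpt_dir

parquet_path = "/mnt/cluster_storage/food101_lite/val.parquet"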

# Define which image to use
idx = 2

# Launch GPU inference task with Ray.
# Store the output under a new name so the training `result` object is not overwritten.
inference_result = ray.get(run_inference_from_checkpoint.remote(best_ckpt_path, parquet_path, idx=idx))
print(inference_result)

# Load label map from Hugging Face
ds = load_dataset("food101", split="train[:1%]")  # load just to get label names
label_names = ds.features["label"].names

# Load the raw image from the same Parquet file
dataset = Food101Dataset(parquet_path, transform=None)  # No transform = raw image
img, _ = dataset[idx]

# Plot the image with predicted and true labels
plt.imshow(img)
plt.axis("off")
plt.title(
    f"Pred: {label_names[inference_result['predicted_label']]}\n"
    f"True: {label_names[inference_result['true_label']]}"
)
plt.show()
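
If several checkpoints were kept, the one with the lowest validation loss can be selected explicitly. The sketch below assumes a list of `(checkpoint, metrics)` pairs, which is the shape Ray Train exposes through `Result.best_checkpoints`. The metric key `val_loss`, the helper name `pick_best_checkpoint`, and the numbers are illustrative assumptions, with plain strings standing in for checkpoint objects.

```python
def pick_best_checkpoint(checkpoints_with_metrics, metric="val_loss"):
    """Return the (checkpoint, metrics) pair with the lowest value of `metric`."""
    scored = [pair for pair in checkpoints_with_metrics if metric in pair[1]]
    if not scored:
        raise ValueError(f"No checkpoint reports the metric {metric!r}.")
    return min(scored, key=lambda pair: pair[1][metric])


# Made-up values for illustration only; strings stand in for Checkpoint objects.
example = [
    ("checkpoint_000000", {"val_loss": 1.92}),
    ("checkpoint_000001", {"val_loss": 1.41}),
    ("checkpoint_000002", {"val_loss": 1.57}),
]
best_ckpt, best_metrics = pick_best_checkpoint(example)
print(best_ckpt, best_metrics["val_loss"])
```

In the notebook, the same function could be called on `result.best_checkpoints`, provided the training loop reported a metric under the chosen key.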

### 17 · Cleanup  
Finally, tidy up by deleting temporary checkpoint folders, the metrics CSV, and any intermediate result directories. Clearing out old artefacts frees disk space and leaves your workspace clean for whatever comes next.

In [None]:
# 17. Cleanup: delete checkpoints and metrics from model training

import os
import shutil

# Base directory
BASE_DIR = "/mnt/cluster_storage/food101_lite"

# Paths to clean
paths_to_delete = [
    os.path.join(BASE_DIR, "tmp_checkpoints"),                 # custom checkpoints
    os.path.join(BASE_DIR, "results", "history.csv"),          # metrics history file
    os.path.join(BASE_DIR, "results", "food101_ft_resume"),    # Ray Train run directories
    os.path.join(BASE_DIR, "results", "food101_ft_run"),
    os.path.join(BASE_DIR, "results", "food101_single_run"),
]

# Delete each path if it exists
for path in paths_to_delete:
    if os.path.exists(path):
        if os.path.isfile(path):
            os.remove(path)
            print(f"Deleted file: {path}")
        else:
            shutil.rmtree(path)
            print(f"Deleted directory: {path}")
    else:
        print(f"Not found (skipped): {path}")
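
Because the cleanup is destructive, it can help to preview what would be removed first. The following is a minimal sketch of the same loop with a dry-run switch, demonstrated on a throwaway temporary directory. The function name `delete_paths` and the `dry_run` flag are hypothetical additions, not part of the notebook.

```python
import os
import shutil
import tempfile


def delete_paths(paths, dry_run=True):
    """Delete files and directories in `paths`; return a list of (kind, path) tuples."""
    actions = []
    for path in paths:
        if not os.path.exists(path):
            actions.append(("missing", path))
        elif os.path.isfile(path):
            if not dry_run:
                os.remove(path)
            actions.append(("file", path))
        else:
            if not dry_run:
                shutil.rmtree(path)
            actions.append(("directory", path))
    return actions


# Demonstrate on a temporary directory instead of real cluster storage.
with tempfile.TemporaryDirectory() as base:
    run_dir = os.path.join(base, "results", "demo_run")
    os.makedirs(run_dir)
    history = os.path.join(base, "results", "history.csv")
    with open(history, "w") as f:
        f.write("epoch,val_loss\n")
    targets = [run_dir, history, os.path.join(base, "tmp_checkpoints")]

    print([kind for kind, _ in delete_paths(targets, dry_run=True)])   # preview only
    print([kind for kind, _ in delete_paths(targets, dry_run=False)])  # actually delete
    print(os.path.exists(run_dir), os.path.exists(history))
```

Running with `dry_run=True` first makes it easy to confirm the path list before anything is removed.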

### 🎉 Wrapping Up & Next Steps

Great job making it to the end. You've taken a realistic computer-vision workload, from raw images through distributed training to GPU inference, and run it on Ray Train with very little boilerplate around GPUs, data parallelism, or fault tolerance. You should now feel comfortable:

* Using **Ray Train’s TorchTrainer** to scale PyTorch training across multiple GPUs and nodes with minimal code changes  
* Wrapping models and data loaders with **`prepare_model()`** and **`prepare_data_loader()`** to enable Ray-managed device placement and distributed execution  
* Sharding data across workers using **`DistributedSampler`**, and coordinating training epochs across Ray workers  
* Configuring **automatic checkpointing and failure recovery** using Ray Train’s built-in `Checkpoint`, `RunConfig`, and `FailureConfig` APIs  
* Running inference as a **GPU-backed Ray task**, a pattern that extends to scaling model predictions across a Ray cluster  
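
To make the sharding idea concrete, here is a simplified, pure-Python model of how a padded round-robin split assigns sample indices to workers. It mirrors the behaviour of `DistributedSampler` with shuffling disabled and `drop_last=False`, but it is an illustration rather than the PyTorch implementation, and `shard_indices` is a hypothetical name.

```python
import math


def shard_indices(dataset_size, num_replicas, rank):
    """Indices one worker sees under a non-shuffled, padded round-robin split."""
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    indices = list(range(dataset_size))
    per_replica = math.ceil(dataset_size / num_replicas)
    total_size = per_replica * num_replicas
    # Pad by repeating indices from the start so every worker gets the same count.
    padding = total_size - len(indices)
    if padding > 0 and indices:
        repeats = math.ceil(padding / len(indices))
        indices += (indices * repeats)[:padding]
    # Each worker takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, shard_indices(10, 4, rank))
```

With 10 samples and 4 workers, each worker receives 3 indices, and the last two workers each see one repeated sample from the start of the dataset. The real sampler additionally reshuffles per epoch when `set_epoch()` is called.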

---

### 🚀 Where can you take this next?

Below are a few directions you might explore to adapt or extend the pattern:

1. **Larger or Custom Datasets**  
   * Swap in the full 75k-image Food-101 training split, or your own dataset in any storage backend (S3, GCS, Azure Blob).  
   * Add multi-file Parquet sharding and let each worker read a different shard.

2. **Model Architectures**  
   * Drop in Vision Transformers (`vit_b_16`, `vit_l_32`) or ConvNeXt; the prepare helpers work exactly the same.  
   * Experiment with transfer learning vs. training from scratch.

3. **Mixed Precision & Performance Tuning**  
   * Enable automatic mixed precision (`torch.cuda.amp`) or bfloat16 to speed up training and save memory.  
   * Profile data-loading throughput and play with `num_workers`, prefetching, and caching.

4. **Hyperparameter Sweeps**  
   * Wrap the training loop in **Ray Tune** to search over learning rates, augmentations, or optimisers.  
   * Use a Ray Tune scheduler such as ASHA to stop underperforming trials early, based on the metrics your training loop reports.

5. **Data Augmentation Pipelines**  
   * Integrate additional transforms inside the dataset class for image augmentation.  
   * Compare CPU vs. GPU-side augmentations for throughput.

6. **Distributed Validation & Metrics**  
   * Replace your simple accuracy printout with more advanced metrics (F1, top-5 accuracy, confusion matrices).  

7. **Model Serving**  
   * Convert the remote inference helper into a **Ray Serve** deployment for low-latency online predictions.  
   * Auto-scale replicas based on request volume.

8. **End-to-End MLOps**  
   * Register checkpoints in a model registry (for example, MLflow, Weights & Biases, or Ray’s built-in MLflow integration).  
   * Schedule the notebook as a Ray Job or CI/CD pipeline for regular retraining runs.
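
As a starting point for the richer metrics mentioned in item 6, the sketch below computes top-k accuracy and a confusion matrix in plain Python. The function names and the toy scores are illustrative assumptions. In practice the per-class scores would be the model's logits gathered from the validation loop.

```python
def top_k_accuracy(scores, labels, k=5):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        return 0.0
    hits = 0
    for row, label in zip(scores, labels):
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        hits += label in ranked[:k]
    return hits / len(labels)


def confusion_matrix(preds, labels, num_classes):
    """Return matrix[true][pred] counts as a list of lists."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for pred, true in zip(preds, labels):
        matrix[true][pred] += 1
    return matrix


# Toy scores for three samples over three classes (made up for illustration).
scores = [
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
]
labels = [1, 1, 0]
preds = [max(range(len(row)), key=lambda c: row[c]) for row in scores]

print("top-1:", top_k_accuracy(scores, labels, k=1))
print("top-2:", top_k_accuracy(scores, labels, k=2))
print(confusion_matrix(preds, labels, num_classes=3))
```

For a 101-class problem such as Food-101, top-5 accuracy is often more informative than top-1 because many dishes look alike.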
