## 15. Batch inference with Ray Data  

Define a **stateful, GPU-backed batch inference pipeline** using Ray Data.  
Each actor loads the model once, keeps it in memory, and runs inference on incoming batches in parallel with the other actors. With one GPU requested per actor, the model is loaded **once per GPU**.  
This pattern scales across multiple GPUs and avoids reloading the model for every batch.
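
To see why a callable class matters here, the following minimal sketch mimics the pattern without Ray or a GPU. `FakePredictor` and its `load_count` counter are hypothetical stand-ins, not part of the tutorial: the constructor plays the role of the expensive model load, and `__call__` handles one batch at a time.

```python
class FakePredictor:
    """Hypothetical stand-in for a stateful predictor (no Ray, no GPU)."""

    load_count = 0  # class-level counter: how many times a "model" was loaded

    def __init__(self, scale: int):
        # Stands in for the expensive step (reading weights, moving to GPU).
        FakePredictor.load_count += 1
        self.scale = scale

    def __call__(self, batch: list[int]) -> list[int]:
        # Stands in for a forward pass over one batch.
        return [self.scale * x for x in batch]


predictor = FakePredictor(scale=2)           # "model" loaded once
batches = [[1, 2], [3, 4], [5]]
outputs = [predictor(b) for b in batches]    # same instance reused per batch

print(outputs)                     # [[2, 4], [6, 8], [10]]
print(FakePredictor.load_count)    # 1
```

Passing a class rather than a function to `map_batches` asks Ray Data to do the equivalent on each actor: construct the object once, then call it for every batch routed to that actor.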


In [None]:
# 15. Batch inference with Ray Data (force GPU actors if available on the cluster)

import ray.data as rdata

class ImageBatchPredictor:
    """Stateful per-actor batch predictor that keeps the model in memory."""
    def __init__(self, checkpoint_path: str):
        # Pick the best available device on the ACTOR (worker), not the driver.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # === Load model & weights once per actor ===
        model = resnet18(num_classes=101)
        checkpoint = Checkpoint.from_directory(checkpoint_path)
        with checkpoint.as_directory() as ckpt_dir:
            state_dict = torch.load(
                os.path.join(ckpt_dir, "model.pt"),
                map_location=self.device,
            )
            # Strip the DDP "module." prefix only where it leads the key
            state_dict = {
                (k[len("module."):] if k.startswith("module.") else k): v
                for k, v in state_dict.items()
            }
            model.load_state_dict(state_dict)

        self.model = model.eval().to(self.device)
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])
        torch.set_grad_enabled(False)

    def __call__(self, batch):
        """batch: Pandas DataFrame with columns ['image_bytes', 'label']"""
        imgs = []
        for b in batch["image_bytes"]:
            img = Image.open(io.BytesIO(b)).convert("RGB")
            imgs.append(self.transform(img).numpy())  # (C,H,W) as numpy
        x = torch.from_numpy(np.stack(imgs, axis=0)).to(self.device)  # (N,C,H,W)

        logits = self.model(x)
        preds = torch.argmax(logits, dim=1).cpu().numpy()

        out = batch.copy()
        out["predicted_label"] = preds.astype(int)
        return out[["predicted_label", "label"]]

def build_inference_dataset(
    checkpoint_path: str,
    parquet_path: str,
    *,
    num_actors: int = 1,
    batch_size: int = 64,
    use_gpu_actors: bool = True,   # default to GPU actors on the cluster
):
    """
    Create a Ray Dataset pipeline that performs batch inference using
    stateful per-actor model loading. By default, requests 1 GPU per actor
    so each actor runs on a GPU worker (driver may have no GPU).
    """
    ds = rdata.read_parquet(parquet_path, columns=["image_bytes", "label"])

    pred_ds = ds.map_batches(
        ImageBatchPredictor,                     # pass the CLASS (stateful actors)
        fn_constructor_args=(checkpoint_path,),  # ctor args for each actor
        batch_size=batch_size,
        batch_format="pandas",
        concurrency=num_actors,                  # number of actor workers
        num_gpus=1 if use_gpu_actors else 0,     # request 1 GPU per actor so it lands on a GPU worker
    )
    return pred_ds
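
The prefix handling inside `ImageBatchPredictor.__init__` can be checked in isolation. The sketch below uses a hypothetical helper, `strip_ddp_prefix`, with plain integers in place of tensors. It assumes the checkpoint was saved from a `DistributedDataParallel`-wrapped model, whose state-dict keys all start with `module.`.

```python
def strip_ddp_prefix(state_dict: dict, prefix: str = "module.") -> dict:
    """Return a copy of state_dict with a leading prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Integers stand in for tensors; only the keys matter here.
ddp_style = {"module.conv1.weight": 1, "module.fc.bias": 2}
plain_style = {"conv1.weight": 1, "fc.bias": 2}

print(strip_ddp_prefix(ddp_style))    # {'conv1.weight': 1, 'fc.bias': 2}
print(strip_ddp_prefix(plain_style))  # {'conv1.weight': 1, 'fc.bias': 2}
```

Only a leading prefix is removed, so a key that merely contains `module.` somewhere in the middle is left untouched, and checkpoints saved without DDP pass through unchanged.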

### 16. Run and visualize Ray Data inference  

Use the best checkpoint to run **Ray Data inference** on a validation sample.  
The model loads once per GPU actor, predictions are batched and parallelized, and the result is visualized alongside the ground-truth label for quick qualitative evaluation.

In [None]:
# 16. Perform inference with Ray Data using the best checkpoint

checkpoint_root = "/mnt/cluster_storage/food101_lite/results/food101_ft_resume"

# Sanity check: the run directory should contain at least one checkpoint_* folder
checkpoint_dirs = sorted(
    [
        d for d in os.listdir(checkpoint_root)
        if d.startswith("checkpoint_") and os.path.isdir(os.path.join(checkpoint_root, d))
    ],
    reverse=True,
)

if not checkpoint_dirs:
    raise FileNotFoundError("No checkpoint directories found.")

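# Use the best checkpoint from the training result.
# The directory yielded by as_directory() is only guaranteed to exist inside the
# `with` block, so pass the checkpoint's storage path (on shared cluster storage,
# visible to the actors) to the inference pipeline instead.
with result.checkpoint.as_directory() as ckpt_dir:
    print("Best checkpoint contents:", os.listdir(ckpt_dir))
best_ckpt_path = result.checkpoint.path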

parquet_path = "/mnt/cluster_storage/food101_lite/val.parquet"

# Which item to visualize
idx = 2

# Build a Ray Data inference pipeline (model is loaded once per GPU actor)
pred_ds = build_inference_dataset(
    checkpoint_path=best_ckpt_path,
    parquet_path=parquet_path,
    num_actors=1,       # adjust to scale out
    batch_size=64,      # adjust for throughput
)

# Materialize predictions up to the desired index and grab the row
pred_rows = pred_ds.take(idx + 1)
inference_row = pred_rows[-1]  # {"predicted_label": ..., "label": ...}
print(inference_row)

# Load label map from Hugging Face (for pretty titles)
ds_tmp = load_dataset("food101", split="train[:1%]")  # just to get label names
label_names = ds_tmp.features["label"].names

# Load the raw image locally for visualization
dataset = Food101Dataset(parquet_path, transform=None)
img, _ = dataset[idx]

# Plot the image with predicted and true labels
plt.imshow(img)
plt.axis("off")
plt.title(
    f"Pred: {label_names[int(inference_row['predicted_label'])]}\n"
    f"True: {label_names[int(inference_row['label'])]}"
)
plt.show()
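
A single image gives only a qualitative check. As a rough quantitative complement, the sketch below computes accuracy from prediction rows shaped like the ones `pred_ds.take()` returns. `accuracy_from_rows` is a hypothetical helper, and the sample rows are made up for illustration.

```python
def accuracy_from_rows(rows: list[dict]) -> float:
    """Fraction of rows whose predicted_label equals label."""
    if not rows:
        raise ValueError("rows must not be empty")
    correct = sum(int(r["predicted_label"]) == int(r["label"]) for r in rows)
    return correct / len(rows)


# Made-up rows with the same keys as the inference output.
sample_rows = [
    {"predicted_label": 3, "label": 3},
    {"predicted_label": 7, "label": 1},
    {"predicted_label": 5, "label": 5},
    {"predicted_label": 0, "label": 0},
]
print(accuracy_from_rows(sample_rows))  # 0.75
```

In the notebook, a call such as `accuracy_from_rows(pred_ds.take(200))` would score the first 200 validation rows. Ray Data executes lazily, so each such call is likely to re-run the inference pipeline unless the dataset has been materialized first.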

### 17. Clean up  
Finally, tidy up by deleting temporary checkpoint folders, the metrics CSV, and any intermediate result directories. Clearing out old artifacts frees disk space and leaves your workspace clean for whatever comes next.

In [None]:
# 17. Cleanup: delete checkpoints and metrics from model training

# Base directory
BASE_DIR = "/mnt/cluster_storage/food101_lite"

# Paths to clean
paths_to_delete = [
    os.path.join(BASE_DIR, "tmp_checkpoints"),           # custom checkpoints
    os.path.join(BASE_DIR, "results", "history.csv"),    # metrics history file
    os.path.join(BASE_DIR, "results", "food101_ft_resume"),  # ray trainer run dir
    os.path.join(BASE_DIR, "results", "food101_ft_run"),
    os.path.join(BASE_DIR, "results", "food101_single_run"),
]

# Delete each path if it exists
for path in paths_to_delete:
    if os.path.exists(path):
        if os.path.isfile(path):
            os.remove(path)
            print(f"Deleted file: {path}")
        else:
            shutil.rmtree(path)
            print(f"Deleted directory: {path}")
    else:
        print(f"Not found (skipped): {path}")
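
Because `shutil.rmtree` is irreversible, it can help to preview deletions first. The following sketch wraps the same file-or-directory logic in a hypothetical `remove_paths` helper with a `dry_run` flag, and exercises it against a throwaway temporary directory rather than the real `BASE_DIR`.

```python
import os
import shutil
import tempfile


def remove_paths(paths, dry_run=True):
    """Delete files/directories; return the paths that were (or would be) removed."""
    removed = []
    for path in paths:
        if not os.path.exists(path):
            continue  # nothing to do for missing paths
        if not dry_run:
            if os.path.isfile(path):
                os.remove(path)
            else:
                shutil.rmtree(path)
        removed.append(path)
    return removed


# Exercise the helper in a temporary sandbox.
sandbox = tempfile.mkdtemp()
ckpt_dir = os.path.join(sandbox, "tmp_checkpoints")
os.makedirs(ckpt_dir)
history = os.path.join(sandbox, "history.csv")
with open(history, "w") as f:
    f.write("epoch,loss\n")
missing = os.path.join(sandbox, "does_not_exist")

targets = [ckpt_dir, history, missing]
preview = remove_paths(targets, dry_run=True)    # nothing deleted yet
deleted = remove_paths(targets, dry_run=False)   # actually deletes
print(preview == deleted, os.path.exists(ckpt_dir), os.path.exists(history))
# True False False
shutil.rmtree(sandbox)
```

Running once with `dry_run=True` and inspecting the returned list before the real pass is a cheap guard against a mistyped path.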

### Wrap up and next steps

You've taken a realistic computer-vision workload from raw images through distributed training to GPU inference, and run it on Ray Train and Ray Data with very little boilerplate for GPU placement, data parallelism, or fault tolerance. You should now feel comfortable:

* Using **Ray Train’s TorchTrainer** to scale PyTorch training across multiple GPUs and nodes with minimal code changes  
* Wrapping models and data loaders with **`prepare_model()`** and **`prepare_data_loader()`** to enable Ray-managed device placement and distributed execution  
* Sharding data across workers and coordinating training epochs across Ray workers  
* Configuring **automatic checkpointing and failure recovery** using Ray Train’s built-in `Checkpoint`, `RunConfig`, and `FailureConfig` APIs  
* Running **Ray Data batch inference** with stateful GPU actors to scale model predictions across a Ray cluster  

---

### Where can you take this next?

Below are a few directions you might explore to adapt or extend the pattern:

1. **Larger or custom datasets**  
   * Swap in the full 75k-image Food-101 split, or your own dataset in any storage backend (S3, GCS, Azure Blob).  
   * Add multi-file Parquet sharding and let each worker read a different shard.

2. **Model architectures**  
   * Drop in Vision Transformers (`vit_b_16`, `vit_l_32`) or ConvNeXt; the prepare helpers work exactly the same.  
   * Experiment with transfer learning versus training from scratch.

3. **Mixed precision and performance tuning**  
   * Enable automatic mixed precision (`torch.cuda.amp`) or bfloat16 to speed up training and save memory.  
   * Profile data-loading throughput and play with `num_workers`, prefetching, and caching.

4. **Hyperparameter sweeps**  
   * Wrap the training loop in **Ray Tune** to search over learning rates, augmentations, or optimizers.  
   * Use the metrics reported from the training loop with a Tune scheduler (for example, ASHA) to stop unpromising trials early.

5. **Data augmentation pipelines**  
   * Integrate additional transforms inside the dataset class for image augmentation.  
   * Compare CPU versus GPU-side augmentations for throughput.

6. **Distributed validation and metrics**  
   * Replace your simple accuracy printout with more advanced metrics (F1, top-5 accuracy, confusion matrices).  

7. **Model serving**  
   * Convert the `ImageBatchPredictor` class into a **Ray Serve** deployment for low-latency online predictions.  
   * Auto-scale replicas based on request volume.

8. **End-to-end MLOps**  
   * Register checkpoints in a model registry (for example, MLflow, Weights & Biases, or Ray’s built-in MLflow integration).  
   * Schedule the notebook as a Ray Job or CI/CD pipeline for regular retraining runs.
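
For the metrics idea in item 6, a small dependency-free sketch of top-k accuracy and a confusion matrix follows. The function names and the toy scores are hypothetical. In practice the scores would be the model's logits, and a library such as `torchmetrics` or scikit-learn would likely be a better fit.

```python
def top_k_accuracy(scores, labels, k=5):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores for this sample.
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in top_k
    return hits / len(labels)


def confusion_matrix(preds, labels, num_classes):
    """matrix[t][p] counts samples with true class t predicted as class p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for pred, label in zip(preds, labels):
        matrix[label][pred] += 1
    return matrix


# Toy example: 3 samples, 3 classes.
scores = [
    [0.1, 0.7, 0.2],   # true class 1, ranked first
    [0.5, 0.3, 0.2],   # true class 1, ranked second
    [0.6, 0.3, 0.1],   # true class 2, ranked third
]
labels = [1, 1, 2]
preds = [max(range(3), key=lambda i: row[i]) for row in scores]

print(top_k_accuracy(scores, labels, k=1))  # 0.3333333333333333
print(top_k_accuracy(scores, labels, k=2))  # 0.6666666666666666
print(confusion_matrix(preds, labels, 3))   # [[0, 0, 0], [1, 1, 0], [1, 0, 0]]
```

To use this with the inference pipeline above, the predictor would need to return per-class scores in addition to `predicted_label`, which it does not do as written.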
