## 1. Imports and setup  
Install the pinned dependencies, then import standard Python utilities, Ray (Core, Data, Train, and its Lightning integration), and PyTorch Lightning.

In [None]:
# 00. Runtime setup
import os, sys, subprocess

# Plain configuration flag (not a secret), set here so it is already in the
# environment when Ray is imported in the next cell.
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"

# Pinned dependency versions -- keep these identical to build.sh.
PINNED_PACKAGES = [
    "torch==2.8.0",
    "torchvision==0.23.0",
    "matplotlib==3.10.6",
    "pyarrow==14.0.2",
    "datasets==2.19.2",
    "lightning==2.5.5",
]
pip_command = [sys.executable, "-m", "pip", "install", "--no-cache-dir", *PINNED_PACKAGES]
subprocess.check_call(pip_command)

In [None]:
# 01. Imports

# Standard libraries
import os, io, json, shutil, tempfile
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

# Ray
import ray, ray.data
from ray.train import ScalingConfig, get_context, RunConfig, FailureConfig, CheckpointConfig, Checkpoint, get_checkpoint
from ray.train.torch import TorchTrainer
from ray.train.lightning import RayLightningEnvironment

# PyTorch / Lightning
import lightning.pytorch as pl
import torch
from torch import nn

# Dataset
from datasets import load_dataset
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm  
from torchvision.transforms import Compose, Resize, CenterCrop
import random

### 2. Load 10% of Food-101
Next, grab roughly 7,500 images (10% of the Food-101 training split) with a single call to `load_dataset`. This trimmed subset trains quickly while still being large enough to demonstrate Ray's scaling behavior.

In [None]:
# 02. Load the first 10% of the food101 train split (~7,500 images).
# The "[:10%]" slice is applied by `datasets` itself, so only that subset is returned.
hf_ds = load_dataset(path="food101", split="train[:10%]")

### 3. Resize and encode images  
Use **Ray Data** to preprocess images in parallel across the cluster.  
First, convert each image to raw JPEG bytes (a serializable format). Ray Data workers then decode each image, resize its shorter side to 256 pixels, center-crop it to 224 × 224 pixels, and re-encode it as JPEG.
Processing with Ray Data makes the pipeline distributed, fault-tolerant, and Parquet-friendly—keeping the dataset compact while scaling efficiently across workers.

In [None]:
# 03. Resize + encode as JPEG bytes (Ray Data; BYTES-BASED)

def _encode_jpeg(pil_image):
    """Return *pil_image* converted to RGB and encoded as JPEG bytes."""
    with io.BytesIO() as sink:
        pil_image.convert("RGB").save(sink, format="JPEG")
        return sink.getvalue()


# One serializable dict per example: raw JPEG bytes plus the label.
# NOTE: every encoded image is held in driver memory in `rows` before it is
# handed to Ray; fine for this ~7,500-image subset, not for the full dataset.
rows = [
    {"image_bytes_raw": _encode_jpeg(example["image"]), "label": example["label"]}
    for example in hf_ds
]

# Turn the driver-side list into a Ray Dataset
ds = ray.data.from_items(rows)

# Worker-side transform: resize (shorter side -> 256 px), then 224 x 224 center crop
transform = Compose([Resize(256), CenterCrop(224)])

def preprocess_images(batch_df):
    """Resize/center-crop every raw JPEG in the batch and re-encode it as JPEG.

    Args:
        batch_df: pandas DataFrame with ``image_bytes_raw`` and ``label`` columns.

    Returns:
        dict with ``image_bytes`` (processed JPEG bytes) and ``label`` lists,
        aligned row-for-row. A row whose image cannot be decoded or transformed
        is left out of both lists instead of failing the whole batch.
    """
    kept_images, kept_labels = [], []
    for raw, label in zip(batch_df["image_bytes_raw"], batch_df["label"]):
        try:
            resized = transform(Image.open(io.BytesIO(raw)).convert("RGB"))
            encoded = io.BytesIO()
            resized.save(encoded, format="JPEG")
        except Exception:
            # Best-effort by design: skip unreadable/corrupt rows, keep the batch
            continue
        kept_images.append(encoded.getvalue())
        kept_labels.append(label)
    return {"image_bytes": kept_images, "label": kept_labels}

# Parallel preprocessing (each map_batches task reserves one CPU).
# Ray Data is lazy: without materialize(), every count()/show()/take()/
# write_parquet() call later in this notebook would re-run preprocess_images
# over the whole dataset. Materializing once executes the pipeline a single
# time and keeps the processed blocks in the Ray object store for reuse.
processed_ds = ds.map_batches(
    preprocess_images,
    batch_format="pandas",
    num_cpus=1,
).materialize()

print("✅ Processed records:", processed_ds.count())
processed_ds.show(3)

### 4. Visual sanity check  
Before committing to hours of training, take nine random samples and plot them with their class names. This quick inspection lets you confirm that images are correctly resized and preprocessed.

In [None]:
# 04. Visualize the dataset (Ray Data version)
label_names = hf_ds.features["label"].names  # integer label -> class-name string

# Shuffle the processed rows and bring nine of them back to the driver
samples = processed_ds.random_shuffle().take(9)

fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(8, 8))
fig.suptitle("Sample Resized Images from food101-lite", fontsize=16)

# Fill the 3 x 3 grid row by row: image, class-name title, no axis ticks
for panel, row in zip(axs.ravel(), samples):
    panel.imshow(Image.open(io.BytesIO(row["image_bytes"])))
    panel.set_title(label_names[row["label"]])
    panel.axis("off")

plt.tight_layout()
plt.show()

### 5. Persist to Parquet  
Now, write the images and labels to a set of Parquet shards. Because Parquet is columnar, you can read just the columns you need during training, which speeds up I/O—especially when multiple workers are reading in parallel under Ray.

In [None]:
# 05. Persist Ray Dataset to Parquet
import os

output_dir = "/mnt/cluster_storage/food101_lite/parquet_256"

# Start from an empty directory. Step 06 reads every file under this path, so
# shards left over from an earlier run of this notebook would otherwise be
# read alongside the new ones and duplicate rows.
shutil.rmtree(output_dir, ignore_errors=True)
os.makedirs(output_dir, exist_ok=True)

# Write each block as its own Parquet shard
processed_ds.write_parquet(output_dir)

print(f"✅ Wrote {processed_ds.count()} records to {output_dir}")

### 6. Load and decode with Ray Data  
Read the Parquet shards into a **Ray Dataset**, decode the JPEG bytes to **Channel-Height-Width (CHW) float32 tensors**, scale them to \[-1, 1\], and drop the original byte column.
Because `decode_and_normalize` is stateless, the default **task-based** execution is sufficient.

In [None]:
# 06. Load & Decode Food-101-Lite

# Location of the Parquet shards produced in step 05
PARQUET_PATH = "/mnt/cluster_storage/food101_lite/parquet_256"

# Load every shard (~7,500 rows of JPEG bytes + label) as one Ray Dataset
ds = ray.data.read_parquet(PARQUET_PATH)
raw_row_count = ds.count()
print("Raw rows:", raw_row_count)

# Decode JPEG -> CHW float32 in [-1, 1]

def decode_and_normalize(batch_df):
    """Decode JPEG bytes to CHW float32 arrays scaled to [-1, 1].

    Args:
        batch_df: pandas DataFrame with an ``image_bytes`` column (encoded JPEG)
            and a ``label`` column, as written in step 05.

    Returns:
        dict with ``image`` (list of 3 x H x W float32 arrays) and ``label``
        (the batch's labels, unchanged and in the same order).

    ``map_batches`` keeps only the columns this function returns, so ``label``
    must be passed through explicitly; otherwise the decoded dataset would have
    images but no training targets.
    """
    images = []
    for b in batch_df["image_bytes"]:
        img = Image.open(io.BytesIO(b)).convert("RGB")
        arr = np.asarray(img, dtype=np.float32) / 255.0  # H x W x 3, in [0, 1]
        arr = (arr - 0.5) / 0.5                          # rescale to [-1, 1]
        arr = arr.transpose(2, 0, 1)                     # 3 x H x W (CHW)
        images.append(arr)
    return {"image": images, "label": batch_df["label"].tolist()}

# Apply in parallel.
#   batch_format="pandas" -> batch_df is a DataFrame; the function returns a dict of lists.
#   The default task-based compute strategy is sufficient for a stateless function.
ds = ds.map_batches(
    decode_and_normalize,
    batch_format="pandas",
    num_cpus=1,
)

# Drop the original JPEG column to save memory, but never drop "label": training
# needs it. decode_and_normalize does not return "image_bytes", so this branch
# is only a defensive guard in case that function is changed to pass it on.
if "image_bytes" in ds.schema().names:
    ds = ds.drop_columns(["image_bytes"])

print("Decoded rows:", ds.count())

### 7. Shuffle and Train/Val split  
Perform a reproducible shuffle, then split the rows 80%/20% into `train_ds` and `val_ds`.
Each split remains a first-class Ray Dataset, enabling distributed, sharded DataLoaders later on.

In [None]:
# 07. Shuffle & Train/Val Split

# Typical 80 / 20 split
SEED = 42  # fixed shuffle seed so the train/val split is repeatable across runs
TOTAL = ds.count()
train_count = int(TOTAL * 0.8)
# random_shuffle is a full (expensive) shuffle -- for large datasets, consider
# file-level or local shuffling instead; Ray Data offers both options.
# NOTE(review): the seed fixes the shuffle itself; the exact row order may
# still depend on how the input blocks are laid out.
ds = ds.random_shuffle(seed=SEED)
train_ds, val_ds = ds.split_at_indices([train_count])
print("Train rows:", train_ds.count())
print("Val rows:",   val_ds.count())