## 01 · Imports & Setup  
Import standard Python utilities, Ray (Core, Data, Train, and its Lightning integration), PyTorch Lightning, and the data-preparation libraries (Hugging Face `datasets`, PyArrow, torchvision).  
Make sure your Anyscale cluster runs Ray ≥ 2.48 so that Ray Train V2 semantics are enabled automatically.

In [None]:
# 01. Imports

# Standard libraries
import os, io, json, shutil
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

# Ray
import ray, ray.data
from ray.train import ScalingConfig, get_context, RunConfig, FailureConfig, CheckpointConfig, Checkpoint, get_checkpoint
from ray.train.torch import TorchTrainer
from ray.train.lightning import RayLightningEnvironment

# PyTorch / Lightning
import lightning.pytorch as pl
import torch
from torch import nn

# Dataset
from datasets import load_dataset
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm  
from torchvision.transforms import Compose, Resize, CenterCrop
import random
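Ray Train V2 semantics depend on the Ray version, so a guard at the top of the notebook can stop early on an older cluster. The sketch below uses two hypothetical helpers, `version_tuple` and `meets_minimum`. These are illustrative names, not Ray APIs. They compare versions numerically, because a plain string comparison would wrongly rank `"2.9"` above `"2.48"`.

```python
import re


def version_tuple(version: str) -> tuple:
    """Turn '2.48.0' into (2, 48, 0), stopping at the first non-numeric part (e.g. 'dev0')."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str = "2.48") -> bool:
    """Numeric comparison, so '2.9.3' correctly ranks below '2.48'."""
    return version_tuple(installed) >= version_tuple(minimum)


# Usage in the notebook (assumes Ray is installed):
#   assert meets_minimum(ray.__version__), f"Need Ray >= 2.48, found {ray.__version__}"
print(meets_minimum("2.48.0"))  # True
print(meets_minimum("2.9.3"))   # False
```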

### 02 · Load 10 % of Food-101  
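Next, load roughly 7,500 images with a single call to `load_dataset`. This is the first 10 % of the Food-101 training split, or 7,575 of its 75,750 images. The trimmed subset trains quickly but is still large enough to demonstrate Ray's scaling behavior. Note that the `train[:10%]` slice takes rows in the split's stored order; it doesn't sample them at random.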

NOTE: This tutorial uses the same dataset as tutorial 04a. If you already ran that tutorial and the Parquet file exists under `/mnt/cluster_storage/food101_lite/parquet_256`, skip cells 02-05.

In [None]:
# 02. Load 10% of food101 (~7,500 images)
ds = load_dataset("food101", split="train[:10%]") 
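If you want a random 10 % instead of the first 10 %, one option is to sample row indices with a fixed seed and pass them to the Hugging Face `Dataset.select` method. This downloads the full training split first. The helper `random_subset_indices` is a hypothetical illustration and is not part of the original tutorial.

```python
import random


def random_subset_indices(n_rows: int, fraction: float, seed: int = 0) -> list:
    """Return a sorted, reproducible random sample of row indices covering `fraction` of the rows."""
    k = round(n_rows * fraction)
    rng = random.Random(seed)                 # private RNG: same seed -> same subset
    return sorted(rng.sample(range(n_rows), k))


# Food-101's training split has 75,750 images, so 10 % is 7,575 rows.
idx = random_subset_indices(75_750, 0.10, seed=42)
print(len(idx))  # 7575

# Usage with Hugging Face datasets (downloads the whole training split first):
#   full = load_dataset("food101", split="train")
#   ds = full.select(idx)
```

Sorting the indices preserves the split's original row order within the subset, which keeps reads more sequential.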

### 03 · Resize and Encode Images  
Preprocess each image in three steps. First, resize it so the shorter side is 256 pixels. Next, center-crop it to 224 × 224 pixels, the input size most ImageNet models expect. Finally, encode the result as raw JPEG bytes. Storing bytes instead of full Python Imaging Library (PIL) objects keeps the dataset compact and Parquet-friendly.

In [None]:
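# 03. Resize + encode as JPEG bytes
transform = Compose([Resize(256), CenterCrop(224)])  # shorter side -> 256 px, then 224 × 224 center crop
records = []
skipped = 0

for example in tqdm(ds, desc="Preprocessing images", unit="img"):
    try:
        # Convert to RGB first so grayscale or CMYK inputs encode as 3-channel JPEGs
        img = transform(example["image"].convert("RGB"))
        buf = io.BytesIO()
        img.save(buf, format="JPEG")
        records.append({
            "image_bytes": buf.getvalue(),
            "label": example["label"]
        })
    except Exception:
        # Skip unreadable images, but count them so silent data loss is visible
        skipped += 1

print(f"Encoded {len(records)} images, skipped {skipped}")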
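The sketch below computes, in pure Python, what `Resize(256)` followed by `CenterCrop(224)` does to an image of arbitrary aspect ratio. It is an approximation of torchvision's behavior, not a re-implementation. It assumes that an integer `size` scales the shorter side and truncates the scaled longer side, and that the crop is centered with rounded offsets. Both helper names are hypothetical.

```python
def resize_shorter_side(width: int, height: int, size: int = 256) -> tuple:
    """Scale so the shorter side equals `size`, keeping the aspect ratio (longer side truncated)."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width: int, height: int, crop: int = 224) -> tuple:
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


# A hypothetical 512 x 384 (landscape) photo
w, h = resize_shorter_side(512, 384)
print((w, h))                  # (341, 256)
print(center_crop_box(w, h))   # (58, 16, 282, 240)
```

Because only the shorter side is fixed, the crop always discards the edges of the longer dimension. In this example, 58 columns are trimmed on the left and 59 on the right.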

### 04 · Visual Sanity Check  
Before you start a training run, take nine random samples and plot them with their class names. This quick inspection confirms that the images were resized and cropped correctly and that each label matches its image.

In [None]:
# 04. Visualize the dataset

label_names = ds.features["label"].names  # maps int → string

samples = random.sample(records, 9)

fig, axs = plt.subplots(3, 3, figsize=(8, 8))
fig.suptitle("Sample Resized Images from food101-lite", fontsize=16)

for ax, rec in zip(axs.flatten(), samples):
    img = Image.open(io.BytesIO(rec["image_bytes"]))
    label_name = label_names[rec["label"]]
    ax.imshow(img)
    ax.set_title(label_name)
    ax.axis("off")

plt.tight_layout()
plt.show()
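A complementary numeric check is to count the records in each class. This reveals missing or heavily imbalanced classes that nine thumbnails can't show. The minimal sketch below uses `collections.Counter` on toy records (hypothetical data). In the notebook, you would pass the real `records` list and `label_names`.

```python
from collections import Counter


def class_counts(records, label_names):
    """Map class name -> number of records, most frequent first."""
    counts = Counter(rec["label"] for rec in records)
    return {label_names[label]: n for label, n in counts.most_common()}


# Toy stand-ins for `records` and `label_names` (hypothetical data)
toy_names = ["apple_pie", "baby_back_ribs", "baklava"]
toy_records = [{"label": 0}, {"label": 0}, {"label": 2}]

counts = class_counts(toy_records, toy_names)
print(counts)                                               # {'apple_pie': 2, 'baklava': 1}
print(len(counts), "of", len(toy_names), "classes present")  # 2 of 3 classes present
```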

### 05 · Persist to Parquet  
Now, write the images and labels to a Parquet file. Because Parquet is columnar, readers can load only the columns they need. This reduces I/O, especially when multiple Ray workers read in parallel.

In [None]:
# 05. Write Dataset to Parquet

output_dir = "/mnt/cluster_storage/food101_lite/parquet_256"
os.makedirs(output_dir, exist_ok=True)

table = pa.Table.from_pydict({
    "image_bytes": [r["image_bytes"] for r in records],
    "label": [r["label"] for r in records]
})
pq.write_table(table, os.path.join(output_dir, "shard_0.parquet"))

print(f"Wrote {len(records)} records to {output_dir}")

### 06 · Load & Decode with Ray Data  
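Read the Parquet shard into a **Ray Dataset**, decode the JPEG bytes to **channel-height-width (CHW) float32 arrays**, and scale them to \[-1, 1\]. The output of `map_batches` contains only the columns the function returns. As a result, the original `image_bytes` and `label` columns are dropped and only `image` remains.  
Because `decode_and_normalize` is stateless, the default **task-based** execution is sufficient.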

In [None]:
# 06. Load & Decode Food-101-Lite

# Path to Parquet shards written earlier
PARQUET_PATH = "/mnt/cluster_storage/food101_lite/parquet_256"

# Read the Parquet files (≈7,500 rows with JPEG bytes + label)
ds = ray.data.read_parquet(PARQUET_PATH)
print("Raw rows:", ds.count())

# Decode JPEG → CHW float32 in [-1, 1]

def decode_and_normalize(batch_df):
    """Decode JPEG bytes and scale to [-1, 1]."""
    images = []
    for b in batch_df["image_bytes"]:
        img = Image.open(io.BytesIO(b)).convert("RGB")
        arr = np.asarray(img, dtype=np.float32) / 255.0       # H × W × 3, 0-1
        arr = (arr - 0.5) / 0.5                               # -1 … 1
        arr = arr.transpose(2, 0, 1)                          # 3 × H × W (CHW)
        images.append(arr)
    # Only returned columns survive map_batches, so "image_bytes" and "label" are dropped
    return {"image": images}

# Apply in parallel
#   batch_format="pandas" → batch_df is a pandas DataFrame; return a dict of lists.
#   The default task-based compute strategy is sufficient for a stateless function.
ds = ds.map_batches(
    decode_and_normalize,
    batch_format="pandas",
    num_cpus=1,  # one CPU per decode task
)

# No drop_columns call is needed: the output schema already contains only "image"
print("Columns:", ds.schema().names)
print("Decoded rows:", ds.count())
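The per-pixel arithmetic in `decode_and_normalize` maps 0 to -1 and 255 to 1, and the transpose moves channels first. The NumPy sketch below checks this on a small synthetic image. It also adds an inverse, `denormalize_chw`, which is handy for plotting arrays from this pipeline with `plt.imshow`. `denormalize_chw` is a hypothetical helper and is not part of the tutorial.

```python
import numpy as np


def normalize_hwc_uint8(img: np.ndarray) -> np.ndarray:
    """Same arithmetic as decode_and_normalize: uint8 HWC -> float32 CHW in [-1, 1]."""
    arr = img.astype(np.float32) / 255.0   # 0 … 1
    arr = (arr - 0.5) / 0.5                # -1 … 1
    return arr.transpose(2, 0, 1)          # HWC -> CHW


def denormalize_chw(arr: np.ndarray) -> np.ndarray:
    """Hypothetical inverse: float32 CHW in [-1, 1] -> uint8 HWC for plt.imshow."""
    img = (arr.transpose(1, 2, 0) * 0.5 + 0.5) * 255.0
    return np.clip(np.rint(img), 0, 255).astype(np.uint8)


# Synthetic 4 x 5 RGB image containing both extremes
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(4, 5, 3), dtype=np.uint8)
img[0, 0] = 0
img[0, 1] = 255

chw = normalize_hwc_uint8(img)
print(chw.shape, chw.dtype)           # (3, 4, 5) float32
print(chw.min(), chw.max())           # -1.0 1.0
restored = denormalize_chw(chw)
print(np.array_equal(restored, img))  # True
```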

### 07 · Shuffle & Train/Val Split  
Perform a reproducible shuffle, then split 80 % / 20 % into `train_ds` and `val_ds`.  
Each split remains a first-class Ray Dataset, enabling distributed, sharded DataLoaders later on.

In [None]:
# 07. Shuffle & Train/Val Split

# Typical 80 / 20 split
TOTAL = ds.count()
train_count = int(TOTAL * 0.8)
ds = ds.random_shuffle(seed=42)  # fixed seed so the shuffle is repeatable across runs
train_ds, val_ds = ds.split_at_indices([train_count])
print("Train rows:", train_ds.count())
print("Val rows:",   val_ds.count())
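The sketch below applies the same seeded-shuffle-then-split logic to plain Python row indices. It shows why a fixed seed makes the split repeatable and how the 80 % cut-off is computed. `seeded_split` is a hypothetical helper.

```python
import random


def seeded_split(n_rows: int, train_frac: float = 0.8, seed: int = 42):
    """Shuffle row indices with a fixed seed, then cut at int(n_rows * train_frac)."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)      # same seed -> same permutation
    cut = int(n_rows * train_frac)            # mirrors train_count in cell 07
    return indices[:cut], indices[cut:]


train_idx, val_idx = seeded_split(7575)
print(len(train_idx), len(val_idx))           # 6060 1515
```

Ray Data also offers `Dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)`, which combines the shuffle and the split in one call. In both approaches, the fixed seed is what makes reruns on the same input produce the same partition.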