## Importing Required Libraries

This cell imports the key libraries needed for distributed deep learning and image processing:

- `ray`: Enables scalable distributed computing; here, Ray Data loads the images and runs batch inference in parallel.
- `torch`: Core PyTorch package for building and training neural networks.
- `torchvision.models` and `torchvision.transforms`: Provide access to pretrained models (such as ResNet-152 and its `ResNet152_Weights`) and common image transformations.
- `PIL.Image`: Supports image loading and manipulation using the Pillow library.
- `numpy`: Essential package for efficient numerical operations and array manipulation.
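Before running the import cell, it can help to confirm that these packages are installed. The sketch below uses only the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical, and the list assumes PyPI distribution names (Pillow is installed as `Pillow` but imported as `PIL`). The `ResNet152_Weights` enum used later requires torchvision 0.13 or newer.

```python
from importlib import metadata
from typing import Dict, List, Optional

# PyPI distribution names (assumed); Pillow is imported as `PIL`.
REQUIRED = ["ray", "torch", "torchvision", "Pillow", "numpy"]


def installed_versions(packages: List[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(REQUIRED).items():
        print(f"{name:12s} {version or 'NOT INSTALLED'}")
```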

In [None]:
import ray 
import torch 
from torchvision.models import resnet152, ResNet152_Weights
from torchvision import transforms
from PIL import Image
import numpy as np

### Loading a Sample Image Batch from S3 with Ray Data

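These cells load image data from a public S3 bucket using Ray Data:

- Reads images from an S3 URI into a Ray Dataset, decoding each one in RGB mode. The `anonymous@` prefix in the URI requests anonymous (unsigned) access to the bucket.
- Limits the dataset to the first 1,000 images with `limit()` for quick experimentation.
- Retrieves a batch of 3 images with `take_batch()`, which returns a dictionary mapping each column name (here `"image"`) to an array of values.
- Converts the first image in the batch from a NumPy array to a PIL Image for display, then lists the batch's keys.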
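To make the "dictionary of arrays" format concrete, here is a plain-NumPy illustration of collating rows into a column-oriented batch. The helper `rows_to_batch` is hypothetical and is not Ray's implementation; keeping differently sized images in a 1-D object array is a simplifying assumption, so inspect the batch you actually receive from `take_batch()`.

```python
import numpy as np


def rows_to_batch(rows: list[dict[str, np.ndarray]]) -> dict[str, np.ndarray]:
    """Collate a list of row dicts into one dict of column arrays (hypothetical helper)."""
    batch: dict[str, np.ndarray] = {}
    for column in rows[0]:
        values = [row[column] for row in rows]
        if all(v.shape == values[0].shape for v in values):
            # Same-shaped images stack into a single (N, H, W, C) array.
            batch[column] = np.stack(values)
        else:
            # Differently sized images cannot be stacked, so store them in a
            # 1-D object array whose elements are the individual images.
            column_array = np.empty(len(values), dtype=object)
            for i, v in enumerate(values):
                column_array[i] = v
            batch[column] = column_array
    return batch


# Three blank "images" of different sizes (toy data).
rows = [
    {"image": np.zeros((320, 480, 3), dtype=np.uint8)},
    {"image": np.zeros((375, 500, 3), dtype=np.uint8)},
    {"image": np.zeros((500, 333, 3), dtype=np.uint8)},
]
batch = rows_to_batch(rows)
print(batch.keys())             # dict_keys(['image'])
print(batch["image"].dtype)     # object
print(batch["image"][0].shape)  # (320, 480, 3)
```

Indexing `batch["image"][0]` then yields a single HWC `uint8` array, which is the form `Image.fromarray()` expects.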

In [None]:
s3_uri = "s3://anonymous@air-example-data-2/imagenette2/train/"
ds = ray.data.read_images(s3_uri, mode="RGB")

In [None]:
subset_ds = ds.limit(1000)

In [None]:
single_batch = subset_ds.take_batch(3)

In [None]:
img = Image.fromarray(single_batch["image"][0])
img

In [None]:
single_batch.keys()

### Applying Preprocessing Transforms to Images from Ray Dataset

This section defines and applies preprocessing steps to each image in the Ray Dataset:

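- Loads the ImageNet preprocessing bundled with the ResNet-152 weights (resize, centre crop, and normalisation).
- Prepends `ToTensor()`, which converts each HWC `uint8` NumPy image into a CHW float tensor scaled to [0, 1], and chains both steps with `transforms.Compose`.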
- Defines a `preprocess_image` function to:
  - Store the original image.
  - Apply the full transform pipeline, saving the result as `"transformed_image"`.
- Applies this function to every row in the dataset with `.map()`, creating a new dataset with both original and transformed images.
- Retrieves two samples as a batch and displays the available keys and the shape of the transformed images.
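The transform pipeline can be approximated in plain NumPy to show what each step does. This sketch assumes the standard ImageNet channel mean and standard deviation and the 256 (resize) / 224 (crop) sizes used by `IMAGENET1K_V1`. It uses nearest-neighbour resizing for brevity, whereas torchvision uses bilinear interpolation, so pixel values will differ slightly from the output of `transform`. All function names are hypothetical.

```python
import numpy as np

# Standard ImageNet channel statistics (RGB order).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def resize_shorter_side(img: np.ndarray, size: int) -> np.ndarray:
    """Nearest-neighbour resize so the shorter side equals `size`, keeping aspect ratio."""
    h, w = img.shape[:2]
    if h <= w:
        new_h, new_w = size, int(size * w / h)
    else:
        new_h, new_w = int(size * h / w), size
    rows = (np.arange(new_h) * h / new_h).astype(int)
    cols = (np.arange(new_w) * w / new_w).astype(int)
    return img[rows[:, None], cols]


def center_crop(img: np.ndarray, size: int) -> np.ndarray:
    """Cut a size x size square from the middle of an HWC image."""
    h, w = img.shape[:2]
    top = int(round((h - size) / 2.0))
    left = int(round((w - size) / 2.0))
    return img[top : top + size, left : left + size]


def preprocess(img: np.ndarray, resize: int = 256, crop: int = 224) -> np.ndarray:
    """HWC uint8 image -> normalised CHW float32 array of shape (3, crop, crop)."""
    x = center_crop(resize_shorter_side(img, resize), crop)
    x = x.astype(np.float32) / 255.0        # scale to [0, 1], as ToTensor() does
    x = (x - IMAGENET_MEAN) / IMAGENET_STD  # per-channel normalisation
    return x.transpose(2, 0, 1)             # HWC -> CHW


out = preprocess(np.zeros((300, 400, 3), dtype=np.uint8))
print(out.shape)  # (3, 224, 224)
```

Each preprocessed image has the `(3, 224, 224)` layout, so a batch of two transformed images has shape `(2, 3, 224, 224)`.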

In [None]:
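weights = ResNet152_Weights.IMAGENET1K_V1
# Resize to 256, centre-crop to 224, and normalise with ImageNet statistics.
imagenet_transforms = weights.transforms()
# ToTensor() first turns the HWC uint8 array into a CHW float tensor in [0, 1].
transform = transforms.Compose([transforms.ToTensor(), imagenet_transforms])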

def preprocess_image(row: dict[str, np.ndarray]):
    return {
        "original_image": row["image"],
        "transformed_image": transform(row["image"]),
    }

In [None]:
transformed_ds = ds.map(preprocess_image)

In [None]:
two_batches = transformed_ds.take_batch(2)

In [None]:
print(f"Batch is a dictionary with the following keys: {two_batches.keys()}")

In [None]:
two_batches["transformed_image"].shape

### Batch Inference Class for Distributed Prediction with ResNet-152

This class encapsulates batch inference logic for ResNet-152 using PyTorch and Ray Data:

- **Initialisation**:  
  - Loads ResNet-152 with pretrained ImageNet weights.
  - Sets the model to evaluation mode for inference.
- **Call Method**:  
  - Accepts a batch of preprocessed images (as NumPy arrays).
  - Converts the batch to a PyTorch tensor and moves it to the GPU if one is available, otherwise to the CPU.
  - Runs the forward pass under `torch.inference_mode()`, which disables gradient tracking and related autograd bookkeeping.
  - Returns both the predicted labels and the original images.
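ResNet-152 returns one raw score (logit) per ImageNet class, and turning scores into labels is an argmax followed by a list lookup. The NumPy sketch below shows that step on made-up logits; the three category names are a toy stand-in for `weights.meta["categories"]`, and `top1_labels` and `softmax` are hypothetical helpers. Because softmax is monotonic, the argmax of the logits picks the same class as the argmax of the probabilities, so the class does not need to compute probabilities to choose a label.

```python
import numpy as np


def top1_labels(logits: np.ndarray, categories: list[str]) -> list[str]:
    """Return the highest-scoring category name for each row of an (N, C) array."""
    class_ids = logits.argmax(axis=1)  # shape (N,)
    return [categories[i] for i in class_ids.tolist()]


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max keeps exp() numerically stable."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


categories = ["tench", "goldfish", "great white shark"]  # toy label list
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
print(top1_labels(logits, categories))  # ['tench', 'great white shark']
```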

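> **Note:**  
> When a callable class is passed to Ray Data's `.map_batches()` or `.map()`, Ray Data starts a pool of actors (sized by `concurrency`) and constructs one instance of the class in each actor, so the model is loaded once per actor rather than once per batch.  
> Ray Data schedules these actors and, when `num_gpus` is set, assigns each one a GPU; moving tensors onto that device is still handled by the class itself via `self.device`. You do **not** need to decorate this class with `@ray.remote`.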
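The benefit of a callable class is that its expensive `__init__` runs once per actor rather than once per batch. The single-process toy below models that idea; `simulate_map_batches` and `CountingModel` are hypothetical, batches are plain lists rather than Ray's dict-of-arrays batches, and round-robin dispatch is a simplification of Ray's dynamic scheduling.

```python
from itertools import cycle


def simulate_map_batches(fn_cls, rows, batch_size, concurrency):
    """Toy model of a callable class running in an actor pool (not Ray's scheduler)."""
    # One instance per "actor": the expensive __init__ runs `concurrency` times.
    actors = [fn_cls() for _ in range(concurrency)]
    batches = [rows[i : i + batch_size] for i in range(0, len(rows), batch_size)]
    outputs = []
    # Hand batches to actors in turn; zip stops once the batches run out.
    for actor, batch in zip(cycle(actors), batches):
        outputs.extend(actor(batch))
    return outputs


class CountingModel:
    inits = 0  # counts how many times the "model" was loaded

    def __init__(self):
        CountingModel.inits += 1  # stands in for loading ResNet-152 weights

    def __call__(self, batch):
        return [x * 2 for x in batch]  # stands in for the forward pass


result = simulate_map_batches(CountingModel, list(range(25)), batch_size=10, concurrency=4)
print(CountingModel.inits, len(result))  # 4 25
```

With 25 rows and `batch_size=10` there are only three batches, so one of the four instances is never used. Unlike this toy, Ray Data may not return results in input order unless order preservation is enabled.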

In [None]:
class BatchInferenceResNet:
    def __init__(self):
        # Runs once per actor, so the weights are loaded once rather than per batch.
        self.weights = ResNet152_Weights.IMAGENET1K_V1
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = resnet152(weights=self.weights).to(self.device)
        self.model.eval()  # batch-norm layers use their stored statistics

    def __call__(self, batch: dict[str, np.ndarray]):
        # batch["transformed_image"] has shape (batch_size, 3, 224, 224).
        torch_batch = torch.from_numpy(batch["transformed_image"]).to(self.device)
        with torch.inference_mode():
            logits = self.model(torch_batch)
        # Highest-scoring ImageNet class index per image, as plain Python ints.
        predicted_classes = logits.argmax(dim=1).cpu().tolist()
        predicted_labels = [
            self.weights.meta["categories"][i] for i in predicted_classes
        ]
        return {
            "predicted_label": predicted_labels,
            "original_image": batch["original_image"],
        }

In [None]:
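predictions = transformed_ds.map_batches(
    BatchInferenceResNet,  # callable class: one instance per actor
    concurrency=4,         # number of actors in the pool
    # num_gpus=1,          # uncomment to reserve one GPU per actor
    batch_size=10,         # rows passed to each __call__
)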
prediction_batch = predictions.take_batch(5)

In [None]:
from IPython.display import display  # available by default in Jupyter

for image, prediction in zip(
    prediction_batch["original_image"], prediction_batch["predicted_label"]
):
    img = Image.fromarray(image)
    display(img)
    print("Label:", prediction)