## PyTorch + Ray Train Distributed Training
This cell imports all the necessary libraries and modules used throughout the notebook:

- **Standard library**: Utilities for file paths, temporary directories, and unique IDs.
- **Third-party libraries**: 
  - `filelock` for safe concurrent file access.
  - `PIL.Image` for image processing.
  - `matplotlib.pyplot` for data visualization.
  - `torch`, `torch.nn`, `torch.utils.data`, and `torchmetrics` for building, training, and evaluating PyTorch models.
- **torchvision**: 
  - Datasets (CIFAR-10), pretrained models (ResNet-18), and common data transforms for image preprocessing.
- **Ray**: 
  - Distributed training utilities including `TorchTrainer` for scalable training, and configuration classes for managing Ray training jobs.

In [None]:
# Standard library
import os
import tempfile
import uuid
from pathlib import Path

# Third-party libraries
from filelock import FileLock
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchmetrics.classification import Accuracy

# torchvision (separate for clarity)
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, ResNet18_Weights
from torchvision import transforms

# Ray imports
import ray
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig, RunConfig, CheckpointConfig

### Loading and Inspecting the CIFAR-10 Test Dataset

This cell downloads the CIFAR-10 test set and examines a single sample.

- `CIFAR10(root=..., download=True, train=False)`: Downloads (if needed) and loads the test split of CIFAR-10. Each sample is a tuple: `(PIL image, label)`.
- `next(iter(data))`: Retrieves the first image-label pair from the dataset.
- `plt.imshow(next(iter(data))[0])`: Displays the image from the first sample using matplotlib.

In [None]:
data = CIFAR10(root="../marimo_notebooks/data", download=True, train=False)
data

In [None]:
# each sample consists of a PIL image and its integer label
next(iter(data))

In [None]:
plt.imshow(next(iter(data))[0])

In [None]:
class_to_idx = data.class_to_idx
class_to_idx
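
`class_to_idx` maps class names to integer labels; to read a sample's label it is often handier to go the other way. A minimal sketch, assuming the standard CIFAR-10 class order (the dictionary is written out by hand so the snippet runs without downloading the dataset, and `idx_to_class` and `decode_label` are illustrative names, not part of torchvision):

```python
# CIFAR-10 class names in label order (hand-written stand-in for data.class_to_idx)
class_to_idx = {
    "airplane": 0, "automobile": 1, "bird": 2, "cat": 3, "deer": 4,
    "dog": 5, "frog": 6, "horse": 7, "ship": 8, "truck": 9,
}

# Invert the mapping so integer labels can be turned back into names
idx_to_class = {idx: name for name, idx in class_to_idx.items()}


def decode_label(label: int) -> str:
    """Return the class name for an integer CIFAR-10 label (hypothetical helper)."""
    return idx_to_class[label]


print(decode_label(3))  # cat
```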

### Loading a Pretrained ResNet-18 Model

This cell demonstrates how to load the pretrained ResNet-18 model from torchvision:

- `ResNet18_Weights.IMAGENET1K_V1`: Specifies the standard pretrained weights trained on ImageNet.
- `resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)`: Loads a ResNet-18 model instance with the pretrained weights applied.

Printing `model` displays the architecture.

In [None]:
weights = ResNet18_Weights.IMAGENET1K_V1

In [None]:
model = resnet18(weights=weights)
model

### Defining the CIFAR-10 DataLoader with Preprocessing

This function prepares PyTorch DataLoaders for the CIFAR-10 dataset, using the same image transforms as those expected by a pretrained ResNet-18 model:

- Retrieves the ResNet-18 ImageNet normalisation and resizing transforms with `weights.transforms()`.
- Downloads the CIFAR-10 training and test sets (if not already present), applying the ResNet-18 transforms to each image.
- Optionally, subsets the data to the first 1,000 samples for quicker experimentation.
- Wraps the datasets with PyTorch `DataLoader` objects to enable mini-batch loading and shuffling during training and evaluation.

Returns:  
`train_dataloader`, `valid_dataloader` — ready-to-use for model training and validation.
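
For intuition about what the normalisation step does to pixel values: pixels are scaled to [0, 1] and then standardised per channel. The plain-Python sketch below mirrors only that arithmetic for a single pixel. It assumes the commonly used ImageNet channel statistics, which should be checked against `weights.transforms()`, and `normalize_pixel` is a hypothetical helper; the real transform works on whole tensors and also resizes and crops.

```python
# Commonly used ImageNet channel statistics (an assumption; verify via weights.transforms())
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to [0, 1], then standardise each channel."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    )


normalized = normalize_pixel((255, 0, 128))
print([round(v, 3) for v in normalized])
```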

In [None]:
def get_cifar_dataloader(batch_size):
    weights = ResNet18_Weights.IMAGENET1K_V1
    resnet18_transforms = weights.transforms()
    with FileLock(os.path.expanduser("~/cifar_data.lock")):
        train = CIFAR10(
            root="~/cifar_data",
            train=True,
            download=True,
            transform=resnet18_transforms,
        )
        valid = CIFAR10(
            root="~/cifar_data",
            train=False,
            download=True,
            transform=resnet18_transforms,
        )
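    # Keep only the first 1,000 samples of each split for quicker experimentation;
    # pass `train` / `valid` to the DataLoaders below to use the full splits.
    train_sub = Subset(train, indices=range(1000))
    valid_sub = Subset(valid, indices=range(1000))
    # dataloaders to get data in batches
    train_dataloader = DataLoader(train_sub, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_sub, batch_size=batch_size)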

    return train_dataloader, valid_dataloader



In [None]:
# get_cifar_dataloader returns (train_dataloader, valid_dataloader); inspect the training one
sample_dataloader, _ = get_cifar_dataloader(3)
single_batch = next(iter(sample_dataloader))

In [None]:
single_batch[1].shape

In [None]:
single_batch[0].shape

### Ray Distributed Training Function for Transfer Learning on CIFAR-10

This function implements the distributed training loop for fine-tuning a ResNet-18 model on the CIFAR-10 dataset using Ray Train. The workflow includes:

- **Configuration**: Reads training hyperparameters (epochs, batch size, learning rate, weight decay, and number of classes) from `config`.
- **Device Setup**: Selects `cuda` if available, then Apple `mps`, and otherwise falls back to CPU; the accuracy metric is placed on this device.
- **Model Preparation**:
  - Loads a ResNet-18 model pre-trained on ImageNet.
  - Freezes all layers except the final fully-connected (`fc`) layer, which is replaced to match the number of CIFAR-10 classes.
  - Wraps the model for distributed training with Ray.
- **Optimizer and Loss**: Uses AdamW for optimization and cross-entropy as the loss function.
- **Data Loading**:
  - Prepares training and validation DataLoaders with transforms compatible with ResNet-18.
  - Ensures data is properly sharded for distributed execution.
- **Training Loop**:
  - Trains for the specified number of epochs, accumulating average loss and accuracy for both training and validation sets.
  - Evaluates on the validation set after each epoch (inference mode, no gradients).
- **Metrics and Checkpointing**:
  - Reports epoch-level training and validation metrics (loss, accuracy) to Ray for centralized tracking.
  - Saves a model checkpoint at each epoch for potential recovery or analysis.
  - Prints metrics from the main worker for logging and monitoring.
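
To make the sharding step concrete: in a distributed run, each worker should see a different slice of the dataset every epoch, and the order should change between epochs. The sketch below imitates that idea in plain Python with a seeded shuffle and a strided split. It illustrates the concept and is not the sampler that Ray or PyTorch actually installs. `shard_indices` is a hypothetical helper, and it assumes there are at least as many samples as workers.

```python
import random


def shard_indices(num_samples, world_size, rank, epoch, seed=0):
    """Return the sample indices one worker would process in one epoch."""
    indices = list(range(num_samples))
    # Seed with the epoch so all workers shuffle identically within an epoch
    # while the order changes between epochs (the role set_epoch plays)
    random.Random(seed + epoch).shuffle(indices)
    # Pad by wrapping around so every worker gets the same number of samples
    per_worker = -(-num_samples // world_size)  # ceiling division
    total = per_worker * world_size
    indices += indices[: total - num_samples]
    # Strided split: worker `rank` takes every world_size-th index
    return indices[rank::world_size]


shards = [shard_indices(10, world_size=3, rank=r, epoch=0) for r in range(3)]
print(shards)
```

With 10 samples and 3 workers, each worker receives 4 indices; two samples are repeated as padding so that all workers run the same number of steps.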

In [None]:
def train_func(config):

    epochs = config["epochs"]
    batch_size = config["batch_size"]
    lr = config["lr"]
    
    
    # use detected device
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

    # metrics
    accuracy = Accuracy(task="multiclass", num_classes=config["num_classes"]).to(device)
    
    weights = ResNet18_Weights.IMAGENET1K_V1


    # load the pretrained model and freeze all of its weights
    model = resnet18(weights=weights)
    for parameter in model.parameters():
        parameter.requires_grad = False

    # replace the final layer (trainable by default) so it outputs num_classes logits
    model.fc = nn.Linear(512, config["num_classes"], bias=True)
    # prepare model
    model = ray.train.torch.prepare_model(model)
 
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])

    train_dataloader, valid_dataloader = get_cifar_dataloader(batch_size=batch_size)
    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
    valid_dataloader = ray.train.torch.prepare_data_loader(valid_dataloader)

    for epoch in range(epochs):
        # checking if training is scheduled in a distributed setting or not.
        if ray.train.get_context().get_world_size() > 1:
            train_dataloader.sampler.set_epoch(epoch)
        train_loss = 0.0
        train_acc = 0.0
        model.train()
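        for idx, batch in enumerate(train_dataloader):
            x, y = batch[0], batch[1]
            y_preds = model(x)
            y_labels = y_preds.argmax(dim=1)
            loss = loss_fn(y_preds, y)
            # backpropagate and update the trainable (fc) parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc = accuracy(y_labels, y)
            train_loss += loss.item()
            train_acc += acc.item()
        # average the per-batch values over the epoch
        train_loss /= len(train_dataloader)
        train_acc /= len(train_dataloader)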

        valid_loss = 0.0
        valid_acc = 0.0
        model.eval()
        with torch.inference_mode():
            for idx, batch in enumerate(valid_dataloader):
                x, y = batch[0], batch[1]
                y_preds = model(x)
                y_labels = y_preds.argmax(dim=1)
                loss = loss_fn(y_preds, y)
                acc = accuracy(y_labels, y)
                valid_loss += loss.item()
                valid_acc += acc.item()
        # average once, after the loop has accumulated every batch
        valid_loss /= len(valid_dataloader)
        valid_acc /= len(valid_dataloader)
        
        metrics = {"epoch":epoch,"train_loss":train_loss, "train_acc":train_acc,
                  "valid_loss":valid_loss, "valid_acc":valid_acc}

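        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # prepare_model only wraps the model in DistributedDataParallel when
            # there is more than one worker, so unwrap `.module` only if it exists
            model_to_save = getattr(model, "module", model)
            torch.save(
                model_to_save.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt")
            )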
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )

        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)
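
The epoch-level numbers in `metrics` are averages of per-batch values: accumulate inside the loop, then divide once after it. A small sketch of that pattern in plain Python, using made-up batch losses (`epoch_average` is a hypothetical name):

```python
def epoch_average(batch_values):
    """Average per-batch values: accumulate in the loop, divide once at the end."""
    total = 0.0
    for value in batch_values:
        total += value
    # Divide after the loop; dividing inside it would shrink the running
    # total on every batch and give a wrong result
    return total / len(batch_values)


batch_losses = [2.0, 1.5, 1.0, 0.5]  # illustrative values, not real results
print(epoch_average(batch_losses))  # 1.25
```

This weights every batch equally, so a smaller final batch counts as much as a full one.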

### Configuring Ray TorchTrainer for Distributed Training

This section sets up the Ray TorchTrainer to launch distributed training with the following components:

- **Global Training Parameters**: 
  - Defines the overall batch size, number of parallel workers, and whether to use GPUs.
  - Calculates per-worker batch size.

- **Training Configuration (`train_config`)**:
  - Specifies learning rate, number of epochs, number of classes, per-worker batch size, and weight decay for the optimizer.

- **Scaling Configuration (`scaling_config`)**:
  - Sets the number of workers and whether to use GPUs for distributed training.

- **Run Configuration (`run_config`)**:
  - Specifies the experiment's storage location for checkpoints and logs.
  - Assigns a unique run name for tracking.
  - Configures checkpointing behavior (e.g., how many to keep, checkpoint selection criteria).

- **Trainer Instantiation**:
  - Creates a `TorchTrainer` object with all the above configs, ready to launch the training loop (`train_func`) in parallel across the specified workers.
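
Because the per-worker batch size comes from integer division, the effective global batch can end up slightly smaller than the nominal one. A quick check, assuming the values used in the configuration cell that follows (a global batch size of 100 and 9 workers); `effective_global_batch_size` is a name introduced here for illustration:

```python
global_batch_size = 100
num_workers = 9

# Integer division drops the remainder, so each worker processes 11 samples per step
per_worker_batch_size = global_batch_size // num_workers
# Samples actually processed per optimisation step across all workers
effective_global_batch_size = per_worker_batch_size * num_workers

print(per_worker_batch_size, effective_global_batch_size)  # 11 99
```

Choosing a global batch size that is a multiple of the worker count avoids the shortfall.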

In [None]:
global_batch_size = 100
num_workers = 9
use_gpu = False

train_config = {
    "lr": 0.01,
    "epochs": 3,
    "num_classes": 10,
    "batch_size": global_batch_size // num_workers,
    "weight_decay": 0.02
}
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
run_config = RunConfig(
    storage_path=str(Path("../marimo_notebooks/data/storage_path").resolve()),  # local testing
    # storage_path="/mnt/cluster_storage",  # we could use S3 as well
    name=f"ray_train_torch_run-{uuid.uuid4().hex}",
    checkpoint_config=CheckpointConfig(
        num_to_keep=1,
        checkpoint_score_attribute="train_acc",
        checkpoint_score_order="max",
    ),
)

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config=train_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

In [None]:
# fitting the model
result = trainer.fit()
print(f"Training result: {result}")

In [None]:
ray.shutdown()