# Distributed training

<div align="left">
<a target="_blank" href="https://console.anyscale.com/"><img src="https://raw.githubusercontent.com/ray-project/ray/c34b74c22a9390aa89baf80815ede59397786d2e/doc/source/_static/img/run-on-anyscale.svg"></a>&nbsp;

<a href="https://github.com/anyscale/multimodal-ai" role="button"><img src="https://img.shields.io/static/v1?label=&amp;message=View%20On%20GitHub&amp;color=586069&amp;logo=github&amp;labelColor=2f363d"></a>&nbsp;
</div>

This tutorial executes a distributed training workload that connects the following heterogeneous workloads:
- preprocessing the dataset prior to training
- distributed training with Ray Train and PyTorch, with observability
- evaluation (batch inference and evaluation logic)
- saving model artifacts to a model registry (MLOps)

**Note**: This tutorial doesn't tune the model. See [Ray Tune](https://docs.ray.io/en/latest/tune/index.html) for experiment execution and hyperparameter tuning at any scale.

<img src="https://raw.githubusercontent.com/anyscale/multimodal-ai/refs/heads/main/images/distributed_training.png" width=1000>

In [None]:
%%bash
pip install -q -r /home/ray/default/requirements.txt
pip install -q -e /home/ray/default/doggos


Successfully registered `ipywidgets, matplotlib` and 4 other packages to be installed on all cluster nodes.
View and update dependencies here: https://console.anyscale.com/cld_kvedZWag2qA8i5BjxUevf5i7/prj_cz951f43jjdybtzkx1s5sjgz99/workspaces/expwrk_23ry3pgfn3jgq2jk3e5z25udhz?workspace-tab=dependencies
Successfully registered `doggos` package to be installed on all cluster nodes.
View and update dependencies here: https://console.anyscale.com/cld_kvedZWag2qA8i5BjxUevf5i7/prj_cz951f43jjdybtzkx1s5sjgz99/workspaces/expwrk_23ry3pgfn3jgq2jk3e5z25udhz?workspace-tab=dependencies


**Note**: A kernel restart may be required for all dependencies to become available. 

If using **uv**, then:
1. Turn off the runtime dependencies (`Dependencies` tab up top > Toggle off `Pip packages`). You then don't need to run the `pip install` commands above.
2. Change the Python kernel of this notebook to use the `venv` (click `base (Python x.yy.zz)` in the top-right corner of the notebook > `Select another Kernel` > `Python Environments...` > `Create Python Environment` > `Venv` > `Use Existing`). All of the notebook's cells then use the virtual environment.
3. Change the Python executable to use `uv run` instead of `python` by removing the `RAY_RUNTIME_ENV_HOOK` environment variable and passing `py_executable` to `ray.init`:
```python
import os
os.environ.pop("RAY_RUNTIME_ENV_HOOK", None)
import ray
ray.init(runtime_env={"py_executable": "uv run", "working_dir": "/home/ray/default"})
```

In [None]:
%load_ext autoreload
%autoreload all


In [None]:
import os
import ray
import sys
sys.path.append(os.path.abspath("../doggos/"))


In [None]:
# If using UV
# os.environ.pop("RAY_RUNTIME_ENV_HOOK", None)


In [None]:
# Enable Ray Train v2. It's too good to wait for public release!
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"
ray.init(
    # connect to existing ray runtime (from previous notebook if still running)
    address=os.environ.get("RAY_ADDRESS", "auto"),
    runtime_env={
        "env_vars": {"RAY_TRAIN_V2_ENABLED": "1"},
        # "py_executable": "uv run", # if using uv 
        # "working_dir": "/home/ray/default",  # if using uv 
    },
)


2025-08-28 05:06:48,041	INFO worker.py:1771 -- Connecting to existing Ray cluster at address: 10.0.17.148:6379...
2025-08-28 05:06:48,052	INFO worker.py:1942 -- Connected to Ray cluster. View the dashboard at https://session-jhxhj69d6ttkjctcxfnsfe7gwk.i.anyscaleuserdata.com
2025-08-28 05:06:48,061	INFO packaging.py:588 -- Creating a file package for local module '/home/ray/default/doggos/doggos'.
2025-08-28 05:06:48,064	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_86cc12e3f2760ca4.zip' (0.03MiB) to Ray cluster...
2025-08-28 05:06:48,065	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_86cc12e3f2760ca4.zip'.
2025-08-28 05:06:48,068	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_563e3191c4f9ed5f5d5e8601702cfa5ff10660e4.zip' (1.09MiB) to Ray cluster...
2025-08-28 05:06:48,073	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_563e3191c4f9ed5f5d5e8601702cfa5ff10660e4.zip'.


| | |
|---|---|
| Python version: | 3.12.11 |
| Ray version: | 2.49.0 |
| Dashboard: | http://session-jhxhj69d6ttkjctcxfnsfe7gwk.i.anyscaleuserdata.com |


In [None]:
%%bash
# This will be removed once Ray Train v2 is enabled by default.
echo "RAY_TRAIN_V2_ENABLED=1" > /home/ray/default/.env


In [None]:
# Load env vars in notebooks.
from dotenv import load_dotenv
load_dotenv()


True

## Preprocess

You need to convert the classes to labels (unique integers) so that you can train a classifier that can correctly predict the class given an input image. But before you do this, apply the same data ingestion and preprocessing as the previous notebook.
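
As a minimal sketch of the class-to-label conversion described above, the snippet below derives a class name from the parent directory of a file path and assigns each unique class an integer. The example paths and the helper names `class_from_path` and `build_class_to_label` are hypothetical, and sorting the classes is an assumption made here for reproducibility; the tutorial itself performs these steps with Ray Data in the cells that follow.

```python
def class_from_path(path):
    """Return the parent directory name, assumed to be the class name."""
    return path.rsplit("/", 3)[-2]


def build_class_to_label(classes):
    """Map each unique class name to an integer label (sorted order is an assumption)."""
    return {name: i for i, name in enumerate(sorted(set(classes)))}


# Hypothetical paths that follow a <split>/<class>/<file> layout.
paths = [
    "doggos-dataset/train/border_collie/1.jpg",
    "doggos-dataset/train/pug/2.jpg",
    "doggos-dataset/train/border_collie/3.jpg",
]
classes = [class_from_path(p) for p in paths]
class_to_label = build_class_to_label(classes)
labels = [class_to_label[c] for c in classes]
print(class_to_label)
print(labels)
```

The classifier only ever sees the integer labels; the mapping is what lets you translate predictions back into class names later.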

In [None]:
def add_class(row):
    row["class"] = row["path"].rsplit("/", 3)[-2]
    return row


In [None]:
# Preprocess data splits.
train_ds = ray.data.read_images("s3://doggos-dataset/train", include_paths=True, shuffle="files")
train_ds = train_ds.map(add_class)
val_ds = ray.data.read_images("s3://doggos-dataset/val", include_paths=True)
val_ds = val_ds.map(add_class)


Define a `Preprocessor` class that:
- creates an embedding for each image. Because the embedding model's weights are frozen, computing the embeddings once during preprocessing, rather than on every forward pass of the model, saves unnecessary compute.
- converts the classes into labels for the classifier.

While you could do this step as a simple operation, organizing it as a class lets you save and load it for inference later.
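
One way to see why a class helps: the fitted mapping is state that has to be written to disk and restored at inference time. The sketch below round-trips a hypothetical `class_to_label` mapping through JSON and rebuilds the inverse `label_to_class` lookup, similar in spirit to what the `Preprocessor` constructor does. The example classes and the temporary directory are assumptions for illustration only.

```python
import json
import os
import tempfile

class_to_label = {"border_collie": 0, "pug": 1}  # hypothetical fitted mapping

with tempfile.TemporaryDirectory() as tmp_dir:
    fp = os.path.join(tmp_dir, "class_to_label.json")
    # Save the mapping, as you would after fitting on the training data.
    with open(fp, "w") as f:
        json.dump(class_to_label, f)
    # Load it back, as you would before running inference.
    with open(fp, "r") as f:
        restored = json.load(f)

# Rebuild the inverse lookup used to turn predicted labels into class names.
label_to_class = {label: name for name, label in restored.items()}
print(label_to_class[1])
```

Storing class names as keys and integers as values keeps the JSON round trip lossless, because JSON object keys are always strings.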

In [None]:
def convert_to_label(row, class_to_label):
    if "class" in row:
        row["label"] = class_to_label[row["class"]]
    return row


In [None]:
import json  # used by Preprocessor.save
import numpy as np
from PIL import Image
import torch
from transformers import CLIPModel, CLIPProcessor
from doggos.embed import EmbedImages


In [None]:
class Preprocessor:
    """Preprocessor class."""
    def __init__(self, class_to_label=None):
        self.class_to_label = class_to_label or {}  # avoid a mutable default argument
        self.label_to_class = {v: k for k, v in self.class_to_label.items()}
        
    def fit(self, ds, column):
        self.classes = ds.unique(column=column)
        self.class_to_label = {tag: i for i, tag in enumerate(self.classes)}
        self.label_to_class = {v: k for k, v in self.class_to_label.items()}
        return self
    
    def transform(self, ds, concurrency=4, batch_size=64, num_gpus=1):
        ds = ds.map(
            convert_to_label, 
            fn_kwargs={"class_to_label": self.class_to_label},
        )
        ds = ds.map_batches(
            EmbedImages,
            fn_constructor_kwargs={
                "model_id": "openai/clip-vit-base-patch32", 
                "device": "cuda",
            },
            concurrency=concurrency,  # use the caller's settings instead of hardcoded values
            batch_size=batch_size,
            num_gpus=num_gpus,
            accelerator_type="T4",
        )
        ds = ds.drop_columns(["image"])
        return ds

    def save(self, fp):
        with open(fp, "w") as f:
            json.dump(self.class_to_label, f)


In [None]:
# Preprocess.
preprocessor = Preprocessor()
preprocessor = preprocessor.fit(train_ds, column="class")
train_ds = preprocessor.transform(ds=train_ds)
val_ds = preprocessor.transform(ds=val_ds)


2025-08-28 05:06:54,182	INFO dataset.py:3248 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2025-08-28 05:06:54,184	INFO logging.py:295 -- Registered dataset logger for dataset dataset_14_0


2025-08-28 05:06:54,206	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_14_0. Full logs are in /tmp/ray/session_2025-08-28_04-57-43_348032_12595/logs/ray-data
2025-08-28 05:06:54,207	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_14_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]


2025-08-28 05:07:03,480	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_14_0 execution finished in 9.27 seconds


<div class="alert alert-block alert"> <b> Data processing</b> 

See this extensive guide on [data loading and preprocessing](https://docs.ray.io/en/latest/train/user-guides/data-loading-preprocessing.html) for the last-mile preprocessing you need to do prior to training your models. Ray Data also supports performant joins, filters, aggregations, and other operations for the more structured data processing your workloads may need.

In [None]:
import shutil


In [None]:
# Write processed data to cloud storage.
preprocessed_data_path = os.path.join("/mnt/cluster_storage", "doggos/preprocessed_data")
if os.path.exists(preprocessed_data_path):  # Clean up.
    shutil.rmtree(preprocessed_data_path)
preprocessed_train_path = os.path.join(preprocessed_data_path, "preprocessed_train")
preprocessed_val_path = os.path.join(preprocessed_data_path, "preprocessed_val")
train_ds.write_parquet(preprocessed_train_path)
val_ds.write_parquet(preprocessed_val_path)


2025-08-28 05:07:04,254	INFO logging.py:295 -- Registered dataset logger for dataset dataset_22_0
2025-08-28 05:07:04,270	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_22_0. Full logs are in /tmp/ray/session_2025-08-28_04-57-43_348032_12595/logs/ray-data
2025-08-28 05:07:04,271	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_22_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbedImages)] -> TaskPoolMapOperator[MapBatches(drop_columns)->Write]


(MapWorker(MapBatches(EmbedImages)) pid=9215, ip=10.0.5.252) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
2025-08-28 05:07:20,682	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_22_0 execution finished in 16.41 seconds
2025-08-28 05:07:20,747	INFO dataset.py:4871 -- Data sink Parquet finished. 2880 rows and 5.9MB data written.
2025-08-28 05:07:20,759	INFO logging.py:295 -- Registered dataset logger for dataset dataset_25_0
2025-08-28 05:07:20,774	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_25_0. Full logs are in /tmp/ray/session_2025-08-28_04-57-43_348032_12595/logs/ray-data
2025-08-28 05:07:20,775	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_25_0

(MapWorker(MapBatches(EmbedImages)) pid=23307, ip=10.0.5.252) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`. [repeated 4x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
2025-08-28 05:07:33,

<div class="alert alert-block alert"> <b> Store often, save compute</b> 

Store the preprocessed data into shared cloud storage to:
- save a record of what this preprocessed data looks like
- avoid triggering the entire preprocessing for each batch the model processes
- avoid [`materialize`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.materialize.html) of the preprocessed data because you shouldn't force large data to fit in memory

## Model

Define the model: a simple two-layer neural network that outputs logits, with softmax applied in `predict_probabilities` to produce class probabilities. Notice that it's all just base PyTorch and nothing else.
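
To make the relationship between raw model outputs (logits) and class probabilities concrete, here is a small pure-Python sketch that doesn't require PyTorch. The logit values are made up. The point is that softmax rescales logits into probabilities that sum to 1 without changing which class scores highest, which is why taking an argmax over logits is enough for predicting a class.

```python
import math


def softmax(logits):
    """Convert a list of logits to probabilities (numerically stable)."""
    shift = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 0.5, -1.0]  # hypothetical scores for three classes
probs = softmax(logits)

# The index of the largest value is the same before and after softmax.
predicted_from_logits = max(range(len(logits)), key=logits.__getitem__)
predicted_from_probs = max(range(len(probs)), key=probs.__getitem__)
print(probs, predicted_from_logits, predicted_from_probs)
```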

In [None]:
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
class ClassificationModel(torch.nn.Module):
    def __init__(self, embedding_dim, hidden_dim, dropout_p, num_classes):
        super().__init__()
        # Hyperparameters
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.dropout_p = dropout_p
        self.num_classes = num_classes

        # Define layers
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.batch_norm = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, batch):
        z = self.fc1(batch["embedding"])
        z = self.batch_norm(z)
        z = self.relu(z)
        z = self.dropout(z)
        z = self.fc2(z)
        return z

    @torch.inference_mode()
    def predict(self, batch):
        z = self(batch)
        y_pred = torch.argmax(z, dim=1).cpu().numpy()
        return y_pred

    @torch.inference_mode()
    def predict_probabilities(self, batch):
        z = self(batch)
        y_probs = F.softmax(z, dim=1).cpu().numpy()
        return y_probs

    def save(self, dp):
        Path(dp).mkdir(parents=True, exist_ok=True)
        with open(Path(dp, "args.json"), "w") as fp:
            json.dump({
                "embedding_dim": self.embedding_dim,
                "hidden_dim": self.hidden_dim,
                "dropout_p": self.dropout_p,
                "num_classes": self.num_classes,
            }, fp, indent=4)
        torch.save(self.state_dict(), Path(dp, "model.pt"))

    @classmethod
    def load(cls, args_fp, state_dict_fp, device="cpu"):
        with open(args_fp, "r") as fp:
            model = cls(**json.load(fp))
        model.load_state_dict(torch.load(state_dict_fp, map_location=device))
        return model


In [None]:
# Initialize model.
num_classes = len(preprocessor.classes)
model = ClassificationModel(
    embedding_dim=512, 
    hidden_dim=256, 
    dropout_p=0.3, 
    num_classes=num_classes,
)
print(model)


ClassificationModel(
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (fc2): Linear(in_features=256, out_features=36, bias=True)
)


## Batching

Take a look at a sample batch of data and ensure that tensors have the proper data type.

In [None]:
from ray.train.torch import get_device


In [None]:
def collate_fn(batch):
    dtypes = {"embedding": torch.float32, "label": torch.int64}
    tensor_batch = {}
    for key in dtypes.keys():
        if key in batch:
            tensor_batch[key] = torch.as_tensor(
                batch[key],
                dtype=dtypes[key],
                device=get_device(),
            )
    return tensor_batch


In [None]:
# Sample batch
sample_batch = train_ds.take_batch(batch_size=3)
collate_fn(batch=sample_batch)


2025-08-28 05:07:34,380	INFO logging.py:295 -- Registered dataset logger for dataset dataset_27_0
2025-08-28 05:07:34,394	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_27_0. Full logs are in /tmp/ray/session_2025-08-28_04-57-43_348032_12595/logs/ray-data
2025-08-28 05:07:34,395	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_27_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbedImages)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> LimitOperator[limit=3]


(MapWorker(MapBatches(EmbedImages)) pid=26114, ip=10.0.5.252) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
2025-08-28 05:07:45,755	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_27_0 execution finished in 11.36 seconds


{'embedding': tensor([[ 0.0245,  0.6505,  0.0627,  ...,  0.4001, -0.2721, -0.0673],
         [-0.2416,  0.2315,  0.0255,  ...,  0.4065,  0.2805, -0.1156],
         [-0.2301, -0.3628,  0.1086,  ...,  0.3038,  0.0543,  0.6214]]),
 'label': tensor([10, 29, 27])}

## Model registry

Create a model registry in [Anyscale user storage](https://docs.anyscale.com/configuration/storage/#user-storage) to save the model checkpoints to. This tutorial uses OSS MLflow, but you can easily [set up other experiment trackers](https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html) with Ray.

In [None]:
import shutil


In [None]:
model_registry = "/mnt/cluster_storage/mlflow/doggos"
if os.path.isdir(model_registry):
    shutil.rmtree(model_registry)  # clean up
os.makedirs(model_registry, exist_ok=True)


## Training

Define the training workload by specifying the:
- experiment and model parameters
- compute scaling configuration
- forward pass for batches of training and validation data
- train loop for each epoch of data and checkpointing

<img src="https://raw.githubusercontent.com/anyscale/multimodal-ai/refs/heads/main/images/trainer.png" width=700>

In [None]:
# Train loop config.
experiment_name = "doggos"
train_loop_config = {
    "model_registry": model_registry,
    "experiment_name": experiment_name,
    "embedding_dim": 512,
    "hidden_dim": 256,
    "dropout_p": 0.3,
    "lr": 1e-3,
    "lr_factor": 0.8,
    "lr_patience": 3,
    "num_epochs": 20,
    "batch_size": 256,
}
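
The `lr_factor` and `lr_patience` values in the config above control a reduce-on-plateau learning rate schedule. The following is a simplified, pure-Python sketch of that idea, not PyTorch's implementation: it ignores the improvement threshold, cooldown, and minimum learning rate that `torch.optim.lr_scheduler.ReduceLROnPlateau` also supports. The function name `reduce_on_plateau` and the loss values are hypothetical.

```python
def reduce_on_plateau(val_losses, lr=1e-3, factor=0.8, patience=3):
    """Return the learning rate in effect after each epoch (simplified sketch)."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # Reduce only after more than `patience` epochs without improvement.
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


# Hypothetical validation losses: improvement stalls after the second epoch.
losses = [1.0, 0.9, 0.95, 0.92, 0.91, 0.93, 0.80]
print(reduce_on_plateau(losses))
```

With a patience of 3, the schedule tolerates three stalled epochs and lowers the learning rate on the fourth.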


In [None]:
# Scaling config
num_workers = 4
scaling_config = ray.train.ScalingConfig(
    num_workers=num_workers,
    use_gpu=True,
    resources_per_worker={"CPU": 8, "GPU": 2},
    accelerator_type="T4",
)


In [None]:
import tempfile
import mlflow
import numpy as np
from ray.train.torch import TorchTrainer


In [None]:
def train_epoch(ds, batch_size, model, num_classes, loss_fn, optimizer):
    model.train()
    loss = 0.0
    ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)
    for i, batch in enumerate(ds_generator):
        optimizer.zero_grad()  # Reset gradients.
        z = model(batch)  # Forward pass.
        targets = F.one_hot(batch["label"], num_classes=num_classes).float()
        J = loss_fn(z, targets)  # Compute loss.
        J.backward()  # Backward pass.
        optimizer.step()  # Update weights.
        loss += (J.detach().item() - loss) / (i + 1)  # Running mean of batch losses.
    return loss
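
The loss update in `train_epoch` is an incremental (running) mean of the per-batch losses. As a small sketch with made-up loss values, the snippet below shows that the incremental update yields the same result as averaging all values at the end, without storing them. The helper name `running_mean` is hypothetical.

```python
def running_mean(values):
    """Incrementally average values without storing them all."""
    mean = 0.0
    for i, value in enumerate(values):
        # Move the mean toward the new value by 1 / (number of values seen).
        mean += (value - mean) / (i + 1)
    return mean


batch_losses = [0.9, 0.7, 0.5, 0.3]  # hypothetical per-batch losses
print(running_mean(batch_losses))
print(sum(batch_losses) / len(batch_losses))
```

Note that this weights every batch equally, so a smaller final batch counts as much as a full one.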


In [None]:
def eval_epoch(ds, batch_size, model, num_classes, loss_fn):
    model.eval()
    loss = 0.0
    y_trues, y_preds = [], []
    ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)
    with torch.inference_mode():
        for i, batch in enumerate(ds_generator):
            z = model(batch)
            targets = F.one_hot(batch["label"], num_classes=num_classes).float()  # one-hot (for loss_fn)
            J = loss_fn(z, targets).item()
            loss += (J - loss) / (i + 1)
            y_trues.extend(batch["label"].cpu().numpy())
            y_preds.extend(torch.argmax(z, dim=1).cpu().numpy())
    return loss, np.vstack(y_trues), np.vstack(y_preds)


In [None]:
def train_loop_per_worker(config):
    # Hyperparameters.
    model_registry = config["model_registry"]
    experiment_name = config["experiment_name"]
    embedding_dim = config["embedding_dim"]
    hidden_dim = config["hidden_dim"]
    dropout_p = config["dropout_p"]
    lr = config["lr"]
    lr_factor = config["lr_factor"]
    lr_patience = config["lr_patience"]
    num_epochs = config["num_epochs"]
    batch_size = config["batch_size"]
    num_classes = config["num_classes"]

    # Experiment tracking.
    if ray.train.get_context().get_world_rank() == 0:
        mlflow.set_tracking_uri(f"file:{model_registry}")
        mlflow.set_experiment(experiment_name)
        mlflow.start_run()
        mlflow.log_params(config)

    # Datasets.
    train_ds = ray.train.get_dataset_shard("train")
    val_ds = ray.train.get_dataset_shard("val")

    # Model.
    model = ClassificationModel(
        embedding_dim=embedding_dim, 
        hidden_dim=hidden_dim, 
        dropout_p=dropout_p, 
        num_classes=num_classes,
    )
    model = ray.train.torch.prepare_model(model)

    # Training components.
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode="min", 
        factor=lr_factor, 
        patience=lr_patience,
    )

    # Training.
    best_val_loss = float("inf")
    for epoch in range(num_epochs):
        # Steps
        train_loss = train_epoch(train_ds, batch_size, model, num_classes, loss_fn, optimizer)
        val_loss, _, _ = eval_epoch(val_ds, batch_size, model, num_classes, loss_fn)
        scheduler.step(val_loss)

        # Checkpoint (metrics, preprocessor and model artifacts).
        with tempfile.TemporaryDirectory() as dp:
            model.module.save(dp=dp)
            metrics = dict(lr=optimizer.param_groups[0]["lr"], train_loss=train_loss, val_loss=val_loss)
            with open(os.path.join(dp, "class_to_label.json"), "w") as fp:
                json.dump(config["class_to_label"], fp, indent=4)
            if ray.train.get_context().get_world_rank() == 0:  # only on main worker 0
                mlflow.log_metrics(metrics, step=epoch)
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    mlflow.log_artifacts(dp)

    # End experiment tracking.
    if ray.train.get_context().get_world_rank() == 0:
        mlflow.end_run()
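
The training loop logs metrics every epoch but logs artifacts only when the validation loss improves on the best value seen so far. A minimal sketch of that bookkeeping, with hypothetical loss values and a hypothetical helper name, looks like this:

```python
def epochs_that_log_artifacts(val_losses):
    """Return the epochs whose validation loss beats every earlier epoch."""
    best_val_loss = float("inf")
    logged = []
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            logged.append(epoch)
    return logged


val_losses = [1.2, 0.9, 1.0, 0.8, 0.85]  # hypothetical values
print(epochs_that_log_artifacts(val_losses))
```

Initializing the best loss to infinity guarantees that the first epoch always logs artifacts, so a run never finishes without a saved model.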


<div class="alert alert-block alert"> <b> Minimal change to your training code</b> 

Notice that there isn't much new Ray Train code on top of the base PyTorch code. You specified how you want to scale out the training workload, load the Ray datasets, and then checkpoint on the main worker node and that's it. See these guides ([PyTorch](https://docs.ray.io/en/latest/train/getting-started-pytorch.html), [PyTorch Lightning](https://docs.ray.io/en/latest/train/getting-started-pytorch-lightning.html), [Hugging Face Transformers](https://docs.ray.io/en/latest/train/getting-started-transformers.html)) to see the minimal change in code needed to distribute your training workloads. See this extensive list of [Ray Train user guides](https://docs.ray.io/en/latest/train/user-guides.html).

In [None]:
# Load preprocessed datasets.
preprocessed_train_ds = ray.data.read_parquet(preprocessed_train_path)
preprocessed_val_ds = ray.data.read_parquet(preprocessed_val_path)




In [None]:
# Trainer.
train_loop_config["class_to_label"] = preprocessor.class_to_label
train_loop_config["num_classes"] = len(preprocessor.class_to_label)
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    datasets={"train": preprocessed_train_ds, "val": preprocessed_val_ds},
)


In [None]:
# Train.
results = trainer.fit()


## Ray Train

- Automatically handles **multi-node, multi-GPU** setup with no manual SSH setup or hostfile configs.
- Lets you define **per-worker fractional resource requirements**, for example, 2 CPUs and 0.5 GPU per worker.
- Runs on **heterogeneous machines** and scales flexibly, for example, CPU for preprocessing and GPU for training.
- Provides built-in **fault tolerance** by retrying failed workers and continuing from the last checkpoint.
- Supports Data Parallel, Model Parallel, Parameter Server, and even custom strategies.
- [Ray Compiled graphs](https://docs.ray.io/en/latest/ray-core/compiled-graph/ray-compiled-graph.html) even allow you to define different parallelism for jointly optimizing multiple models, whereas frameworks like Megatron and DeepSpeed only allow for one global setting.
- You can also use Torch DDP, FSDP, DeepSpeed, etc., under the hood.

🔥 [RayTurbo Train](https://docs.anyscale.com/rayturbo/rayturbo-train) offers even more improvement to the price-performance ratio, performance monitoring and more:
- **elastic training** to scale to a dynamic number of workers, continue training on fewer resources, even on spot instances.
- **purpose-built dashboard** designed to streamline the debugging of Ray Train workloads:
    - Monitoring: View the status of training runs and train workers.
    - Metrics: See insights on training throughput and training system operation time.
    - Profiling: Investigate bottlenecks, hangs, or errors from individual training worker processes.

<img src="https://raw.githubusercontent.com/anyscale/multimodal-ai/refs/heads/main/images/train_dashboard.png" width=700>

You can view experiment metrics and model artifacts in the model registry. You're using OSS MLflow so you can run the server by pointing to the model registry location:

```bash
mlflow server -h 0.0.0.0 -p 8080 --backend-store-uri /mnt/cluster_storage/mlflow/doggos
```

You can view the dashboard by going to the **Overview tab** > **Open Ports**. 

<img src="https://raw.githubusercontent.com/anyscale/multimodal-ai/refs/heads/main/images/mlflow.png" width=1000>

You also have access to the Ray Dashboard and the Train workload-specific dashboards shown earlier.

<img src="https://raw.githubusercontent.com/anyscale/multimodal-ai/refs/heads/main/images/train_metrics.png" width=1000>


In [None]:
# Sorted runs
mlflow.set_tracking_uri(f"file:{model_registry}")
sorted_runs = mlflow.search_runs(
    experiment_names=[experiment_name], 
    order_by=["metrics.val_loss ASC"])
best_run = sorted_runs.iloc[0]
best_run


run_id                                      d54aa07059384d139ea572123ae9409c
experiment_id                                             653138458592289747
status                                                              FINISHED
artifact_uri               file:///mnt/cluster_storage/mlflow/doggos/6531...
start_time                                  2025-08-28 05:10:15.049000+00:00
end_time                                    2025-08-28 05:10:33.936000+00:00
metrics.lr                                                             0.001
metrics.val_loss                                                    0.778273
metrics.train_loss                                                   0.39104
params.lr_factor                                                         0.8
params.hidden_dim                                                        256
params.embedding_dim                                                     512
params.dropout_p                                                         0.3

## Production Job

You can easily wrap the training workload as a production-grade [Anyscale Job](https://docs.anyscale.com/platform/jobs/) ([API ref](https://docs.anyscale.com/reference/job-api/)).

**Note**: 
- This Job uses a `containerfile` to define dependencies, but you could easily use a pre-built image as well.
- You can specify the compute as a [compute config](https://docs.anyscale.com/configuration/compute-configuration/) or inline in a [job config](https://docs.anyscale.com/reference/job-api#job-cli) file.
- When you don't specify compute while launching from a workspace, this configuration defaults to the compute configuration of the workspace.

In [None]:
%%bash
# Production model training job
anyscale job submit -f /home/ray/default/configs/train_model.yaml


Output
(anyscale +0.8s) Submitting job with config JobConfig(name='train-image-model', image_uri='anyscale/ray:2.48.0-slim-py312-cu128', compute_config=None, env_vars=None, py_modules=['/home/ray/default/doggos'], py_executable=None, cloud=None, project=None, ray_version=None, job_queue_config=None).
(anyscale +3.0s) Uploading local dir '/home/ray/default' to cloud storage.
(anyscale +3.8s) Uploading local dir '/home/ray/default/doggos' to cloud storage.
(anyscale +4.9s) Job 'train-image-model' submitted, ID: 'prodjob_zfy5ak9a5masjb4vuidtxvxpqt'.
(anyscale +4.9s) View the job in the UI: https://console.anyscale.com/jobs/prodjob_zfy5ak9a5masjb4vuidtxvxpqt
(anyscale +4.9s) Use `--wait` to wait for the job to run and stream logs.


<img src="https://raw.githubusercontent.com/anyscale/multimodal-ai/refs/heads/main/images/train_job.png" width=1000>

## Evaluation

This tutorial concludes by evaluating the trained model on the test dataset. Evaluation is essentially the same as a batch inference workload: you apply the model to batches of data and then calculate metrics by comparing the predictions with the true labels. Ray Data is heavily optimized for throughput, so preserving order isn't a priority. For evaluation, however, matching each prediction to its true label is crucial. Achieve this by preserving the entire row and adding the predicted label as another column to each row.
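
As a rough illustration of why keeping the label and prediction in the same row matters, the sketch below computes accuracy and per-class precision and recall from a list of rows. Because each row carries both values, the metrics are the same no matter what order the rows arrive in. The rows and the helper names are hypothetical, and this is plain Python rather than the Ray Data or scikit-learn calls used in the tutorial.

```python
def accuracy(rows):
    """Fraction of rows whose prediction matches the label."""
    return sum(r["label"] == r["prediction"] for r in rows) / len(rows)


def per_class_precision_recall(rows, label):
    """Precision and recall for one class, treated as the positive class."""
    tp = sum(r["label"] == label and r["prediction"] == label for r in rows)
    fp = sum(r["label"] != label and r["prediction"] == label for r in rows)
    fn = sum(r["label"] == label and r["prediction"] != label for r in rows)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Hypothetical rows; each row keeps its own label next to its prediction.
rows = [
    {"label": 0, "prediction": 0},
    {"label": 1, "prediction": 0},
    {"label": 1, "prediction": 1},
    {"label": 2, "prediction": 2},
]
shuffled = list(reversed(rows))  # simulate a different row order
print(accuracy(rows), accuracy(shuffled))
print(per_class_precision_recall(rows, label=1))
```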

In [None]:
from urllib.parse import urlparse
from sklearn.metrics import multilabel_confusion_matrix


In [None]:
class TorchPredictor:
    def __init__(self, preprocessor, model):
        self.preprocessor = preprocessor
        self.model = model
        self.model.eval()

    def __call__(self, batch, device="cuda"):
        self.model.to(device)
        batch["prediction"] = self.model.predict(collate_fn(batch))
        return batch

    def predict_probabilities(self, batch, device="cuda"):
        self.model.to(device)
        predicted_probabilities = self.model.predict_probabilities(collate_fn(batch))
        batch["probabilities"] = [
            {
                self.preprocessor.label_to_class[i]: float(prob)
                for i, prob in enumerate(probabilities)
            }
            for probabilities in predicted_probabilities
        ]
        return batch
    
    @classmethod
    def from_artifacts_dir(cls, artifacts_dir):
        with open(os.path.join(artifacts_dir, "class_to_label.json"), "r") as fp:
            class_to_label = json.load(fp)
        preprocessor = Preprocessor(class_to_label=class_to_label)
        model = ClassificationModel.load(
            args_fp=os.path.join(artifacts_dir, "args.json"), 
            state_dict_fp=os.path.join(artifacts_dir, "model.pt"),
        )
        return cls(preprocessor=preprocessor, model=model)


In [None]:
# Load and preprocess eval dataset.
artifacts_dir = urlparse(best_run.artifact_uri).path
predictor = TorchPredictor.from_artifacts_dir(artifacts_dir=artifacts_dir)
test_ds = ray.data.read_images("s3://doggos-dataset/test", include_paths=True)
test_ds = test_ds.map(add_class)
test_ds = predictor.preprocessor.transform(ds=test_ds)


In [None]:
# y_pred (batch inference).
pred_ds = test_ds.map_batches(
    predictor,
    concurrency=4,
    batch_size=64,
    num_gpus=1,
    accelerator_type="T4",
)
pred_ds.take(1)


2025-08-28 05:10:42,369	INFO logging.py:295 -- Registered dataset logger for dataset dataset_40_0
2025-08-28 05:10:42,388	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_40_0. Full logs are in /tmp/ray/session_2025-08-28_04-57-43_348032_12595/logs/ray-data
2025-08-28 05:10:42,388	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_40_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbedImages)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> TaskPoolMapOperator[MapBatches(TorchPredictor)] -> LimitOperator[limit=1]


(MapWorker(MapBatches(EmbedImages)) pid=33395, ip=10.0.5.252) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
(MapWorker(MapBatches(EmbedImages)) pid=6674, ip=10.0.5.20) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`. [repeated 3x across cluster]
2025-08-28 05:10:59,374	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_40_0 execution finished in 16.98 seconds


[{'path': 'doggos-dataset/test/basset/basset_10005.jpg',
  'class': 'basset',
  'label': 30,
  'embedding': array([ 8.86104554e-02, -5.89382686e-02,  1.15464866e-01,  2.15815112e-01,
         -3.43266308e-01, -3.35150540e-01,  1.48883224e-01, -1.02369718e-01,
         -1.69915810e-01,  4.34856862e-03,  2.41593361e-01,  1.79200619e-01,
          4.34402555e-01,  4.59785998e-01,  1.59284808e-02,  4.16959971e-01,
          5.20779848e-01,  1.86366066e-01, -3.43496174e-01, -4.00813907e-01,
         -1.15213782e-01, -3.04853529e-01,  1.77998394e-01,  1.82090014e-01,
         -3.56360346e-01, -2.30711952e-01,  1.69025257e-01,  3.78455579e-01,
          8.37044120e-02, -4.81875241e-02,  3.17967087e-01, -1.40099749e-01,
         -2.15949178e-01, -4.72761095e-01, -3.01893711e-01,  7.59940967e-02,
         -2.64865339e-01,  5.89084566e-01, -3.75831634e-01,  3.11807573e-01,
         -3.82964134e-01, -1.86417520e-01,  1.07007243e-01,  4.81416702e-01,
         -3.70819569e-01,  9.12090182e-01,  3.1

In [None]:
def batch_metric(batch):
    labels = batch["label"]
    preds = batch["prediction"]
    mcm = multilabel_confusion_matrix(labels, preds)
    tn, fp, fn, tp = [], [], [], []
    for i in range(mcm.shape[0]):
        tn.append(mcm[i, 0, 0])  # True negatives
        fp.append(mcm[i, 0, 1])  # False positives
        fn.append(mcm[i, 1, 0])  # False negatives
        tp.append(mcm[i, 1, 1])  # True positives
    return {"TN": tn, "FP": fp, "FN": fn, "TP": tp}
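
To make the shape of `batch_metric`'s output concrete, here is a small pure-Python sketch of the same one-vs-rest counting. It's an illustration, not the scikit-learn implementation: `one_vs_rest_counts` is a hypothetical helper, and it assumes single-label predictions with classes taken from the union of the labels and predictions in the batch.

```python
def one_vs_rest_counts(labels, preds):
    """Per-class TN/FP/FN/TP for single-label predictions, in sorted class order."""
    classes = sorted(set(labels) | set(preds))
    counts = {"TN": [], "FP": [], "FN": [], "TP": []}
    for c in classes:
        tp = sum(1 for y, p in zip(labels, preds) if y == c and p == c)
        fp = sum(1 for y, p in zip(labels, preds) if y != c and p == c)
        fn = sum(1 for y, p in zip(labels, preds) if y == c and p != c)
        tn = len(labels) - tp - fp - fn  # Every remaining row is a true negative.
        counts["TN"].append(tn)
        counts["FP"].append(fp)
        counts["FN"].append(fn)
        counts["TP"].append(tp)
    return counts


# Four rows, three classes; the row with label 1 is mispredicted as class 2.
counts = one_vs_rest_counts(labels=[0, 1, 2, 2], preds=[0, 2, 2, 2])
print(counts)
```

Each list holds one entry per class seen in the batch, so summing a column across all batches yields the totals used in the next cell.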


In [None]:
# Aggregated metrics after processing all batches.
metrics_ds = pred_ds.map_batches(batch_metric)
aggregate_metrics = metrics_ds.sum(["TN", "FP", "FN", "TP"])

# Aggregate the confusion matrix components across all batches.
tn = aggregate_metrics["sum(TN)"]
fp = aggregate_metrics["sum(FP)"]
fn = aggregate_metrics["sum(FN)"]
tp = aggregate_metrics["sum(TP)"]

# Calculate metrics.
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
accuracy = (tp + tn) / (tp + tn + fp + fn)


2025-08-28 05:10:59,627	INFO logging.py:295 -- Registered dataset logger for dataset dataset_43_0
2025-08-28 05:10:59,639	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_43_0. Full logs are in /tmp/ray/session_2025-08-28_04-57-43_348032_12595/logs/ray-data
2025-08-28 05:10:59,640	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_43_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbedImages)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> TaskPoolMapOperator[MapBatches(TorchPredictor)] -> TaskPoolMapOperator[MapBatches(batch_metric)] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]


(MapWorker(MapBatches(EmbedImages)) pid=34103, ip=10.0.5.252) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
(MapWorker(MapBatches(EmbedImages)) pid=40389, ip=10.0.5.252) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`. [repeated 3x across cluster]
2025-08-28 05:12:20,741	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_43_0 execution finished in 81.10 seconds


In [None]:
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1: {f1:.2f}")
print(f"Accuracy: {accuracy:.2f}")


Precision: 0.84
Recall: 0.84
F1: 0.84
Accuracy: 0.98
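
One way to read these numbers: because the counts are summed over every class before the ratios are taken, the metrics are micro-averaged. With one predicted label per image, each wrong prediction adds exactly one false positive and one false negative, so precision, recall, and F1 coincide. The accuracy term also counts true negatives from every one-vs-rest comparison, which is why it comes out higher. The sketch below uses made-up counts (100 images, 10 classes, 84 correct), not the tutorial's data, and `micro_metrics` is a hypothetical helper.

```python
def micro_metrics(tn, fp, fn, tp):
    """Micro-averaged metrics from confusion counts summed over all classes."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    return precision, recall, f1, accuracy


# Hypothetical totals: n images, k classes, `correct` right predictions.
n, k, correct = 100, 10, 84
tp = correct           # Each correct row is a true positive for its class.
fp = n - correct       # Each wrong row is a false positive for the predicted class...
fn = n - correct       # ...and a false negative for the true class.
tn = n * k - tp - fp - fn  # Everything else across the n * k comparisons.

precision, recall, f1, accuracy = micro_metrics(tn, fp, fn, tp)
print(f"precision={precision:.2f} recall={recall:.2f} f1={f1:.2f} accuracy={accuracy:.3f}")
```

In this toy setup the fraction of correctly classified images is 0.84, matching precision, recall, and F1, while the one-vs-rest accuracy is 0.968.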


**🚨 Note**: Reset this notebook using the **"🔄 Restart"** button located in the notebook's menu bar. Restarting frees up the variables, utils, and other state used in this notebook.