In [1]:
import csv
import datetime
import os   
import tempfile

from pathlib import Path

import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import pandas as pd
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torchvision.models import resnet18
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray
from ray.train import ScalingConfig, RunConfig
from ray.train.torch import TorchTrainer

In [4]:
def train_loop_ray_train(config: dict):
    """Training function that Ray Train runs on every worker.

    Args:
        config: Must provide ``"global_batch_size"`` (divided across workers
            with integer division to get the per-worker batch size) and
            ``"num_epochs"`` (number of passes over the data loader).

    Side effects:
        Prints the world size and per-worker batch size, then after each
        epoch prints metrics and saves a checkpoint through the helper
        functions ``print_metrics_ray_train`` and
        ``save_checkpoint_and_metrics_ray_train``.
    """
    criterion = CrossEntropyLoss()

    # Model comes from the Ray Train-aware helper; the optimizer is built on
    # the parameters of whatever that helper returns.
    model = load_model_ray_train()
    optimizer = Adam(model.parameters(), lr=1e-5)

    # Split the global batch size evenly across the participating workers.
    global_batch_size = config["global_batch_size"]
    world_size = ray.train.get_context().get_world_size()
    batch_size = global_batch_size // world_size
    print(f"{world_size=}\n{batch_size=}")

    data_loader = build_data_loader_ray_train(batch_size=batch_size)

    for epoch in range(config["num_epochs"]):
        # Only epoch-aware samplers (such as a DistributedSampler) expose
        # `set_epoch`; calling it unconditionally raises AttributeError for
        # any other sampler, so guard the call.
        sampler = getattr(data_loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        for images, labels in data_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()

            loss.backward()
            optimizer.step()

        # `loss` here is the loss of the last batch processed in this epoch.
        metrics = print_metrics_ray_train(loss, epoch)
        save_checkpoint_and_metrics_ray_train(model, metrics)

In [3]:
# Hyperparameters handed to the per-worker training function.
train_loop_config = dict(num_epochs=2, global_batch_size=128)

In [5]:
# Two workers without GPUs; switch use_gpu to True only when GPUs are available.
scaling_config = ScalingConfig(
    use_gpu=False,
    num_workers=2,
)

In [6]:
def load_model_ray_train() -> torch.nn.Module:
    """Build the network with ``build_resnet18`` and prepare it for Ray Train.

    Returns:
        Whatever ``ray.train.torch.prepare_model`` returns for the freshly
        built model.
    """
    # prepare_model takes the place of a manual `model.to("cuda")`.
    return ray.train.torch.prepare_model(build_resnet18())

In [7]:
def build_data_loader_ray_train(batch_size: int) -> torch.utils.data.DataLoader:
    """Create the MNIST training loader and prepare it for Ray Train.

    Args:
        batch_size: Batch size passed to the underlying ``DataLoader``.

    Returns:
        The result of ``ray.train.torch.prepare_data_loader`` applied to a
        shuffling, ``drop_last=True`` loader over the MNIST training split
        stored under ``./data`` (downloaded if missing).
    """
    preprocessing = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    mnist_train = MNIST(root="./data", train=True, download=True, transform=preprocessing)

    loader = torch.utils.data.DataLoader(
        mnist_train,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

    # prepare_data_loader is what installs a DistributedSampler on the loader.
    return ray.train.torch.prepare_data_loader(loader)

In [8]:
def print_metrics_ray_train(loss: torch.Tensor, epoch: int) -> dict[str, float]:
    """Print the metrics for one epoch and return them.

    Args:
        loss: Loss value; converted to a Python number with ``.item()``.
        epoch: Index of the epoch the loss belongs to.

    Returns:
        A dict with keys ``"loss"`` and ``"epoch"``. The training loop passes
        this dict on to the checkpoint/report helper, so the function must
        return it (the previous ``-> None`` annotation was wrong).
    """
    metrics = {"loss": loss.item(), "epoch": epoch}
    # No rank filtering here: every worker prints its own metrics.
    world_rank = ray.train.get_context().get_world_rank()
    print(f"{metrics=} {world_rank=}")
    return metrics

In [None]:
def save_checkpoint_and_metrics_ray_train(
    model: torch.nn.Module, metrics: dict[str, float]
) -> None:
    """Save the model weights to a temporary checkpoint and report it with the metrics.

    Args:
        model: Model whose ``state_dict`` is saved as ``model.pt``. If it
            exposes a ``.module`` attribute (as a DistributedDataParallel
            wrapper does), the wrapped module's weights are saved; otherwise
            the model's own weights are saved.
        metrics: Metrics passed to ``ray.train.report`` together with the
            checkpoint.
    """
    # Unwrap only when there is something to unwrap: a plain, unwrapped model
    # has no `.module`, and accessing it unconditionally raises AttributeError.
    unwrapped_model = getattr(model, "module", model)

    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        torch.save(
            unwrapped_model.state_dict(),
            os.path.join(temp_checkpoint_dir, "model.pt"),
        )

        # Report inside the `with` block so the directory still exists while
        # the checkpoint is being handed over.
        ray.train.report(
            metrics,
            checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
        )

In [None]:
# Set `storage_folder` to a shared directory or cloud URI to control where
# results and checkpoints are written. While it is None, no path is built
# (interpolating None would yield the literal path "None/training/") and
# `storage_path=None` lets Ray Train use its default storage location.
storage_folder = None
storage_path = f"{storage_folder}/training/" if storage_folder else None
run_config = RunConfig(storage_path=storage_path, name="distributed-mnist-resnet18")

In [None]:
# Wire the per-worker training function together with its hyperparameters,
# the worker layout, and the run/storage settings defined in earlier cells.
trainer = TorchTrainer(
    train_loop_ray_train,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

In [None]:
# Run the training job; the returned `result` is read further down
# (e.g. `result.checkpoint`) to load the trained weights.
result = trainer.fit()

In [None]:
# Run inference on the GPU when one is visible, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Materialize the checkpoint from the training result as a local directory
# and restore the weights saved there as "model.pt".
ckpt = result.checkpoint
with ckpt.as_directory() as ckpt_dir:
    model_path = os.path.join(ckpt_dir, "model.pt")
    loaded_model_ray_train = build_resnet18()
    state_dict = torch.load(
        model_path,
        map_location=torch.device(device),
        weights_only=True,
    )
    loaded_model_ray_train.load_state_dict(state_dict)
    # Move to the target device and switch to evaluation mode.
    loaded_model_ray_train.to(device).eval()

loaded_model_ray_train

In [None]:
# Draw a 3x3 grid of randomly chosen samples, each titled with its label and
# the prediction of the restored model.
# NOTE(review): `dataset` is not defined in the cells shown here; it is
# presumably created in an earlier cell and yields (image, label) pairs - confirm.
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
to_model_input = Compose([ToTensor(), Normalize((0.5,), (0.5,))])

for i in range(1, cols * rows + 1):
    sample_idx = np.random.randint(0, len(dataset.data))
    img, label = dataset[sample_idx]
    normalized_img = to_model_input(img).to(device)

    # Forward pass without gradient tracking; add a batch dimension first.
    with torch.no_grad():
        logits = loaded_model_ray_train(normalized_img.unsqueeze(0))
    prediction = logits.argmax().cpu()

    ax = figure.add_subplot(rows, cols, i)
    ax.set_title(f"label: {label}; pred: {int(prediction)}")
    ax.axis("off")
    ax.imshow(img, cmap="gray")

In [None]:
# Same training function, scaled out to four GPU workers.
# Note: this rebinds `scaling_config`, `trainer`, and `result` from earlier cells.
scaling_config = ScalingConfig(use_gpu=True, num_workers=4)

trainer = TorchTrainer(
    train_loop_ray_train,
    train_loop_config={"num_epochs": 2, "global_batch_size": 128},
    scaling_config=scaling_config,
    run_config=run_config,
)
result = trainer.fit()

# Last expression of the cell: display the metrics table of the run.
result.metrics_dataframe