In [1]:
import ray
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig,RunConfig
import torch.nn as nn
import torch
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18 , ResNet18_Weights
from torchvision.transforms import ToTensor, Compose
from torch.utils.data import Subset,DataLoader
import matplotlib.pyplot as plt
from torchmetrics.classification import Accuracy
from PIL import Image
from filelock import FileLock
from pathlib import Path

import tempfile
import os
import uuid

In [2]:
data = CIFAR10(root="../marimo_notebooks/data",download=True,train=False)
data

Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ../marimo_notebooks/data
    Split: Test

In [3]:
# the data contains of a PIL image and the label
next(iter(data))

(<PIL.Image.Image image mode=RGB size=32x32>, 3)

In [4]:
class_to_idx = data.class_to_idx
class_to_idx

{'airplane': 0,
 'automobile': 1,
 'bird': 2,
 'cat': 3,
 'deer': 4,
 'dog': 5,
 'frog': 6,
 'horse': 7,
 'ship': 8,
 'truck': 9}

In [5]:
model = resnet18(weights=ResNet18_Weights)
model



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [33]:
def get_cifar_dataloader(batch_size):
    imagenet_transforms = ResNet18_Weights.IMAGENET1K_V1.transforms
    full_transform = Compose([ToTensor(),imagenet_transforms()])
    with FileLock(os.path.expanduser("~/cifar_data.lock")):
        train = CIFAR10(
            root="~/cifar_data",
            train=True,
            download=True,
            transform=full_transform,
        )
        valid = CIFAR10(
            root="~/cifar_data",
            train=False,
            download=True,
            transform=full_transform,
        )
    train_sub = Subset(train,indices=range(300))
    valid_sub = Subset(valid,indices=range(300))
    # dataloaders to get data in batches
    train_dataloader = DataLoader(train_sub, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_sub, batch_size=batch_size)

    return train_dataloader, valid_dataloader



In [34]:
class SimpleModel(nn.Module):
    def __init__(self, in_channels=3, hidden_features=128, out_features=10):
        super().__init__()

        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=2, padding="same"),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)),  # force to fixed H × W
            nn.Flatten(),
            nn.Linear(16 * 8 * 8, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features)
        )

    def forward(self, x):
        return self.block1(x)

In [43]:
def train_func(config):

    epochs = config["epochs"]
    batch_size = config["batch_size"]
    lr = config["lr"]
    
    
    # use detected device
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

    # metrics
    accuracy = Accuracy(task="multiclass", num_classes=config["num_classes"]).to(device)
    train_loss = 0.0
    train_acc = 0.0
    num_batches = 0.0

    checkpoint_path = Path("../marimo_notebooks/data/checkpoint")
    checkpoint_path.mkdir(parents=True,exist_ok=True)

    #device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # weights = ResNet18_Weights.IMAGENET1K_V1
    # model = resnet18(weights=weights)
    model = SimpleModel(in_channels=3,hidden_features=128,out_features=10)
    # for parameter in model.parameters():
    #     parameter.requires_grad = False
    # model.fc = nn.Linear(512,config["num_classes"],bias=True)
    
    model = ray.train.torch.prepare_model(model)
 
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])

    train_dataloader, valid_dataloader = get_cifar_dataloader(batch_size=batch_size)
    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
    valid_dataloader = ray.train.torch.prepare_data_loader(valid_dataloader)

    for epoch in range(epochs):
        # checking if training is scheduled in a distributed setting or not.
        if ray.train.get_context().get_world_size() > 1:
            train_dataloader.sampler.set_epoch(epoch)
        model.train()
        for idx, batch in enumerate(train_dataloader):
            x, y = batch[0], batch[1]
            y_preds = model(x)
            y_labels = y_preds.argmax(dim=1)
            loss = loss_fn(y_preds,y)
            acc = accuracy(y_labels,y)
            train_loss +=  loss.item()
            train_acc += acc.item()
        train_loss /=len(train_dataloader)
        train_acc /=len(train_dataloader)
        metrics = {"epoch":epoch,"train_loss":train_loss, "train_acc":train_acc}

        ray.train.report(metrics)
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)
        


In [44]:
global_batch_size = 10
num_workers = 8
use_gpu = False

train_config = {
    "lr": 1e-2,
    "epochs": 10,
    "num_classes": 10,
    "batch_size": global_batch_size // num_workers,
    "weight_decay": 1e-2
}
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
run_config = RunConfig(
    storage_path=str(Path("../marimo_notebooks/data/storage_path").resolve()), 
    name=f"ray_train_torch_run-{uuid.uuid4().hex}",    
)

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config=train_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

In [45]:
result = trainer.fit()
print(f"Training result: {result}")

2025-08-02 13:14:28,171	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2025-08-02 13:14:28 (running for 00:00:00.11)
Using FIFO scheduling algorithm.
Logical resource usage: 9.0/10 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-08-02_12-51-11_285675_74959/artifacts/2025-08-02_13-14-28/ray_train_torch_run-a5e79510deaa4859836740ec8ef14248/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-08-02 13:14:33 (running for 00:00:05.14)
Using FIFO scheduling algorithm.
Logical resource usage: 9.0/10 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-08-02_12-51-11_285675_74959/artifacts/2025-08-02_13-14-28/ray_train_torch_run-a5e79510deaa4859836740ec8ef14248/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-08-02 13:14:38 (running for 00:00:10.20)
Using FIFO scheduling algorithm.
Logical resource usage: 9.0/10 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-08-02_12-51-11_285675_74959/artifacts/2025-08-02_13-14-28/ray_train_torch_run-a5e79510deaa4859836740ec8ef

2025-08-02 13:14:54,644	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/shoaib/code/anyscale-projects/anyscale-demo/marimo_notebooks/data/storage_path/ray_train_torch_run-a5e79510deaa4859836740ec8ef14248' in 0.0050s.
2025-08-02 13:14:54,646	INFO tune.py:1041 -- Total run time: 26.48 seconds (26.46 seconds for the tuning loop).


== Status ==
Current time: 2025-08-02 13:14:54 (running for 00:00:26.47)
Using FIFO scheduling algorithm.
Logical resource usage: 9.0/10 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-08-02_12-51-11_285675_74959/artifacts/2025-08-02_13-14-28/ray_train_torch_run-a5e79510deaa4859836740ec8ef14248/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


Training result: Result(
  metrics={'epoch': 9, 'train_loss': 2.3948231630960457, 'train_acc': 0.1095295858409676},
  path='/Users/shoaib/code/anyscale-projects/anyscale-demo/marimo_notebooks/data/storage_path/ray_train_torch_run-a5e79510deaa4859836740ec8ef14248/TorchTrainer_3ca76_00000_0_2025-08-02_13-14-28',
  filesystem='local',
  checkpoint=None
)


[36m(autoscaler +51m44s)[0m [autoscaler] Downscaling node i-04494db17bc8de61d (node IP: 100.78.239.19) due to node idle termination.
[36m(autoscaler +51m44s)[0m [autoscaler] Downscaling node i-0e47beae8cd6fcbc6 (node IP: 100.80.166.40) due to node idle termination.


In [10]:
ray.shutdown()