In [1]:
import ray
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig,RunConfig
import torch.nn as nn
import torch
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18 , ResNet18_Weights
from torchvision.transforms import ToTensor, Compose
from torch.utils.data import Subset,DataLoader
import matplotlib.pyplot as plt
from torchmetrics.classification import Accuracy
from PIL import Image
from filelock import FileLock
from pathlib import Path

import tempfile
import os
import uuid

In [2]:
data = CIFAR10(root="../marimo_notebooks/data",download=True,train=False)
data

100%|██████████| 170M/170M [00:06<00:00, 27.4MB/s] 


Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ../marimo_notebooks/data
    Split: Test

In [3]:
# the data contains of a PIL image and the label
next(iter(data))

(<PIL.Image.Image image mode=RGB size=32x32>, 3)

In [4]:
class_to_idx = data.class_to_idx
class_to_idx

{'airplane': 0,
 'automobile': 1,
 'bird': 2,
 'cat': 3,
 'deer': 4,
 'dog': 5,
 'frog': 6,
 'horse': 7,
 'ship': 8,
 'truck': 9}

In [5]:
model = resnet18(weights=ResNet18_Weights)
model



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/ray/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 235MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [6]:
def get_cifar_dataloader(batch_size):
    imagenet_transforms = ResNet18_Weights.IMAGENET1K_V1.transforms
    full_transform = Compose([ToTensor(),imagenet_transforms()])
    with FileLock(os.path.expanduser("~/cifar_data.lock")):
        train = CIFAR10(
            root="~/cifar_data",
            train=True,
            download=True,
            transform=full_transform,
        )
        valid = CIFAR10(
            root="~/cifar_data",
            train=False,
            download=True,
            transform=full_transform,
        )

    # dataloaders to get data in batches
    train_dataloader = DataLoader(train, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid, batch_size=batch_size)

    return train_dataloader, valid_dataloader



In [17]:
def train_func(config):

    epochs = config["epochs"]
    batch_size = config["batch_size"]
    lr = config["lr"]
    
    
    # use detected device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # metrics
    accuracy = Accuracy(task="multiclass", num_classes=config["num_classes"]).to(device)
    train_loss = 0.0
    train_acc = 0.0
    num_batches = 0.0

    checkpoint_path = Path("../marimo_notebooks/data/checkpoint")
    checkpoint_path.mkdir(parents=True,exist_ok=True)

    #device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weights = ResNet18_Weights.IMAGENET1K_V1
    model = resnet18(weights=weights)
    for parameter in model.parameters():
        parameter.requires_grad = False
    model.fc = nn.Linear(512,config["num_classes"],bias=True)
    
    model = ray.train.torch.prepare_model(model)
 
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])

    train_dataloader, valid_dataloader = get_cifar_dataloader(batch_size=batch_size)
    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
    valid_dataloader = ray.train.torch.prepare_data_loader(valid_dataloader)

    for epoch in range(epochs):
        # checking if training is scheduled in a distributed setting or not.
        if ray.train.get_context().get_world_size() > 1:
            train_dataloader.sampler.set_epoch(epoch)
        model.train()
        for idx, batch in enumerate(train_dataloader):
            x, y = batch[0], batch[1]
            y_preds = model(x)
            y_labels = y_preds.argmax(dim=1)
            loss = loss_fn(y_preds,y)
            acc = accuracy(y_labels,y)
            train_loss +=  loss.item()
            train_acc += acc.item()
        train_loss /=len(train_dataloader)
        train_acc /=len(train_dataloader)
        metrics = {"epoch":epoch,"train_loss":train_loss, "train_acc":train_acc}

        ray.train.report(metrics)
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)
        


In [20]:
global_batch_size = 50
num_workers = 2
use_gpu = True

train_config = {
    "lr": 1e-2,
    "epochs": 2,
    "num_classes": 10,
    "batch_size": global_batch_size // num_workers,
    "weight_decay": 1e-2
}
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
run_config = RunConfig(
    storage_path=str(Path("../marimo_notebooks/data/storage_path").resolve()), 
    name=f"ray_train_torch_run-{uuid.uuid4().hex}",    
)

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config=train_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

In [21]:
result = trainer.fit()
print(f"Training result: {result}")

2025-08-02 10:37:48,201	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2025-08-02 10:37:48 (running for 00:00:00.11)
Using FIFO scheduling algorithm.
Logical resource usage: 0/0 CPUs, 0/0 GPUs (0.0/1.0 anyscale/region:us-east-2, 0.0/1.0 anyscale/provider:aws, 0.0/1.0 anyscale/cpu_only:true, 0.0/1.0 anyscale/node-group:head)
Result logdir: /tmp/ray/session_2025-08-02_09-49-46_005504_2484/artifacts/2025-08-02_10-37-48/ray_train_torch_run-09552a7fb27843988fe015d8aff96b99/driver_artifacts
Number of trials: 1/1 (1 PENDING)




== Status ==
Current time: 2025-08-02 10:37:53 (running for 00:00:05.13)
Using FIFO scheduling algorithm.
Logical resource usage: 0/0 CPUs, 0/0 GPUs (0.0/1.0 anyscale/region:us-east-2, 0.0/1.0 anyscale/provider:aws, 0.0/1.0 anyscale/cpu_only:true, 0.0/1.0 anyscale/node-group:head)
Result logdir: /tmp/ray/session_2025-08-02_09-49-46_005504_2484/artifacts/2025-08-02_10-37-48/ray_train_torch_run-09552a7fb27843988fe015d8aff96b99/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-08-02 10:37:58 (running for 00:00:10.15)
Using FIFO scheduling algorithm.
Logical resource usage: 0/0 CPUs, 0/0 GPUs (0.0/1.0 anyscale/region:us-east-2, 0.0/1.0 anyscale/provider:aws, 0.0/1.0 anyscale/cpu_only:true, 0.0/1.0 anyscale/node-group:head)
Result logdir: /tmp/ray/session_2025-08-02_09-49-46_005504_2484/artifacts/2025-08-02_10-37-48/ray_train_torch_run-09552a7fb27843988fe015d8aff96b99/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-08-0

[36m(RayTrainWorker pid=3094, ip=100.78.239.19)[0m Setting up process group for: env:// [rank=0, world_size=2]
[36m(TorchTrainer pid=3017, ip=100.78.239.19)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=3017, ip=100.78.239.19)[0m - (node_id=2ab4bcd89526cab344b84201bff7bde50d65a5b35d2bb59cb64d7272, ip=100.78.239.19, pid=3094) world_rank=0, local_rank=0, node_rank=0
[36m(TorchTrainer pid=3017, ip=100.78.239.19)[0m - (node_id=cf5abbe830417e514bcf76df544dde86c33f36eb0227b0a6f6a1e96c, ip=100.80.166.40, pid=2955) world_rank=1, local_rank=0, node_rank=1


== Status ==
Current time: 2025-08-02 10:39:58 (running for 00:02:10.65)
Using FIFO scheduling algorithm.
Logical resource usage: 1.0/16 CPUs, 2.0/2 GPUs (0.0/3.0 anyscale/provider:aws, 0.0/2.0 anyscale/node-group:1xT4:8CPU-32GB, 0.0/2.0 anyscale/accelerator_shape:1xT4, 0.0/3.0 anyscale/region:us-east-2, 0.0/2.0 accelerator_type:T4, 0.0/1.0 anyscale/node-group:head, 0.0/1.0 anyscale/cpu_only:true)
Result logdir: /tmp/ray/session_2025-08-02_09-49-46_005504_2484/artifacts/2025-08-02_10-37-48/ray_train_torch_run-09552a7fb27843988fe015d8aff96b99/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


[36m(RayTrainWorker pid=2955, ip=100.80.166.40)[0m Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/ray/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]9.19)[0m 
 47%|████▋     | 21.0M/44.7M [00:00<00:00, 219MB/s]m 
 96%|█████████▌| 42.8M/44.7M [00:00<00:00, 224MB/s]m 
100%|██████████| 44.7M/44.7M [00:00<00:00, 247MB/s]m 
[36m(RayTrainWorker pid=2955, ip=100.80.166.40)[0m Moving model to device: cuda:0
[36m(RayTrainWorker pid=2955, ip=100.80.166.40)[0m Wrapping provided model in DistributedDataParallel.
100%|██████████| 44.7M/44.7M [00:00<00:00, 223MB/s]m 
 93%|█████████▎| 158M/170M [00:01<00:00, 82.8MB/s]0m 
100%|██████████| 170M/170M [00:02<00:00, 81.2MB/s]0m 


== Status ==
Current time: 2025-08-02 10:40:03 (running for 00:02:15.67)
Using FIFO scheduling algorithm.
Logical resource usage: 1.0/16 CPUs, 2.0/2 GPUs (0.0/3.0 anyscale/provider:aws, 0.0/2.0 anyscale/node-group:1xT4:8CPU-32GB, 0.0/2.0 anyscale/accelerator_shape:1xT4, 0.0/3.0 anyscale/region:us-east-2, 0.0/2.0 accelerator_type:T4, 0.0/1.0 anyscale/node-group:head, 0.0/1.0 anyscale/cpu_only:true)
Result logdir: /tmp/ray/session_2025-08-02_09-49-46_005504_2484/artifacts/2025-08-02_10-37-48/ray_train_torch_run-09552a7fb27843988fe015d8aff96b99/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




 94%|█████████▍| 160M/170M [00:03<00:00, 107MB/s][0m 
100%|██████████| 170M/170M [00:03<00:00, 51.2MB/s]0m 


== Status ==
Current time: 2025-08-02 10:40:08 (running for 00:02:20.69)
Using FIFO scheduling algorithm.
Logical resource usage: 1.0/16 CPUs, 2.0/2 GPUs (0.0/3.0 anyscale/region:us-east-2, 0.0/3.0 anyscale/provider:aws, 0.0/2.0 anyscale/node-group:1xT4:8CPU-32GB, 0.0/2.0 anyscale/accelerator_shape:1xT4, 0.0/2.0 accelerator_type:T4, 0.0/1.0 anyscale/node-group:head, 0.0/1.0 anyscale/cpu_only:true)
Result logdir: /tmp/ray/session_2025-08-02_09-49-46_005504_2484/artifacts/2025-08-02_10-37-48/ray_train_torch_run-09552a7fb27843988fe015d8aff96b99/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-08-02 10:40:13 (running for 00:02:25.71)
Using FIFO scheduling algorithm.
Logical resource usage: 1.0/16 CPUs, 2.0/2 GPUs (0.0/3.0 anyscale/region:us-east-2, 0.0/3.0 anyscale/provider:aws, 0.0/2.0 anyscale/node-group:1xT4:8CPU-32GB, 0.0/2.0 anyscale/accelerator_shape:1xT4, 0.0/2.0 accelerator_type:T4, 0.0/1.0 anyscale/node-group:head, 0.0/1.0 anyscale/cpu_only:true

2025-08-02 10:42:27,115	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/ray/default/anyscale-demo/marimo_notebooks/data/storage_path/ray_train_torch_run-09552a7fb27843988fe015d8aff96b99' in 0.0038s.
2025-08-02 10:42:27,118	INFO tune.py:1041 -- Total run time: 278.92 seconds (278.91 seconds for the tuning loop).


== Status ==
Current time: 2025-08-02 10:42:27 (running for 00:04:38.91)
Using FIFO scheduling algorithm.
Logical resource usage: 1.0/16 CPUs, 2.0/2 GPUs (0.0/2.0 anyscale/node-group:1xT4:8CPU-32GB, 0.0/3.0 anyscale/provider:aws, 0.0/3.0 anyscale/region:us-east-2, 0.0/2.0 anyscale/accelerator_shape:1xT4, 0.0/2.0 accelerator_type:T4, 0.0/1.0 anyscale/node-group:head, 0.0/1.0 anyscale/cpu_only:true)
Result logdir: /tmp/ray/session_2025-08-02_09-49-46_005504_2484/artifacts/2025-08-02_10-37-48/ray_train_torch_run-09552a7fb27843988fe015d8aff96b99/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


Training result: Result(
  metrics={'epoch': 1, 'train_loss': 2.4506428616306786, 'train_acc': 0.07719911850573122},
  path='/home/ray/default/anyscale-demo/marimo_notebooks/data/storage_path/ray_train_torch_run-09552a7fb27843988fe015d8aff96b99/TorchTrainer_bb9a3_00000_0_2025-08-02_10-37-48',
  filesystem='local',
  checkpoint=None
)


[36m(autoscaler +51m44s)[0m [autoscaler] Downscaling node i-04494db17bc8de61d (node IP: 100.78.239.19) due to node idle termination.
[36m(autoscaler +51m44s)[0m [autoscaler] Downscaling node i-0e47beae8cd6fcbc6 (node IP: 100.80.166.40) due to node idle termination.
