In [85]:
import os
import tempfile
import uuid
from pathlib import Path

import matplotlib.pyplot as plt
import ray
import torch
import torch.nn as nn
from filelock import FileLock
from PIL import Image
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
from torch.utils.data import DataLoader, Subset
from torchmetrics.classification import Accuracy
from torchvision.datasets import CIFAR10
from torchvision.models import ResNet18_Weights, resnet18
from torchvision.transforms import Compose, ToTensor

In [86]:
# Load the CIFAR-10 test split (downloaded on first use) for a quick look at the data,
# binding it to `data` and displaying its summary in one expression.
(data := CIFAR10(root="../marimo_notebooks/data", train=False, download=True))

Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ../marimo_notebooks/data
    Split: Test

In [87]:
# Each sample consists of a PIL image and its integer class label; show the first one.
data[0]

(<PIL.Image.Image image mode=RGB size=32x32>, 3)

In [88]:
# Mapping from class name to the integer label used in the dataset; bind and display it.
(class_to_idx := data.class_to_idx)

{'airplane': 0,
 'automobile': 1,
 'bird': 2,
 'cat': 3,
 'deer': 4,
 'dog': 5,
 'frog': 6,
 'horse': 7,
 'ship': 8,
 'truck': 9}

In [89]:
# Load ResNet-18 with ImageNet-pretrained weights. `weights` must be a specific member of
# the ResNet18_Weights enum (the same IMAGENET1K_V1 preset used in train_func), not the
# enum class itself.
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [94]:
def get_cifar_dataloader(batch_size):
    """Build CIFAR-10 train and validation DataLoaders preprocessed for ResNet-18.

    Every image is passed through ``ToTensor`` and then through the preprocessing
    preset returned by ``ResNet18_Weights.IMAGENET1K_V1.transforms()``. Both splits
    are loaded (and downloaded if missing) under ``~/cifar_data`` while holding a
    file lock at ``~/cifar_data.lock``, presumably so several workers on one machine
    do not download concurrently.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_dataloader, valid_dataloader)`` tuple; only the training loader
        shuffles.
    """
    preprocess = Compose([ToTensor(), ResNet18_Weights.IMAGENET1K_V1.transforms()])

    def load_split(is_train):
        return CIFAR10(
            root="~/cifar_data",
            train=is_train,
            download=True,
            transform=preprocess,
        )

    lock_path = os.path.expanduser("~/cifar_data.lock")
    with FileLock(lock_path):
        train_set = load_split(True)
        valid_set = load_split(False)

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(valid_set, batch_size=batch_size),
    )

In [95]:
def _run_epoch(model, dataloader, loss_fn, accuracy, optimizer=None):
    """Make one pass over `dataloader`, training when `optimizer` is given.

    Args:
        model: The classifier (possibly wrapped by ``prepare_model``).
        dataloader: Yields ``(images, labels)`` batches; ``prepare_data_loader``
            is expected to have moved them to the model's device.
        loss_fn: Criterion applied to ``(logits, labels)``.
        accuracy: torchmetrics ``Accuracy`` instance; reset at the start of the pass.
        optimizer: If not ``None`` the model is trained (zero_grad/backward/step);
            otherwise it is only evaluated, with gradients disabled.

    Returns:
        ``(mean_loss, acc)``: the average per-batch loss seen by this worker, and
        ``accuracy.compute()`` over the whole pass.
    """
    is_training = optimizer is not None
    model.train(is_training)
    accuracy.reset()
    total_loss = 0.0
    num_batches = 0
    with torch.set_grad_enabled(is_training):
        for images, labels in dataloader:
            logits = model(images)
            loss = loss_fn(logits, labels)
            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # .item() yields a plain float so the autograd graph is not kept alive.
            total_loss += loss.item()
            num_batches += 1
            accuracy.update(logits.argmax(dim=1), labels)
    mean_loss = total_loss / max(num_batches, 1)
    return mean_loss, accuracy.compute().item()


def train_func(config):
    """Per-worker training loop run by Ray's ``TorchTrainer``.

    Fine-tunes only a new final linear layer on top of a frozen, ImageNet-pretrained
    ResNet-18 using CIFAR-10. After every epoch it evaluates on the validation split
    and calls ``ray.train.report`` with the metrics; the rank-0 worker also attaches a
    checkpoint containing the model ``state_dict``.

    Args:
        config: Dict with keys ``"epochs"``, ``"batch_size"`` (per worker), ``"lr"``
            and ``"num_classes"``.
    """
    epochs = config["epochs"]
    batch_size = config["batch_size"]
    lr = config["lr"]
    num_classes = config["num_classes"]

    # Device Ray assigned to this worker; the metric state must live there too.
    device = ray.train.torch.get_device()
    accuracy = Accuracy(task="multiclass", num_classes=num_classes).to(device)

    # Freeze the pretrained backbone; only the replacement head is trainable.
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    for parameter in model.parameters():
        parameter.requires_grad = False
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model = ray.train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=1e-2)

    train_dataloader, valid_dataloader = get_cifar_dataloader(batch_size=batch_size)
    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
    valid_dataloader = ray.train.torch.prepare_data_loader(valid_dataloader)

    context = ray.train.get_context()
    for epoch in range(epochs):
        # In a distributed run the sampler must be told the epoch to reshuffle.
        if context.get_world_size() > 1:
            train_dataloader.sampler.set_epoch(epoch)

        train_loss, train_acc = _run_epoch(
            model, train_dataloader, loss_fn, accuracy, optimizer
        )
        valid_loss, valid_acc = _run_epoch(model, valid_dataloader, loss_fn, accuracy)

        # Losses are this worker's averages. Accuracy comes from torchmetrics
        # compute(), which may aggregate across workers when torch.distributed is
        # initialized.
        metrics = {
            "epoch": epoch,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "valid_loss": valid_loss,
            "valid_acc": valid_acc,
        }

        with tempfile.TemporaryDirectory() as checkpoint_dir:
            checkpoint = None
            # Weights are identical on every rank, so only rank 0 saves them.
            if context.get_world_rank() == 0:
                # Unwrap a DDP-style wrapper (if any) so saved keys have no prefix.
                state_dict = getattr(model, "module", model).state_dict()
                torch.save(state_dict, os.path.join(checkpoint_dir, "model.pt"))
                checkpoint = ray.train.Checkpoint.from_directory(checkpoint_dir)
                print(
                    f"epoch:{epoch} train_loss:{train_loss} train_acc:{train_acc} "
                    f"valid_loss:{valid_loss} valid_acc:{valid_acc}"
                )
            # Every worker must call report; only rank 0 attaches a checkpoint.
            ray.train.report(metrics, checkpoint=checkpoint)


In [96]:
# The global batch is split evenly across workers: each worker receives
# global_batch_size // num_workers samples per step.
global_batch_size = 100
num_workers = 5
use_gpu = False

train_config = dict(
    lr=1e-3,
    epochs=3,
    num_classes=10,
    batch_size=global_batch_size // num_workers,
)

scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

# A fresh uuid per run keeps results from different runs in separate directories;
# resolve() turns the relative storage path into an absolute one.
run_config = RunConfig(
    name=f"train_run-{uuid.uuid4().hex}",
    storage_path=str(Path("../marimo_notebooks/data/storage_path").resolve()),
)

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=scaling_config,
    train_loop_config=train_config,
    run_config=run_config,
)

In [None]:
# Launch the run configured above and show the object it returns.
result = trainer.fit()
print("Training result:", result)

2025-08-02 01:01:01,620	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2025-08-02 01:01:01 (running for 00:00:00.11)
Using FIFO scheduling algorithm.
Logical resource usage: 6.0/10 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-08-02_00-42-36_455163_84003/artifacts/2025-08-02_01-01-01/train_run-d2b626c3cdc94a9d9be870ab5e1edf78/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-08-02 01:01:06 (running for 00:00:05.14)
Using FIFO scheduling algorithm.
Logical resource usage: 6.0/10 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-08-02_00-42-36_455163_84003/artifacts/2025-08-02_01-01-01/train_run-d2b626c3cdc94a9d9be870ab5e1edf78/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-08-02 01:01:11 (running for 00:00:10.25)
Using FIFO scheduling algorithm.
Logical resource usage: 6.0/10 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-08-02_00-42-36_455163_84003/artifacts/2025-08-02_01-01-01/train_run-d2b626c3cdc94a9d9be870ab5e1edf78/driver_artifacts
Number 

2025-08-02 01:02:54,842	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/shoaib/code/anyscale-projects/anyscale-demo/marimo_notebooks/data/storage_path/train_run-d2b626c3cdc94a9d9be870ab5e1edf78' in 0.0038s.


== Status ==
Current time: 2025-08-02 01:02:54 (running for 00:01:53.22)
Using FIFO scheduling algorithm.
Logical resource usage: 6.0/10 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-08-02_00-42-36_455163_84003/artifacts/2025-08-02_01-01-01/train_run-d2b626c3cdc94a9d9be870ab5e1edf78/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


