In [1]:
# Standard library
import os
import tempfile
import uuid
from pathlib import Path

# Third-party libraries
from filelock import FileLock
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchmetrics.classification import Accuracy

# torchvision (separate for clarity)
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, ResNet18_Weights
from torchvision import transforms

# Ray imports
import ray
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig, RunConfig, CheckpointConfig

In [2]:
# Load the CIFAR-10 test split (train=False), downloading it if it is not already on disk,
# and display the dataset summary.
data = CIFAR10(root="../marimo_notebooks/data",download=True,train=False)
data

Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ../marimo_notebooks/data
    Split: Test

In [3]:
# Each dataset item is a (PIL image, integer label) pair; show the first one.
next(iter(data))

(<PIL.Image.Image image mode=RGB size=32x32>, 3)

In [6]:
# Mapping from class name to the integer label stored in the dataset.
class_to_idx = data.class_to_idx
class_to_idx

{'airplane': 0,
 'automobile': 1,
 'bird': 2,
 'cat': 3,
 'deer': 4,
 'dog': 5,
 'frog': 6,
 'horse': 7,
 'ship': 8,
 'truck': 9}

In [7]:
# Pretrained ResNet-18 weight set (the IMAGENET1K_V1 enum member) used to initialise the model.
weights = ResNet18_Weights.IMAGENET1K_V1

In [8]:
# Build ResNet-18 from the selected pretrained weights. Pass the concrete enum
# member held in `weights`, not the `ResNet18_Weights` class itself: the `weights`
# argument expects a specific member such as ResNet18_Weights.IMAGENET1K_V1.
model = resnet18(weights=weights)
model



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
def create_transforms():
    """Build the image preprocessing pipelines for training and validation.

    Both pipelines resize to 150x150, crop to 128x128, convert to a tensor and
    normalize each channel with mean 0.5 and std 0.5. The training pipeline
    crops at a random location and additionally applies a random horizontal
    flip and colour jitter; the validation pipeline uses a centre crop and no
    augmentation.

    Returns:
        tuple: ``(train_transform, valid_transform)`` as ``transforms.Compose`` objects.
    """
    resize_to = (150, 150)
    crop_to = (128, 128)
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)

    def _tensor_and_normalize():
        # Fresh transform instances per pipeline: ToTensor gives [0,1], Normalize maps to [-1,1]
        return [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]

    train_steps = [
        transforms.Resize(resize_to),
        transforms.RandomCrop(crop_to),
        transforms.RandomHorizontalFlip(p=0.2),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    ]
    valid_steps = [
        transforms.Resize(resize_to),
        transforms.CenterCrop(crop_to),
    ]

    train_transform = transforms.Compose(train_steps + _tensor_and_normalize())
    valid_transform = transforms.Compose(valid_steps + _tensor_and_normalize())
    return train_transform, valid_transform

In [10]:
def get_cifar_dataloader(batch_size):
    """Create train and validation DataLoaders over 1000-sample CIFAR-10 subsets.

    The datasets are constructed (and downloaded if missing) under
    ``~/cifar_data`` while holding a file lock at ``~/cifar_data.lock``.

    Args:
        batch_size: Batch size used by both DataLoaders.

    Returns:
        tuple: ``(train_dataloader, valid_dataloader)``. Only the training
        loader shuffles.
    """
    train_transform, valid_transform = create_transforms()
    data_root = "~/cifar_data"
    lock_path = os.path.expanduser("~/cifar_data.lock")

    def _load_split(is_train, transform):
        # One CIFAR-10 split with its transform attached
        return CIFAR10(
            root=data_root,
            train=is_train,
            download=True,
            transform=transform,
        )

    with FileLock(lock_path):
        train_full = _load_split(True, train_transform)
        valid_full = _load_split(False, valid_transform)

    # Keep only the first 1000 samples of each split
    first_thousand = range(1000)
    train_subset = Subset(train_full, indices=first_thousand)
    valid_subset = Subset(valid_full, indices=first_thousand)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_subset, batch_size=batch_size)
    return train_loader, valid_loader



In [11]:
# get_cifar_dataloader returns (train_dataloader, valid_dataloader); take the
# training loader with a tiny batch size and pull one batch to inspect shapes.
sample_dataloader = get_cifar_dataloader(3)[0]
single_batch = next(iter(sample_dataloader))

In [12]:
single_batch[1].shape

torch.Size([3])

In [13]:
single_batch[0].shape

torch.Size([3, 3, 128, 128])

In [None]:
def train_func(config):
    """Per-worker training loop passed to Ray Train's ``TorchTrainer``.

    Loads an ImageNet-pretrained ResNet-18, freezes every pretrained parameter,
    replaces the final fully connected layer with a fresh ``num_classes``-way
    head and trains that head on the loaders from ``get_cifar_dataloader``.
    After every epoch it evaluates on the validation loader and reports the
    metrics (plus a checkpoint from rank 0) to Ray Train.

    Args:
        config (dict): Hyperparameters. Keys read: ``"epochs"``,
            ``"batch_size"`` (handed to each worker's DataLoader), ``"lr"``,
            ``"weight_decay"`` and ``"num_classes"``.

    Reported metrics per epoch: ``epoch``, ``train_loss``, ``train_acc``,
    ``valid_loss``, ``valid_acc`` (losses/accuracies are means over this
    worker's batches).
    """
    epochs = config["epochs"]
    batch_size = config["batch_size"]
    lr = config["lr"]
    weight_decay = config["weight_decay"]
    num_classes = config["num_classes"]

    # Use the device Ray assigned to this worker so the metric sits on the same
    # device that prepare_model / prepare_data_loader place the model and batches on.
    device = ray.train.torch.get_device()
    accuracy = Accuracy(task="multiclass", num_classes=num_classes).to(device)

    # Freeze all pretrained weights, then attach a new trainable classification head.
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    for parameter in model.parameters():
        parameter.requires_grad = False
    model.fc = nn.Linear(model.fc.in_features, num_classes, bias=True)

    model = ray.train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    # Only the new head requires gradients; give the optimizer just those parameters.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=weight_decay)

    train_dataloader, valid_dataloader = get_cifar_dataloader(batch_size=batch_size)
    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
    valid_dataloader = ray.train.torch.prepare_data_loader(valid_dataloader)

    context = ray.train.get_context()
    is_distributed = context.get_world_size() > 1
    is_rank_zero = context.get_world_rank() == 0

    for epoch in range(epochs):
        # In a multi-worker run the sampler must be told the epoch so that
        # shuffling differs between epochs.
        if is_distributed:
            train_dataloader.sampler.set_epoch(epoch)

        # ---- training pass ----
        model.train()
        train_loss = 0.0
        train_acc = 0.0
        for x, y in train_dataloader:
            optimizer.zero_grad()
            y_preds = model(x)
            loss = loss_fn(y_preds, y)
            # Without backward() + step() the weights would never change.
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += accuracy(y_preds.argmax(dim=1), y).item()
        num_train_batches = max(len(train_dataloader), 1)  # guard against an empty loader
        train_loss /= num_train_batches
        train_acc /= num_train_batches

        # ---- validation pass ----
        model.eval()
        valid_loss = 0.0
        valid_acc = 0.0
        with torch.no_grad():
            for x, y in valid_dataloader:
                y_preds = model(x)
                valid_loss += loss_fn(y_preds, y).item()
                valid_acc += accuracy(y_preds.argmax(dim=1), y).item()
        num_valid_batches = max(len(valid_dataloader), 1)
        valid_loss /= num_valid_batches
        valid_acc /= num_valid_batches

        metrics = {
            "epoch": epoch,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "valid_loss": valid_loss,
            "valid_acc": valid_acc,
        }

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint = None
            # Every worker must call report(), but only rank 0 needs to write
            # the (identical) model weights.
            if is_rank_zero:
                # The model only has a `.module` attribute when it has been
                # wrapped (multi-worker run); fall back to the model itself.
                unwrapped_model = getattr(model, "module", model)
                torch.save(
                    unwrapped_model.state_dict(),
                    os.path.join(temp_checkpoint_dir, "model.pt"),
                )
                checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
            ray.train.report(metrics, checkpoint=checkpoint)

        if is_rank_zero:
            print(metrics)

In [None]:
# ---- Scale of the run ----
num_workers = 5
use_gpu = True
global_batch_size = 300

# ---- Hyperparameters handed to every worker's train_func ----
train_config = dict(
    lr=0.01,
    epochs=3,
    num_classes=10,
    # each worker gets an equal share of the global batch
    batch_size=global_batch_size // num_workers,
    weight_decay=0.02,
)

scaling_config = ScalingConfig(
    num_workers=num_workers,
    use_gpu=use_gpu,
)

# Keep a single checkpoint, chosen as the one with the highest reported "train_acc".
run_config = RunConfig(
    storage_path="/mnt/cluster_storage",  # we could use s3 as well
    name=f"ray_train_torch_run-{uuid.uuid4().hex}",  # unique name per run
    checkpoint_config=CheckpointConfig(
        num_to_keep=1,
        checkpoint_score_attribute="train_acc",
        checkpoint_score_order="max",
    ),
)

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config=train_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

In [None]:
# Run the configured trainer and print the result object it returns.
result = trainer.fit()
print(f"Training result: {result}")

In [None]:
# Shut down Ray now that the run has been executed.
ray.shutdown()