# CIFAR-10 Classification with Hyperparameter Tuning using Ray Tune

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

import ray
from ray import tune
from ray.tune import Tuner, TuneConfig
from ray.tune.schedulers import ASHAScheduler
from ray.air import session


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
print("Using device:", device)


Using device: cuda


In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
])

train_val_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

num_train = 49000
num_val = 1000
train_set, val_set = torch.utils.data.random_split(train_val_set, [num_train, len(train_val_set)-num_train])

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)


In [4]:
class TwoLayerNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(TwoLayerNet, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

In [5]:
def train_fc_model(config, train_loader=None, val_loader=None):
    model = TwoLayerNet(3 * 32 * 32, config["hidden_dim"], 10).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
    criterion = nn.CrossEntropyLoss()

    def l2_regularization(model):
        return sum(torch.norm(p, 2)**2 for name, p in model.named_parameters() if 'weight' in name)

    for _ in range(config["epochs"]):
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss += config["reg"] * l2_regularization(model)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    model.eval()
    correct = total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)

    session.report({"val_accuracy": correct / total})


In [None]:
tuner_fc = tune.Tuner(
    tune.with_resources(
        tune.with_parameters(train_fc_model, train_loader=train_loader, val_loader=val_loader),
        resources={"cpu": 5, "gpu": 0.5}
    ),
    param_space={
        "lr": tune.choice([0.01, 0.05, 0.1, 0.001, 0.0005]),
        "hidden_dim": tune.choice([128, 256, 512, 1000, 2000]),
        "reg": tune.choice([0.0001, 0.001, 0.01, 0.005, 0.0005, 0.05]),
        "epochs": 20
    },
    tune_config=tune.TuneConfig(
        num_samples=50,
        metric="val_accuracy",
        mode="max",
        scheduler=ASHAScheduler(),
        max_concurrent_trials=5
    )
)

results_fc = tuner_fc.fit()
best_fc = results_fc.get_best_result(metric="val_accuracy", mode="max")
print("Best FC config:", best_fc.config)
print("Best FC accuracy:", best_fc.metrics["val_accuracy"])


0,1
Current time:,2025-08-07 15:11:43
Running for:,00:00:07.62
Memory:,15.3/15.5 GiB

Trial name,status,loc,hidden_dim,lr,reg
train_fc_model_56737_00000,RUNNING,172.27.229.196:193115,1000,0.1,0.05
train_fc_model_56737_00001,RUNNING,172.27.229.196:193116,256,0.1,0.005
train_fc_model_56737_00002,PENDING,,256,0.1,0.005
train_fc_model_56737_00003,PENDING,,512,0.0005,0.005
train_fc_model_56737_00004,PENDING,,128,0.1,0.01


RuntimeError: Caught unexpected exception: Task was killed due to the node running low on memory.
Memory on the node (IP: 172.27.229.196, ID: 4c474c019b2ca202b6fd6001b8f3c655d11f437df1a62ba5dfd9093d) where the task (actor ID: 031542740f78cdba16d1ffe901000000, name=ImplicitFunc.__init__, pid=193116, memory used=0.46GB) was running was 15.48GB / 15.53GB (0.996813), which exceeds the memory usage threshold of 0.95. Ray killed this worker (ID: 9d7f413fa7f20a8b5720c5819209dec99de3169a43b31cfd50c44ff7) because it was the most recently scheduled task; to see more information about memory usage on this node, use `ray logs raylet.out -ip 172.27.229.196`. To see the logs of the worker, use `ray logs worker-9d7f413fa7f20a8b5720c5819209dec99de3169a43b31cfd50c44ff7*out -ip 172.27.229.196. Top 10 memory users:
PID	MEM(GB)	COMMAND
181774	1.11	/home/sraja/.vscode-server/bin/488a1f239235055e34e673291fb8d8c810886f81/node /home/sraja/.vscode-ser...
185449	1.01	/home/sraja/miniconda3/envs/pytorch/bin/python -m ipykernel_launcher --f=/run/user/1000/jupyter/runt...
44572	0.97	/home/sraja/miniconda3/envs/pytorch/bin/python -m ipykernel_launcher --f=/run/user/1000/jupyter/runt...
191085	0.79	/home/sraja/miniconda3/envs/pytorch/bin/python -m ipykernel_launcher --f=/run/user/1000/jupyter/runt...
185358	0.72	/home/sraja/miniconda3/envs/pytorch/bin/python -m ipykernel_launcher --f=/run/user/1000/jupyter/runt...
191017	0.71	/home/sraja/miniconda3/envs/pytorch/bin/python -m ipykernel_launcher --f=/run/user/1000/jupyter/runt...
488	0.63	/home/sraja/.vscode-server/bin/488a1f239235055e34e673291fb8d8c810886f81/node --dns-result-order=ipv4...
44291	0.50	/home/sraja/miniconda3/envs/pytorch/bin/python -m ipykernel_launcher --f=/run/user/1000/jupyter/runt...
48433	0.48	/home/sraja/miniconda3/envs/pytorch/lib/python3.12/site-packages/ray/core/src/ray/gcs/gcs_server --l...
193116	0.46	
Refer to the documentation on how to address the out of memory issue: https://docs.ray.io/en/latest/ray-core/scheduling/ray-oom-prevention.html. Consider provisioning more memory on this node or reducing task parallelism by requesting more CPUs per task. Set max_restarts and max_task_retries to enable retry when the task crashes due to OOM. To adjust the kill threshold, set the environment variable `RAY_memory_usage_threshold` when starting Ray. To disable worker killing, set the environment variable `RAY_memory_monitor_refresh_ms` to zero.

[33m(raylet)[0m [2025-08-07 15:12:34,109 E 191306 191306] (raylet) node_manager.cc:3041: 5 Workers (tasks / actors) killed due to memory pressure (OOM), 0 Workers crashed due to other reasons at node (ID: 4c474c019b2ca202b6fd6001b8f3c655d11f437df1a62ba5dfd9093d, IP: 172.27.229.196) over the last time period. To see more information about the Workers killed on this node, use `ray logs raylet.out -ip 172.27.229.196`
[33m(raylet)[0m 
[33m(raylet)[0m Refer to the documentation on how to address the out of memory issue: https://docs.ray.io/en/latest/ray-core/scheduling/ray-oom-prevention.html. Consider provisioning more memory on this node or reducing task parallelism by requesting more CPUs per task. To adjust the kill threshold, set the environment variable `RAY_memory_usage_threshold` when starting Ray. To disable worker killing, set the environment variable `RAY_memory_monitor_refresh_ms` to zero.
