# CIFAR-10 Classification with Hyperparameter Tuning using Ray Tune

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

import ray
from ray import tune
from ray.tune import Tuner, TuneConfig
from ray.tune.schedulers import ASHAScheduler
from ray.air import session


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
print("Using device:", device)


Using device: cuda


In [3]:
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
])

train_val_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

num_train = 49000
num_val = 1000
train_set, val_set = torch.utils.data.random_split(train_val_set, [num_train, len(train_val_set)-num_train])

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)


In [4]:
from train_cnn import train_cnn_model

In [5]:
tuner_cnn = tune.Tuner(
    tune.with_resources(
        tune.with_parameters(train_cnn_model, train_loader=train_loader, val_loader=val_loader),
        resources={"cpu": 20, "gpu": 1}
    ),
    param_space={
    "lr": tune.choice([
        1e-2,   # Fast learning, might overshoot
        5e-3,
        1e-3,   # Commonly works well
        5e-4,
        1e-4,   # Slower but safer convergence
    ]),
    "reg": tune.choice([
        1e-5,    # Very light regularization
        1e-4,    # Commonly used
        5e-4,
        1e-3,
        5e-3,    # Strong regularization (may underfit)
    ]),
        "epochs": 20
    },
    tune_config=tune.TuneConfig(
        num_samples=50,
        metric="val_accuracy",
        mode="max",
        scheduler=ASHAScheduler(),
        max_concurrent_trials=20
    )
)

results_cnn = tuner_cnn.fit()
best_cnn = results_cnn.get_best_result(metric="val_accuracy", mode="max")
print("Best CNN config:", best_cnn.config)
print("Best CNN accuracy:", best_cnn.metrics["val_accuracy"])


0,1
Current time:,2025-08-07 16:39:48
Running for:,00:34:43.19
Memory:,10.9/15.5 GiB

Trial name,status,loc,lr,reg,iter,total time (s),val_accuracy
train_cnn_model_cf3ca_00000,TERMINATED,172.27.229.196:237427,0.0005,0.0001,1,36.649,0.698
train_cnn_model_cf3ca_00001,TERMINATED,172.27.229.196:238182,0.0001,1e-05,1,37.9894,0.641
train_cnn_model_cf3ca_00002,TERMINATED,172.27.229.196:238778,0.0005,0.005,1,37.809,0.684
train_cnn_model_cf3ca_00003,TERMINATED,172.27.229.196:239370,0.005,0.0001,1,38.314,0.684
train_cnn_model_cf3ca_00004,TERMINATED,172.27.229.196:239959,0.01,1e-05,1,37.6684,0.694
train_cnn_model_cf3ca_00005,TERMINATED,172.27.229.196:240549,0.005,0.0001,1,36.8856,0.714
train_cnn_model_cf3ca_00006,TERMINATED,172.27.229.196:241138,0.01,1e-05,1,36.7393,0.69
train_cnn_model_cf3ca_00007,TERMINATED,172.27.229.196:241727,0.005,0.005,1,37.0745,0.576
train_cnn_model_cf3ca_00008,TERMINATED,172.27.229.196:242310,0.001,0.0005,1,36.8606,0.711
train_cnn_model_cf3ca_00009,TERMINATED,172.27.229.196:242899,0.01,0.005,1,37.5434,0.539


2025-08-07 16:39:48,303	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/sraja/ray_results/train_cnn_model_2025-08-07_16-05-02' in 0.0132s.
2025-08-07 16:39:48,311	INFO tune.py:1041 -- Total run time: 2083.22 seconds (2083.18 seconds for the tuning loop).


Best CNN config: {'lr': 0.005, 'reg': 0.0001, 'epochs': 20}
Best CNN accuracy: 0.721
