# Hyperparameter Tuning with `ray`

## Basic training script

Classify images from the CIFAR10 dataset. We look at the effect of batch size, learning rate, and momentum.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

In [2]:
DATA_PATH = "/home/tzhao/Workspace/hyperparameter-tuning-example/data"
NUM_EPOCHS = 4
DEVICE = "cpu" # Ray is very picky with resource usage, so we will use cpu to simplify configs

In [3]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

## Hyperparameter tuning with `ray`

In [4]:
from ray import tune
from ray.air import session

In [5]:
def train_ray_tune(config):

    ###########################################
    ##### RAY TUNE: Initialize parameters #####
    ###########################################
    
    batch_size=config["batch_size"]
    lr=config["lr"]
    momentum=config["momentum"]
    verbose = False
    download = False

    ###########################################
    ###########################################
    ###########################################

    # Load dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root=DATA_PATH, train=True, download=download, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(root=DATA_PATH, train=False, download=download, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    # Load nn
    net = Net().to(DEVICE)

    # Load loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    
    # Train
    for epoch in range(NUM_EPOCHS):
        
        # Train single epoch
        for data in trainloader:
            inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        
        # Test
        test_loss = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for data in testloader:
                inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)
                outputs = net(inputs)

                # Get test loss
                loss = criterion(outputs, labels)
                test_loss += loss.item()

                # Get test acc
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_acc = correct / total

        if verbose:
            print(f"Epoch {epoch}: Test loss: {test_loss:.2f}")
            print(f"Epoch {epoch}: Test acc: {test_acc:.2f}")

        #########################################################
        ##### RAY TUNE: Report test loss and handle pruning #####
        #########################################################

        session.report({"test_acc": test_acc})

        #########################################################
        #########################################################
        #########################################################

    if verbose:
        print("Training completed.")
        print(f"Final test loss: {test_loss:.2f}")
        print(f"Final test acc: {test_acc:.2f}")
    
    return {"test_acc" : test_acc}

In [6]:
search_space = {
    "batch_size" : tune.randint(4,128),
    "lr": tune.loguniform(1e-5, 1e-1),
    "momentum": tune.uniform(0.1, 0.99),
}

tuner = tune.Tuner(
    train_ray_tune,
    tune_config=tune.TuneConfig(
        metric="test_acc",
        mode="max",
        num_samples=-1, # No limit to number of samples
        time_budget_s=300, # Limit time budget to 5 mins
    ),
    param_space=search_space,
)
results = tuner.fit()

0,1
Current time:,2023-07-21 08:28:48
Running for:,00:05:00.53
Memory:,32.7/125.6 GiB

Trial name,status,loc,batch_size,lr,momentum,iter,total time (s),test_acc
train_ray_tune_71ab8_00000,TERMINATED,10.47.0.57:1134061,5,0.000471392,0.559327,5.0,185.959,0.4464
train_ray_tune_71ab8_00001,TERMINATED,10.47.0.57:1134060,42,0.000360726,0.36892,5.0,107.06,0.1007
train_ray_tune_71ab8_00002,TERMINATED,10.47.0.57:1134062,124,5.23836e-05,0.931896,5.0,96.736,0.0761
train_ray_tune_71ab8_00003,TERMINATED,10.47.0.57:1134063,75,1.33158e-05,0.164153,5.0,99.7822,0.1
train_ray_tune_71ab8_00004,TERMINATED,10.47.0.57:1134064,124,0.00138568,0.862644,5.0,97.3587,0.2989
train_ray_tune_71ab8_00005,TERMINATED,10.47.0.57:1134065,11,0.0388903,0.193108,5.0,138.924,0.5291
train_ray_tune_71ab8_00006,TERMINATED,10.47.0.57:1134066,10,0.0566167,0.807891,5.0,141.457,0.2436
train_ray_tune_71ab8_00007,TERMINATED,10.47.0.57:1134067,88,0.0467266,0.547522,5.0,96.4963,0.5646
train_ray_tune_71ab8_00008,TERMINATED,10.47.0.57:1134068,105,0.00768337,0.730069,5.0,99.5277,0.4486
train_ray_tune_71ab8_00009,TERMINATED,10.47.0.57:1134069,103,0.0578928,0.815008,5.0,98.0258,0.5884


2023-07-21 08:28:48,841	INFO timeout.py:54 -- Reached timeout of 300 seconds. Stopping all trials.
2023-07-21 08:28:59,393	INFO tune.py:1148 -- Total run time: 310.97 seconds (300.46 seconds for the tuning loop).
[2m[36m(train_ray_tune pid=1134070)[0m Traceback (most recent call last):
[2m[36m(train_ray_tune pid=1134070)[0m   File "python/ray/_raylet.pyx", line 1364, in ray._raylet.execute_task.function_executor
[2m[36m(train_ray_tune pid=1134070)[0m   File "/home/tzhao/Workspace/hyperparameter-tuning-example/env_ray/lib/python3.8/site-packages/ray/_private/function_manager.py", line 726, in actor_method_executor
[2m[36m(train_ray_tune pid=1134070)[0m     return method(__ray_actor, *args, **kwargs)
[2m[36m(train_ray_tune pid=1134070)[0m   File "/home/tzhao/Workspace/hyperparameter-tuning-example/env_ray/lib/python3.8/site-packages/ray/util/tracing/tracing_helper.py", line 464, in _resume_span
[2m[36m(train_ray_tune pid=1134070)[0m     return method(self, *_args, **_kw

In [7]:
print("Best config is:", results.get_best_result().config)
print("Best test_acc is:", results.get_best_result().metrics["test_acc"])

Best config is: {'batch_size': 5, 'lr': 0.009555123918268555, 'momentum': 0.18686923960135482}
Best test_acc is: 0.6143
