# Hyperparameter Tuning with `ray`

## Basic training script

Classify images from the CIFAR10 dataset. We look at the effect of batch size, learning rate, and momentum.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

In [2]:
DATA_PATH = "/home/tzhao/Workspace/hyperparameter-tuning-example/data"
NUM_EPOCHS = 4
DEVICE = "cpu" # Ray is very picky with resource usage, so we will use cpu to simplify configs

In [3]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def train(batch_size, lr, momentum, verbose=True, download=True):
    # Load dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root=DATA_PATH, train=True, download=download, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(root=DATA_PATH, train=False, download=download, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    # Load nn
    net = Net().to(DEVICE)

    # Load loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    
    # Train
    for epoch in range(NUM_EPOCHS):
        
        # Train single epoch
        for data in trainloader:
            inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        
        # Test
        test_loss = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for data in testloader:
                inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)
                outputs = net(inputs)

                # Get test loss
                loss = criterion(outputs, labels)
                test_loss += loss.item()

                # Get test acc
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_acc = correct / total

        if verbose:
            print(f"Epoch {epoch}: Test loss: {test_loss:.2f}")
            print(f"Epoch {epoch}: Test acc: {test_acc:.2f}")

    if verbose:
        print("Training completed.")
        print(f"Final test loss: {test_loss:.2f}")
        print(f"Final test acc: {test_acc:.2f}")
    
    return test_acc

In [4]:
# train(batch_size=4, lr=0.001, momentum=0.9, download=False)

In [5]:
# train(batch_size=128, lr=0.001, momentum=0.9, download=False)

Increasing batch size seems to give us worse accuracy, but is desirable because it runs faster. Can we get the bet of both worlds with hyperparameter tuning?

## Hyperparameter tuning with `ray`

In [6]:
from ray import tune
from ray.air import session

In [7]:
def train_ray_tune(config):

    ###########################################
    ##### RAY TUNE: Initialize parameters #####
    ###########################################
    
    batch_size=config["batch_size"]
    lr=config["lr"]
    momentum=config["momentum"]
    verbose = False
    download = False

    ###########################################
    ###########################################
    ###########################################

    # Load dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root=DATA_PATH, train=True, download=download, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(root=DATA_PATH, train=False, download=download, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    # Load nn
    net = Net().to(DEVICE)

    # Load loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    
    # Train
    for epoch in range(NUM_EPOCHS):
        
        # Train single epoch
        for data in trainloader:
            inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        
        # Test
        test_loss = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for data in testloader:
                inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)
                outputs = net(inputs)

                # Get test loss
                loss = criterion(outputs, labels)
                test_loss += loss.item()

                # Get test acc
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_acc = correct / total

        if verbose:
            print(f"Epoch {epoch}: Test loss: {test_loss:.2f}")
            print(f"Epoch {epoch}: Test acc: {test_acc:.2f}")

        #########################################################
        ##### RAY TUNE: Report test loss and handle pruning #####
        #########################################################

        session.report({"test_acc": test_acc})

        #########################################################
        #########################################################
        #########################################################

    if verbose:
        print("Training completed.")
        print(f"Final test loss: {test_loss:.2f}")
        print(f"Final test acc: {test_acc:.2f}")
    
    return {"test_acc" : test_acc}

In [8]:
search_space = {
    "batch_size" : tune.randint(4,128),
    "lr": tune.loguniform(1e-5, 1e-1),
    "momentum": tune.uniform(0.1, 0.99),
}

tuner = tune.Tuner(
    train_ray_tune,
    tune_config=tune.TuneConfig(
        metric="test_acc",
        mode="max",
        num_samples=-1, # No limit to number of samples
        time_budget_s=300, # Limit time budget to 5 mins
    ),
    param_space=search_space,
)
results = tuner.fit()

0,1
Current time:,2023-07-21 05:46:18
Running for:,00:05:00.47
Memory:,29.8/125.6 GiB

Trial name,status,loc,batch_size,lr,momentum,iter,total time (s),test_acc
train_ray_tune_be24f_00000,TERMINATED,10.47.0.57:1031826,64,0.0249643,0.885086,5.0,90.683,0.605
train_ray_tune_be24f_00001,TERMINATED,10.47.0.57:1031825,8,6.02233e-05,0.182763,5.0,137.554,0.1
train_ray_tune_be24f_00002,TERMINATED,10.47.0.57:1031827,16,0.000367249,0.81679,5.0,112.615,0.3537
train_ray_tune_be24f_00003,TERMINATED,10.47.0.57:1031828,16,5.86507e-05,0.777631,5.0,112.165,0.1009
train_ray_tune_be24f_00004,TERMINATED,10.47.0.57:1031829,128,1.08189e-05,0.444186,5.0,88.754,0.1036
train_ray_tune_be24f_00005,TERMINATED,10.47.0.57:1031830,128,0.0131742,0.368587,5.0,88.7651,0.3909
train_ray_tune_be24f_00006,TERMINATED,10.47.0.57:1031831,16,0.0434757,0.852419,5.0,110.096,0.3316
train_ray_tune_be24f_00007,TERMINATED,10.47.0.57:1031832,4,0.00159883,0.565577,5.0,177.395,0.5833
train_ray_tune_be24f_00008,TERMINATED,10.47.0.57:1031833,32,0.0148393,0.658874,5.0,97.993,0.5942
train_ray_tune_be24f_00009,TERMINATED,10.47.0.57:1031834,4,2.89867e-05,0.700876,5.0,178.282,0.1006


2023-07-21 05:46:18,667	INFO timeout.py:54 -- Reached timeout of 300 seconds. Stopping all trials.
[2m[36m(train_ray_tune pid=1031832)[0m Traceback (most recent call last):
[2m[36m(train_ray_tune pid=1031832)[0m   File "python/ray/_raylet.pyx", line 1364, in ray._raylet.execute_task.function_executor
[2m[36m(train_ray_tune pid=1031832)[0m   File "/home/tzhao/Workspace/hyperparameter-tuning-example/env_ray/lib/python3.8/site-packages/ray/_private/function_manager.py", line 726, in actor_method_executor
[2m[36m(train_ray_tune pid=1031832)[0m     return method(__ray_actor, *args, **kwargs)
[2m[36m(train_ray_tune pid=1031832)[0m   File "/home/tzhao/Workspace/hyperparameter-tuning-example/env_ray/lib/python3.8/site-packages/ray/util/tracing/tracing_helper.py", line 464, in _resume_span
[2m[36m(train_ray_tune pid=1031832)[0m     return method(self, *_args, **_kwargs)
[2m[36m(train_ray_tune pid=1031832)[0m   File "/home/tzhao/Workspace/hyperparameter-tuning-example/env_ra

Best config is: {'batch_size': 64, 'lr': 0.024964308064702182, 'momentum': 0.8850857682485122}


In [10]:
print("Best config is:", results.get_best_result().config)
print("Best test_acc is:", results.get_best_result().metrics["test_acc"])

Best config is: {'batch_size': 64, 'lr': 0.024964308064702182, 'momentum': 0.8850857682485122}
Best test_acc is: 0.605
