In [1]:
import ray
from ray import tune
from ray.tune.schedulers import HyperBandScheduler
from ray.tune import CLIReporter
from ray.train import RunConfig
import torch
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn
from utils import * 
from training_utils import *

data_full_path = "/scratch/zw2688/Court_Vision_Model_Dev/data/classification_dataset_groupby_env_split"      
  
def train_resnet(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    model.fc = nn.Linear(model.fc.in_features, 1)
    model.to(device)
    
    if config['normalize']:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                             std=[0.229, 0.224, 0.225])
    else:
        normalize = None
    train_loader, valid_loader, _ = get_data_loaders(config["batch_size"], config["img_size"],  data_full_path, normalize)
    
    if config["optimizer"] == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config['momentum'], nesterov=config["nestrov"])
    elif config["optimizer"] == "adam":
        optimizer = optim.Adam(model.parameters(), lr=config["lr"])
    else:
        optimizer = optim.RMSprop(model.parameters(), lr=config["lr"])
    
    # optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config['momentum'], nesterov=config["nestrov"])
    
    criterion = nn.BCEWithLogitsLoss()
    while True:
        train_epoch(model, optimizer, criterion, train_loader, device)
        _, test_accuracy = test(model, criterion, valid_loader, device)
        ray.train.report(metrics = {"accuracy": test_accuracy})

In [2]:
batch_sizes = [64, 128, 256]
img_sizes = [224, 112, 96]
optimizers = ["rmsprop"]

# hyperband
hyperband_scheduler = HyperBandScheduler(
    time_attr='training_iteration',
    metric='accuracy',
    mode='max',
    max_t=120,
    reduction_factor=3
    )

hyperband_analysis = tune.run(
    train_resnet,
    name="tuning_cls_resnet18_augmented_rmsprop",
    stop={
        "accuracy": 0.93,
        "training_iteration": 100
    },
    resources_per_trial={
        "gpu": 0.1,
    },
    config = {
        "lr": tune.loguniform(8e-5, 5e-3),
        "batch_size": tune.grid_search(batch_sizes),
        "optimizer": tune.grid_search(optimizers),
        "img_size": tune.grid_search(img_sizes),
        #"nestrov": tune.grid_search([True, False]),
        #"momentum": tune.uniform(0.6, 0.95),
        "normalize": tune.grid_search([True, False]),
    },
    scheduler=hyperband_scheduler,
    num_samples=20,
    storage_path = "/scratch/zw2688/Court_Vision_Model_Dev/tune_results",
)


2024-01-05 11:02:07,674	INFO worker.py:1664 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
2024-01-05 11:02:09,691	INFO tune.py:220 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `tune.run(...)`.
2024-01-05 11:02:09,694	INFO tune.py:586 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


0,1
Current time:,2024-01-05 12:31:04
Running for:,01:28:54.73
Memory:,113.1/377.3 GiB

Trial name,status,loc,batch_size,img_size,lr,normalize,optimizer,iter,total time (s),accuracy
train_resnet_c80af_00002,RUNNING,10.32.35.157:500800,256,224,0.000689031,True,rmsprop,15,5203.87,0.509091
train_resnet_c80af_00004,RUNNING,10.32.35.157:500821,128,112,0.00013357,True,rmsprop,60,5284.6,0.869091
train_resnet_c80af_00010,RUNNING,10.32.35.157:500940,128,224,0.000854492,False,rmsprop,17,5175.56,0.56
train_resnet_c80af_00011,RUNNING,10.32.35.157:500942,256,224,0.00116204,False,rmsprop,16,5090.61,0.643636
train_resnet_c80af_00017,RUNNING,10.32.35.157:541896,256,96,0.000177545,False,rmsprop,9,557.5,0.883636
train_resnet_c80af_00020,RUNNING,10.32.35.157:542073,256,224,0.000116616,True,rmsprop,1,314.572,0.498182
train_resnet_c80af_00026,RUNNING,10.32.35.157:542558,256,96,0.000637777,True,rmsprop,7,434.765,0.687273
train_resnet_c80af_00005,PAUSED,10.32.35.157:500827,256,112,0.00191392,True,rmsprop,40,460.809,0.745455
train_resnet_c80af_00006,PAUSED,10.32.35.157:500829,64,96,0.000532234,True,rmsprop,40,2881.38,0.883636
train_resnet_c80af_00007,PAUSED,10.32.35.157:500841,128,96,0.00126159,True,rmsprop,40,2792.19,0.909091


Trial name,accuracy
train_resnet_c80af_00000,0.952727
train_resnet_c80af_00001,0.832727
train_resnet_c80af_00002,0.509091
train_resnet_c80af_00003,0.934545
train_resnet_c80af_00004,0.869091
train_resnet_c80af_00005,0.745455
train_resnet_c80af_00006,0.883636
train_resnet_c80af_00007,0.909091
train_resnet_c80af_00008,0.818182
train_resnet_c80af_00009,0.934545


2024-01-05 11:37:16,938	INFO hyperband.py:543 -- Restoring from a previous point in time. Previous=1; Now=1
2024-01-05 11:39:42,138	INFO hyperband.py:543 -- Restoring from a previous point in time. Previous=4; Now=1
2024-01-05 11:44:09,684	INFO hyperband.py:543 -- Restoring from a previous point in time. Previous=1; Now=1
2024-01-05 11:44:47,511	INFO hyperband.py:543 -- Restoring from a previous point in time. Previous=4; Now=1
2024-01-05 11:55:50,734	INFO hyperband.py:543 -- Restoring from a previous point in time. Previous=1; Now=1
2024-01-05 11:57:41,903	INFO hyperband.py:543 -- Restoring from a previous point in time. Previous=4; Now=1
2024-01-05 12:00:12,695	INFO hyperband.py:543 -- Restoring from a previous point in time. Previous=1; Now=1
2024-01-05 12:01:34,819	INFO hyperband.py:543 -- Restoring from a previous point in time. Previous=1; Now=1
2024-01-05 12:02:20,347	INFO hyperband.py:543 -- Restoring from a previous point in time. Previous=1; Now=1
2024-01-05 12:03:31,829	INFO