In [1]:
import numpy as np
import os
import tempfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from filelock import FileLock
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from typing import Dict
import ray
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import ASHAScheduler
from sklearn.metrics import precision_score, recall_score, f1_score

RANDOM_SEED = 123

In [2]:
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

1
NVIDIA GeForce RTX 4080 SUPER


device(type='cuda', index=0)

In [3]:
def load_data(data_dir="./data"):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since
    # DataLoader is not threadsafe.
    with FileLock(os.path.expanduser("~/.data.lock")):
        trainset = torchvision.datasets.CIFAR10(
            root=data_dir, train=True, download=True, transform=transform)

        testset = torchvision.datasets.CIFAR10(
            root=data_dir, train=False, download=True, transform=transform)

    return trainset, testset

In [4]:
def load_test_data():
    # Load fake data for running a quick smoke-test.
    trainset = torchvision.datasets.FakeData(
        128, (3, 32, 32), num_classes=10, transform=transforms.ToTensor()
    )
    testset = torchvision.datasets.FakeData(
        16, (3, 32, 32), num_classes=10, transform=transforms.ToTensor()
    )
    return trainset, testset

In [5]:
class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5) # Same as flatten
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [6]:
def train_cifar(config): # Function API trainable
    net = Net(config["l1"], config["l2"])

    # device = "cpu"
    # if torch.cuda.is_available():
    #     device = "cuda:0"
    #     if torch.cuda.device_count() > 1:
    #         net = nn.DataParallel(net)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    # nn.module.parameters() returns an iterator over the module's parameters, it is typically passed to an optimizer.
    # SGD optimizer
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # Load existing checkpoint through `get_checkpoint()` API.
    if train.get_checkpoint():
        loaded_checkpoint = train.get_checkpoint()
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(loaded_checkpoint_dir, "checkpoint.pt")
            )
            net.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)

    if config["smoke_test"]:
        trainset, _ = load_test_data()
    else:
        trainset, _ = load_data()

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs])

    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=0 if config["smoke_test"] else 8,
    )
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=0 if config["smoke_test"] else 8,
    )

    for epoch in range(10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad() # Because pytorch accumulates the gradients on subsequent backward passes, which is convenient for RNNs. But not for CNNs.

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward() # TODO: What does this do?
            optimizer.step() # TODO: What does this do?

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                running_loss = 0.0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        val_preds = []
        val_labels = []
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()
                preds = torch.argmax(outputs, dim=1)
                val_preds.extend(preds.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())
                val_steps += 1

        precision = precision_score(val_labels, val_preds, average='weighted', zero_division=0)
        recall = recall_score(val_labels, val_preds, average='weighted', zero_division=0)
        f1 = f1_score(val_labels, val_preds, average='weighted', zero_division=0)
        
        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and will potentially be accessed through in ``get_checkpoint()``
        # in future iterations.
        # Note to save a file like checkpoint, you still need to put it under a directory
        # to construct a checkpoint.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
            torch.save(
                (net.state_dict(), optimizer.state_dict()), path
            )
            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            #----- Report Score to Ray Tune -----
            train.report(
                {"loss": (val_loss / val_steps), "accuracy": correct / total, "precision": precision, "recall": recall, "f1": f1},
                checkpoint=checkpoint,
            )
    print("Finished Training")

In [7]:
def test_best_model(best_result, smoke_test=False):
    best_trained_model = Net(best_result.config["l1"], best_result.config["l2"])
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    best_trained_model.to(device)

    checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")

    model_state, optimizer_state = torch.load(checkpoint_path)
    best_trained_model.load_state_dict(model_state)

    if smoke_test:
        _, testset = load_test_data()
    else:
        _, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2
    )

    correct = 0
    total = 0
    test_preds = []
    test_labels = []
    with torch.no_grad(): # We don't need to calculate the gradients when testing
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = best_trained_model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            preds = torch.argmax(outputs, dim=1)
            test_preds.extend(preds.cpu().numpy())
            test_labels.extend(labels.cpu().numpy())

    precision = precision_score(test_labels, test_preds, average='weighted')
    recall = recall_score(test_labels, test_preds, average='weighted')
    f1 = f1_score(test_labels, test_preds, average='weighted')
    print("Best trial test set accuracy: {}".format(correct / total))
    print("Best trial test set precision: {}".format(precision))
    print("Best trial test set recall: {}".format(recall))
    print("Best trial test set f1: {}".format(f1))

In [8]:

def main(num_samples=10, max_num_epochs=10, gpus_per_trial=1, smoke_test=False):
    config = {
        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(4, 9)),
        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(4, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([8, 16, 32, 64]),
        "smoke_test": smoke_test,
    }
    scheduler = ASHAScheduler( #Schedule is used to stop searches early
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)
    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_cifar), # train_cifar is the trainable function
            resources={"cpu": 6, "gpu": gpus_per_trial} # Per trail resource
        ),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            scheduler=scheduler,
            num_samples=num_samples,
        ),
        run_config=train.RunConfig(stop={"training_iteration": 5}),
        param_space=config,
    )
    results = tuner.fit()
    
    best_result = results.get_best_result("loss", "min")

    print("Best trial config: {}".format(best_result.config))
    print("Best trial final validation loss: {}".format(best_result.metrics["loss"]))
    print("Best trial final validation accuracy: {}".format(best_result.metrics["accuracy"]))
    print("Best trial final validation precision: {}".format(best_result.metrics["precision"]))
    print("Best trial final validation recall: {}".format(best_result.metrics["recall"]))

    # Testdata
    test_best_model(best_result, smoke_test=smoke_test) 
    return results

# The gpus_per_trail can be fractional, e.g. 0.5, just make sure GPU has enpugh memory.
# num_samples is the number of sample from hyperparameter space. = -1 means infinite samples until a stopping condition is met.
result =  main(num_samples=10, max_num_epochs=10, gpus_per_trial=0.25, smoke_test=True) 

0,1
Current time:,2024-09-10 21:55:24
Running for:,00:00:08.83
Memory:,5.6/31.2 GiB

Trial name,status,loc,batch_size,lr,iter,total time (s),loss,accuracy,precision
train_cifar_08c38_00000,TERMINATED,172.30.197.30:195293,32,0.0285523,2,0.551997,2.31532,0.115385,0.0133136
train_cifar_08c38_00001,TERMINATED,172.30.197.30:195294,16,0.000935676,1,0.451533,2.31188,0.0769231,0.00591716
train_cifar_08c38_00002,TERMINATED,172.30.197.30:195295,64,0.000122106,5,0.61636,2.27455,0.230769,0.0532544
train_cifar_08c38_00003,TERMINATED,172.30.197.30:195296,64,0.000793641,1,0.435413,2.32098,0.115385,0.0133136
train_cifar_08c38_00004,TERMINATED,172.30.197.30:195860,32,0.000156045,1,0.437824,2.33,0.0384615,0.00147929
train_cifar_08c38_00005,TERMINATED,172.30.197.30:195861,32,0.0292311,2,0.488307,2.2928,0.0769231,0.00591716
train_cifar_08c38_00006,TERMINATED,172.30.197.30:195862,64,0.00128328,4,0.495528,2.29078,0.0769231,0.00591716
train_cifar_08c38_00007,TERMINATED,172.30.197.30:195863,8,0.000428544,1,0.454755,2.3117,0.153846,0.0246154
train_cifar_08c38_00008,TERMINATED,172.30.197.30:196241,8,0.000158407,2,0.376354,2.31101,0.0384615,0.00147929
train_cifar_08c38_00009,TERMINATED,172.30.197.30:196242,32,0.000127323,2,0.349763,2.30854,0.115385,0.0133136


[36m(train_cifar pid=195295)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/yuhang/ray_results/train_cifar_2024-09-10_21-55-13/train_cifar_08c38_00002_2_batch_size=64,lr=0.0001_2024-09-10_21-55-15/checkpoint_000000)
[36m(train_cifar pid=195295)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/yuhang/ray_results/train_cifar_2024-09-10_21-55-13/train_cifar_08c38_00002_2_batch_size=64,lr=0.0001_2024-09-10_21-55-15/checkpoint_000001)
[36m(train_cifar pid=195295)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/yuhang/ray_results/train_cifar_2024-09-10_21-55-13/train_cifar_08c38_00002_2_batch_size=64,lr=0.0001_2024-09-10_21-55-15/checkpoint_000002)
[36m(train_cifar pid=196241)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/yuhang/ray_results/train_cifar_2024-09-10_21-55-13/train_cifar_08c38_00008_8_batch_size=8,lr=0.0002_2024-09-10_21-55-15/checkpoint_000001)[

Best trial config: {'l1': 64, 'l2': 16, 'lr': 0.00012210592916273645, 'batch_size': 64, 'smoke_test': True}
Best trial final validation loss: 2.274550199508667
Best trial final validation accuracy: 0.23076923076923078
Best trial final validation precision: 0.05325443786982249
Best trial final validation recall: 0.23076923076923078


  model_state, optimizer_state = torch.load(checkpoint_path)


Best trial test set accuracy: 0.0625
Best trial test set precision: 0.00390625
Best trial test set recall: 0.0625
Best trial test set f1: 0.007352941176470588


  _warn_prf(average, modifier, msg_start, len(result))


In [9]:
result

ResultGrid<[
  Result(
    metrics={'loss': 2.3153200149536133, 'accuracy': 0.11538461538461539, 'precision': 0.013313609467455622, 'recall': 0.11538461538461539, 'f1': 0.02387267904509284},
    path='/home/yuhang/ray_results/train_cifar_2024-09-10_21-55-13/train_cifar_08c38_00000_0_batch_size=32,lr=0.0286_2024-09-10_21-55-15',
    filesystem='local',
    checkpoint=Checkpoint(filesystem=local, path=/home/yuhang/ray_results/train_cifar_2024-09-10_21-55-13/train_cifar_08c38_00000_0_batch_size=32,lr=0.0286_2024-09-10_21-55-15/checkpoint_000001)
  ),
  Result(
    metrics={'loss': 2.3118847608566284, 'accuracy': 0.07692307692307693, 'precision': 0.00591715976331361, 'recall': 0.07692307692307693, 'f1': 0.01098901098901099},
    path='/home/yuhang/ray_results/train_cifar_2024-09-10_21-55-13/train_cifar_08c38_00001_1_batch_size=16,lr=0.0009_2024-09-10_21-55-15',
    filesystem='local',
    checkpoint=Checkpoint(filesystem=local, path=/home/yuhang/ray_results/train_cifar_2024-09-10_21-55-13/

In [10]:
import wandb
import os
os.environ['WANDB_NOTEBOOK_NAME'] = 'ray_torch_example_SGD.ipynb'

notes = "Ray_Tune_Test"
run = wandb.init(project='cifar10', notes=notes, tags=['Ray_tune', 'cifar10', "Auto hyperparameter settings"])

Traceback (most recent call last):
  File "/home/yuhang/anaconda3/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/yuhang/anaconda3/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/yuhang/anaconda3/lib/python3.9/site-packages/wandb/__main__.py", line 1, in <module>
    from wandb.cli import cli
  File "/home/yuhang/anaconda3/lib/python3.9/site-packages/wandb/cli/cli.py", line 65, in <module>
    logging.basicConfig(
  File "/home/yuhang/anaconda3/lib/python3.9/logging/__init__.py", line 2003, in basicConfig
    h = FileHandler(filename, mode,
  File "/home/yuhang/anaconda3/lib/python3.9/logging/__init__.py", line 1146, in __init__
    StreamHandler.__init__(self, self._open())
  File "/home/yuhang/anaconda3/lib/python3.9/logging/__init__.py", line 1175, in _open
    return open(self.baseFilename, self.mode, encoding=self.encoding,
PermissionError: [Errno 13] Permission denied: 

ServiceStartProcessError: The wandb service process exited with 1. Ensure that `sys.executable` is a valid python interpreter. You can override it with the `_executable` setting or with the `WANDB__EXECUTABLE` environment variable.
{'command': ['/home/yuhang/anaconda3/bin/python', '-m', 'wandb', 'service', '--debug', '--port-filename', '/tmp/tmpcqoe1oec/port-193188.txt', '--pid', '193188', '--serve-sock'], 'sys_executable': '/home/yuhang/anaconda3/bin/python', 'which_python': '/home/yuhang/anaconda3/bin/python3', 'proc_out': '', 'proc_err': ''}