## 5. Hyperparameter tuning the PyTorch model using Ray Tune

The first step is to move all the PyTorch code into a function that we can pass as the `trainable` argument of `tune.Tuner`.

In [14]:
def train_pytorch(config): # we change the function so it accepts a config dictionary
    criterion = CrossEntropyLoss()

    model = resnet18()
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.to("cuda")

    optimizer = Adam(model.parameters(), lr=config["lr"])
    
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    train_data = MNIST(root="./data", train=True, download=True, transform=transform)
    data_loader = DataLoader(train_data, batch_size=config["batch_size"], shuffle=True, drop_last=True)

    for epoch in range(config["num_epochs"]):
        for images, labels in data_loader:
            images, labels = images.to("cuda"), labels.to("cuda")
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Report the metrics using train.report instead of print
        train.report({"loss": loss.item()})

The second and third steps are the same as before: we define the tuner and run it by calling its `fit` method.

In [15]:
tuner = tune.Tuner(
    trainable=tune.with_resources(train_pytorch, {"gpu": 1}), # we will dedicate 1 GPU to each trial
    param_space={
        "num_epochs": 1,
        "batch_size": 128,
        "lr": tune.loguniform(1e-4, 1e-1),
    },
    tune_config=tune.TuneConfig(
        mode="min",
        metric="loss",
        num_samples=2,
        search_alg=tune.search.BasicVariantGenerator(),
        scheduler=tune.schedulers.FIFOScheduler(),
    ),
)

results = tuner.fit()

| Field | Value |
|---|---|
| Current time | 2024-11-29 09:44:37 |
| Running for | 00:00:25.36 |
| Memory | 4.6/31.0 GiB |

| Trial name | status | loc | lr | iter | total time (s) | loss |
|---|---|---|---|---|---|---|
| train_pytorch_7cf0c_00000 | TERMINATED | 10.0.35.222:38448 | 0.0582891 | 1 | 19.9064 | 0.264131 |
| train_pytorch_7cf0c_00001 | TERMINATED | 10.0.20.35:10307 | 0.00787705 | 1 | 20.2805 | 0.0564279 |


(train_pytorch pid=38448) Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
(train_pytorch pid=38448) Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
(train_pytorch pid=10307, ip=10.0.20.35) Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz [repeated 12x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(train_pytorch pid=10307, ip=10.0.20.35) Failed to download (trying next): [repeated 6x across cluster]
(train_pytorch pid=10307, ip=10.0.20.35) HTTP Error 403: Forbidden [repeated 6x across cluster]


2024-11-29 09:44:37,315	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/ray/ray_results/train_pytorch_2024-11-29_09-44-11' in 0.0033s.
2024-11-29 09:44:37,319	INFO tune.py:1041 -- Total run time: 25.37 seconds (25.36 seconds for the tuning loop).
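
The search space above draws `lr` with `tune.loguniform(1e-4, 1e-1)`. To illustrate what a log-uniform draw means, here is a minimal standard-library sketch. It is not Ray's implementation: the helper `sample_loguniform` and the decade-counting check are hypothetical names introduced only for this illustration. The point is that each order of magnitude in the range should receive roughly the same share of samples.

```python
import math
import random


def sample_loguniform(low, high, rng):
    """Draw one value whose logarithm is uniform on [log(low), log(high)].

    Hypothetical helper for illustration; not part of Ray Tune.
    """
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)  # fixed seed so the sketch is reproducible
samples = [sample_loguniform(1e-4, 1e-1, rng) for _ in range(10_000)]

# Count how many samples land in each decade of the range:
# [1e-4, 1e-3), [1e-3, 1e-2), [1e-2, 1e-1]
decades = [0, 0, 0]
for s in samples:
    index = min(int(math.log10(s) + 4), 2)  # clamp the upper endpoint into the last bin
    decades[index] += 1

print(decades)  # three counts of similar size, one per decade
```

A plain uniform draw over the same interval would instead place about 90% of the samples in the top decade, which is why a log-uniform distribution is a common choice for learning rates.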


Finally, we can get the best result and its configuration:

In [16]:
best_result = results.get_best_result()
best_result.config

{'num_epochs': 1, 'batch_size': 128, 'lr': 0.007877049646500664}
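
Because the `TuneConfig` sets `metric="loss"` and `mode="min"`, calling `get_best_result()` without arguments is expected to return the trial with the lowest last-reported loss. The selection rule can be sketched in plain Python using the two rows of the trial table above. The `pick_best` helper is a hypothetical stand-in written for this illustration, not Ray's code.

```python
# Final values copied from the trial table above.
trials = [
    {"name": "train_pytorch_7cf0c_00000", "lr": 0.0582891, "loss": 0.264131},
    {"name": "train_pytorch_7cf0c_00001", "lr": 0.00787705, "loss": 0.0564279},
]


def pick_best(trials, metric, mode):
    """Return the trial with the lowest (mode="min") or highest (mode="max") metric.

    Hypothetical helper for illustration; not part of Ray Tune.
    """
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    chooser = min if mode == "min" else max
    return chooser(trials, key=lambda trial: trial[metric])


best = pick_best(trials, metric="loss", mode="min")
print(best["name"], best["lr"])
```

The selected trial is the second one, whose learning rate matches the `lr` shown in `best_result.config` once rounding in the table is taken into account.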