# Ray Tune - A Deeper Dive Using MNIST with PyTorch

A [previous notebook](02-Understanding-Hyperparameter-Tuning.ipynb) explained the concept of hyperparameter tuning/optimization (HPO) and walked through the basics of using [Ray Tune](https://ray.readthedocs.io/en/latest/tune.html), and another [notebook on Tune and Sklearn](03-Ray-Tune-with-Sklearn.ipynb) showed Tune's drop-in replacements for HPO.

Now we'll use another example to explore more of the Tune API features. We'll use the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset of hand-written digits and train a [PyTorch](https://pytorch.org/) model to recognize them.

In [1]:
import os 
from torchvision import datasets, transforms
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock

## PyTorch Hyperparameter Tuning

Our example will closely follow the code in the [PyTorch MNIST example](https://github.com/pytorch/examples/blob/master/mnist/main.py). However, we will create an even simpler model than the one in the example, although you could try that model and compare its predictions.

Let's start by defining a few global variables: the approximate number of training samples to use per epoch, the approximate number of test samples to use per evaluation, and the location of the data.

In [2]:
EPOCH_SIZE = 512
TEST_SIZE = 256

DATA_ROOT = '../data/mnist'

The following class defines a convolutional neural network.

> **Tip:** Most of these code definitions can be found in `mnist.py`, too.

In [3]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)
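
The `192` used in `nn.Linear` and `x.view` follows from the layer shapes. As a rough check, assuming the standard 28×28 single-channel MNIST images and PyTorch's defaults (stride 1 and no padding for the convolution, stride equal to the kernel size for max pooling), the arithmetic can be sketched in plain Python. The helper names `conv_out` and `pool_out` are ours, not part of PyTorch:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Output side length of max pooling whose stride equals its kernel size."""
    return (size - kernel) // kernel + 1


side = conv_out(28, kernel=3)      # 28x28 input -> 26x26 after conv1
side = pool_out(side, kernel=3)    # 26x26 -> 8x8 after max_pool2d(..., 3)
features = 3 * side * side         # 3 output channels * 8 * 8

print(features)  # 192
```

If you change the kernel sizes or the number of channels in `ConvNet`, both occurrences of `192` would need to change accordingly.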

After creating that network, we can now create our data loaders for training and test data. These are just plain [PyTorch `DataLoaders`](https://pytorch.org/docs/1.1.0/data.html?highlight=dataloader#torch.utils.data.DataLoader) with two additions:

1. A `FileLock` is added to ensure that only one process downloads the data on each machine, just in case we have multiple workers per machine in our Ray cluster.
2. The root directory for the data can be specified and it will be created if it doesn't exist.

Otherwise, this code is identical to the [PyTorch example version](https://github.com/pytorch/examples/blob/master/mnist/main.py#L101).

In [4]:
def get_data_loaders():
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # We add FileLock here because multiple workers on the same machine could try to
    # download the data at the same time and overwrite each other's files; the
    # download is not safe to run concurrently from several processes.
    # You wouldn't need this for single-process training.
    lock_file = f'{DATA_ROOT}/data.lock'
    # exist_ok=True avoids a race when several workers create the directory at once.
    os.makedirs(DATA_ROOT, exist_ok=True)
        
    with FileLock(os.path.expanduser(lock_file)):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_ROOT, train=True, download=True, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)

        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_ROOT, train=False, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)
    return train_loader, test_loader
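
To make the `transforms.Normalize((0.1307, ), (0.3081, ))` step concrete: `ToTensor()` scales pixel values to the range [0, 1], and `Normalize` then applies `(x - mean) / std` per channel. A minimal plain-Python sketch of that arithmetic for a single pixel, where the `normalize` helper is only an illustration and not a torchvision function:

```python
MEAN, STD = 0.1307, 0.3081  # the constants passed to transforms.Normalize above


def normalize(pixel):
    """Apply (x - mean) / std to one pixel already scaled to [0, 1]."""
    return (pixel - MEAN) / STD


print(round(normalize(0.0), 4))  # black pixel -> -0.4242
print(round(normalize(1.0), 4))  # white pixel -> 2.8215
```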

Now we define our training and test functions. The argument order differs slightly from the original PyTorch example, but the difference is inconsequential. The training function takes a model, an optimizer, the training data loader, and the device, then trains the model on roughly `EPOCH_SIZE` samples.

In [5]:
def train(model, optimizer, train_loader, device=torch.device("cpu")):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx * len(data) > EPOCH_SIZE:
            return
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
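
The loss above pairs `F.nll_loss` with the `log_softmax` output of `ConvNet`, which together amount to a cross-entropy loss. The following plain-Python sketch works through that computation for a single example with three made-up logits. The helper functions are illustrative re-implementations, not the PyTorch ones, and `F.nll_loss` additionally averages over the batch by default:

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class for one example."""
    return -log_probs[target]


logits = [2.0, 0.5, -1.0]            # made-up scores for three classes
log_probs = log_softmax(logits)
loss_right = nll_loss(log_probs, 0)  # target is the highest-scoring class
loss_wrong = nll_loss(log_probs, 2)  # target is the lowest-scoring class

print(loss_right < loss_wrong)  # True
```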

Similarly, for our test function, we define a basic accuracy metric (the fraction of correct predictions) that we will track. We could add more metrics, but we'll keep it simple.

In [6]:
def test(model, data_loader, device=torch.device("cpu")):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            if batch_idx * len(data) > TEST_SIZE:
                break
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total
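
Because both loops stop only once `batch_idx * len(data)` exceeds the limit, they process slightly more samples than `EPOCH_SIZE` and `TEST_SIZE` suggest. A small sketch of that counting, assuming every batch is full (64 samples, as set in `get_data_loaders()`); `batches_processed` is a hypothetical helper, not part of the notebook:

```python
def batches_processed(limit, batch_size=64):
    """Count the batches consumed before `batch_idx * batch_size > limit` triggers."""
    batch_idx = 0
    while batch_idx * batch_size <= limit:
        batch_idx += 1
    return batch_idx


train_batches = batches_processed(512)  # EPOCH_SIZE
test_batches = batches_processed(256)   # TEST_SIZE

print(train_batches, train_batches * 64)  # 9 576
print(test_batches, test_batches * 64)    # 5 320
```

Under this assumption each accuracy value is a multiple of 1/320, which is consistent with values such as 0.078125 (25/320) in the outputs of this notebook.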

Finally, we create a wrapper function for this particular model. All we need to do is specify the configuration for the model that we would like to train, and the function will do the rest:

1. Retrieve the data with the loaders returned by `get_data_loaders()`
2. Create the `ConvNet` model
3. Optimize the model using _stochastic gradient descent_.

In [7]:
def train_mnist(config):
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config['momentum'])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        print(f"accuracy: {acc}")
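
The two hyperparameters we tune, `lr` and `momentum`, both enter the SGD update. As a scalar illustration, assuming the update rule documented for `torch.optim.SGD` with its default settings (no dampening, weight decay, or Nesterov momentum), the sketch below minimizes `x**2`. The function `sgd_momentum_step` is our own stand-in, not library code:

```python
def sgd_momentum_step(param, grad, velocity, lr, momentum):
    """One SGD-with-momentum update for a single scalar parameter."""
    velocity = momentum * velocity + grad  # accumulate a running gradient
    param = param - lr * velocity          # step against the accumulated gradient
    return param, velocity


x, v = 1.0, 0.0
for _ in range(2):
    grad = 2 * x  # derivative of x**2
    x, v = sgd_momentum_step(x, grad, v, lr=0.1, momentum=0.9)

print(round(x, 4))  # 0.46
```

With `momentum=0.0` the same two steps would only reach 0.64, which hints at why the momentum value can matter as much as the learning rate.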

### Single-Node Hyperparameter Tuning

Let's show what we might do if we performed hyperparameter tuning on a single machine. We would have to enumerate all the possibilities and either train them serially or use something like multiprocessing to train them in parallel. That setup takes a little bit of work so people often decide to train them serially, which is easiest, but requires the most time.

This is what we might do.

In [8]:
import itertools
conf = {
    "lr": [0.001, 0.01, 0.1],
    "momentum": [0.001, 0.01, 0.1, 0.9]
}

combinations = list(itertools.product(*conf.values()))
print(len(combinations))
combinations

12


[(0.001, 0.001),
 (0.001, 0.01),
 (0.001, 0.1),
 (0.001, 0.9),
 (0.01, 0.001),
 (0.01, 0.01),
 (0.01, 0.1),
 (0.01, 0.9),
 (0.1, 0.001),
 (0.1, 0.01),
 (0.1, 0.1),
 (0.1, 0.9)]
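
The tuples above depend on the key order of `conf`. One possible variation, shown here only as a sketch, keeps the hyperparameter names attached to each combination, so every entry can be passed straight to a function that expects a `config` dictionary:

```python
import itertools

conf = {
    "lr": [0.001, 0.01, 0.1],
    "momentum": [0.001, 0.01, 0.1, 0.9],
}

# Pair each tuple of values with the hyperparameter names it belongs to.
keys = list(conf)
configs = [dict(zip(keys, values)) for values in itertools.product(*conf.values())]

print(len(configs))  # 12
print(configs[0])    # {'lr': 0.001, 'momentum': 0.001}
```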

In [9]:
for lr, momentum in combinations:
    train_mnist({"lr":lr, "momentum":momentum})
    break # we'll stop this after one run and just use it for illustrative purposes

accuracy: 0.078125
accuracy: 0.1
accuracy: 0.090625
accuracy: 0.134375
accuracy: 0.090625
accuracy: 0.121875
accuracy: 0.11875
accuracy: 0.13125
accuracy: 0.11875
accuracy: 0.1875
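
For comparison, here is a rough sketch of the hand-rolled parallel approach mentioned above, using only the standard library. `fake_train` is a hypothetical stand-in for `train_mnist` that returns a made-up score rather than a measured accuracy, and a thread pool is used only so the sketch runs anywhere; for real CPU-bound training you would more likely reach for `ProcessPoolExecutor`:

```python
import itertools
from concurrent.futures import ThreadPoolExecutor


def fake_train(config):
    """Hypothetical stand-in for train_mnist: returns a made-up score."""
    return config, config["lr"] * (1.0 + config["momentum"])


conf = {
    "lr": [0.001, 0.01, 0.1],
    "momentum": [0.001, 0.01, 0.1, 0.9],
}
configs = [dict(zip(conf, values)) for values in itertools.product(*conf.values())]

# Run up to four configurations at a time and collect their results in order.
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(fake_train, configs))

best_config, best_score = max(results, key=lambda r: r[1])
print(best_config)  # {'lr': 0.1, 'momentum': 0.9}
```

Even this small version leaves out scheduling across machines, result logging, and early stopping, which is the bookkeeping Ray Tune takes over.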


### Distributed Hyperparameter Tuning with Ray Tune

Ray Tune makes it trivial to move this code from a single node to multiple nodes. Let's see how to use the code we've written with Ray Tune.

First, we set up Ray as before.

In [10]:
import ray
from ray import tune

In [11]:
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster

setup_ray_cluster(
  num_worker_nodes=2,
  num_cpus_per_node=4,
  collect_log_to_path="/dbfs/path/to/ray_collected_logs"
)
ray.init()

2022-03-16 16:06:02,219	INFO services.py:1412 -- View the Ray dashboard at http://127.0.0.1:8266


{'node_ip_address': '127.0.0.1',
 'raylet_ip_address': '127.0.0.1',
 'redis_address': None,
 'object_store_address': '/tmp/ray/session_2022-03-16_16-05-59_806072_59299/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2022-03-16_16-05-59_806072_59299/sockets/raylet',
 'webui_url': '127.0.0.1:8266',
 'session_dir': '/tmp/ray/session_2022-03-16_16-05-59_806072_59299',
 'metrics_export_port': 57268,
 'gcs_address': '127.0.0.1:61829',
 'address': '127.0.0.1:61829',
 'node_id': '5909d4d6039dfaf48d012ffbdded22d86fcb8439183926c4f64db43b'}

The first change is that we'll perform an exhaustive `grid_search` over our hyperparameters, as we did in the previous lesson. Our hyperparameters are the learning rate, `lr`, and the `momentum`.

In [12]:
config = {
    "lr": tune.grid_search([0.001, 0.01, 0.1]),
    "momentum": tune.grid_search([0.001, 0.01, 0.1, 0.9])
}

Next we modify our trainable, `train_mnist`, to report its accuracy to Tune with `tune.report()` instead of printing it:

In [13]:
def train_mnist(config):
    from ray.tune import report
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config['momentum'])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        # This sends the score to Tune.
        report(mean_accuracy=acc)

That's all that we need to change in order for Ray Tune to be able to parallelize our different hyperparameter combinations. 

When we execute a hyperparameter sweep, we perform an **experiment**. Each distinct combination of our different hyperparameters constitutes a single **trial**.

## Tune's Functional vs. Class API

Above and in the previous lesson, we used the **functional API**. This API is the most convenient for quickly setting up experiments, but it provides less overall flexibility than the **class API** [`tune.Trainable`](https://docs.ray.io/en/latest/tune/api_docs/trainable.html#tune-trainable).

We'll try both, starting with the functional API.

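We add a stopping criterion, `stop={"training_iteration": 20}`, so this will go reasonably quickly, while still producing good results. Note that the functional `train_mnist` loops only 10 times, so its trials finish after 10 reported iterations, before this limit is reached. Consider removing this condition if you don't mind waiting longer and you want optimal results.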

**Note**: Unlike the functional API, in which the trainable can call `tune.report()` whenever it has a result, the class API method `step()` reports results only through its return value.

In [15]:
%%time
analysis_func = tune.run(train_mnist, config=config, stop={"training_iteration": 20},
                         verbose=1)

2022-03-16 16:08:03,116	INFO tune.py:639 -- Total run time: 9.76 seconds (9.62 seconds for the tuning loop).


CPU times: user 1.96 s, sys: 443 ms, total: 2.4 s
Wall time: 9.8 s


In [16]:
print("Best config: ", analysis_func.get_best_config(metric="mean_accuracy", mode="max"))

Best config:  {'lr': 0.1, 'momentum': 0.9}
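
Conceptually, asking for the best config with `mode="max"` amounts to picking the trial whose reported `mean_accuracy` is highest and returning its hyperparameters. A plain-Python sketch of that selection, using three of the trial results reported by this run (this mirrors the idea only, not Tune's implementation):

```python
# Three (mean_accuracy, lr, momentum) results taken from this run's trials.
rows = [
    {"mean_accuracy": 0.9125, "lr": 0.1, "momentum": 0.1},
    {"mean_accuracy": 0.925, "lr": 0.1, "momentum": 0.9},
    {"mean_accuracy": 0.8625, "lr": 0.1, "momentum": 0.001},
]

# mode="max": keep the row with the largest metric value.
best = max(rows, key=lambda row: row["mean_accuracy"])
best_config = {"lr": best["lr"], "momentum": best["momentum"]}

print(best_config)  # {'lr': 0.1, 'momentum': 0.9}
```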


In [17]:
analysis_func.dataframe().sort_values('mean_accuracy', ascending=False).head()

Unnamed: 0,mean_accuracy,time_this_iter_s,done,timesteps_total,episodes_total,training_iteration,trial_id,experiment_id,date,timestamp,time_total_s,pid,hostname,node_ip,time_since_restore,timesteps_since_restore,iterations_since_restore,config/lr,config/momentum,logdir
11,0.925,0.189417,False,,,10,e89bd_00011,bc3fdff939df4138b5fa96ced1456ed5,2022-03-16_16-08-02,1647472082,3.927645,59733,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,3.927645,0,10,0.1,0.9,/Users/jules/ray_results/train_mnist_2022-03-1...
8,0.9125,0.249835,False,,,10,e89bd_00008,a985488b7ece464d9f05caf2c0c82b5f,2022-03-16_16-08-02,1647472082,4.740164,59713,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,4.740164,0,10,0.1,0.1,/Users/jules/ray_results/train_mnist_2022-03-1...
2,0.8625,0.288845,False,,,10,e89bd_00002,5d1a6afd211d48768e5a45822aac051e,2022-03-16_16-08-02,1647472082,4.531698,59716,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,4.531698,0,10,0.1,0.001,/Users/jules/ray_results/train_mnist_2022-03-1...
10,0.8625,0.176915,False,,,10,e89bd_00010,3f62f2ae811c43afbe3d7ad9314c28de,2022-03-16_16-08-02,1647472082,5.008541,59709,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,5.008541,0,10,0.01,0.9,/Users/jules/ray_results/train_mnist_2022-03-1...
5,0.828125,0.241335,False,,,10,e89bd_00005,9a94566af3d840279d126911c843a47e,2022-03-16_16-08-02,1647472082,4.790111,59711,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,4.790111,0,10,0.1,0.01,/Users/jules/ray_results/train_mnist_2022-03-1...


In [18]:
analysis_func.dataframe()[['mean_accuracy', 'config/lr', 'config/momentum']].sort_values('mean_accuracy', ascending=False)

| | mean_accuracy | config/lr | config/momentum |
|---|---|---|---|
| 11 | 0.925 | 0.1 | 0.9 |
| 8 | 0.9125 | 0.1 | 0.1 |
| 2 | 0.8625 | 0.1 | 0.001 |
| 10 | 0.8625 | 0.01 | 0.9 |
| 5 | 0.828125 | 0.1 | 0.01 |
| 7 | 0.759375 | 0.01 | 0.1 |
| 4 | 0.753125 | 0.01 | 0.01 |
| 9 | 0.74375 | 0.001 | 0.9 |
| 1 | 0.603125 | 0.01 | 0.001 |
| 6 | 0.23125 | 0.001 | 0.1 |


How long did it take? We'll compare this value with a different training run in the next lesson.

In [19]:
stats = analysis_func.stats()
secs = stats["timestamp"] - stats["start_time"]
print(f'{secs:7.2f} seconds, {secs/60.0:7.2f} minutes')

   0.17 seconds,    0.00 minutes


### Use Tune's Trainable Class API

When you pass a subclass of `tune.Trainable`, Tune creates a `Trainable` object in a separate process (using the [Ray Actor API](https://docs.ray.io/en/latest/actors.html#actor-guide)).

 * `setup` is invoked once when training starts.
 * `step` is invoked multiple times. Each time, the `Trainable` object executes one logical iteration of training in the tuning process, which may include one or more iterations of actual training.


In [20]:
class TrainMNIST(tune.Trainable):
    def setup(self, config):
        self.config = config
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet()
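        # Pass momentum as well; otherwise the momentum values in the search space are ignored.
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.config["lr"],
                                   momentum=self.config["momentum"])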
    
    def step(self):
        train(self.model, self.optimizer, self.train_loader)
        acc = test(self.model, self.test_loader)
        return {"mean_accuracy": acc}
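
To illustrate the `setup`/`step` lifecycle without Ray, here is a toy driver loop. Everything in it is hypothetical: `ToyTrainable` fakes an improving score, and `run_trial` only imitates how a `training_iteration` stopping criterion bounds the number of `step()` calls. Tune's real scheduling is considerably more involved:

```python
class ToyTrainable:
    """Hypothetical stand-in with the same setup/step shape as tune.Trainable."""

    def setup(self, config):
        self.lr = config["lr"]
        self.score = 0.0

    def step(self):
        # Fake "training": move the score a fraction of the way toward 1.0.
        self.score += self.lr * (1.0 - self.score)
        return {"mean_accuracy": self.score}


def run_trial(trainable_cls, config, stop):
    """Call setup once, then step until the iteration limit is reached."""
    trainable = trainable_cls()
    trainable.setup(config)
    result = {}
    for iteration in range(1, stop["training_iteration"] + 1):
        result = trainable.step()
        result["training_iteration"] = iteration
    return result


final = run_trial(ToyTrainable, {"lr": 0.1}, stop={"training_iteration": 20})
print(final["training_iteration"])  # 20
```

Because state lives on the object between `step()` calls, the driver can decide after every iteration whether to continue, which is one source of the extra flexibility of the class API.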

In [21]:
%%time
analysis = tune.run(
    TrainMNIST, 
    config=config,
    stop={"training_iteration": 20},
    verbose=1
)

2022-03-16 16:08:37,211	INFO tune.py:639 -- Total run time: 12.28 seconds (12.14 seconds for the tuning loop).


CPU times: user 2.43 s, sys: 529 ms, total: 2.96 s
Wall time: 12.3 s


In [22]:
print("Best config: ", analysis.get_best_config(metric="mean_accuracy", mode="max"))

Best config:  {'lr': 0.1, 'momentum': 0.01}


In [23]:
# Get a dataframe for analyzing trial results.
df = analysis.dataframe()
df.head()

Unnamed: 0,mean_accuracy,done,timesteps_total,episodes_total,training_iteration,trial_id,experiment_id,date,timestamp,time_this_iter_s,time_total_s,pid,hostname,node_ip,time_since_restore,timesteps_since_restore,iterations_since_restore,config/lr,config/momentum,logdir
0,0.48125,True,,,20,fb6d9_00000,73230a5195ab42e1a161f7a9f65c0ce0,2022-03-16_16-08-36,1647472116,0.275424,5.515932,59813,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,5.515932,0,20,0.001,0.001,/Users/jules/ray_results/TrainMNIST_2022-03-16...
1,0.825,True,,,20,fb6d9_00001,bbfcd25b039a477bbf17147ce72686dd,2022-03-16_16-08-36,1647472116,0.274208,5.332833,59817,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,5.332833,0,20,0.01,0.001,/Users/jules/ray_results/TrainMNIST_2022-03-16...
2,0.890625,True,,,20,fb6d9_00002,7d2c226f1e76443e95da34646a215ce1,2022-03-16_16-08-36,1647472116,0.254596,5.547505,59816,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,5.547505,0,20,0.1,0.001,/Users/jules/ray_results/TrainMNIST_2022-03-16...
3,0.203125,True,,,20,fb6d9_00003,324c4693660c4ea3b8ae19782915ad90,2022-03-16_16-08-36,1647472116,0.272964,5.48914,59819,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,5.48914,0,20,0.001,0.01,/Users/jules/ray_results/TrainMNIST_2022-03-16...
4,0.803125,True,,,20,fb6d9_00004,fccbb38ac63743fb86723edb5bc485c2,2022-03-16_16-08-36,1647472116,0.253441,5.469929,59821,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,5.469929,0,20,0.01,0.01,/Users/jules/ray_results/TrainMNIST_2022-03-16...


In [24]:
analysis.dataframe().sort_values('mean_accuracy', ascending=False).head()

Unnamed: 0,mean_accuracy,done,timesteps_total,episodes_total,training_iteration,trial_id,experiment_id,date,timestamp,time_this_iter_s,time_total_s,pid,hostname,node_ip,time_since_restore,timesteps_since_restore,iterations_since_restore,config/lr,config/momentum,logdir
5,0.93125,True,,,20,fb6d9_00005,e4276e94428d4bcba9df31e6501f5712,2022-03-16_16-08-36,1647472116,0.27469,5.549835,59820,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,5.549835,0,20,0.1,0.01,/Users/jules/ray_results/TrainMNIST_2022-03-16...
11,0.93125,True,,,20,fb6d9_00011,f6a6cb77c8d447c4a9e5cb6d9c3bf1c8,2022-03-16_16-08-36,1647472116,0.281194,5.506077,59834,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,5.506077,0,20,0.1,0.9,/Users/jules/ray_results/TrainMNIST_2022-03-16...
8,0.925,True,,,20,fb6d9_00008,7ae81f0ab99645138b1bd295816a1632,2022-03-16_16-08-36,1647472116,0.276365,5.507662,59815,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,5.507662,0,20,0.1,0.1,/Users/jules/ray_results/TrainMNIST_2022-03-16...
2,0.890625,True,,,20,fb6d9_00002,7d2c226f1e76443e95da34646a215ce1,2022-03-16_16-08-36,1647472116,0.254596,5.547505,59816,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,5.547505,0,20,0.1,0.001,/Users/jules/ray_results/TrainMNIST_2022-03-16...
1,0.825,True,,,20,fb6d9_00001,bbfcd25b039a477bbf17147ce72686dd,2022-03-16_16-08-36,1647472116,0.274208,5.332833,59817,Juless-MacBook-Pro-16-inch-2019,127.0.0.1,5.332833,0,20,0.01,0.001,/Users/jules/ray_results/TrainMNIST_2022-03-16...


It's easier to see what we want if we project out the interesting columns:

In [25]:
analysis.dataframe()[['mean_accuracy', 'config/lr', 'config/momentum']].sort_values('mean_accuracy', ascending=False)

| | mean_accuracy | config/lr | config/momentum |
|---|---|---|---|
| 5 | 0.93125 | 0.1 | 0.01 |
| 11 | 0.93125 | 0.1 | 0.9 |
| 8 | 0.925 | 0.1 | 0.1 |
| 2 | 0.890625 | 0.1 | 0.001 |
| 1 | 0.825 | 0.01 | 0.001 |
| 7 | 0.8125 | 0.01 | 0.1 |
| 4 | 0.803125 | 0.01 | 0.01 |
| 10 | 0.7875 | 0.01 | 0.9 |
| 0 | 0.48125 | 0.001 | 0.001 |
| 6 | 0.290625 | 0.001 | 0.1 |


How long did it take? We'll compare this value with a different training run in the next lesson.

In [26]:
stats = analysis.stats()
secs = stats["timestamp"] - stats["start_time"]
print(f'{secs:7.2f} seconds, {secs/60.0:7.2f} minutes')

  10.22 seconds,    0.17 minutes


The next lesson will explore optimization algorithms that speed up HPO.

In [27]:
shutdown_ray_cluster()