# Ray Tune - A Deeper Dive Using MNIST with PyTorch


Apopted from Anyscal unde Apache 2.0



In [None]:
print('NOTE: Intentionally crashing session to use the newly installed library.\n')

!pip uninstall -y pyarrow
!pip install ray

# A hack to force the runtime to restart, needed to include the above dependencies.
import os
os._exit(0)

NOTE: Intentionally crashing session to use the newly installed library.

Found existing installation: pyarrow 6.0.1
Uninstalling pyarrow-6.0.1:
  Successfully uninstalled pyarrow-6.0.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ray
  Downloading ray-1.12.1-cp37-cp37m-manylinux2014_x86_64.whl (53.2 MB)
[K     |████████████████████████████████| 53.2 MB 211 kB/s 
Collecting frozenlist
  Downloading frozenlist-1.3.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (144 kB)
[K     |████████████████████████████████| 144 kB 66.6 MB/s 
Collecting aiosignal
  Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)
Collecting virtualenv
  Downloading virtualenv-20.14.1-py2.py3-none-any.whl (8.8 MB)
[K     |████████████████████████████████| 8.8 MB 48.2 MB/s 
Collecting grpcio<=1.43.0,>=1.28.1
  Downloading grpcio-1.43.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.1

In [1]:
import os 
from torchvision import datasets, transforms
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock

## PyTorch Hyperparameter Tuning

Our example will closely follow the code in the [PyTorch MNIST example](https://github.com/pytorch/examples/blob/master/mnist/main.py). However, we will create an even simpler model than the one in the example, although you could try that model and compare its predictions.

Let's start by defining a few global variables for epoch and test sizes. Also define a data location.

In [2]:
EPOCH_SIZE = 512
TEST_SIZE = 256

DATA_ROOT = 'data/mnist'

The following class defines a convolutional neural network.



In [3]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

After creating that network, we can now create our data loaders for training and test data. These are just plain [PyTorch `DataLoaders`](https://pytorch.org/docs/1.1.0/data.html?highlight=dataloader#torch.utils.data.DataLoader) with two additions:

1. A `FileLock` is added to ensure that only one process downloads the data on each machine, just in case we have multiple workers per machine in our Ray cluster.
2. The root directory for the data can be specified and it will be created if it doesn't exist.

Otherwise, this code is identical to the [PyTorch example version](https://github.com/pytorch/examples/blob/master/mnist/main.py#L101).

In [4]:
def get_data_loaders():
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # We add FileLock here because multiple workers on the same machine coulde try 
    # download the data. This would cause overwrites, since DataLoader is not threadsafe.
    # You wouldn't need this for single-process training.
    lock_file = f'{DATA_ROOT}/data.lock'
    import os
    if not os.path.exists(DATA_ROOT):
        os.makedirs(DATA_ROOT)
        
    with FileLock(os.path.expanduser(lock_file)):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_ROOT, train=True, download=True, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)

        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_ROOT, train=False, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)
    return train_loader, test_loader

Now we define our training and test functions. While the arguments are a bit switched up from the original PyTorch tutorial, the difference is inconsequential. The arguments are an optimizer, a model, the training data loader, and our device. Then we train the model.

In [5]:
def train(model, optimizer, train_loader, device=torch.device("cpu")):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx * len(data) > EPOCH_SIZE:
            return
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

Similarly for our test model, we define a basic _average correct prediction_ metric that we will track. We could add more metrics, but we'll keep it simple.

In [6]:
def test(model, data_loader, device=torch.device("cpu")):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            if batch_idx * len(data) > TEST_SIZE:
                break
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total

Finally, we create a wrapper function for this particular model. In doing so all we need to do is specify the configuration for the model that we would like to train and the function will do the rest:

1. Retrieve the data with the loaders returned by `get_data_loaders()`
2. Create the `ConvNet` model
3. Optimize the model using _stochastic gradient descent_.

In [7]:
def train_mnist(config):
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config['momentum'])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        print(f"accuracy: {acc}")

### Single-Node Hyperparameter Tuning

Let's show what we might do if we performed hyperparameter tuning on a single machine. We would have to enumerate all the possibilities and either train them serially or use something like multiprocessing to train them in parallel. That setup takes a little bit of work so people often decide to train them serially, which is easiest, but requires the most time.

This is what we might do.

In [8]:
import itertools
conf = {
    "lr": [0.001, 0.01, 0.1],
    "momentum": [0.001, 0.01, 0.1, 0.9]
}

combinations = list(itertools.product(*conf.values()))
print(len(combinations))
combinations

12


[(0.001, 0.001),
 (0.001, 0.01),
 (0.001, 0.1),
 (0.001, 0.9),
 (0.01, 0.001),
 (0.01, 0.01),
 (0.01, 0.1),
 (0.01, 0.9),
 (0.1, 0.001),
 (0.1, 0.01),
 (0.1, 0.1),
 (0.1, 0.9)]

In [9]:
for lr, momentum in combinations:
    train_mnist({"lr":lr, "momentum":momentum})
    break # we'll stop this after one run and just use it for illustrative purposes

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw

accuracy: 0.140625
accuracy: 0.196875
accuracy: 0.19375
accuracy: 0.18125
accuracy: 0.23125
accuracy: 0.178125
accuracy: 0.234375
accuracy: 0.259375
accuracy: 0.26875
accuracy: 0.296875


### Distributed Hyperparameter Tuning with Ray Tune

Ray Tune makes it trivial to move this code from a single node to multiple nodes. Let's see how to use the code we've written with Ray Tune.

First, we set up Ray as before.

In [10]:
import ray
from ray import tune

In [11]:
ray.init(ignore_reinit_error=True)

RayContext(dashboard_url='', python_version='3.7.13', ray_version='1.12.1', ray_commit='4863e33856b54ccf8add5cbe75e41558850a1b75', address_info={'node_ip_address': '172.28.0.2', 'raylet_ip_address': '172.28.0.2', 'redis_address': None, 'object_store_address': '/tmp/ray/session_2022-06-02_07-35-09_914084_131/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2022-06-02_07-35-09_914084_131/sockets/raylet', 'webui_url': '', 'session_dir': '/tmp/ray/session_2022-06-02_07-35-09_914084_131', 'metrics_export_port': 62552, 'gcs_address': '172.28.0.2:64268', 'address': '172.28.0.2:64268', 'node_id': 'eeae3e690c607dc264338965e95a5ced29e3fbe946276af6e7c9e6a8'})

The first change is we'll perform a strict `grid_search` on our hyperparameters, like we used in the previous lesson. Our hyperparameters are the learning rate, `lr`, and the `momentum`.

In [12]:
config = {
    "lr": tune.grid_search([0.001, 0.01, 0.1]),
    "momentum": tune.grid_search([0.001, 0.01, 0.1, 0.9])
}

Next we modify our trainable, `train_mnist`, to use Tune's "reporting" logger:

In [13]:
def train_mnist(config):
    from ray.tune import report
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config['momentum'])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        # This sends the score to Tune.
        report(mean_accuracy=acc)

That's all that we need to change in order for Ray Tune to be able to parallelize our different hyperparameter combinations. 

When we execute a hyperparameter sweep, we perform an **experiment**. Each distinct combination of our different hyperparameters constitutes a single **trial**.

## Tune's Functional vs. Class API

In the above previous lesson, we used the **functional API**. This API is most convenient for quickly setting up experiments, but it provides less overall flexbility compared to the **class API** [`tune.Trainable`](https://docs.ray.io/en/latest/tune/api_docs/trainable.html#tune-trainable).

We'll try both, starting with the functional API.

We add a stopping criterion, `stop={"training_iteration": 20}`, so this will go reasonably quickly, while still producing good results. Consider removing this condition if you don't mind waiting longer and you want optimal results.

**Note**: Unlike the functional API, in which you the trainable can call a `tune.report()`, the class API method `cls.step()` can only return a value.

In [14]:
%%time
analysis_func = tune.run(train_mnist, config=config, stop={"training_iteration": 20},
                         verbose=1)

2022-06-02 07:36:31,976	INFO tune.py:702 -- Total run time: 64.03 seconds (63.82 seconds for the tuning loop).


CPU times: user 1.99 s, sys: 310 ms, total: 2.3 s
Wall time: 1min 4s


In [15]:
print("Best config: ", analysis_func.get_best_config(metric="mean_accuracy", mode="max"))

Best config:  {'lr': 0.1, 'momentum': 0.9}


In [16]:
analysis_func.dataframe().sort_values('mean_accuracy', ascending=False).head()

Unnamed: 0,mean_accuracy,time_this_iter_s,done,timesteps_total,episodes_total,training_iteration,trial_id,experiment_id,date,timestamp,...,pid,hostname,node_ip,time_since_restore,timesteps_since_restore,iterations_since_restore,warmup_time,config/lr,config/momentum,logdir
11,0.9,0.303608,False,,,10,92ca2_00011,8706745ce63c43ea87178790fa936a2f,2022-06-02_07-36-31,1654155391,...,1107,440e97f5cc6d,172.28.0.2,5.242462,0,10,0.007531,0.1,0.9,/root/ray_results/train_mnist_2022-06-02_07-35...
10,0.89375,0.439497,False,,,10,92ca2_00010,1e63728f84da4e42b14fc062b32c6c79,2022-06-02_07-36-31,1654155391,...,1086,440e97f5cc6d,172.28.0.2,6.270285,0,10,0.004675,0.01,0.9,/root/ray_results/train_mnist_2022-06-02_07-35...
5,0.88125,0.331679,False,,,10,92ca2_00005,745bc33c0478405589f3a106ab9d9be5,2022-06-02_07-36-01,1654155361,...,723,440e97f5cc6d,172.28.0.2,5.437617,0,10,0.004883,0.1,0.01,/root/ray_results/train_mnist_2022-06-02_07-35...
8,0.846875,0.403535,False,,,10,92ca2_00008,de449b5c22604364be2d4cad13d0932f,2022-06-02_07-36-20,1654155380,...,962,440e97f5cc6d,172.28.0.2,5.516682,0,10,0.004658,0.1,0.1,/root/ray_results/train_mnist_2022-06-02_07-35...
2,0.815625,0.431837,False,,,10,92ca2_00002,f979dfa10ff2489cbc5d58ebc471870d,2022-06-02_07-35-49,1654155349,...,564,440e97f5cc6d,172.28.0.2,8.282019,0,10,0.004937,0.1,0.001,/root/ray_results/train_mnist_2022-06-02_07-35...


In [17]:
analysis_func.dataframe()[['mean_accuracy', 'config/lr', 'config/momentum']].sort_values('mean_accuracy', ascending=False)

Unnamed: 0,mean_accuracy,config/lr,config/momentum
11,0.9,0.1,0.9
10,0.89375,0.01,0.9
5,0.88125,0.1,0.01
8,0.846875,0.1,0.1
2,0.815625,0.1,0.001
4,0.803125,0.01,0.01
9,0.75625,0.001,0.9
7,0.74375,0.01,0.1
1,0.390625,0.01,0.001
3,0.23125,0.001,0.01


How long did it take? We'll compare this value with a different training run in the next lesson.

In [18]:
stats = analysis_func.stats()
secs = stats["timestamp"] - stats["start_time"]
print(f'{secs:7.2f} seconds, {secs/60.0:7.2f} minutes')

  61.49 seconds,    1.02 minutes


### Use Tune's Trainable Class API

As a subclass of `tune.Trainable`, Tune will create a Trainable object on a separate process (using the [Ray Actor API](https://docs.ray.io/en/latest/actors.html#actor-guide)).

 * setup function is invoked once training starts.
 * step is invoked multiple times. Each time, the Trainable object executes one logical iteration of training in the tuning process, which may include one or more iterations of actual training.


In [19]:
class TrainMNIST(tune.Trainable):
    def setup(self, config):
        self.config = config
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet()
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.config["lr"])
    
    def step(self):
        train(self.model, self.optimizer, self.train_loader)
        acc = test(self.model, self.test_loader)
        return {"mean_accuracy": acc}

In [20]:
%%time
analysis = tune.run(
    TrainMNIST, 
    config=config,
    stop={"training_iteration": 20},
    verbose=1
)

2022-06-02 07:38:47,252	INFO tune.py:702 -- Total run time: 93.28 seconds (93.13 seconds for the tuning loop).


CPU times: user 2.57 s, sys: 380 ms, total: 2.95 s
Wall time: 1min 33s


In [21]:
print("Best config: ", analysis.get_best_config(metric="mean_accuracy", mode="max"))

Best config:  {'lr': 0.1, 'momentum': 0.9}


In [22]:
# Get a dataframe for analyzing trial results.
df = analysis.dataframe()
df.head()

Unnamed: 0,mean_accuracy,done,timesteps_total,episodes_total,training_iteration,trial_id,experiment_id,date,timestamp,time_this_iter_s,...,pid,hostname,node_ip,time_since_restore,timesteps_since_restore,iterations_since_restore,warmup_time,config/lr,config/momentum,logdir
0,0.13125,True,,,20,d1f71_00000,b330b9a3984f4d488cd03dbc4c78037a,2022-06-02_07-37-29,1654155449,0.414801,...,1259,440e97f5cc6d,172.28.0.2,8.407925,0,20,1.208078,0.001,0.001,/root/ray_results/TrainMNIST_2022-06-02_07-37-...
1,0.78125,True,,,20,d1f71_00001,58799d43dd1e47ceba8090c10866bdb6,2022-06-02_07-37-31,1654155451,0.356045,...,1307,440e97f5cc6d,172.28.0.2,7.948132,0,20,1.014063,0.01,0.001,/root/ray_results/TrainMNIST_2022-06-02_07-37-...
2,0.85625,True,,,20,d1f71_00002,eeb3ea4f7cea464ea7030cc036b45480,2022-06-02_07-37-45,1654155465,0.42426,...,1385,440e97f5cc6d,172.28.0.2,8.085383,0,20,1.187414,0.1,0.001,/root/ray_results/TrainMNIST_2022-06-02_07-37-...
3,0.153125,True,,,20,d1f71_00003,5d99138e90e4434b97f15f8b7732428a,2022-06-02_07-37-45,1654155465,0.254512,...,1390,440e97f5cc6d,172.28.0.2,7.955377,0,20,2.009065,0.001,0.01,/root/ray_results/TrainMNIST_2022-06-02_07-37-...
4,0.828125,True,,,20,d1f71_00004,ae2bc5ab2c7c4048bb6266a92bcba23f,2022-06-02_07-38-01,1654155481,0.392295,...,1534,440e97f5cc6d,172.28.0.2,8.284576,0,20,1.110083,0.01,0.01,/root/ray_results/TrainMNIST_2022-06-02_07-37-...


In [23]:
analysis.dataframe().sort_values('mean_accuracy', ascending=False).head()

Unnamed: 0,mean_accuracy,done,timesteps_total,episodes_total,training_iteration,trial_id,experiment_id,date,timestamp,time_this_iter_s,...,pid,hostname,node_ip,time_since_restore,timesteps_since_restore,iterations_since_restore,warmup_time,config/lr,config/momentum,logdir
11,0.9125,True,,,20,d1f71_00011,ce25c6d099be49b08c3cab45bec70c88,2022-06-02_07-38-47,1654155527,0.235299,...,1949,440e97f5cc6d,172.28.0.2,8.01466,0,20,1.453798,0.1,0.9,/root/ray_results/TrainMNIST_2022-06-02_07-37-...
8,0.878125,True,,,20,d1f71_00008,e89b1f9af0224f649d99bcf11e622b8e,2022-06-02_07-38-31,1654155511,0.408687,...,1788,440e97f5cc6d,172.28.0.2,8.44405,0,20,2.147907,0.1,0.1,/root/ray_results/TrainMNIST_2022-06-02_07-37-...
5,0.86875,True,,,20,d1f71_00005,de776b3e566a42a3b29937acd9053282,2022-06-02_07-38-02,1654155482,0.411252,...,1591,440e97f5cc6d,172.28.0.2,7.774504,0,20,0.976152,0.1,0.01,/root/ray_results/TrainMNIST_2022-06-02_07-37-...
10,0.865625,True,,,20,d1f71_00010,750f659178674dc4ae41b510939c7a75,2022-06-02_07-38-46,1654155526,0.397708,...,1919,440e97f5cc6d,172.28.0.2,8.430455,0,20,1.960877,0.01,0.9,/root/ray_results/TrainMNIST_2022-06-02_07-37-...
2,0.85625,True,,,20,d1f71_00002,eeb3ea4f7cea464ea7030cc036b45480,2022-06-02_07-37-45,1654155465,0.42426,...,1385,440e97f5cc6d,172.28.0.2,8.085383,0,20,1.187414,0.1,0.001,/root/ray_results/TrainMNIST_2022-06-02_07-37-...


It's easier to see what we want if project out the interesting columns:

In [24]:
analysis.dataframe()[['mean_accuracy', 'config/lr', 'config/momentum']].sort_values('mean_accuracy', ascending=False)

Unnamed: 0,mean_accuracy,config/lr,config/momentum
11,0.9125,0.1,0.9
8,0.878125,0.1,0.1
5,0.86875,0.1,0.01
10,0.865625,0.01,0.9
2,0.85625,0.1,0.001
4,0.828125,0.01,0.01
7,0.7875,0.01,0.1
1,0.78125,0.01,0.001
6,0.4,0.001,0.1
3,0.153125,0.001,0.01


How long did it take? We'll compare this value with a different training run in the next lesson.

In [25]:
stats = analysis.stats()
secs = stats["timestamp"] - stats["start_time"]
print(f'{secs:7.2f} seconds, {secs/60.0:7.2f} minutes')

  85.51 seconds,    1.43 minutes


The next lesson will explore optimization algorithms that speed up HPO.

In [26]:
ray.shutdown()  # "Undo ray.init()".