# Ray Tune - A Deeper Dive Using MNIST with PyTorch

© 2019-2022, Anyscale. All Rights Reserved

![Anyscale Academy](../images/AnyscaleAcademyLogo.png)

A [previous notebook](02-Understanding-Hyperparameter-Tuning.ipynb) explained the concept of hyperparameter tuning/optimization (HPO) and walked through the basics of using [Ray Tune](https://ray.readthedocs.io/en/latest/tune.html), and another [notebook on Tune and Sklearn](03-Ray-Tune-with-Sklearn.ipynb) showed Tune's drop-in replacements for HPO.

Now we'll use another example to explore more of the Tune API features. We'll use the [MNIST](http://yann.lecun.com/exdb/mnist/) of hand-written digits and train a [PyTorch](https://pytorch.org/) model to recognize them.

In [2]:
import os 
from torchvision import datasets, transforms
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock

## PyTorch Hyperparameter Tuning

Our example will closely follow the code in the [PyTorch MNIST example](https://github.com/pytorch/examples/blob/master/mnist/main.py). However, we will create an even simpler model than the one in the example, although you could try that model and compare its predictions.

Let's start by defining a few global variables for epoch and test sizes. Also define a data location.

In [3]:
EPOCH_SIZE = 512
TEST_SIZE = 256

DATA_ROOT = '../data/mnist'

The following class defines a convolutional neural network.

> **Tip:** Most of these code definitions can be found in `mnist.py`, too.

In [4]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

After creating that network, we can now create our data loaders for training and test data. These are just plain [PyTorch `DataLoaders`](https://pytorch.org/docs/1.1.0/data.html?highlight=dataloader#torch.utils.data.DataLoader) with two additions:

1. A `FileLock` is added to ensure that only one process downloads the data on each machine, just in case we have multiple workers per machine in our Ray cluster.
2. The root directory for the data can be specified and it will be created if it doesn't exist.

Otherwise, this code is identical to the [PyTorch example version](https://github.com/pytorch/examples/blob/master/mnist/main.py#L101).

In [5]:
def get_data_loaders():
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # We add FileLock here because multiple workers on the same machine coulde try 
    # download the data. This would cause overwrites, since DataLoader is not threadsafe.
    # You wouldn't need this for single-process training.
    lock_file = f'{DATA_ROOT}/data.lock'
    import os
    if not os.path.exists(DATA_ROOT):
        os.makedirs(DATA_ROOT)
        
    with FileLock(os.path.expanduser(lock_file)):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_ROOT, train=True, download=True, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)

        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_ROOT, train=False, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)
    return train_loader, test_loader

Now we define our training and test functions. While the arguments are a bit switched up from the original PyTorch tutorial, the difference is inconsequential. The arguments are an optimizer, a model, the training data loader, and our device. Then we train the model.

In [6]:
def train(model, optimizer, train_loader, device=torch.device("cpu")):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx * len(data) > EPOCH_SIZE:
            return
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

Similarly for our test model, we define a basic _average correct prediction_ metric that we will track. We could add more metrics, but we'll keep it simple.

In [7]:
def test(model, data_loader, device=torch.device("cpu")):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            if batch_idx * len(data) > TEST_SIZE:
                break
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total

Finally, we create a wrapper function for this particular model. In doing so all we need to do is specify the configuration for the model that we would like to train and the function will do the rest:

1. Retrieve the data with the loaders returned by `get_data_loaders()`
2. Create the `ConvNet` model
3. Optimize the model using _stochastic gradient descent_.

In [8]:
def train_mnist(config):
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config['momentum'])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        print(f"accuracy: {acc}")

### Single-Node Hyperparameter Tuning

Let's show what we might do if we performed hyperparameter tuning on a single machine. We would have to enumerate all the possibilities and either train them serially or use something like multiprocessing to train them in parallel. That setup takes a little bit of work so people often decide to train them serially, which is easiest, but requires the most time.

This is what we might do.

In [9]:
import itertools
conf = {
    "lr": [0.001, 0.01, 0.1],
    "momentum": [0.001, 0.01, 0.1, 0.9]
}

combinations = list(itertools.product(*conf.values()))
print(len(combinations))
combinations

12


[(0.001, 0.001),
 (0.001, 0.01),
 (0.001, 0.1),
 (0.001, 0.9),
 (0.01, 0.001),
 (0.01, 0.01),
 (0.01, 0.1),
 (0.01, 0.9),
 (0.1, 0.001),
 (0.1, 0.01),
 (0.1, 0.1),
 (0.1, 0.9)]

In [10]:
for lr, momentum in combinations:
    train_mnist({"lr":lr, "momentum":momentum})
    break # we'll stop this after one run and just use it for illustrative purposes

100%|██████████| 9.91M/9.91M [00:00<00:00, 26.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.82MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 12.7MB/s]
100%|██████████| 4.54k/4.54k [00:00<?, ?B/s]


accuracy: 0.046875
accuracy: 0.04375
accuracy: 0.040625
accuracy: 0.075
accuracy: 0.071875
accuracy: 0.059375
accuracy: 0.115625
accuracy: 0.084375
accuracy: 0.06875
accuracy: 0.1


### Distributed Hyperparameter Tuning with Ray Tune

Ray Tune makes it trivial to move this code from a single node to multiple nodes. Let's see how to use the code we've written with Ray Tune.

First, we set up Ray as before.

In [11]:
import ray
from ray import tune

In [12]:
ray.init(ignore_reinit_error=True)

2026-01-18 17:13:34,306	INFO worker.py:2007 -- Started a local Ray instance.


0,1
Python version:,3.11.9
Ray version:,2.53.0


The first change is we'll perform a strict `grid_search` on our hyperparameters, like we used in the previous lesson. Our hyperparameters are the learning rate, `lr`, and the `momentum`.

In [13]:
config = {
    "lr": tune.grid_search([0.001, 0.01, 0.1]),
    "momentum": tune.grid_search([0.001, 0.01, 0.1, 0.9])
}

Next we modify our trainable, `train_mnist`, to use Tune's "reporting" logger:

In [18]:
def train_mnist(config):
    from ray.tune import report
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config['momentum'])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        # This sends the score to Tune.
        report({"mean_accuracy": acc})

That's all that we need to change in order for Ray Tune to be able to parallelize our different hyperparameter combinations. 

When we execute a hyperparameter sweep, we perform an **experiment**. Each distinct combination of our different hyperparameters constitutes a single **trial**.

## Tune's Functional vs. Class API

In the above previous lesson, we used the **functional API**. This API is most convenient for quickly setting up experiments, but it provides less overall flexbility compared to the **class API** [`tune.Trainable`](https://docs.ray.io/en/latest/tune/api_docs/trainable.html#tune-trainable).

We'll try both, starting with the functional API.

We add a stopping criterion, `stop={"training_iteration": 20}`, so this will go reasonably quickly, while still producing good results. Consider removing this condition if you don't mind waiting longer and you want optimal results.

**Note**: Unlike the functional API, in which you the trainable can call a `tune.report()`, the class API method `cls.step()` can only return a value.

In [19]:
%%time
analysis_func = tune.run(train_mnist, config=config, stop={"training_iteration": 20},
                         verbose=1)
print("Done")

0,1
Current time:,2026-01-18 17:18:29
Running for:,00:00:28.99
Memory:,14.0/15.7 GiB

Trial name,status,loc,lr,momentum,acc,iter,total time (s)
train_mnist_8c74b_00000,TERMINATED,127.0.0.1:1976,0.001,0.001,0.128125,10,5.43428
train_mnist_8c74b_00001,TERMINATED,127.0.0.1:15476,0.01,0.001,0.753125,10,6.21703
train_mnist_8c74b_00002,TERMINATED,127.0.0.1:22484,0.1,0.001,0.81875,10,5.03465
train_mnist_8c74b_00003,TERMINATED,127.0.0.1:29268,0.001,0.01,0.1625,10,5.98899
train_mnist_8c74b_00004,TERMINATED,127.0.0.1:23452,0.01,0.01,0.76875,10,5.99368
train_mnist_8c74b_00005,TERMINATED,127.0.0.1:22648,0.1,0.01,0.825,10,4.77235
train_mnist_8c74b_00006,TERMINATED,127.0.0.1:11056,0.001,0.1,0.215625,10,5.83804
train_mnist_8c74b_00007,TERMINATED,127.0.0.1:21600,0.01,0.1,0.8,10,4.605
train_mnist_8c74b_00008,TERMINATED,127.0.0.1:27836,0.1,0.1,0.896875,10,4.63792
train_mnist_8c74b_00009,TERMINATED,127.0.0.1:27632,0.001,0.9,0.7125,10,4.03019


  0%|          | 0.00/9.91M [00:00<?, ?B/s]
 10%|▉         | 983k/9.91M [00:00<00:01, 8.83MB/s]
 39%|███▊      | 3.83M/9.91M [00:00<00:00, 19.2MB/s]
 62%|██████▏   | 6.13M/9.91M [00:00<00:00, 20.7MB/s]
100%|██████████| 9.91M/9.91M [00:00<00:00, 21.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.46MB/s]
  0%|          | 0.00/1.65M [00:00<?, ?B/s]
 20%|█▉        | 328k/1.65M [00:00<00:00, 3.27MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 8.47MB/s]
100%|██████████| 4.54k/4.54k [00:00<?, ?B/s]
2026-01-18 17:18:29,102	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to 'C:/Users/gyanr/ray_results/train_mnist_2026-01-18_17-18-00' in 0.0470s.
2026-01-18 17:18:29,110	INFO tune.py:1041 -- Total run time: 29.04 seconds (28.94 seconds for the tuning loop).


Done
CPU times: total: 1.58 s
Wall time: 29.5 s


In [20]:
print("Best config: ", analysis_func.get_best_config(metric="mean_accuracy", mode="max"))

Best config:  {'lr': 0.1, 'momentum': 0.1}


In [21]:
analysis_func.dataframe().sort_values('mean_accuracy', ascending=False).head()

Unnamed: 0,mean_accuracy,timestamp,checkpoint_dir_name,done,training_iteration,trial_id,date,time_this_iter_s,time_total_s,pid,hostname,node_ip,time_since_restore,iterations_since_restore,config/lr,config/momentum,logdir
8,0.896875,1768774708,,False,10,8c74b_00008,2026-01-18_17-18-28,0.217912,4.637921,27836,gylenovo,127.0.0.1,4.637921,10,0.1,0.1,8c74b_00008
10,0.8625,1768774708,,False,10,8c74b_00010,2026-01-18_17-18-28,0.16312,4.519816,23744,gylenovo,127.0.0.1,4.519816,10,0.01,0.9,8c74b_00010
11,0.853125,1768774708,,False,10,8c74b_00011,2026-01-18_17-18-28,0.149197,4.985619,4272,gylenovo,127.0.0.1,4.985619,10,0.1,0.9,8c74b_00011
5,0.825,1768774707,,False,10,8c74b_00005,2026-01-18_17-18-27,0.300431,4.772352,22648,gylenovo,127.0.0.1,4.772352,10,0.1,0.01,8c74b_00005
2,0.81875,1768774707,,False,10,8c74b_00002,2026-01-18_17-18-27,0.217618,5.034654,22484,gylenovo,127.0.0.1,5.034654,10,0.1,0.001,8c74b_00002


In [22]:
analysis_func.dataframe()[['mean_accuracy', 'config/lr', 'config/momentum']].sort_values('mean_accuracy', ascending=False)

Unnamed: 0,mean_accuracy,config/lr,config/momentum
8,0.896875,0.1,0.1
10,0.8625,0.01,0.9
11,0.853125,0.1,0.9
5,0.825,0.1,0.01
2,0.81875,0.1,0.001
7,0.8,0.01,0.1
4,0.76875,0.01,0.01
1,0.753125,0.01,0.001
9,0.7125,0.001,0.9
6,0.215625,0.001,0.1


How long did it take? We'll compare this value with a different training run in the next lesson.

In [26]:
print(f"Total time: {analysis_func.results_df['time_total_s'].sum()}")


Total time: 62.05755257606506


### Use Tune's Trainable Class API

As a subclass of `tune.Trainable`, Tune will create a Trainable object on a separate process (using the [Ray Actor API](https://docs.ray.io/en/latest/actors.html#actor-guide)).

 * setup function is invoked once training starts.
 * step is invoked multiple times. Each time, the Trainable object executes one logical iteration of training in the tuning process, which may include one or more iterations of actual training.


In [27]:
class TrainMNIST(tune.Trainable):
    def setup(self, config):
        self.config = config
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet()
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.config["lr"])
    
    def step(self):
        train(self.model, self.optimizer, self.train_loader)
        acc = test(self.model, self.test_loader)
        return {"mean_accuracy": acc}

In [28]:
%%time
analysis = tune.run(
    TrainMNIST, 
    config=config,
    stop={"training_iteration": 20},
    verbose=1
)

0,1
Current time:,2026-01-18 17:21:34
Running for:,00:00:34.55
Memory:,14.8/15.7 GiB

Trial name,status,loc,lr,momentum,acc,iter,total time (s)
TrainMNIST_f7732_00000,TERMINATED,127.0.0.1:19336,0.001,0.001,0.5375,20,7.43316
TrainMNIST_f7732_00001,TERMINATED,127.0.0.1:19072,0.01,0.001,0.821875,20,6.81005
TrainMNIST_f7732_00002,TERMINATED,127.0.0.1:27036,0.1,0.001,0.9125,20,6.9478
TrainMNIST_f7732_00003,TERMINATED,127.0.0.1:10020,0.001,0.01,0.1875,20,6.38798
TrainMNIST_f7732_00004,TERMINATED,127.0.0.1:14904,0.01,0.01,0.859375,20,7.17814
TrainMNIST_f7732_00005,TERMINATED,127.0.0.1:27640,0.1,0.01,0.915625,20,7.09179
TrainMNIST_f7732_00006,TERMINATED,127.0.0.1:16868,0.001,0.1,0.209375,20,6.82371
TrainMNIST_f7732_00007,TERMINATED,127.0.0.1:27744,0.01,0.1,0.8,20,6.82445
TrainMNIST_f7732_00008,TERMINATED,127.0.0.1:24368,0.1,0.1,0.93125,20,6.89844
TrainMNIST_f7732_00009,TERMINATED,127.0.0.1:3012,0.001,0.9,0.353125,20,5.58593


  0%|          | 0.00/9.91M [00:00<?, ?B/s]
  8%|▊         | 819k/9.91M [00:00<00:01, 8.15MB/s]
 30%|███       | 3.01M/9.91M [00:00<00:00, 14.2MB/s]
 55%|█████▍    | 5.41M/9.91M [00:00<00:00, 18.2MB/s]
 78%|███████▊  | 7.70M/9.91M [00:00<00:00, 19.1MB/s]
100%|██████████| 9.91M/9.91M [00:00<00:00, 19.2MB/s]
  0%|          | 0.00/28.9k [00:00<?, ?B/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.33MB/s]
  0%|          | 0.00/1.65M [00:00<?, ?B/s]
 20%|█▉        | 328k/1.65M [00:00<00:00, 3.24MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.94MB/s]
100%|██████████| 4.54k/4.54k [00:00<?, ?B/s]
2026-01-18 17:21:34,169	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to 'C:/Users/gyanr/ray_results/TrainMNIST_2026-01-18_17-20-59' in 0.0481s.
2026-01-18 17:21:34,173	INFO tune.py:1041 -- Total run time: 34.58 seconds (34.50 seconds for the tuning loop).


CPU times: total: 2.19 s
Wall time: 35 s


In [29]:
print("Best config: ", analysis.get_best_config(metric="mean_accuracy", mode="max"))

Best config:  {'lr': 0.1, 'momentum': 0.1}


In [30]:
# Get a dataframe for analyzing trial results.
df = analysis.dataframe()
df.head()

Unnamed: 0,mean_accuracy,done,training_iteration,trial_id,date,timestamp,time_this_iter_s,time_total_s,pid,hostname,node_ip,time_since_restore,iterations_since_restore,config/lr,config/momentum,logdir
0,0.5375,True,20,f7732_00000,2026-01-18_17-21-32,1768774892,0.362598,7.433163,19336,gylenovo,127.0.0.1,7.433163,20,0.001,0.001,f7732_00000
1,0.821875,True,20,f7732_00001,2026-01-18_17-21-32,1768774892,0.29761,6.81005,19072,gylenovo,127.0.0.1,6.81005,20,0.01,0.001,f7732_00001
2,0.9125,True,20,f7732_00002,2026-01-18_17-21-32,1768774892,0.274889,6.947797,27036,gylenovo,127.0.0.1,6.947797,20,0.1,0.001,f7732_00002
3,0.1875,True,20,f7732_00003,2026-01-18_17-21-32,1768774892,0.243473,6.38798,10020,gylenovo,127.0.0.1,6.38798,20,0.001,0.01,f7732_00003
4,0.859375,True,20,f7732_00004,2026-01-18_17-21-32,1768774892,0.259074,7.178136,14904,gylenovo,127.0.0.1,7.178136,20,0.01,0.01,f7732_00004


In [31]:
analysis.dataframe().sort_values('mean_accuracy', ascending=False).head()

Unnamed: 0,mean_accuracy,done,training_iteration,trial_id,date,timestamp,time_this_iter_s,time_total_s,pid,hostname,node_ip,time_since_restore,iterations_since_restore,config/lr,config/momentum,logdir
8,0.93125,True,20,f7732_00008,2026-01-18_17-21-31,1768774891,0.373442,6.898444,24368,gylenovo,127.0.0.1,6.898444,20,0.1,0.1,f7732_00008
5,0.915625,True,20,f7732_00005,2026-01-18_17-21-32,1768774892,0.225655,7.091789,27640,gylenovo,127.0.0.1,7.091789,20,0.1,0.01,f7732_00005
2,0.9125,True,20,f7732_00002,2026-01-18_17-21-32,1768774892,0.274889,6.947797,27036,gylenovo,127.0.0.1,6.947797,20,0.1,0.001,f7732_00002
11,0.90625,True,20,f7732_00011,2026-01-18_17-21-34,1768774894,0.125302,4.290864,25872,gylenovo,127.0.0.1,4.290864,20,0.1,0.9,f7732_00011
4,0.859375,True,20,f7732_00004,2026-01-18_17-21-32,1768774892,0.259074,7.178136,14904,gylenovo,127.0.0.1,7.178136,20,0.01,0.01,f7732_00004


It's easier to see what we want if project out the interesting columns:

In [32]:
analysis.dataframe()[['mean_accuracy', 'config/lr', 'config/momentum']].sort_values('mean_accuracy', ascending=False)

Unnamed: 0,mean_accuracy,config/lr,config/momentum
8,0.93125,0.1,0.1
5,0.915625,0.1,0.01
2,0.9125,0.1,0.001
11,0.90625,0.1,0.9
4,0.859375,0.01,0.01
10,0.83125,0.01,0.9
1,0.821875,0.01,0.001
7,0.8,0.01,0.1
0,0.5375,0.001,0.001
9,0.353125,0.001,0.9


How long did it take? We'll compare this value with a different training run in the next lesson.

In [34]:
df = analysis.results_df
elapsed = df["time_total_s"].max()
print(f"{elapsed:7.2f} seconds, {elapsed/60:7.2f} minutes")


   7.43 seconds,    0.12 minutes


The next lesson will explore optimization algorithms that speed up HPO.

In [36]:
ray.shutdown()  # "Undo ray.init()".