In [10]:
import numpy as np
import torch
import torch.optim as optim
from torchvision import datasets

from ray import tune
from ray.tune import track
from ray.tune.schedulers import ASHAScheduler
from ray.tune.examples.mnist_pytorch import get_data_loaders, ConvNet, train, test

In [17]:
def train_mnist(config):
    train_loader, test_loader = get_data_loaders()
    print(train_loader)
    model = ConvNet()
    print(model.parameters(), model)
    optimizer = optim.SGD(model.parameters(), lr=config["lr"])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        tune.track.log(mean_accuracy=acc)


In [18]:
analysis = tune.run(
    train_mnist, config={"lr": tune.grid_search([0.01, 0.1, 0.3]),
                        "kernel_size": tune.grid_search([(3, 3), (5, 5)])})

print("Best config: ", analysis.get_best_config(metric="mean_accuracy"))

# Get a dataframe for analyzing trial results.
df = analysis.dataframe()

Trial name,status,loc,lr
train_mnist_00000,RUNNING,,0.01
train_mnist_00001,PENDING,,0.1
train_mnist_00002,PENDING,,0.3


[2m[36m(pid=19220)[0m 2020-06-22 17:17:08,837	INFO trainable.py:217 -- Getting current IP.
[2m[36m(pid=19220)[0m <torch.utils.data.dataloader.DataLoader object at 0x7fc68fc433c8>
[2m[36m(pid=19220)[0m <generator object Module.parameters at 0x7fc68fc335c8> ConvNet(
[2m[36m(pid=19220)[0m   (conv1): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
[2m[36m(pid=19220)[0m   (fc): Linear(in_features=192, out_features=10, bias=True)
[2m[36m(pid=19220)[0m )
[2m[36m(pid=19218)[0m 2020-06-22 17:17:09,509	INFO trainable.py:217 -- Getting current IP.
Result for train_mnist_00000:
  date: 2020-06-22_17-17-09
  done: false
  experiment_id: 525127be7dbf443382c032938a096105
  experiment_tag: 0_lr=0.01
  hostname: Ravis-MacBook-Pro.local
  iterations_since_restore: 1
  mean_accuracy: 0.159375
  node_ip: 10.9.12.130
  pid: 19220
  time_since_restore: 0.5281491279602051
  time_this_iter_s: 0.5281491279602051
  time_total_s: 0.5281491279602051
  timestamp: 1592871429
  timesteps_since_r

Trial name,status,loc,lr,acc,total time (s),iter
train_mnist_00000,RUNNING,10.9.12.130:19220,0.01,0.578125,2.96579,5
train_mnist_00001,RUNNING,10.9.12.130:19218,0.1,0.7,1.26454,2
train_mnist_00002,RUNNING,10.9.12.130:19219,0.3,0.628125,1.44808,2


Result for train_mnist_00001:
  date: 2020-06-22_17-17-15
  done: false
  experiment_id: e520d994287b4fec87a315061852c6ca
  experiment_tag: 1_lr=0.1
  hostname: Ravis-MacBook-Pro.local
  iterations_since_restore: 9
  mean_accuracy: 0.915625
  node_ip: 10.9.12.130
  pid: 19218
  time_since_restore: 5.47992205619812
  time_this_iter_s: 0.5470671653747559
  time_total_s: 5.47992205619812
  timestamp: 1592871435
  timesteps_since_restore: 0
  training_iteration: 8
  trial_id: '00001'
  
Result for train_mnist_00002:
  date: 2020-06-22_17-17-16
  done: false
  experiment_id: a239731db39a486c8f7ca05207aef46c
  experiment_tag: 2_lr=0.3
  hostname: Ravis-MacBook-Pro.local
  iterations_since_restore: 10
  mean_accuracy: 0.91875
  node_ip: 10.9.12.130
  pid: 19219
  time_since_restore: 5.6572489738464355
  time_this_iter_s: 0.3730130195617676
  time_total_s: 5.6572489738464355
  timestamp: 1592871436
  timesteps_since_restore: 0
  training_iteration: 9
  trial_id: '00002'
  


Trial name,status,loc,lr,acc,total time (s),iter
train_mnist_00000,TERMINATED,,0.01,0.690625,5.17549,9
train_mnist_00001,TERMINATED,,0.1,0.871875,6.09064,9
train_mnist_00002,TERMINATED,,0.3,0.91875,5.65725,9


Best config:  {'lr': 0.3}


In [16]:
dfs = analysis.trial_dataframes
[d.mean_accuracy.plot() for d in dfs.values()]

[<matplotlib.axes._subplots.AxesSubplot at 0x7fe8fc1e97b8>,
 <matplotlib.axes._subplots.AxesSubplot at 0x7fe8fc1e97b8>,
 <matplotlib.axes._subplots.AxesSubplot at 0x7fe8fc1e97b8>]