In [1]:
!pip install ray

Collecting ray
  Downloading ray-1.10.0-cp37-cp37m-manylinux2014_x86_64.whl (59.6 MB)
[K     |████████████████████████████████| 59.6 MB 1.6 MB/s 
Collecting redis>=3.5.0
  Downloading redis-4.1.4-py3-none-any.whl (175 kB)
[K     |████████████████████████████████| 175 kB 68.4 MB/s 
Collecting deprecated>=1.2.3
  Downloading Deprecated-1.2.13-py2.py3-none-any.whl (9.6 kB)
Installing collected packages: deprecated, redis, ray
Successfully installed deprecated-1.2.13 ray-1.10.0 redis-4.1.4


In [2]:
!pip install torchvision



In [3]:
from functools import partial
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

In [4]:
def load_data(data_dir='./data'):
  """Build the CIFAR-10 train and test datasets rooted at ``data_dir``.

  Both splits are requested with ``download=True`` and share one transform:
  conversion to a tensor followed by per-channel normalisation with mean 0.5
  and std 0.5.

  Args:
    data_dir: Directory passed to torchvision as the dataset ``root``.

  Returns:
    A ``(trainset, testset)`` tuple of ``torchvision.datasets.CIFAR10`` objects.
  """
  channel_mean = (0.5, 0.5, 0.5)
  channel_std = (0.5, 0.5, 0.5)
  transform = transforms.Compose(
      [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
  )

  # Same construction for both splits; only the `train` flag differs.
  trainset, testset = [
      torchvision.datasets.CIFAR10(
          root=data_dir, train=is_train, download=True, transform=transform
      )
      for is_train in (True, False)
  ]
  return trainset, testset

In [5]:
class Net(nn.Module):
    """Small convolutional classifier with two tunable hidden widths.

    Two 5x5 convolutions (3->6 and 6->16 channels), each followed by ReLU and
    2x2 max-pooling, feed three fully connected layers that end in 10 outputs.
    The flatten size of ``16 * 5 * 5`` matches 32x32 inputs
    (32 -> 28 -> 14 -> 10 -> 5 per spatial side).

    Args:
        l1: Output width of the first fully connected layer.
        l2: Output width of the second fully connected layer.
    """

    def __init__(self, l1=120, l2=84):
        super().__init__()
        # Convolutional feature extractor; one pooling module is shared by both stages.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Fully connected head.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=l1)
        self.fc2 = nn.Linear(in_features=l1, out_features=l2)
        self.fc3 = nn.Linear(in_features=l2, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)

In [6]:
def train_cifar(config, checkpoint_dir=None, data_dir=None, max_num_epochs=10):
  """Train ``Net`` on CIFAR-10 as a Ray Tune function trainable.

  After every epoch the model and optimizer state are checkpointed through
  ``tune.checkpoint_dir`` and the validation loss/accuracy are sent to Tune
  with ``tune.report``.

  Args:
    config: Trial hyperparameters. Reads ``l1`` and ``l2`` (hidden layer
      widths), ``lr`` (SGD learning rate) and ``batch_size``.
    checkpoint_dir: If truthy, directory holding a ``checkpoint`` file with a
      ``(model_state, optimizer_state)`` tuple to restore before training.
    data_dir: Dataset root handed to ``load_data``. When ``None`` the
      ``load_data`` default location is used.
    max_num_epochs: Number of epochs to run. Defaults to 10.
  """
  net = Net(config['l1'], config['l2'])

  device = 'cpu'
  if torch.cuda.is_available():
    device = 'cuda:0'
    if torch.cuda.device_count() > 1:
      net = nn.DataParallel(net)
  net.to(device)

  criterion = nn.CrossEntropyLoss()
  optimizer = optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9)

  if checkpoint_dir:
    # map_location keeps the restore working when the checkpoint was written
    # on a different device than the one this process trains on.
    model_state, optimizer_state = torch.load(
        os.path.join(checkpoint_dir, 'checkpoint'),
        map_location=device
    )
    net.load_state_dict(model_state)
    optimizer.load_state_dict(optimizer_state)

  # Passing None through would hand torchvision `root=None`; fall back to the
  # default location of load_data instead.
  if data_dir is None:
    trainset, _ = load_data()
  else:
    trainset, _ = load_data(data_dir)

  # 80/20 split of the training set into train and validation subsets.
  train_size = int(len(trainset) * 0.8)
  train_subset, val_subset = random_split(
      trainset, [train_size, len(trainset) - train_size]
  )

  batch_size = int(config['batch_size'])
  # Never ask for more loader workers than the machine has CPUs.
  num_workers = min(8, os.cpu_count() or 1)

  trainloader = torch.utils.data.DataLoader(
      train_subset,
      batch_size=batch_size,
      shuffle=True,
      num_workers=num_workers
  )
  # Order is irrelevant for evaluation, so the validation loader is not shuffled.
  valloader = torch.utils.data.DataLoader(
      val_subset,
      batch_size=batch_size,
      shuffle=False,
      num_workers=num_workers
  )

  for epoch in range(max_num_epochs):
    running_loss = 0.0
    running_steps = 0
    for i, (inputs, labels) in enumerate(trainloader):
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      running_loss += loss.item()
      running_steps += 1
      if i % 2000 == 1999:
        print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / running_steps))
        # Reset both accumulators together so each print is the mean over the
        # last window only, not a partial sum divided by all steps so far.
        running_loss = 0.0
        running_steps = 0

    val_loss = 0.0
    val_steps = 0
    total = 0
    correct = 0
    with torch.no_grad():
      for inputs, labels in valloader:
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = net(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # .item() yields a plain Python float, so the reported metric is not a
        # numpy value.
        val_loss += criterion(outputs, labels).item()
        val_steps += 1

    # Distinct name so the `checkpoint_dir` argument is not shadowed.
    with tune.checkpoint_dir(epoch) as epoch_checkpoint_dir:
      path = os.path.join(epoch_checkpoint_dir, 'checkpoint')
      torch.save((net.state_dict(), optimizer.state_dict()), path)

    tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
  print('Finished training')



In [7]:
def test_accuracy(net, device="cpu"):
    """Return the fraction of CIFAR-10 test images that ``net`` classifies correctly.

    The test split comes from ``load_data()`` (default data directory) and is
    read in order, in batches of 4, with gradients disabled.

    Args:
        net: Model to evaluate; called on image batches moved to ``device``.
        device: Device the image and label batches are moved to.

    Returns:
        ``correct / total`` as a float.
    """
    _, testset = load_data()
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2
    )

    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            # Index of the highest score per row is the predicted class.
            predicted = torch.max(net(images).data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return n_correct / n_seen

In [8]:
from ray.tune import schedulers
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=1):
  """Run a Ray Tune search over ``train_cifar`` and evaluate the best trial.

  The trial with the lowest last-reported validation loss is selected, its
  checkpoint is loaded into a fresh ``Net``, and its accuracy on the CIFAR-10
  test set is printed.

  Args:
    num_samples: Number of trials sampled from the search space.
    max_num_epochs: ``max_t`` of the ASHA scheduler.
    gpus_per_trial: GPU resources requested for each trial.

  Raises:
    RuntimeError: If no trial produced a ``loss`` result to select from.
  """
  data_dir = os.path.abspath('./data')
  # Fetch the dataset once up front so the trials find it already on disk.
  load_data(data_dir)

  # np.random.randint(2, 4) draws 2 or 3, so l1 and l2 are each 4 or 8.
  # NOTE(review): that is a very narrow width range - confirm it is intended.
  config = {
      "l1": tune.sample_from(lambda _: 2**np.random.randint(2, 4)),
      "l2": tune.sample_from(lambda _: 2**np.random.randint(2, 4)),
      "lr": tune.loguniform(1e-4, 1e-1),
      "batch_size": tune.choice([2, 4, 8])
  }
  scheduler = ASHAScheduler(
      metric='loss',
      mode='min',
      max_t=max_num_epochs,
      grace_period=1,
      reduction_factor=2
  )
  reporter = CLIReporter(
      metric_columns=['loss', 'accuracy', 'training_iteration']
  )
  result = tune.run(
      partial(train_cifar, data_dir=data_dir),
      resources_per_trial={'cpu': 1, 'gpu': gpus_per_trial},
      config=config,
      num_samples=num_samples,
      scheduler=scheduler,
      progress_reporter=reporter
  )

  best_trial = result.get_best_trial('loss', 'min', 'last')
  if best_trial is None:
    # Fail with a clear message instead of an AttributeError on `.config`.
    raise RuntimeError('No trial reported a "loss" metric; cannot pick a best trial.')
  print('Best trial config: {}'.format(best_trial.config))
  print("Best trial final validation loss: {}".format(
      best_trial.last_result["loss"]))
  print("Best trial final validation accuracy: {}".format(
      best_trial.last_result["accuracy"]))

  best_trained_model = Net(best_trial.config['l1'], best_trial.config['l2'])
  device = 'cpu'
  if torch.cuda.is_available():
    device = 'cuda:0'
    if gpus_per_trial > 1:
      best_trained_model = nn.DataParallel(best_trained_model)
  best_trained_model.to(device)

  best_checkpoint_dir = best_trial.checkpoint.value
  # Only the model weights are needed for evaluation; the optimizer state is
  # discarded. map_location lets the checkpoint load regardless of the device
  # it was saved from.
  model_state, _ = torch.load(
      os.path.join(best_checkpoint_dir, 'checkpoint'),
      map_location=device
  )
  best_trained_model.load_state_dict(model_state)

  test_acc = test_accuracy(best_trained_model, device)
  print('Best trial test set accuracy: {}'.format(test_acc))

In [9]:
# Request a GPU per trial only when CUDA is present; on a CPU-only machine a
# GPU request cannot be satisfied, so no trial would ever be scheduled.
main(
    num_samples=10,
    max_num_epochs=10,
    gpus_per_trial=1 if torch.cuda.is_available() else 0,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting /content/data/cifar-10-python.tar.gz to /content/data
Files already downloaded and verified


2022-02-20 02:35:23,364	INFO registry.py:70 -- Detected unknown callable for trainable. Converting to class.
2022-02-20 02:35:23,400	INFO logger.py:606 -- pip install "ray[tune]" to see TensorBoard files.


== Status ==
Current time: 2022-02-20 02:35:23 (running for 00:00:00.25)
Memory usage on this node: 1.6/12.7 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 1.0/2 CPUs, 1.0/1 GPUs, 0.0/6.77 GiB heap, 0.0/3.38 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_results/DEFAULT_2022-02-20_02-35-23
Number of trials: 10/10 (9 PENDING, 1 RUNNING)
+---------------------+----------+----------------+--------------+------+------+-------------+
| Trial name          | status   | loc            |   batch_size |   l1 |   l2 |          lr |
|---------------------+----------+----------------+--------------+------+------+-------------|
| DEFAULT_c1147_00000 | RUNNING  | 172.28.0.2:328 |            4 |    4 |    4 | 0.00839908  |
| DEFAULT_c1147_00001 | PENDING  |                |            4 |    4 |    4 | 0.0657213   |
| DEFAULT_c1147_00002 | PENDING  |                |            2 |  

[2m[36m(func pid=328)[0m   cpuset_checked))


== Status ==
Current time: 2022-02-20 02:35:39 (running for 00:00:16.20)
Memory usage on this node: 3.4/12.7 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 1.0/2 CPUs, 1.0/1 GPUs, 0.0/6.77 GiB heap, 0.0/3.38 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_results/DEFAULT_2022-02-20_02-35-23
Number of trials: 10/10 (9 PENDING, 1 RUNNING)
+---------------------+----------+----------------+--------------+------+------+-------------+
| Trial name          | status   | loc            |   batch_size |   l1 |   l2 |          lr |
|---------------------+----------+----------------+--------------+------+------+-------------|
| DEFAULT_c1147_00000 | RUNNING  | 172.28.0.2:328 |            4 |    4 |    4 | 0.00839908  |
| DEFAULT_c1147_00001 | PENDING  |                |            4 |    4 |    4 | 0.0657213   |
| DEFAULT_c1147_00002 | PENDING  |                |            2 |  

[2m[36m(func pid=329)[0m   cpuset_checked))


== Status ==
Current time: 2022-02-20 02:43:31 (running for 00:08:08.42)
Memory usage on this node: 3.3/12.7 GiB
Using AsyncHyperBand: num_stopped=1
Bracket: Iter 8.000: -2.30607836933136 | Iter 4.000: -2.307292249584198 | Iter 2.000: -2.3080029556274413 | Iter 1.000: -2.3037743733882903
Resources requested: 1.0/2 CPUs, 1.0/1 GPUs, 0.0/6.77 GiB heap, 0.0/3.38 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_results/DEFAULT_2022-02-20_02-35-23
Number of trials: 10/10 (8 PENDING, 1 RUNNING, 1 TERMINATED)
+---------------------+------------+----------------+--------------+------+------+-------------+---------+------------+----------------------+
| Trial name          | status     | loc            |   batch_size |   l1 |   l2 |          lr |    loss |   accuracy |   training_iteration |
|---------------------+------------+----------------+--------------+------+------+-------------+---------+------------+----------------------|
| DEFAULT_c1147_00001 | RUNNING    | 172.28

[2m[36m(func pid=2648)[0m   cpuset_checked))


== Status ==
Current time: 2022-02-20 02:44:28 (running for 00:09:04.50)
Memory usage on this node: 3.3/12.7 GiB
Using AsyncHyperBand: num_stopped=2
Bracket: Iter 8.000: -2.30607836933136 | Iter 4.000: -2.307292249584198 | Iter 2.000: -2.3080029556274413 | Iter 1.000: -2.330776154947281
Resources requested: 1.0/2 CPUs, 1.0/1 GPUs, 0.0/6.77 GiB heap, 0.0/3.38 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_results/DEFAULT_2022-02-20_02-35-23
Number of trials: 10/10 (7 PENDING, 1 RUNNING, 2 TERMINATED)
+---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------+
| Trial name          | status     | loc             |   batch_size |   l1 |   l2 |          lr |    loss |   accuracy |   training_iteration |
|---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------|
| DEFAULT_c1147_00002 | RUNNING    | 172.

[2m[36m(func pid=4754)[0m   cpuset_checked))


== Status ==
Current time: 2022-02-20 02:58:18 (running for 00:22:54.61)
Memory usage on this node: 3.4/12.7 GiB
Using AsyncHyperBand: num_stopped=3
Bracket: Iter 8.000: -2.305985100722313 | Iter 4.000: -2.3054033162117005 | Iter 2.000: -2.302004478538036 | Iter 1.000: -2.3037743733882903
Resources requested: 1.0/2 CPUs, 1.0/1 GPUs, 0.0/6.77 GiB heap, 0.0/3.38 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_results/DEFAULT_2022-02-20_02-35-23
Number of trials: 10/10 (6 PENDING, 1 RUNNING, 3 TERMINATED)
+---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------+
| Trial name          | status     | loc             |   batch_size |   l1 |   l2 |          lr |    loss |   accuracy |   training_iteration |
|---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------|
| DEFAULT_c1147_00003 | RUNNING    | 17

[2m[36m(func pid=6731)[0m   cpuset_checked))


== Status ==
Current time: 2022-02-20 03:03:09 (running for 00:27:45.68)
Memory usage on this node: 3.4/12.7 GiB
Using AsyncHyperBand: num_stopped=4
Bracket: Iter 8.000: -2.305891832113266 | Iter 4.000: -2.303514382839203 | Iter 2.000: -2.296006001448631 | Iter 1.000: -2.1320536488890647
Resources requested: 1.0/2 CPUs, 1.0/1 GPUs, 0.0/6.77 GiB heap, 0.0/3.38 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_results/DEFAULT_2022-02-20_02-35-23
Number of trials: 10/10 (5 PENDING, 1 RUNNING, 4 TERMINATED)
+---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------+
| Trial name          | status     | loc             |   batch_size |   l1 |   l2 |          lr |    loss |   accuracy |   training_iteration |
|---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------|
| DEFAULT_c1147_00004 | RUNNING    | 172

[2m[36m(func pid=6957)[0m   cpuset_checked))


== Status ==
Current time: 2022-02-20 03:04:05 (running for 00:28:41.71)
Memory usage on this node: 3.4/12.7 GiB
Using AsyncHyperBand: num_stopped=5
Bracket: Iter 8.000: -2.305891832113266 | Iter 4.000: -2.303514382839203 | Iter 2.000: -2.296006001448631 | Iter 1.000: -2.3037743733882903
Resources requested: 1.0/2 CPUs, 1.0/1 GPUs, 0.0/6.77 GiB heap, 0.0/3.38 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_results/DEFAULT_2022-02-20_02-35-23
Number of trials: 10/10 (4 PENDING, 1 RUNNING, 5 TERMINATED)
+---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------+
| Trial name          | status     | loc             |   batch_size |   l1 |   l2 |          lr |    loss |   accuracy |   training_iteration |
|---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------|
| DEFAULT_c1147_00005 | RUNNING    | 172

[2m[36m(func pid=8918)[0m   cpuset_checked))


== Status ==
Current time: 2022-02-20 03:09:01 (running for 00:33:37.79)
Memory usage on this node: 3.4/12.7 GiB
Using AsyncHyperBand: num_stopped=6
Bracket: Iter 8.000: -2.1428248338222504 | Iter 4.000: -2.169320965766907 | Iter 2.000: -2.2916658594965935 | Iter 1.000: -2.302197839713097
Resources requested: 1.0/2 CPUs, 1.0/1 GPUs, 0.0/6.77 GiB heap, 0.0/3.38 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_results/DEFAULT_2022-02-20_02-35-23
Number of trials: 10/10 (3 PENDING, 1 RUNNING, 6 TERMINATED)
+---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------+
| Trial name          | status     | loc             |   batch_size |   l1 |   l2 |          lr |    loss |   accuracy |   training_iteration |
|---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------|
| DEFAULT_c1147_00006 | RUNNING    | 17

[2m[36m(func pid=9144)[0m   cpuset_checked))


== Status ==
Current time: 2022-02-20 03:09:57 (running for 00:34:33.83)
Memory usage on this node: 3.4/12.7 GiB
Using AsyncHyperBand: num_stopped=7
Bracket: Iter 8.000: -2.1428248338222504 | Iter 4.000: -2.169320965766907 | Iter 2.000: -2.2916658594965935 | Iter 1.000: -2.3037743733882903
Resources requested: 1.0/2 CPUs, 1.0/1 GPUs, 0.0/6.77 GiB heap, 0.0/3.38 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_results/DEFAULT_2022-02-20_02-35-23
Number of trials: 10/10 (2 PENDING, 1 RUNNING, 7 TERMINATED)
+---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------+
| Trial name          | status     | loc             |   batch_size |   l1 |   l2 |          lr |    loss |   accuracy |   training_iteration |
|---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------|
| DEFAULT_c1147_00007 | RUNNING    | 1

[2m[36m(func pid=9378)[0m   cpuset_checked))


== Status ==
Current time: 2022-02-20 03:10:53 (running for 00:35:29.87)
Memory usage on this node: 3.4/12.7 GiB
Using AsyncHyperBand: num_stopped=8
Bracket: Iter 8.000: -2.1428248338222504 | Iter 4.000: -2.169320965766907 | Iter 2.000: -2.2916658594965935 | Iter 1.000: -2.3055326722860334
Resources requested: 1.0/2 CPUs, 1.0/1 GPUs, 0.0/6.77 GiB heap, 0.0/3.38 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_results/DEFAULT_2022-02-20_02-35-23
Number of trials: 10/10 (1 PENDING, 1 RUNNING, 8 TERMINATED)
+---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------+
| Trial name          | status     | loc             |   batch_size |   l1 |   l2 |          lr |    loss |   accuracy |   training_iteration |
|---------------------+------------+-----------------+--------------+------+------+-------------+---------+------------+----------------------|
| DEFAULT_c1147_00008 | RUNNING    | 1

[2m[36m(func pid=11333)[0m   cpuset_checked))


== Status ==
Current time: 2022-02-20 03:15:55 (running for 00:40:31.88)
Memory usage on this node: 3.4/12.7 GiB
Using AsyncHyperBand: num_stopped=9
Bracket: Iter 8.000: -1.9797578355312346 | Iter 4.000: -2.0351275486946108 | Iter 2.000: -2.287325717544556 | Iter 1.000: -2.3037743733882903
Resources requested: 1.0/2 CPUs, 1.0/1 GPUs, 0.0/6.77 GiB heap, 0.0/3.38 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_results/DEFAULT_2022-02-20_02-35-23
Number of trials: 10/10 (1 RUNNING, 9 TERMINATED)
+---------------------+------------+------------------+--------------+------+------+-------------+---------+------------+----------------------+
| Trial name          | status     | loc              |   batch_size |   l1 |   l2 |          lr |    loss |   accuracy |   training_iteration |
|---------------------+------------+------------------+--------------+------+------+-------------+---------+------------+----------------------|
| DEFAULT_c1147_00009 | RUNNING    | 172.28.0.

2022-02-20 03:23:36,158	INFO tune.py:636 -- Total run time: 2892.80 seconds (2892.54 seconds for the tuning loop).


Best trial config: {'l1': 4, 'l2': 4, 'lr': 0.0024844230134879433, 'batch_size': 8}
Best trial final validation loss: 1.3967038196563721
Best trial final validation accuracy: 0.4974
Files already downloaded and verified
Files already downloaded and verified
Best trial test set accuracy: 0.5056
