In [1]:
%%capture
!pip install ray
!pip install tensorboardX

In [2]:
from functools import partial
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler  # 성능이 떨어지는 시도를 조기에 종료



In [3]:
def load_data(data_dir='./data'):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform
    )
    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform
    )

    return trainset, testset

## Model define

In [4]:
class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [5]:
def train_cifar(config, checkpoint_dir=None, data_dir=None):
    net = Net(config['l1'], config['l2'])
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)  # 다중 gpu
    net.to(device) 
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9)

    if checkpoint_dir:
        model_state, optimizer_state = torch.load(os.path.join(checkpoint_dir, 'checkpoint'))
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    trainset, testset = load_data(data_dir)

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs]  # 앞이 val 뒤가 train
    )
    
    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=4)
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=4)
    
    for epoch in range(10):
        running_loss = 0.0
        epoch_steps = 0
        net.train()
        for i, data in enumerate(trainloader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            epoch_steps+=1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                running_loss = 0.0  # 2000 iter만큼의 loss 평균계산

        net.eval()
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                predicted = torch.argmax(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()
                val_steps += 1

        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, 'checkpoint')
            torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=(val_loss / val_steps), accuracy = correct / total)  # validation의 1epoch의 평균 loss와 accuracy
    print("Finished Training")


In [6]:
def test_accuracy(net, device="cpu"):
    trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2)

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [7]:
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=1):
    data_dir = os.path.abspath('./data')
    load_data(data_dir)

    config= {
    'l1': tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    'l2': tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    'lr': tune.loguniform(1e-4, 1e-1),  # 균일하게 샘플링
    'batch_size': tune.choice([2, 4, 8, 16])
    }

    scheduler = ASHAScheduler(
        metric='loss',
        mode='min',
        max_t = max_num_epochs,
        grace_period=1,
        reduction_factor=2
    )
    reporter = CLIReporter(
        metric_columns = ['loss','accuracy','training_iteration']
    )
    result = tune.run(
        partial(train_cifar, data_dir=data_dir),
        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter)
    
    best_trial = result.get_best_trial('loss','min','last')
    print('Best trial config: {}'.format(best_trial.config))
    print('Best trial final validation loss: {}'.format(
        best_trial.last_result['loss']
    ))
    print('Best trial final validation accuracy: {}'.format(
        best_trial.last_result['accuracy']
    ))

    best_trained_model = Net(best_trial.config['l1'], best_trial.config['l2'])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.device_count() > 1:
        best_trained_model = nn.DataParallel(best_trained_model)  # 다중 gpu
    best_trained_model.to(device) 

    best_checkpoint_dir = best_trial.checkpoint.value
    model_state, optimizer_state = torch.load(os.path.join(
        best_checkpoint_dir, 'checkpoint'
    ))
    best_trained_model.load_state_dict(model_state)

    test_acc = test_accuracy(best_trained_model, device)
    print('Best trial test set accuracy: {}'.format(test_acc))

In [8]:
if __name__ == '__main__':
    main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting /content/data/cifar-10-python.tar.gz to /content/data
Files already downloaded and verified


2021-06-05 14:44:58,392	INFO services.py:1269 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m
2021-06-05 14:44:59,991	INFO registry.py:65 -- Detected unknown callable for trainable. Converting to class.


== Status ==
Memory usage on this node: 1.5/25.5 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 2.0/4 CPUs, 0/1 GPUs, 0.0/15.01 GiB heap, 0.0/7.51 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_results/DEFAULT_2021-06-05_14-45-00
Number of trials: 10/10 (9 PENDING, 1 RUNNING)
+---------------------+----------+-------+--------------+------+------+-------------+
| Trial name          | status   | loc   |   batch_size |   l1 |   l2 |          lr |
|---------------------+----------+-------+--------------+------+------+-------------|
| DEFAULT_9a930_00000 | RUNNING  |       |            2 |    4 |   32 | 0.0865766   |
| DEFAULT_9a930_00001 | PENDING  |       |            2 |    4 |  128 | 0.0002696   |
| DEFAULT_9a930_00002 | PENDING  |       |            8 |  128 |  256 | 0.00268091  |
| DEFAULT_9a930_00003 | PENDING  |       |            4 |    4 |    4 | 0.000419058 |
| 



[2m[36m(pid=907)[0m Files already downloaded and verified
[2m[36m(pid=907)[0m Files already downloaded and verified
[2m[36m(pid=263)[0m [5,  4000] loss: 0.531
Result for DEFAULT_9a930_00002:
  accuracy: 0.5852
  date: 2021-06-05_14-49-01
  done: false
  experiment_id: 986236757ff3498cbaab052c46b73ebe
  hostname: 382b57b17247
  iterations_since_restore: 5
  loss: 1.1785512973427772
  node_ip: 172.28.0.2
  pid: 263
  should_checkpoint: true
  time_since_restore: 152.491952419281
  time_this_iter_s: 30.527422189712524
  time_total_s: 152.491952419281
  timestamp: 1622904541
  timesteps_since_restore: 0
  training_iteration: 5
  trial_id: 9a930_00002
  
== Status ==
Memory usage on this node: 1.9/25.5 GiB
Using AsyncHyperBand: num_stopped=3
Bracket: Iter 8.000: None | Iter 4.000: -1.2120213317751884 | Iter 2.000: -1.3779540592521429 | Iter 1.000: -1.8466704591214658
Resources requested: 4.0/4 CPUs, 0/1 GPUs, 0.0/15.01 GiB heap, 0.0/7.51 GiB objects (0.0/1.0 accelerator_type:P100)




[2m[36m(pid=1051)[0m Files already downloaded and verified
[2m[36m(pid=263)[0m [6,  2000] loss: 0.966
[2m[36m(pid=1051)[0m Files already downloaded and verified
[2m[36m(pid=263)[0m [6,  4000] loss: 0.497
[2m[36m(pid=1051)[0m [1,  2000] loss: 2.303
Result for DEFAULT_9a930_00002:
  accuracy: 0.59
  date: 2021-06-05_14-49-31
  done: false
  experiment_id: 986236757ff3498cbaab052c46b73ebe
  hostname: 382b57b17247
  iterations_since_restore: 6
  loss: 1.2135483667850495
  node_ip: 172.28.0.2
  pid: 263
  should_checkpoint: true
  time_since_restore: 182.64447927474976
  time_this_iter_s: 30.15252685546875
  time_total_s: 182.64447927474976
  timestamp: 1622904571
  timesteps_since_restore: 0
  training_iteration: 6
  trial_id: 9a930_00002
  
== Status ==
Memory usage on this node: 1.9/25.5 GiB
Using AsyncHyperBand: num_stopped=4
Bracket: Iter 8.000: None | Iter 4.000: -1.2120213317751884 | Iter 2.000: -1.3779540592521429 | Iter 1.000: -1.9670624036312103
Resources requested:



[2m[36m(pid=1258)[0m Files already downloaded and verified
[2m[36m(pid=263)[0m [8,  2000] loss: 0.858
[2m[36m(pid=1258)[0m Files already downloaded and verified
[2m[36m(pid=1258)[0m [1,  2000] loss: 2.321
[2m[36m(pid=263)[0m [8,  4000] loss: 0.450
[2m[36m(pid=1258)[0m [1,  4000] loss: 1.153
Result for DEFAULT_9a930_00002:
  accuracy: 0.597
  date: 2021-06-05_14-50-31
  done: false
  experiment_id: 986236757ff3498cbaab052c46b73ebe
  hostname: 382b57b17247
  iterations_since_restore: 8
  loss: 1.2321154302060604
  node_ip: 172.28.0.2
  pid: 263
  should_checkpoint: true
  time_since_restore: 242.07878518104553
  time_this_iter_s: 29.598763704299927
  time_total_s: 242.07878518104553
  timestamp: 1622904631
  timesteps_since_restore: 0
  training_iteration: 8
  trial_id: 9a930_00002
  
== Status ==
Memory usage on this node: 1.9/25.5 GiB
Using AsyncHyperBand: num_stopped=5
Bracket: Iter 8.000: -1.2321154302060604 | Iter 4.000: -1.2120213317751884 | Iter 2.000: -1.3779540



[2m[36m(pid=1496)[0m Files already downloaded and verified
[2m[36m(pid=1496)[0m Files already downloaded and verified
Result for DEFAULT_9a930_00006:
  accuracy: 0.2175
  date: 2021-06-05_14-51-37
  done: false
  experiment_id: 28afb8e9d59e417b8fa7e726553cc772
  hostname: 382b57b17247
  iterations_since_restore: 1
  loss: 1.9607061433076858
  node_ip: 172.28.0.2
  pid: 1258
  should_checkpoint: true
  time_since_restore: 87.57974410057068
  time_this_iter_s: 87.57974410057068
  time_total_s: 87.57974410057068
  timestamp: 1622904697
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: 9a930_00006
  
== Status ==
Memory usage on this node: 1.8/25.5 GiB
Using AsyncHyperBand: num_stopped=6
Bracket: Iter 8.000: -1.2321154302060604 | Iter 4.000: -1.2120213317751884 | Iter 2.000: -1.3779540592521429 | Iter 1.000: -1.9670624036312103
Resources requested: 4.0/4 CPUs, 0/1 GPUs, 0.0/15.01 GiB heap, 0.0/7.51 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_re



[2m[36m(pid=1496)[0m [2, 10000] loss: 0.331
[2m[36m(pid=1681)[0m Files already downloaded and verified
[2m[36m(pid=1681)[0m Files already downloaded and verified
Result for DEFAULT_9a930_00007:
  accuracy: 0.3439
  date: 2021-06-05_14-53-09
  done: true
  experiment_id: 90cad9c44fde443e996f56954cd12cb8
  hostname: 382b57b17247
  iterations_since_restore: 2
  loss: 1.652332221508026
  node_ip: 172.28.0.2
  pid: 1496
  should_checkpoint: true
  time_since_restore: 95.02956581115723
  time_this_iter_s: 45.57747721672058
  time_total_s: 95.02956581115723
  timestamp: 1622904789
  timesteps_since_restore: 0
  training_iteration: 2
  trial_id: 9a930_00007
  
== Status ==
Memory usage on this node: 1.9/25.5 GiB
Using AsyncHyperBand: num_stopped=8
Bracket: Iter 8.000: -1.2321154302060604 | Iter 4.000: -1.2120213317751884 | Iter 2.000: -1.5673164680689573 | Iter 1.000: -1.9638842734694482
Resources requested: 4.0/4 CPUs, 0/1 GPUs, 0.0/15.01 GiB heap, 0.0/7.51 GiB objects (0.0/1.0 accel



[2m[36m(pid=1766)[0m Files already downloaded and verified
[2m[36m(pid=1766)[0m Files already downloaded and verified
[2m[36m(pid=1681)[0m [1,  2000] loss: 2.304
[2m[36m(pid=1681)[0m [1,  4000] loss: 1.131
[2m[36m(pid=1766)[0m [1,  2000] loss: 1.863
[2m[36m(pid=1681)[0m [1,  6000] loss: 0.658
[2m[36m(pid=1766)[0m [1,  4000] loss: 0.810
[2m[36m(pid=1681)[0m [1,  8000] loss: 0.463
Result for DEFAULT_9a930_00009:
  accuracy: 0.4502
  date: 2021-06-05_14-53-43
  done: false
  experiment_id: 26af2145d3a0426483b72efac5469fb3
  hostname: 382b57b17247
  iterations_since_restore: 1
  loss: 1.5222246812343598
  node_ip: 172.28.0.2
  pid: 1766
  should_checkpoint: true
  time_since_restore: 32.17391896247864
  time_this_iter_s: 32.17391896247864
  time_total_s: 32.17391896247864
  timestamp: 1622904823
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: 9a930_00009
  
== Status ==
Memory usage on this node: 1.9/25.5 GiB
Using AsyncHyperBand: num_stopped=8
Brac

2021-06-05 14:55:06,051	INFO tune.py:549 -- Total run time: 606.07 seconds (605.79 seconds for the tuning loop).


Result for DEFAULT_9a930_00009:
  accuracy: 0.4681
  date: 2021-06-05_14-55-05
  done: true
  experiment_id: 26af2145d3a0426483b72efac5469fb3
  hostname: 382b57b17247
  iterations_since_restore: 4
  loss: 1.5232763239383698
  node_ip: 172.28.0.2
  pid: 1766
  should_checkpoint: true
  time_since_restore: 114.65959119796753
  time_this_iter_s: 22.97699475288391
  time_total_s: 114.65959119796753
  timestamp: 1622904905
  timesteps_since_restore: 0
  training_iteration: 4
  trial_id: 9a930_00009
  
== Status ==
Memory usage on this node: 1.6/25.5 GiB
Using AsyncHyperBand: num_stopped=10
Bracket: Iter 8.000: -1.2321154302060604 | Iter 4.000: -1.367648827856779 | Iter 2.000: -1.5349007671117783 | Iter 1.000: -1.913810465836525
Resources requested: 2.0/4 CPUs, 0/1 GPUs, 0.0/15.01 GiB heap, 0.0/7.51 GiB objects (0.0/1.0 accelerator_type:P100)
Result logdir: /root/ray_results/DEFAULT_2021-06-05_14-45-00
Number of trials: 10/10 (1 RUNNING, 9 TERMINATED)
+---------------------+------------+----