## Training a Wide ResNet with Dask Classifier

In [1]:
from dask.distributed import Client

In [26]:
import os
os.chdir('/home/ubuntu/adadamp-experiments')

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import csv
from copy import copy
from adadamp.adadamp import DaskClassifier, DaskClassifierExpiriments

In [4]:
# training client
from dask.distributed import Client

def _prep():
    """Import distributed's torch protocol module in the calling process.

    The import is performed only for its side effects; nothing is returned.
    NOTE(review): presumably this registers torch-aware serialization with
    distributed on each worker -- confirm against the distributed docs.
    """
    import distributed.protocol.torch  # noqa: F401  (imported for side effects)

# Start a local Dask client. With processes=False the captured output of this
# cell shows an in-process ("inproc://") scheduler with a single worker.
client = Client(processes=False)
# Execute _prep on the workers so the torch protocol module is imported there.
client.run(_prep)
# Bare expression: render the client/cluster summary as the cell output.
client

Port 8787 is already in use. 
Perhaps you already have a cluster running?
Hosting the diagnostics dashboard on a random port instead.


0,1
Client  Scheduler: inproc://172.31.40.124/6887/1  Dashboard: http://172.31.40.124:44411/status,Cluster  Workers: 1  Cores: 4  Memory: 16.48 GB


In [5]:
from model import Wide_ResNet

# Send the local model.py to the Dask workers.
# NOTE(review): presumably needed so workers can import Wide_ResNet when the
# model is deserialized -- confirm.
client.upload_file("./exp-dask/model.py")

In [6]:
# CIFAR-10 loading, adapted from
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# Images are converted to tensors, then normalized per channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Train / test splits share the same on-disk root and the same transform.
train_set = torchvision.datasets.CIFAR10(
    root='./exp-dask/data',
    train=True,
    download=True,
    transform=transform,
)
test_set = torchvision.datasets.CIFAR10(
    root='./exp-dask/data',
    train=False,
    download=True,
    transform=transform,
)

# Human-readable class names; len(classes) is used as the number of outputs.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Model for experiment 1, which is described as:
#    “Decaying learning rate” follows the original implementation;
#     the batch size is constant, while the learning rate repeatedly
#     decays by a factor of 5 at a sequence of steps
# My understanding is that this is the "control" experiment: neither the
# number of workers nor the batch size is touched.

# Prefer the first CUDA device when one is available, otherwise use the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = DaskClassifierExpiriments(
    # --- network: Wide ResNet, one output per CIFAR-10 class ---
    module=Wide_ResNet,
    module__depth=16,
    module__widen_factor=4,
    module__dropout_rate=0.3,
    module__num_classes=len(classes),
    # --- objective ---
    loss=nn.CrossEntropyLoss,
    # --- optimizer: SGD with Nesterov momentum and weight decay ---
    optimizer=torch.optim.SGD,
    optimizer__lr=0.1,
    optimizer__momentum=0.9,
    optimizer__nesterov=True,
    optimizer__weight_decay=5e-4,
    # --- training loop / distribution settings ---
    batch_size=128,
    max_epochs=200,
    device=device,
    grads_per_worker=128,
    client=client,
)

In [None]:
def train(model, train_set, test_set, n_epochs=200, epoch_sched=(), lr_sched=(), bs_sched=()):
    """
    Train ``model`` for ``n_epochs`` epochs, applying a learning-rate /
    batch-size schedule along the way.

    Every epoch calls ``model.partial_fit(train_set)`` once, then evaluates
    ``model.score(test_set)`` and records the result.

    Parameters:
    model: object exposing ``set_lr``, ``set_bs``, ``partial_fit``, ``score``,
        ``get_params`` and a ``meta_`` mapping
    train_set: dataset handed to ``model.partial_fit``
    test_set: dataset handed to ``model.score``
    n_epochs: number of epochs to run
    epoch_sched: update lr and bs at the epochs in this sequence. Entries are
        consumed strictly in order, so they should be increasing; an entry
        that never matches the current epoch (out of order, or >= n_epochs)
        is never applied and also blocks every entry after it
    lr_sched: update lr to the value at the matching position. Must be the
        same length as epoch_sched
    bs_sched: update bs to the value at the matching position. Must be the
        same length as epoch_sched

    Returns:
    A list with one dict per epoch holding "epoch", "score", and everything
    from ``model.get_params()`` and ``model.meta_``.
    """
    # Normalise to lists first: this accepts tuples (or any iterable) and
    # guarantees the caller's sequences are never mutated.
    epochs = list(epoch_sched)
    lrs = list(lr_sched)
    bss = list(bs_sched)
    assert len(epochs) == len(lrs) == len(bss), "Invalid schedules. Epoch, lr and bs schedules should all be the same length."

    pending = list(zip(epochs, lrs, bss))
    next_update = 0  # index of the next schedule entry waiting to be applied

    history = []
    for epoch in range(n_epochs):
        # check for updates
        if next_update < len(pending) and pending[next_update][0] == epoch:
            _, lr, bs = pending[next_update]
            next_update += 1
            model.set_lr(lr)
            model.set_bs(bs)
            print("[Epoch {}] Updated model params:\n\tlr: {}\n\tbs: {}".format(epoch, lr, bs))
        # run
        model.partial_fit(train_set)
        score = model.score(test_set)
        datum = {"epoch": epoch, "score": score, **model.get_params(), **model.meta_}
        print("[Epoch {}] Score: {}".format(epoch, score))
        history.append(datum)

    return history

In [29]:
# Experiment 1 schedule. Position i of each list takes effect at epoch
# exp1_epochs[i]; e.g. "update LR to 0.1 and bs to 640 on the 60th epoch".
exp1_epochs = [0, 60, 120, 180]
# learning rate stays at 0.1 for every stage
exp1_lr = [0.1] * 4
# batch size grows by a factor of 5 per stage: 128, 640, 3200, 16000
exp1_bs = [128 * 5 ** stage for stage in range(4)]
# train
hist = train(
    model,
    train_set,
    test_set,
    n_epochs=200,
    epoch_sched=exp1_epochs,
    lr_sched=exp1_lr,
    bs_sched=exp1_bs,
)

Updated model params on epoch 0:
	lr: 0.1
	bs: 128
[Epoch 0] Score: 0.10239999741315842


KeyboardInterrupt: 

In [None]:
# Experiment 2 schedule. Position i of each list takes effect at epoch
# exp2_epochs[i]: the batch size is raised once (128 -> 640 at epoch 60) and
# the learning rate is then divided by 5 at epochs 120 and 180.
exp2_epochs = [0, 60, 120, 180]
exp2_lr = [0.1,
           0.1,
           0.1 / 5,
           0.1 / 5 / 5
          ]
exp2_bs = [128, 640, 640, 640]
# NOTE(review): `model` is the same object already trained by the earlier
# experiment cell, and `hist` is overwritten here. If each experiment is meant
# to start from fresh weights, rebuild the model before this call -- confirm.
# train (pass the exp2_* schedules defined above, not the exp1_* ones)
hist = train(model, train_set, test_set, n_epochs=200, epoch_sched=exp2_epochs, lr_sched=exp2_lr, bs_sched=exp2_bs)

In [None]:
# Experiment 3 schedule. Position i of each list takes effect at epoch
# exp3_epochs[i]: the batch size stays at 128 throughout while the learning
# rate is divided by 5 at epochs 60, 120 and 180.
exp3_epochs = [0, 60, 120, 180]
exp3_lr = [0.1,
           0.1 / 5,
           0.1 / 5 / 5,
           0.1 / 5 / 5 / 5
          ]
exp3_bs = [128, 128, 128, 128]
# NOTE(review): `model` is the same object already trained by the earlier
# experiment cells, and `hist` is overwritten here. If each experiment is meant
# to start from fresh weights, rebuild the model before this call -- confirm.
# train (pass the exp3_* schedules defined above, not the exp1_* ones)
hist = train(model, train_set, test_set, n_epochs=200, epoch_sched=exp3_epochs, lr_sched=exp3_lr, bs_sched=exp3_bs)

In [None]:
# Dump the metadata records held in model.curr_metas to CSV: one row per
# record, with the column names taken from the keys of the first record.
toCSV = model.curr_metas
# Validate before opening: mode 'w' truncates the target, so opening first and
# then failing on toCSV[0] would wipe any previously saved results.
if len(toCSV) == 0:
    raise ValueError("model.curr_metas is empty; nothing to write to CSV")
with open('./exp-dask/exp1-decreaseingLR-const-workers-v0.csv', 'w', encoding='utf8', newline='') as output_file:
    fc = csv.DictWriter(output_file, fieldnames=toCSV[0].keys())
    fc.writeheader()
    fc.writerows(toCSV)