## Training a Wide ResNet with Dask Classifier

In [None]:
from dask.distributed import Client

In [None]:
import os
# Work from the project checkout so the relative paths used in this notebook
# ("./exp-dask/model.py", "./data") resolve against it.
# The location defaults to the original checkout path and can be overridden
# by setting the ADADAMP_PROJECT_DIR environment variable before starting
# the kernel, so the notebook is not tied to one machine's layout.
os.chdir(os.environ.get('ADADAMP_PROJECT_DIR', '/home/ubuntu/proj/adadamp-experiments'))

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from adadamp.adadamp import DaskClassifier

In [None]:
# training client
from dask.distributed import Client

def _prep() -> None:
    """Import ``distributed.protocol.torch`` in whichever process runs this.

    The imported name is never used; the import is performed only for its
    side effects. Presumably it registers Dask's serialization support for
    torch objects -- TODO confirm against the distributed documentation.
    """
    from distributed.protocol import torch

# Connect to the Dask scheduler at a fixed address.
# NOTE(review): `processes=False` looks like a local-cluster option; it
# probably has no effect when an explicit scheduler address is given -- confirm.
client = Client('172.31.42.174:8786', processes=False)
# Execute `_prep` on the cluster through Client.run.
client.run(_prep)
# Bare last expression: the notebook displays the client summary.
client

In [65]:
from model import Wide_ResNet

# Send the model source file to the cluster via Client.upload_file.
# Presumably this lets the workers import `model` (and thus Wide_ResNet)
# themselves -- confirm.
client.upload_file("./exp-dask/model.py")

In [68]:
# CIFAR-10 loading, adapted from
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
transform = transforms.Compose([
    transforms.ToTensor(),
    # Same value (0.5) for each of the three channels, for both arguments.
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])

# Train and test splits share the root directory and the transform.
train_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Class names; len(classes) is what sizes the model's output later on.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [69]:
from torch.utils.data import Dataset
from dask.distributed import get_client

def run(
    model: nn.Module,
    train_set: Dataset,
    test_set: Dataset,
    max_epochs: int = 5,
):
    """Train ``model`` for ``max_epochs`` rounds, scoring it after each one.

    Each round calls ``model.partial_fit(train_set)`` followed by
    ``model.score(test_set)`` and then snapshots ``model.meta_`` into the
    history. This function never uses a Dask client directly, so it does
    not look one up.

    Parameters
    ----------
    model
        Object exposing ``partial_fit``, ``score``, ``meta_`` and
        ``get_params``. NOTE(review): the annotation says ``nn.Module``,
        but these are estimator-style members (the notebook passes a
        ``DaskClassifier``) -- confirm the intended type.
    train_set
        Passed to ``model.partial_fit`` once per round.
    test_set
        Passed to ``model.score`` once per round.
    max_epochs
        Number of fit/score rounds. With ``0`` no round runs and the
        history is empty.

    Returns
    -------
    tuple
        ``(hist, params)``. ``hist`` holds one dict per round, built as
        ``{"epoch": <1-based round number>, **model.meta_}`` (so an
        ``"epoch"`` key inside ``meta_`` would take precedence).
        ``params`` is ``model.get_params()`` taken after the last round.
    """
    hist = []
    for epoch in range(max_epochs):
        # The progress line is 0-based while the recorded "epoch" is 1-based.
        print(f"Epoch {epoch}...", end=" ")
        model.partial_fit(train_set)
        print("done")
        model.score(test_set)  # records info in model.meta_
        datum = {"epoch": epoch + 1, **model.meta_}
        print(datum)
        hist.append(datum)
    return hist, model.get_params()

In [70]:
# model for experiment 1:
#    “Decaying learning rate” follows the original implementation;
#     the batch size is constant, while the learning rate repeatedly
#     decays by a factor of 5 at a sequence of steps
# my understanding is that this is the "control" experiment, where we touch neither the number of
# workers nor the batch size
# Build the classifier for this experiment.
# The `module__*` / `optimizer__*` keywords are presumably forwarded to
# Wide_ResNet(...) and torch.optim.SGD(...) respectively -- confirm against
# DaskClassifier.
# NOTE(review): nothing in this call configures a learning-rate decay
# schedule, although the comment above describes one; check whether
# DaskClassifier handles that internally.
model = DaskClassifier(
    module=Wide_ResNet,
    module__depth=16,
    module__widen_factor=4,
    module__dropout_rate=0.3,
    module__num_classes=len(classes),  # one output per entry in `classes`
    loss=nn.CrossEntropyLoss,
    optimizer=torch.optim.SGD,
    optimizer__lr=0.1,
    optimizer__momentum=0.9,
    optimizer__nesterov=True,
    optimizer__weight_decay=0.5e-3,  # i.e. 5e-4
    batch_size=128,
)

In [71]:
# Positional and keyword arguments are assembled separately and then
# unpacked into run().
args = (model, train_set, test_set)
kwargs = dict(max_epochs=10)
# NOTE(review): the output saved with this cell stops after "Epoch 0..." with
# "TypeError: super(type, obj): obj must be an instance or subtype of type".
# The cause is not visible in this notebook -- investigate before relying on
# `hist` / `params`.
hist, params = run(*args, **kwargs)

Epoch 0... 

TypeError: super(type, obj): obj must be an instance or subtype of type