## Example usage for DaskClassifier

In [1]:
import os
# Later cells use paths relative to the notebook's parent directory ('./data',
# './notebooks/stats.json'), so make that the working directory. The directory the
# kernel started in is remembered on the first run, so re-running this cell does not
# keep climbing further up the directory tree.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [2]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import numpy as np
import json
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

In [3]:
from dask.distributed import Client
# Thread-based (processes=False) local Dask cluster; see the note below for why.
# Re-running this cell would otherwise start a second in-process cluster next to the
# first one, so close any previous client first (closing a client created this way
# also shuts down the local cluster it started).
if isinstance(globals().get('client'), Client):
    client.close()
client = Client(processes=False)
client

0,1
Client  Scheduler: inproc://10.64.32.28/41934/1  Dashboard: http://10.64.32.28:8787/status,Cluster  Workers: 1  Cores: 8  Memory: 17.18 GB


`processes=False` gives a large performance increase: with a Dask Distributed client, thread-based workers communicate much faster than process-based ones. At each step, the model is copied to every worker. With threads this is almost instant, because all workers share the same memory; with processes, the model has to be copied into each process's own memory.

Use `processes=False` with GPUs as well. In that setup, the model and gradient computations live on each GPU.

In [4]:
# Automatically re-import edited modules (such as adadamp, imported next) before each cell runs.
%load_ext autoreload
%autoreload 2

In [5]:
from adadamp import DaskClassifier

In [6]:
class Net(nn.Module):
    """Small CNN for 28x28 single-channel images that outputs log-probabilities over 10 classes.

    Architecture taken from https://github.com/pytorch/examples/blob/master/mnist/main.py
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the same order as the upstream example, so parameter
        # names and seeded initialisation are unchanged.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # 9216 = 64 channels * 12 * 12: a 28x28 input shrinks to 26 and then 24 after the two
        # unpadded 3x3 convs, then to 12 after the 2x2 max-pool.
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        # Feature extractor: conv -> ReLU -> conv -> ReLU -> 2x2 max-pool -> dropout.
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, kernel_size=2))
        # Classifier head on the flattened features (dim 0, the batch, is kept).
        hidden = self.dropout2(F.relu(self.fc1(torch.flatten(features, start_dim=1))))
        # log_softmax output is meant to be paired with NLLLoss.
        return F.log_softmax(self.fc2(hidden), dim=1)

In [7]:
from torch.utils.data import Dataset

def run(
    model,
    train_set: Dataset,
    test_set: Dataset,
    max_epochs: int = 5,
):
    """Train ``model`` for ``max_epochs`` rounds and record metrics after each round.

    Parameters
    ----------
    model
        An estimator exposing ``partial_fit(dataset)``, ``score(dataset)``,
        ``get_params()`` and a ``meta_`` dict that ``score`` fills in, such as the
        ``DaskClassifier`` built later in this notebook. (It is not an
        ``nn.Module``, which has none of these methods.)
    train_set
        Dataset passed to ``model.partial_fit`` in every round.
    test_set
        Dataset passed to ``model.score`` after every round.
    max_epochs
        Number of rounds. Each round makes one ``partial_fit`` call, presumably one
        pass over ``train_set``; confirm against DaskClassifier.

    Returns
    -------
    hist : list of dict
        One record per round: ``{"epoch": <1-based round>, **model.meta_}``, a
        shallow snapshot of ``meta_`` taken right after scoring.
    params : dict
        ``model.get_params()`` after the last round.
    """
    hist = []
    for epoch in range(1, max_epochs + 1):
        model.partial_fit(train_set)
        model.score(test_set)  # score() records its metrics in model.meta_
        # Copy meta_ into a fresh dict so later rounds cannot overwrite this record.
        datum = {"epoch": epoch, **model.meta_}
        print(datum)  # progress log (this runs inside a Dask task when submitted via client.submit)
        hist.append(datum)
    return hist, model.get_params()

In [8]:
# transforms
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# data
train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST('./data', train=False, transform=transform)

In [9]:
# Estimator wrapping Net. NLLLoss pairs with the log_softmax output of Net.forward.
# NOTE(review): the `optimizer__*` kwargs are presumably forwarded to optim.SGD
# (skorch-style double-underscore routing), while weight_decay goes to DaskClassifier
# itself rather than to the optimizer; confirm against adadamp.
model = DaskClassifier(
    module=Net,
    weight_decay=1e-5,
    loss=nn.NLLLoss,
    optimizer=optim.SGD,
    optimizer__lr=1e-4,  # same value as the former 0.1e-3 literal
    optimizer__momentum=0.9,
    batch_size=128,
)


In [None]:
# Run the whole training loop as one Dask task and wait for (hist, params).
# `run` trains `model` and is presumably stochastic (Net uses dropout), so it is not a
# pure function. With Dask's default pure=True, submitting identically-tokenised arguments
# can return an existing future instead of re-running the training; pure=False forces
# a fresh run every time.
args = (model, train_set, test_set)
kwargs = dict(max_epochs=10)
future = client.submit(run, *args, pure=False, **kwargs)
hist, params = client.gather(future)

In [None]:
# Show the first epoch's record: the epoch number plus whatever score() stored in model.meta_.
hist[0]

In [21]:
# JSON-serialisable record of the run. Hyperparameters that are classes (e.g.
# module=Net, loss=nn.NLLLoss, optimizer=optim.SGD) cannot be written by json.dump,
# so they are dropped. isinstance(v, type) also catches classes whose metaclass
# subclasses `type` (e.g. ABCMeta); an exact `type(v) == type` comparison misses those.
save = {
    'history': hist,
    'params': {k: v for k, v in params.items() if not isinstance(v, type)},
}

In [25]:
# Persist the run record to ./notebooks/stats.json. The path is relative to the working
# directory set by the os.chdir('..') cell at the top of the notebook.
with open('./notebooks/stats.json', mode='w') as f:
    json.dump(obj=save, fp=f)