## Example usage of `DaskClassifier`

In [1]:
import sys
import os
from pathlib import Path
# Remember the notebook's own directory only on the first execution. Without this
# guard, re-running the cell would recompute DIR from the already-changed cwd and
# climb one more directory up on every run.
try:
    DIR
except NameError:
    DIR = Path(".").absolute()
if str(DIR) not in sys.path:
    # Makes modules next to the notebook (e.g. model.py) importable; skip if already added.
    sys.path.append(str(DIR))
os.chdir(str(DIR.parent))  # make the notebook behave as if it's running in the parent dir

In [2]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import numpy as np
import json
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

In [3]:
from dask.distributed import Client

def _prep():
    """Import ``distributed.protocol.torch`` on the worker this runs on.

    The module is imported only for its side effects (presumably registering
    Dask's torch-specific serializers); nothing from it is used here.
    """
    import distributed.protocol.torch  # noqa: F401  (side-effect import)


# Thread-based workers in this process (processes=False) rather than subprocesses.
client = Client(processes=False)
client.run(_prep)
client

0,1
Client  Scheduler: inproc://172.31.40.124/7483/1  Dashboard: http://172.31.40.124:8787/status,Cluster  Workers: 1  Cores: 4  Memory: 16.48 GB


Using `processes=False` (threads instead of separate processes) gives a large performance increase when a Dask Distributed client is present. At each step the model is copied to every worker. With threads this is almost instant because all workers share the same memory; with processes it is not, because the model has to be serialized and sent to each worker.

`processes=False` should also be used with GPUs; in that case the model and gradient computations live on each GPU.

In [4]:
# Reload imported modules (e.g. adadamp, model) automatically before executing code,
# so edits to their source files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [5]:
from adadamp import DaskClassifier

In [6]:
# model from https://github.com/pytorch/examples/blob/master/mnist/main.py
from model import Net
# Ship model.py to the Dask workers so they can import `Net` too. The path is
# relative to the parent directory that the first cell chdir'd into.
client.upload_file("notebooks/model.py")

In [7]:
def get_model_weights(model):
    """Return a scalar checksum of ``model``'s parameters.

    For each parameter tensor, its elements are summed and the absolute value of
    that sum is taken; the per-tensor results are added together as Python floats.
    This is only a cheap "did the weights change?" signal: different weights can
    map to the same checksum. Returns ``0`` when the model has no parameters.
    """
    return sum(param.sum().abs().item() for param in model.parameters())

In [8]:
from torch.utils.data import Dataset
from dask.distributed import get_client

def run(
    model: "DaskClassifier",
    train_set: Dataset,
    test_set: Dataset,
    max_epochs: int = 5,
    score_on_test: bool = False,
):
    """Train ``model`` one epoch at a time and record metrics after each epoch.

    Each epoch calls ``model.partial_fit(train_set)``, checks that the weight
    checksum of ``model.module_`` changed, then calls ``model.score`` (which
    records its results in ``model.meta_``) and snapshots ``model.meta_``.

    Parameters
    ----------
    model : estimator exposing ``module_``, ``partial_fit``, ``score``,
        ``meta_`` and ``get_params`` (a DaskClassifier in this notebook).
    train_set : dataset passed to ``partial_fit``.
    test_set : dataset scored when ``score_on_test`` is True.
    max_epochs : number of ``partial_fit`` calls.
    score_on_test : if False (default, the original behaviour), score on
        ``train_set``; if True, score on ``test_set``.

    Returns
    -------
    (hist, params)
        ``hist`` is a list with one dict per epoch, ``{"epoch": k, **model.meta_}``
        with 1-based ``k``; ``params`` is ``model.get_params()`` after training.

    Raises
    ------
    Whatever ``get_client()`` raises when no Dask client is available, and
    AssertionError if a ``partial_fit`` call leaves the weight checksum unchanged.
    """
    # Fail fast if no Dask client is running; the returned client itself is not needed.
    get_client()

    # Scoring on the train set is the historical default ("temporarily using
    # train set for testing"); opt in to the held-out set with score_on_test=True.
    score_set = test_set if score_on_test else train_set

    hist = []
    for epoch in range(max_epochs):
        # train
        print(f"Epoch {epoch}...", end=" ")
        pre_weights = get_model_weights(model.module_)
        model.partial_fit(train_set)
        print("done")

        # ensure update; explicit raise so the check survives `python -O`
        if get_model_weights(model.module_) == pre_weights:
            raise AssertionError("ERROR: Model weights not changed after partial fit")

        # evaluate
        model.score(score_set)  # records info in model.meta_
        datum = {"epoch": epoch + 1, **model.meta_}
        print(datum)
        hist.append(datum)
    return hist, model.get_params()

In [9]:
# Preprocessing: convert each image to a tensor, then normalize it with
# mean 0.1307 and std 0.3081.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# MNIST splits stored under ./data; only the training split requests a download.
train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_set = datasets.MNIST("./data", train=False, transform=transform)

In [10]:
# Build the Dask-backed classifier around `Net`. The `optimizer__*` keyword
# arguments reach the optimizer (the SGD repr printed during training shows
# lr 0.001 and momentum 0.5).
model = DaskClassifier(
    module=Net,
    # NOTE(review): the SGD repr printed during training shows `weight_decay: 0`,
    # so this top-level value does not reach the optimizer. If optimizer-level
    # decay is intended, `optimizer__weight_decay=1e-5` is probably what was
    # meant — confirm how DaskClassifier consumes `weight_decay`.
    weight_decay=1e-5,
    loss=nn.CrossEntropyLoss,
    optimizer=optim.SGD,
    optimizer__lr=0.001,
    optimizer__momentum=0.5,
    batch_size=128,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)
model.initialize()

In [11]:
# Show the estimator object (it only has the default `<... at 0x...>` repr).
model

<adadamp._dist.DaskClassifier at 0x7f39b9900390>

In [14]:
# Train for 20 epochs. `hist` receives one metrics dict per epoch and `params`
# receives the estimator's get_params() output after training.
args = (model, train_set, test_set)
kwargs = {"max_epochs": 20}
hist, params = run(*args, **kwargs)

Epoch 0... Current Batch: 0
Updating model
Current Batch: 1


Function:  gradient
args:      ((Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
), SGD (
Parameter Group 0
    dampening: 0
    lr: 0.001
    momentum: 0.5
    nesterov: False
    weight_decay: 0
)), Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           ))
kwargs:    {'device': device(type='cuda', index=0), 'loss': CrossEntropyLoss(), 'idx': array([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
       81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])}
Exception: AssertionError('Gradients not cleared b

AssertionError: Gradients not cleared before loss.backward()

In [None]:
# Metrics recorded after the first epoch (`run` stores 1-based "epoch" numbers).
hist[0]

In [12]:
import pandas as pd

# Total partial_fit time over all epochs divided by the largest recorded update
# count. Assuming `n_updates` is a running total (TODO confirm), this is the
# mean time per model update.
df = pd.DataFrame(hist)
avg_update_time = df["partial_fit__time"].sum() / df["n_updates"].max()
msg = f"Avg. update time = {1000 * avg_update_time:0.0f}ms"
print(msg)

Avg. update time = 218ms


In [13]:
# Bundle the training history and hyperparameters for JSON export.
# Hyperparameters whose value is a class (presumably module=Net, optimizer=optim.SGD,
# loss=nn.CrossEntropyLoss from get_params()) are not JSON-serializable, so they
# are dropped. `isinstance(v, type)` also catches classes created with a custom
# metaclass, which the previous `type(v) != type` check let through.
save = {
    'history': hist,
    'params': {k: v for k, v in params.items() if not isinstance(v, type)},
}

In [14]:
# Serialize before opening the file. Opening with 'w' truncates stats.json, so if
# `save` held a value json cannot encode, the previous results would be destroyed
# and replaced by a partial file. Serializing first leaves the old file untouched.
payload = json.dumps(save)
with open('./notebooks/stats.json', 'w') as f:
    f.write(payload)