## Example usage for DaskClassifier

In [None]:
import sys
import os
from pathlib import Path
# Absolute path of the directory the kernel is currently running in.
DIR = Path(".").absolute()
# Relocate only while the kernel is still inside the notebook directory.
# Without this guard, re-running the cell recomputes DIR from the already
# changed working directory, climbs one more level each time and appends a
# wrong entry to sys.path. The "notebooks" name matches the relative paths
# used further down ("notebooks/model.py", "./notebooks/stats.json").
if DIR.name == "notebooks":
    sys.path.append(str(DIR))  # keep modules next to the notebook importable
    os.chdir(str(DIR.parent))  # make the notebook assume it's in the parent dir

In [13]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import numpy as np
import json
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

In [14]:
from dask.distributed import Client

def _prep():
    """Import ``distributed.protocol.torch`` purely for its import-time effects.

    Handed to ``client.run`` below so that the import is executed there
    rather than only in this notebook's namespace.
    NOTE(review): presumably the import registers Dask's serialization
    support for PyTorch objects -- confirm against the distributed docs.
    """
    from distributed.protocol import torch

# processes=False: thread-based workers (see the explanation after this cell).
client = Client(processes=False)
client.run(_prep)  # execute _prep through the client
client  # last expression: show the client summary as the cell output

Port 8787 is already in use. 
Perhaps you already have a cluster running?
Hosting the diagnostics dashboard on a random port instead.


0,1
Client  Scheduler: inproc://172.31.40.124/4271/12  Dashboard: http://172.31.40.124:35601/status,Cluster  Workers: 1  Cores: 4  Memory: 16.48 GB


`processes=False` gives a large performance increase because threads communicate much faster than processes when a Dask Distributed client is present. At each step, the model is copied to each worker. With threads that copy is almost instant because all workers share the same memory; with processes it is not.

`processes=False` should still be used with GPUs. In that case, the model and the gradient computation live on each GPU.

In [15]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
from adadamp import DaskClassifier

In [17]:
# model from https://github.com/pytorch/examples/blob/master/mnist/main.py
from model import Net
# Send the module file to the Dask workers as well; presumably required so
# they can import `model` when they handle `Net` -- TODO confirm.
# The path is relative to the parent directory the first cell switches to.
client.upload_file("notebooks/model.py")

In [18]:
def get_model_weights(model):
    """Return a scalar fingerprint of a module's parameters.

    Each parameter tensor is summed, the absolute value of that sum is
    taken, and the per-tensor results are added together as Python numbers.
    Returns ``0`` when the module has no parameters.
    """
    return sum(torch.abs(torch.sum(p)).item() for p in model.parameters())

In [19]:
from torch.utils.data import Dataset
from dask.distributed import get_client

def run(
    model: nn.Module,
    train_set: Dataset,
    test_set: Dataset,
    max_epochs: int = 5,
):
    """Train ``model`` one epoch at a time and collect its metadata per epoch.

    Parameters
    ----------
    model
        Estimator exposing ``partial_fit``, ``score``, ``get_params`` and the
        ``module_`` / ``meta_`` attributes.
    train_set
        Dataset passed to ``partial_fit`` and, for now, also to ``score``.
    test_set
        Currently unused: scoring temporarily runs on ``train_set``.
    max_epochs
        Number of ``partial_fit`` calls to make.

    Returns
    -------
    hist, params
        ``hist`` holds one dict per epoch: an ``"epoch"`` key set to the
        1-based epoch number, followed by the entries of ``model.meta_``.
        ``params`` is the result of ``model.get_params()``.

    Raises
    ------
    AssertionError
        If the weight fingerprint of ``model.module_`` is the same before and
        after a ``partial_fit`` call.
    """
    # Called only for its check: it raises when no Dask client can be found,
    # so the loop fails before any training starts. The returned client is
    # not needed here.
    get_client()
    hist = []
    for epoch in range(max_epochs):
        # train
        print(f"Epoch {epoch}...", end=" ")
        pre_weights = get_model_weights(model.module_)
        model.partial_fit(train_set)
        print("done")

        # ensure update
        assert pre_weights != get_model_weights(model.module_), "ERROR: Model weights not changed after partial fit"

        # test model
        # temporarily using train set for testing
        model.score(train_set)  # records info in model.meta_
        datum = {"epoch": epoch + 1, **model.meta_}
        print(datum)
        hist.append(datum)
    return hist, model.get_params()

In [20]:
# Values handed to transforms.Normalize as (mean, std), and the dataset root.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)
DATA_ROOT = './data'

# transforms: convert to tensors first, then normalise
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(MNIST_MEAN, MNIST_STD)]
)

# data: only the train split is created with download=True
train_set = datasets.MNIST(DATA_ROOT, train=True, download=True, transform=transform)
test_set = datasets.MNIST(DATA_ROOT, train=False, transform=transform)

In [21]:
# Use "cuda:0" when CUDA is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# model: gather the constructor arguments, then build and initialize
model_kwargs = dict(
    module=Net,
    weight_decay=1e-5,
    loss=nn.CrossEntropyLoss,
    optimizer=optim.SGD,
    optimizer__lr=0.001,
    optimizer__momentum=0.9,
    batch_size=128,
    device=device,
)
model = DaskClassifier(**model_kwargs)
model.initialize()

In [22]:
# Bare expression: show the estimator's repr as the cell output.
model

<adadamp._dist.DaskClassifier at 0x7f8ad3909890>

In [23]:
# Run the training loop for 20 epochs; scoring inside run() currently uses
# the train set, the test set is only passed through.
kwargs = {"max_epochs": 20}
args = (model, train_set, test_set)
hist, params = run(*args, **kwargs)

Epoch 0... 

distributed.utils - ERROR - 'lengths'
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/distributed/utils.py", line 665, in log_errors
    yield
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/distributed/protocol/numpy.py", line 104, in deserialize_numpy_ndarray
    frames = merge_frames(header, frames)
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/distributed/protocol/utils.py", line 65, in merge_frames
    lengths = list(header["lengths"])
KeyError: 'lengths'
distributed.worker - ERROR - '_update_model-82ea7cbe6f59d3354e2ce9b05f024a3c'
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/distributed/worker.py", line 2473, in execute
    data[k] = self.data[k]
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/zict/buffer.py", line 70, in __getitem__


KeyboardInterrupt: 



In [None]:
# Inspect the record collected for the first epoch.
hist[0]

In [12]:
import pandas as pd
# One row per epoch record in `hist`.
df = pd.DataFrame(hist)

# Total time recorded for partial_fit, divided by the largest update count.
total_fit_time = df["partial_fit__time"].sum()
max_updates = df["n_updates"].max()
avg_update_time = total_fit_time / max_updates

# Scaled by 1000 for the "ms" label in the message.
msg = "Avg. update time = {:0.0f}ms".format(1000 * avg_update_time)
print(msg)

Avg. update time = 218ms


In [13]:
# Bundle the per-epoch history with the estimator parameters for json.dump.
# Parameter values that are classes are dropped because class objects cannot
# be JSON-encoded. isinstance() also filters classes created by a custom
# metaclass, which an exact `type(v) != type` comparison would let through.
save = {
    'history': hist,
    'params': {k: v for k, v in params.items() if not isinstance(v, type)},
}

In [14]:
# Write the collected statistics as JSON; the path is relative to the
# current working directory.
with open('./notebooks/stats.json', 'w') as stats_file:
    json.dump(save, stats_file)