## Example usage for DaskClassifier

In [2]:
import sys
import os
from pathlib import Path
# Absolute path of the directory the kernel is currently running in.
DIR = Path.cwd()
# Keep that directory importable after the working directory changes below.
sys.path.append(str(DIR))
# Make the notebook behave as if it's located in the parent directory.
os.chdir(str(DIR.parent))

In [3]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import numpy as np
import json
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

In [4]:
from dask.distributed import Client

def _prep():
    """Import ``distributed.protocol.torch``; handed to ``client.run`` below."""
    # The import is performed only for its side effects: the bound name is local
    # to this function and never used. Presumably it registers distributed's
    # torch serialization support on the worker -- TODO confirm.
    from distributed.protocol import torch

# Start a Dask client with `processes=False` (thread-based workers; the
# markdown note below explains why).
client = Client(processes=False)
# Execute `_prep` through the client before any training work is submitted.
client.run(_prep)
# Last expression of the cell: shows the client/cluster summary.
client

0,1
Client  Scheduler: inproc://172.31.40.124/5147/1  Dashboard: http://172.31.40.124:8787/status,Cluster  Workers: 1  Cores: 4  Memory: 16.48 GB


`processes=False` gives a large performance increase because threads communicate much faster than processes when a Dask Distributed client is present. At each step, the model is copied to each worker. With threads that copy is almost instant because all workers share the same memory; with processes it is not.

`processes=False` should still be used with GPUs. There, the model/gradient calculation will live on each GPU.

In [5]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [6]:
from adadamp import DaskClassifier

In [7]:
# model from https://github.com/pytorch/examples/blob/master/mnist/main.py
# (`model` is importable because the first cell appended the notebook's own
# directory to `sys.path`.)
from model import Net
# Send the file defining `Net` to the Dask workers, presumably so they can
# import it as well. The path is relative to the working directory, which the
# first cell moved to the parent directory.
client.upload_file("notebooks/model.py")

In [8]:
def get_model_weights(model):
    """
    Collapse every parameter of ``model`` into a single number.

    For each tensor yielded by ``model.parameters()`` the elements are summed
    and the absolute value of that sum is taken; the per-tensor values are then
    added together.

    Parameters
    ----------
    model
        Object exposing ``parameters()`` that yields tensors.

    Returns
    -------
    The accumulated value, or ``0`` when ``model`` has no parameters.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.sum().abs().item()
    return total

In [9]:
from torch.utils.data import Dataset
from dask.distributed import get_client

def run(
    model: nn.Module,
    train_set: Dataset,
    test_set: Dataset,
    max_epochs: int = 5,
):
    """
    Train ``model`` for ``max_epochs`` passes, recording stats after each one.

    Parameters
    ----------
    model
        Estimator exposing ``module_``, ``partial_fit``, ``score``, ``meta_``
        and ``get_params``.
    train_set
        Dataset handed to ``model.partial_fit`` once per epoch. It is also the
        dataset that is scored (see ``test_set``).
    test_set
        Currently unused: scoring is temporarily done on ``train_set``.
    max_epochs
        Number of ``partial_fit`` calls to make.

    Returns
    -------
    hist : list of dict
        One entry per epoch: ``{"epoch": <1-based epoch>, **model.meta_}``.
    params
        Whatever ``model.get_params()`` returns after the last epoch.

    Raises
    ------
    AssertionError
        If the value from ``get_model_weights(model.module_)`` is identical
        before and after a ``partial_fit`` call.
    """
    # The client object is not needed below. The lookup is kept (without
    # binding the result) so any error it raises when no client is available
    # still surfaces here, before training starts.
    get_client()
    hist = []
    for epoch in range(max_epochs):
        # train
        print(f"Epoch {epoch}...", end=" ")
        pre_weights = get_model_weights(model.module_)
        model.partial_fit(train_set)
        print("done")

        # ensure update
        assert pre_weights != get_model_weights(model.module_), "ERROR: Model weights not changed after partial fit"

        # test model
        # NOTE: scoring temporarily uses the training set; `test_set` is ignored.
        model.score(train_set)  # records info in model.meta_
        datum = {"epoch": epoch + 1, **model.meta_}
        print(datum)
        hist.append(datum)
    return hist, model.get_params()

In [10]:
# Values passed to `transforms.Normalize` as its mean and std arguments.
NORM_MEAN = (0.1307,)
NORM_STD = (0.3081,)
# Directory handed to `datasets.MNIST` as the dataset root.
DATA_ROOT = './data'

# transforms: convert to a tensor, then normalize
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(NORM_MEAN, NORM_STD)]
)

# data: only the training split is created with `download=True`
train_set = datasets.MNIST(DATA_ROOT, train=True, download=True, transform=transform)
test_set = datasets.MNIST(DATA_ROOT, train=False, transform=transform)

In [14]:
# model: the module, loss and optimizer are passed as classes, not instances.
# Runs on "cuda:0" when CUDA is available and on the CPU otherwise.
model = DaskClassifier(
    module=Net,
    loss=nn.CrossEntropyLoss,
    optimizer=optim.Adagrad,
    weight_decay=1e-5,
    batch_size=1024,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)
model.initialize()

In [15]:
# Display the estimator's repr.
model

<adadamp._dist.DaskClassifier at 0x7fb389d90f50>

In [None]:
# Positional inputs for `run`: the estimator followed by the two datasets.
args = model, train_set, test_set
# Train for 20 epochs instead of `run`'s default of 5.
kwargs = {"max_epochs": 20}
hist, params = run(*args, **kwargs)

Epoch 0... done
{'epoch': 1, 'n_updates': 59, 'n_data': 60416, 'score__calls': 1, 'partial_fit__calls': 1, 'n_workers': 32, 'partial_fit__time': 27.53844165802002, 'partial_fit__batch_size': 1024, 'weight_aggregate': 445.99322985112667, 'score__acc': 0.9115833640098572, 'score__loss': 0.3206614746729533, 'score__time': 9.681713342666626}
Epoch 1... done
{'epoch': 2, 'n_updates': 118, 'n_data': 120832, 'score__calls': 2, 'partial_fit__calls': 2, 'n_workers': 32, 'partial_fit__time': 27.275994539260864, 'partial_fit__batch_size': 1024, 'weight_aggregate': 456.9124222099781, 'score__acc': 0.9200500249862671, 'score__loss': 0.32105222832361857, 'score__time': 9.63679027557373}
Epoch 2... done
{'epoch': 3, 'n_updates': 177, 'n_data': 181248, 'score__calls': 3, 'partial_fit__calls': 3, 'n_workers': 32, 'partial_fit__time': 27.26951766014099, 'partial_fit__batch_size': 1024, 'weight_aggregate': 468.07095888257027, 'score__acc': 0.9214833378791809, 'score__loss': 0.32481049439112347, 'score__t

In [None]:
# Inspect the record collected for the first epoch.
hist[0]

In [12]:
import pandas as pd
# One row per epoch record returned by `run`.
df = pd.DataFrame(hist)
# Total time spent in `partial_fit`, summed over all epochs.
total_fit_time = df["partial_fit__time"].sum()
# The recorded output above shows `n_updates` accumulating across epochs, so
# its maximum is taken as the total number of updates.
total_updates = df["n_updates"].max()
avg_update_time = total_fit_time / total_updates
msg = "Avg. update time = {:0.0f}ms".format(1000 * avg_update_time)
print(msg)

Avg. update time = 218ms


In [13]:
# Bundle what gets written to disk. Parameters whose value is a class are
# dropped (e.g. the `Net`, loss and optimizer classes given to the estimator,
# if `get_params` returns them) because `json.dump` cannot serialize classes.
# `isinstance(v, type)` also catches classes built with a custom metaclass,
# which an exact `type(v) != type` comparison lets through.
save = {
    'history': hist,
    'params': {k: v for k, v in params.items() if not isinstance(v, type)},
}

In [14]:
# Persist the collected stats as JSON; the path is relative to the working
# directory.
with open('./notebooks/stats.json', 'w') as stats_file:
    stats_file.write(json.dumps(save))