## Example usage for DaskClassifier

In [1]:
import sys
import os
from pathlib import Path
# Locate the notebook directory in a way that survives re-running this cell.
# On the first run the kernel starts inside `notebooks/`; after the chdir below
# "." is the project root, so naively re-reading "." would climb one more level
# on every re-run. Later cells rely on the project root being the working
# directory (they use paths such as "notebooks/model.py").
_cwd = Path(".").absolute()
DIR = _cwd if _cwd.name == "notebooks" else _cwd / "notebooks"
if str(DIR) not in sys.path:  # avoid stacking duplicate entries on re-run
    sys.path.append(str(DIR))
os.chdir(str(DIR.parent))  # make the notebook assume it's in the parent dir

In [2]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import numpy as np
import json
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

In [3]:
from dask.distributed import Client

def _prep():
    """Import ``distributed.protocol.torch`` purely for its import-time side effects.

    NOTE(review): presumably this registers Dask's serialization support for
    torch objects -- confirm against the distributed documentation.
    """
    import distributed.protocol.torch  # noqa: F401  (imported for side effects only)

# Start a Dask client whose workers are threads inside this process
# (`processes=False`; the rationale is given in the markdown note below).
client = Client(processes=False)
# Execute `_prep` through the client so its import happens on the workers
# before any training work is submitted.
client.run(_prep)
client  # last expression of the cell: display the client/cluster summary

Perhaps you already have a cluster running?
Hosting the HTTP server on port 57458 instead


0,1
Client  Scheduler: inproc://10.64.32.28/60622/1  Dashboard: http://10.64.32.28:57458/status,Cluster  Workers: 1  Cores: 8  Memory: 17.18 GB


`processes=False` gives a large performance increase because threads communicate much faster than processes when a Dask Distributed client is present. At each step, the model is copied to each worker. That's almost instant with threads because they share the same memory; that's not true with processes.

`processes=False` should still be used with GPUs. There, the model/gradient calculation will live on each GPU.

In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
from adadamp import DaskClassifier

In [6]:
# model from https://github.com/pytorch/examples/blob/master/mnist/main.py
from model import Net
# Also send the module file to the Dask workers. The path is relative to the
# project root, which the first cell made the working directory.
client.upload_file("notebooks/model.py")

In [7]:
from torch.utils.data import Dataset
from dask.distributed import get_client

def run(
    model: nn.Module,
    train_set: Dataset,
    test_set: Dataset,
    max_epochs: int = 5,
):
    """
    Alternate training and scoring for ``max_epochs`` rounds.

    Each round calls ``model.partial_fit(train_set)`` once, then
    ``model.score(test_set)``, and records a snapshot of ``model.meta_``.

    Parameters
    ----------
    model
        Estimator exposing ``partial_fit``, ``score``, ``meta_`` and
        ``get_params`` (this notebook passes a ``DaskClassifier``).
    train_set
        Dataset handed to ``model.partial_fit`` every round.
    test_set
        Dataset handed to ``model.score`` every round.
    max_epochs
        Number of rounds to run; ``0`` returns an empty history.

    Returns
    -------
    hist : list of dict
        One dict per round: ``{"epoch": <1-based round>, **model.meta_}``.
    params
        Whatever ``model.get_params()`` returns after the last round.
    """
    hist = []
    for epoch in range(max_epochs):
        # Progress message is 0-based; the stored "epoch" field is 1-based.
        print(f"Epoch {epoch}...", end=" ")
        model.partial_fit(train_set)
        print("done")
        model.score(test_set)  # records info in model.meta_
        # Unpack meta_ into a new dict so each history entry is its own snapshot.
        datum = {"epoch": epoch + 1, **model.meta_}
        print(datum)
        hist.append(datum)
    return hist, model.get_params()

In [8]:
# Preprocessing pipeline: convert each image to a tensor, then normalize it.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST pixel mean / std -- confirm.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# MNIST train / test splits stored under ./data; only the training split
# requests a download. Both share the preprocessing pipeline above.
train_set = datasets.MNIST(
    './data', train=True, download=True, transform=transform
)
test_set = datasets.MNIST('./data', train=False, transform=transform)

In [9]:
# model
# DaskClassifier is given the classes themselves (Net, CrossEntropyLoss, SGD),
# not instances.
# NOTE(review): the `optimizer__*` names look like double-underscore routing of
# keyword arguments to the optimizer -- confirm against the adadamp docs.
model = DaskClassifier(
    module=Net,
    weight_decay=1e-5,
    loss=nn.CrossEntropyLoss,
    optimizer=optim.SGD,
    optimizer__lr=0.01,
    optimizer__momentum=0.9,
    batch_size=128,
)


In [10]:
# Run ten train/score rounds; `hist` receives the per-round records and
# `params` the estimator parameters reported after the final round.
args = model, train_set, test_set
kwargs = {"max_epochs": 10}
hist, params = run(*args, **kwargs)

Epoch 0... done
{'epoch': 1, 'n_updates': 469, 'n_data': 60032, 'score__calls': 1, 'partial_fit__calls': 1, 'partial_fit__time': 105.9611120223999, 'partial_fit__batch_size': 128, 'score__acc': 0.7397000193595886, 'score__loss': 2.7202408935546876, 'score__time': 14.996641159057617}
Epoch 1... done
{'epoch': 2, 'n_updates': 938, 'n_data': 120064, 'score__calls': 2, 'partial_fit__calls': 2, 'partial_fit__time': 99.81652903556824, 'partial_fit__batch_size': 128, 'score__acc': 0.7318999767303467, 'score__loss': 2.802848291015625, 'score__time': 14.40046501159668}
Epoch 2... done
{'epoch': 3, 'n_updates': 1407, 'n_data': 180096, 'score__calls': 3, 'partial_fit__calls': 3, 'partial_fit__time': 112.15792894363403, 'partial_fit__batch_size': 128, 'score__acc': 0.7470999956130981, 'score__loss': 2.8819994995117186, 'score__time': 13.294518947601318}
Epoch 3... done
{'epoch': 4, 'n_updates': 1876, 'n_data': 240128, 'score__calls': 4, 'partial_fit__calls': 4, 'partial_fit__time': 86.937776088714

In [11]:
# Record for the first round: the 1-based epoch number plus the contents of
# model.meta_ captured right after scoring.
hist[0]

{'epoch': 1,
 'n_updates': 469,
 'n_data': 60032,
 'score__calls': 1,
 'partial_fit__calls': 1,
 'partial_fit__time': 105.9611120223999,
 'partial_fit__batch_size': 128,
 'score__acc': 0.7397000193595886,
 'score__loss': 2.7202408935546876,
 'score__time': 14.996641159057617}

In [12]:
import pandas as pd
# One row per training round.
df = pd.DataFrame(hist)
# In the recorded output `partial_fit__time` is reported per call while
# `n_updates` accumulates across calls, so total time is the column sum and
# total updates is the column maximum.
total_fit_seconds = df["partial_fit__time"].sum()
total_updates = df["n_updates"].max()
avg_update_time = total_fit_seconds / total_updates
msg = "Avg. update time = {:0.0f}ms".format(1000 * avg_update_time)
print(msg)

Avg. update time = 218ms


In [13]:
# Bundle the training history and the estimator parameters for JSON export.
# Parameters whose value is a class (such as the module, loss and optimizer
# classes passed to DaskClassifier) are dropped because the json module cannot
# encode class objects. `isinstance(v, type)` also catches classes built with a
# custom metaclass, which a `type(v) != type` comparison would let through.
save = {
    'history': hist,
    'params': {k: v for k, v in params.items() if not isinstance(v, type)},
}

In [14]:
# Serialize before opening the file: opening with 'w' truncates immediately, so
# encoding first means a failed json encode cannot leave an empty or
# half-written stats.json behind.
payload = json.dumps(save)
with open('./notebooks/stats.json', 'w') as f:
    f.write(payload)