## Example usage for DaskClassifier

In [1]:
import os
# Make the notebook behave as if it were started from its parent directory,
# so relative paths such as ./data and ./notebooks/... resolve from there.
# Guarded so that re-running this cell does not keep walking up the tree:
# if the cwd already contains the notebooks/ directory, we are already there.
if not os.path.isdir('notebooks'):
    os.chdir('..')

In [2]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import numpy as np
import json
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from adadamp import DaskClassifier

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
# model from https://github.com/pytorch/examples/blob/master/mnist/main.py
class Net(nn.Module):
    """Two-conv CNN classifier, as in the PyTorch MNIST example.

    ``fc1`` expects 9216 = 64 * 12 * 12 flattened features. That is what
    the conv/pool stack yields for a 1x28x28 (MNIST-sized) input:
    28 -> 26 -> 24 after the two unpadded 3x3 convs, then 12 after 2x2
    max-pooling. ``forward`` returns log-probabilities of shape (batch, 10).
    """

    def __init__(self):
        super(Net, self).__init__()
        # Attribute names and creation order are kept: the names are the
        # state_dict keys, and the order fixes how parameters are initialised.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Feature extractor: conv -> relu -> conv -> relu -> 2x2 pool -> dropout.
        features = self.dropout1(
            F.max_pool2d(F.relu(self.conv2(F.relu(self.conv1(x)))), 2)
        )
        # Classifier head on the per-sample flattened features.
        hidden = self.dropout2(F.relu(self.fc1(torch.flatten(features, 1))))
        return F.log_softmax(self.fc2(hidden), dim=1)

In [5]:
def train(model, device, train_loader, n_epochs, log_every=100):
    """Fit ``model`` batch by batch and print its running training score.

    Each ``(data, target)`` batch from ``train_loader`` is moved to
    ``device``, passed to ``model.fit`` and then scored with
    ``model.score`` on that same batch. Every ``log_every`` batches (batch 0
    of each epoch is always logged), the mean of the per-batch scores seen so
    far in the current epoch is printed under the label "Accuracy".

    Parameters
    ----------
    model : object with ``fit(X, y)`` and ``score(X, y)`` methods
        For example, the ``DaskClassifier`` built in this notebook.
    device : torch.device
        Device every batch is moved to before fitting.
    train_loader : iterable of ``(data, target)`` batches
        Must also support ``len()`` and expose a ``.dataset`` with ``len()``,
        as ``torch.utils.data.DataLoader`` does. Both lengths are only used
        for the progress message.
    n_epochs : int
        Number of passes over ``train_loader``.
    log_every : int, optional
        Print progress every ``log_every`` batches. Must be >= 1.

    Raises
    ------
    ValueError
        If ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        # Fail up front; otherwise ``batch_idx % 0`` would raise a bare
        # ZeroDivisionError after the first batch had already been fitted.
        raise ValueError(
            "log_every must be a positive integer, got {!r}".format(log_every))

    for epoch in range(1, n_epochs + 1):
        accs = []  # per-batch scores for the current epoch only
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            model.fit(data, target)

            # Scored on the batch that was just fitted, so this is a training
            # (not held-out) score.
            # NOTE(review): an earlier run failed here with "Expected input
            # batch_size (640) to match target batch_size (64)"; confirm
            # score() receives batches of the shape it expects.
            accs.append(model.score(data, target))

            if batch_idx % log_every == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tAccuracy: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), sum(accs) / len(accs)))

In [6]:
# run configuration
train_kwargs = dict(batch_size=64)  # forwarded to the training DataLoader
# NOTE(review): log_interval is not used below; the train() call passes log_every=100.
log_interval = 10
device = torch.device("cpu")

In [7]:
# transforms: convert each image to a tensor, then normalize it with the
# constants from the PyTorch MNIST example (commonly cited as the MNIST
# training-set pixel mean and standard deviation).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))]
)

In [8]:
# data: MNIST train and test splits rooted at ./data (download=True fetches
# the files if they are not already there).
data_root = './data'
dataset1 = datasets.MNIST(data_root, train=True, download=True, transform=transform)
# NOTE(review): the test split is loaded but not used anywhere in this notebook.
dataset2 = datasets.MNIST(data_root, train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=dataset1, **train_kwargs)

In [9]:
# model: wrap Net in adadamp's DaskClassifier. NLLLoss pairs with Net's
# log_softmax output; Adadelta is the optimizer.
model = DaskClassifier(
    module=Net,
    loss=nn.NLLLoss,
    optimizer=optim.Adadelta,
    # presumably routed to the optimizer's `lr` (skorch-style `optimizer__`
    # prefix); confirm against adadamp
    optimizer__lr=1.0,
    batch_size=64,
)

In [11]:
# 10 passes over the training loader, printing progress every 100 batches
train(model, device, train_loader, n_epochs=10, log_every=100)



In [12]:
# Statistics recorded by the classifier during training, unpacked as a pair.
(meta, batch) = model.get_stats()

In [13]:
# Bundle both stat collections into a single document for json.dump.
save_json = dict(meta=meta, batch=batch)

In [14]:
# Write the collected stats under ./notebooks/ (relative to the cwd set by the
# os.chdir('..') cell at the top of the notebook).
stats_path = './notebooks/stats_10epoch.json'
with open(stats_path, 'w') as stats_file:
    json.dump(save_json, fp=stats_file)