In [4]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import matplotlib.pyplot as plt
from functools import lru_cache
import numpy as np
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from adadamp import DaskBaseDamper, DaskBaseDamper2

In [5]:
# Load IPython's autoreload extension and set it to mode 2 so that edits to
# imported modules are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [6]:
# model from https://github.com/pytorch/examples/blob/master/mnist/main.py
class Net(nn.Module):
    """Small convolutional classifier emitting log-probabilities over 10 classes.

    Layout: two 3x3 convolutions (stride 1, no padding) with ReLU, a 2x2
    max-pool, dropout, then two fully connected layers with ReLU and dropout
    in between. ``fc1`` expects 9216 input features, i.e. 64 channels * 12 * 12,
    which is what a single-channel 28x28 input produces after the
    convolutions and the pooling step.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 1 -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # Regularisation applied after pooling and after the hidden layer.
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # Classifier head: flattened features -> 128 hidden units -> 10 outputs.
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return ``log_softmax`` scores (taken over dim 1) for the batch ``x``."""
        conv_features = F.relu(self.conv2(F.relu(self.conv1(x))))
        pooled = self.dropout1(F.max_pool2d(conv_features, 2))
        # Keep the batch dimension, collapse everything else for the linear layers.
        hidden = F.relu(self.fc1(torch.flatten(pooled, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

In [16]:
def setup_and_train(epochs=14, batch_size=64):
    """Build the MNIST training loader and a ``DaskBaseDamper2`` model, then train.

    Parameters
    ----------
    epochs : int
        Number of passes over the training set; ``train`` is called once per
        epoch with epoch numbers ``1..epochs``.
    batch_size : int
        Mini-batch size. The same value is used for the ``DataLoader`` and for
        the ``batch_size`` argument of ``DaskBaseDamper2`` so the two cannot
        drift apart.

    Returns
    -------
    The ``DaskBaseDamper2`` instance after the last epoch, so the caller can
    inspect or evaluate it.

    Notes
    -----
    Downloads MNIST into ``./data`` if it is not already there.
    """
    device = torch.device("cpu")

    # Convert images to tensors and normalise them
    # (constants presumably are the MNIST mean / std -- confirm against the upstream example).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Training split only; nothing in this function evaluates on the test split.
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)

    # dask model
    model = DaskBaseDamper2(
        module=Net,
        loss=nn.NLLLoss,
        optimizer=optim.Adadelta,
        optimizer__lr=1.0,
        batch_size=batch_size,
    )

    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, epoch)

    return model
    

In [17]:
def train(model, device, train_loader, epoch, log_interval=1000):
    """Run one pass over ``train_loader``, fitting and scoring ``model`` per batch.

    Parameters
    ----------
    model
        Object exposing ``fit(data, target)`` and ``score(data, target)``;
        ``score`` must return a ``(loss, accuracy)`` pair.
    device
        Device every batch is moved to before it is handed to ``model``.
    train_loader
        Sized iterable of ``(data, target)`` batches with a ``dataset``
        attribute (used only for the progress message).
    epoch : int
        Epoch number, used only in the progress message.
    log_interval : int
        A progress line is printed for every batch whose index is a multiple
        of this value.

    Returns
    -------
    tuple
        ``(mean_loss, accuracy)``: the mean of the per-batch losses over the
        epoch (``0.0`` when the loader is empty) and the accuracy reported for
        the last batch (``0`` when the loader is empty).
    """
    total_loss = 0.0
    accuracy = 0
    batches_seen = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # NOTE(review): if DaskBaseDamper2.fit re-initialises the module on every
        # call (as scikit-learn style ``fit`` methods commonly do), per-batch
        # training should go through an incremental method instead -- confirm.
        model.fit(data, target)

        # NOTE(review): an earlier run reported "Expected input batch_size (640)
        # to match target batch_size (64)." from this call -- verify that it is
        # resolved in the model's ``score`` implementation.
        batch_loss, accuracy = model.score(data, target)

        # Accumulate a plain float so no tensor (or its graph) is kept alive
        # across iterations.
        total_loss += float(batch_loss)
        batches_seen = batch_idx + 1

        if batch_idx % log_interval == 0:
            # Report the running mean so the logged value stays on the scale
            # of a single batch loss instead of growing with the batch count.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), total_loss / batches_seen))

    mean_loss = total_loss / batches_seen if batches_seen else 0.0
    return mean_loss, accuracy

In [19]:
# Run the whole pipeline for a single epoch (epochs=1).
# NOTE(review): the recorded output of this cell is a KeyboardInterrupt, so the
# saved run did not complete -- re-run before relying on any result.
setup_and_train(1)



KeyboardInterrupt: 