In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
import numpy as np

import itertools
from time import sleep, time
import toolz
import numpy as np

In [2]:
import distributed
# Display which installation of `distributed` was imported; the recorded
# output below points at a local development checkout.
distributed.__file__

'/Users/ssievert/Developer/dask/distributed/distributed/__init__.py'

### Model

In [3]:
class Net(nn.Module):
    """
    Small convolutional classifier.

    Two 5x5 convolutions (1 -> 10 -> 20 channels), each followed by 2x2 max
    pooling and ReLU, feed a two-layer fully connected head (320 -> 50 -> 10).
    ``forward`` returns log-probabilities over the 10 outputs.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # Fully connected classifier head.
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """
        Map a batch of single-channel images to per-class log-probabilities.

        The convolutional features are flattened to rows of 320 values, so
        the input spatial size must be one that yields 320 features per
        sample after both conv/pool stages.
        """
        # Stage 1: conv -> pool -> ReLU.
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))

        # Stage 2: conv -> channel dropout -> pool -> ReLU.
        conv2_out = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(conv2_out, 2))

        # Flatten and classify; dropout is only active in training mode.
        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
    
def get_model(model=None):
    """
    Build a model instance.

    With no argument, return a fresh ``Net``.  Otherwise ``model`` is the
    name of an attribute of ``torchvision.models``; that attribute is called
    with no arguments and the result is returned.
    """
    if model is not None:
        factory = getattr(torchvision.models, model)
        return factory()
    return Net()

# Names exported by torchvision.models that contain 'resnet', excluding the
# 'resnet' attribute itself.
resnet_models = list(filter(
    lambda name: name != 'resnet' and 'resnet' in name,
    dir(torchvision.models),
))
# Instantiate every one of those models eagerly.
models = list(map(get_model, resnet_models))

In [4]:
from types import SimpleNamespace

# Hyper-parameters and run options, bundled as attribute-style settings.
args = SimpleNamespace(
    batch_size=64,
    test_batch_size=1000,
    epochs=2,
    lr=0.01,
    momentum=0.5,
    no_cuda=True,
    seed=42,
    log_interval=80,
)

# CUDA is used only when it is not disabled in `args` and is available.
use_cuda = (not args.no_cuda) and torch.cuda.is_available()
torch.manual_seed(args.seed)

# Extra DataLoader keyword arguments; only populated when running on CUDA.
if use_cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

### Dask

In [5]:
from dask.distributed import Client, wait
# Client() is created with no arguments; the recorded output below shows the
# resulting scheduler on 127.0.0.1 with 8 workers.
client = Client()
client

0,1
Client  Scheduler: tcp://127.0.0.1:56886  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 8  Cores: 8  Memory: 17.18 GB


### Parameter Server


In [6]:
import torch.optim as optim
import copy

def clone(model):
    """Return an independent deep copy of ``model``."""
    duplicate = copy.deepcopy(model)
    return duplicate

class PS:
    """
    Parameter server: holds the model and applies gradients pushed by workers.

    Two copies of the model are kept.  ``model`` is the published copy that
    ``pull`` returns; ``_model`` is a private copy whose parameters the SGD
    optimizer updates.  After every optimizer step the private copy is
    cloned into ``model``.
    """
    # Class-level defaults; both are overwritten per instance in __init__.
    model = None
    n_steps = 0

    def __init__(self, model, args, device, max_iter=30, grads_per_step=4):
        """
        Parameters
        ----------
        model
            Model to serve; it is moved to ``device``.
        args
            Object exposing ``lr`` and ``momentum`` (SGD settings).
        device
            Target passed to ``model.to``.
        max_iter : int
            Number of optimizer steps after which ``push`` starts returning
            ``None``.
        grads_per_step : int
            How many accepted gradient pushes are summed before one
            optimizer step is taken.  Defaults to 4, the value that used to
            be hard-coded.

        Raises
        ------
        ValueError
            If ``grads_per_step`` is smaller than 1 (no step could ever run).
        """
        if grads_per_step < 1:
            raise ValueError("grads_per_step must be at least 1")
        model = model.to(device)
        self.model = model
        self._model = clone(model)
        self.optimizer = optim.SGD(self._model.parameters(), lr=args.lr, momentum=args.momentum)
        self.n_steps = 0
        self._grads_recvd = 0
        self._updating = False
        self.max_iter = max_iter
        self.grads_per_step = grads_per_step

    def pull(self):
        """
        For a worker to pull a model from this PS
        """
        return self.model

    def push(self, grads, key=None):
        """
        For a worker to push some gradients to this PS

        Parameters
        ----------
        grads
            Mapping of parameter name -> gradient tensor.  A falsy value
            (``None`` or an empty mapping) pushes nothing and only reports
            the current step.
        key
            Step number the gradients were computed for.  Gradients whose
            key is older than the current step are discarded.  ``None``
            skips the staleness check.

        Returns
        -------
        The current step count, or ``None`` once ``max_iter`` steps have
        been taken (only when ``grads`` is truthy).
        """
        if not grads:
            return self.n_steps

        # Stop after exactly max_iter optimizer steps.
        if self.n_steps >= self.max_iter:
            return None

        # Gradients arriving while a step is being applied are dropped.
        if self._updating:
            return self.n_steps

        # Stale gradients (computed against an older model) are dropped.
        if key is not None and key != self.n_steps:
            assert key < self.n_steps, "worker step is ahead of the parameter server"
            return self.n_steps

        # First gradient of a new step: clear what the previous step left.
        if self._grads_recvd == 0:
            self.optimizer.zero_grad()

        self._grads_recvd += 1
        self.aggregate(grads)
        if self._grads_recvd == self.grads_per_step:
            self._updating = True
            try:
                self._grads_recvd = 0
                self.step()
                self.n_steps += 1
            finally:
                # Never leave the server stuck in the "updating" state.
                self._updating = False
        return self.n_steps

    def aggregate(self, grads):
        """
        Add ``grads[name]`` onto the gradient of every named parameter of
        the private model.
        """
        for name, param in self._model.named_parameters():
            if param.grad is None:
                # A plain zero tensor: not tracked by autograd, and not
                # derived from the parameter's values.
                param.grad = torch.zeros_like(param)
            param.grad += grads[name]

    def step(self):
        """
        Apply one optimizer step to the private model and publish a clone
        of it as ``model``.
        """
        self.optimizer.step()
        self.model = clone(self._model)

### Execution

In [7]:
# Fresh Net instance; a later cell scatters it to the cluster and hands it to
# the parameter server.
model = Net()

In [8]:
# Run configuration.
# NOTE(review): `iters` is not referenced anywhere else in the visible
# notebook, and the parameter server below gets a literal max_iter=30 --
# confirm which value is intended.
iters = 50
batch_size = 128
# Number of worker tasks submitted later (distinct from the DataLoader
# 'num_workers' option that may be present in `kwargs`).
num_workers = 4

# MNIST training set: convert to tensors, then normalize with the given
# mean / std constants.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
mnist_train = datasets.MNIST('../data', train=True, download=True,
                             transform=mnist_transform)
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True, **kwargs)

device = torch.device("cuda" if use_cuda else "cpu")

# Ship the loader and the model to the cluster.  Both names are rebound to
# what `client.scatter` returns, so this cell is not safe to run twice.
train_loader, model = client.scatter([train_loader, model])

# Create the parameter server as a dask actor and wait for its handle.
ps_future = client.submit(PS, model, args, device,
                          actor=True, max_iter=30)
ps = ps_future.result()


In [9]:
def train(model, device, data, target):
    """
    Run one forward/backward pass and leave this batch's gradients on
    ``model``.

    Parameters
    ----------
    model
        Module whose output is fed to ``F.nll_loss`` (i.e. it is expected to
        produce log-probabilities).
    device
        Device that ``data`` and ``target`` are moved to.
    data, target
        Input batch and its labels.

    Returns
    -------
    The same ``model`` object, in training mode, with each parameter's
    ``.grad`` holding the gradient of the loss on this batch only.
    """
    model.train()
    # backward() accumulates into .grad, so clear any gradients left over
    # from a previous call on the same model object; otherwise a model that
    # is reused for several batches would report the sum of all of them.
    model.zero_grad()
    data, target = data.to(device), target.to(device)
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    return model

In [10]:
import toolz
def worker(ps, device, train_loader):
    """
    Training loop run by one worker against the parameter server ``ps``.

    Repeatedly: pull the model when the server's step has changed, compute
    gradients on the next batch with ``train``, and push them tagged with
    the step they were computed for.  The loop ends when ``push`` returns
    ``None``.

    Parameters
    ----------
    ps
        Parameter-server handle whose ``pull()`` / ``push()`` calls return
        objects with a ``.result()`` method.
    device
        Device passed through to ``train``.
    train_loader
        Iterable of ``(data, target)`` batches.
    """
    batches = iter(train_loader)
    # Pushing no gradients only reports the server's current step; start
    # from it so a worker that joins late does not compute a gradient for
    # step 0 that the server would discard as stale.
    step = ps.push(None).result()
    last_step = -1
    while True:
        # Only fetch a new model when the server has moved to a new step.
        if step != last_step:
            model = ps.pull().result()
            last_step = step

        try:
            data, target = next(batches)
        except StopIteration:
            # Loader exhausted: begin another pass over the data.
            batches = iter(train_loader)
            data, target = next(batches)
        model = train(model, device, data, target)

        # Debug output: first three values of the first parameter tensor.
        # .cpu() keeps this working when the model lives on a CUDA device.
        check = toolz.first(model.parameters())
        check = check.detach().cpu().numpy().flat[:3]
        print(step, check)

        grads = {name: p.grad.data for name, p in model.named_parameters()}

        step = ps.push(grads, key=step).result()
        if step is None:
            break

In [11]:
# Launch the worker loops.  `worker` is impure (it loops, prints and talks to
# the parameter server), so pass pure=False: with the default pure=True,
# submissions with identical arguments may be given the same task key and
# collapse into a single task instead of `num_workers` independent ones.
futures = [client.submit(worker, ps, device, train_loader, pure=False)
           for i in range(num_workers)]
# NOTE: wait() only blocks until the tasks finish; it does not re-raise
# exceptions raised inside a worker.
wait(futures);

In [12]:
# Read the published model from the parameter-server actor.  This rebinds the
# notebook-level name `model`.
model = ps.model
model

Net(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)

### Serialization time

Most of the wait time when fetching the model is serialization time, plus roughly 1–2 ms of overhead each way.

In [13]:
# Time one fetch of the `model` attribute from the parameter-server actor.
%timeit model = ps.model

100 loops, best of 3: 6.46 ms per loop


In [14]:
from distributed.protocol import serialize, deserialize
# Time a local serialize -> deserialize round trip of the model, for
# comparison with the actor fetch timed above.
%timeit _ = deserialize(*serialize(model))

100 loops, best of 3: 3.62 ms per loop


In [21]:
%load_ext snakeviz

In [23]:
%%snakeviz  
# Profile 100 serialize -> deserialize round trips of the model.
for i in range(100):
    deserialize(*serialize(model))

 
*** Profile stats marshalled to file '/var/folders/kk/wvsqd9_j5j12y8bszfc6gr9h0000gp/T/tmpyi76hecc'. 
