In [1]:
import pickle
import glob
import os
import socket

import numpy as np
from math import ceil


def _get_dataset_func_map():
    return {
        'mnist': load_mnist,
#         'fashion-mnist': load_fashion_mnist,
#         'pubfig83': load_pubfig83,
#         'cifar10': load_cifar10,
#         'cifar20': load_cifar20,
#         'cifar100': load_cifar100,
#         'bsm': load_bsm,
#         'BraTS17': load_BraTS17,
    }


def get_dataset_list():
    """Return the names of all registered datasets as a new list."""
    return [name for name in _get_dataset_func_map()]


def load_dataset(dataset, **kwargs):
    """Load a registered dataset by name.

    Args:
        dataset: Name of the dataset; must be a key of the loader registry
            returned by ``_get_dataset_func_map()``.
        **kwargs: Forwarded unchanged to the selected loader function.

    Returns:
        Whatever the selected loader returns.

    Raises:
        ValueError: If ``dataset`` is not a registered dataset name.
    """
    loaders = _get_dataset_func_map()
    if dataset not in loaders:
        # Both values are passed to format() explicitly. The previous message
        # embedded "{get_dataset_list()}" as a format field, which made
        # str.format raise KeyError before the ValueError could be built.
        raise ValueError(
            "Dataset {} not in list of datasets {}".format(dataset, list(loaders.keys())))
    return loaders[dataset](**kwargs)


def _get_dataset_dir(server=None):
    if server is None:
        server = socket.gethostname()
    server_to_path = {'spr-gpu01': os.path.join('/', 'raid', 'datasets'),
                      'edwardsb-Z270X-UD5': os.path.join('/', 'data'),
                      'msheller-ubuntu': os.path.join('/', 'home', 'msheller', 'datasets')}
    return server_to_path[server]


def _unpickle(file):
    with open(file, 'rb') as fo:
        d = pickle.load(fo, encoding='bytes')
    return d


def _read_mnist(path, **kwargs):
    X_train, y_train = _read_mnist_kind(path, kind='train', **kwargs)
    X_test, y_test = _read_mnist_kind(path, kind='t10k', **kwargs)

    return X_train, y_train, X_test, y_test


# from https://github.com/zalandoresearch/fashion-mnist/blob/master/utils/mnist_reader.py
def _read_mnist_kind(path, kind='train', one_hot=True, **kwargs):
    import os
    import gzip
    import numpy as np

    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % kind)

    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)

    with gzip.open(images_path, 'rb') as imgpath:
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)

    images = images.astype(float) / 255
    if one_hot:
        labels = _one_hot(labels.astype(np.int), 10)

    return images, labels


def load_mnist(**kwargs):
    """Load MNIST from ``<dataset dir>/mnist/input_data``.

    Keyword arguments are forwarded to ``_read_mnist``, whose result is
    returned unchanged.
    """
    mnist_dir = os.path.join(_get_dataset_dir(), 'mnist', 'input_data')
    splits = _read_mnist(mnist_dir, **kwargs)
    return splits


def load_fashion_mnist(**kwargs):
    """Load Fashion-MNIST from ``<dataset dir>/fashion-mnist``.

    Keyword arguments are forwarded to ``_read_mnist``, whose result is
    returned unchanged.
    """
    fashion_dir = os.path.join(_get_dataset_dir(), 'fashion-mnist')
    splits = _read_mnist(fashion_dir, **kwargs)
    return splits


def _one_hot(y, n):
    return np.eye(n)[y]

In [2]:
import abc

class FLModel(metaclass=abc.ABCMeta):
    """Abstract interface for a model whose tensors are exchanged for aggregation.

    Every method is abstract, so a concrete subclass must implement all six
    before it can be instantiated.
    """

    @abc.abstractmethod
    def get_tensor_dict(self):
        """Return all parameters for aggregation, including optimizer parameters, if appropriate."""

    @abc.abstractmethod
    def set_tensor_dict(self, tensor_dict):
        """Load the parameters in ``tensor_dict`` (the counterpart of ``get_tensor_dict``),
        including optimizer parameters, if appropriate."""

    @abc.abstractmethod
    def train_epoch(self):
        """Run one training epoch; the return value is defined by the implementation."""

    @abc.abstractmethod
    def get_training_data_size(self):
        """Return the size of the training data."""

    @abc.abstractmethod
    def validate(self):
        """Run validation; the return value is defined by the implementation."""

    @abc.abstractmethod
    def get_validation_data_size(self):
        """Return the size of the validation data."""


In [3]:
import torch
import torch.nn as nn


class PyTorchFLModel(FLModel, nn.Module):
    """Base class implementing the ``FLModel`` tensor-dict interface for PyTorch modules.

    WIP code. Goal is to simplify porting a model to this framework.

    Model tensors are exported as numpy arrays keyed by their ``state_dict()``
    names. Optimizer state tensors are exported under ``'<position>_<name>'``
    keys (for example ``'0_momentum_buffer'``), where ``<position>`` is the
    enumeration index of the entry in the optimizer's ``state`` mapping.

    Subclasses must implement ``get_optimizer`` plus the remaining ``FLModel``
    methods.
    """

    def __init__(self):
        # calls nn.Module init
        super(PyTorchFLModel, self).__init__()

    @abc.abstractmethod
    def get_optimizer(self):
        """Return the optimizer whose state tensors are exported and restored."""

    @staticmethod
    def _to_numpy(tensor):
        """Return an independent numpy copy of ``tensor``.

        ``Tensor.cpu()`` is a no-op for CPU tensors and ``Tensor.numpy()``
        shares memory with its tensor, so without the explicit copy a snapshot
        taken from a CPU model would keep changing as training continues.
        """
        return tensor.detach().cpu().numpy().copy()

    @staticmethod
    def _to_tensor_like(value, reference):
        """Copy ``value`` into a new tensor with ``reference``'s dtype and device.

        ``torch.Tensor(value)`` always produces float32; matching the
        reference keeps the dtype of the tensor being replaced.
        """
        return torch.tensor(value, dtype=reference.dtype, device=reference.device)

    @staticmethod
    def _iter_optimizer_tensors(optimizer_state):
        """Yield ``(export_key, state_key, name, tensor)`` for every tensor in ``optimizer_state``.

        ``optimizer_state`` is the ``'state'`` entry of ``optimizer.state_dict()``.
        Non-dict entries and non-tensor values are skipped.
        """
        # FIXME: this is really fragile. Need to understand what could change here
        for position, state_key in enumerate(optimizer_state.keys()):
            entry = optimizer_state[state_key]
            if not isinstance(entry, dict):
                continue
            for name, value in entry.items():
                if isinstance(value, torch.Tensor):
                    yield '{}_{}'.format(position, name), state_key, name, value

    def get_optimizer_tensors(self):
        """Return the optimizer's state tensors as ``{export_key: numpy array}``.

        Only state that already exists in the optimizer is exported; the
        result is empty when the optimizer holds no state tensors.
        """
        state = self.get_optimizer().state_dict()['state']
        return {key: self._to_numpy(tensor)
                for key, _, _, tensor in self._iter_optimizer_tensors(state)}

    def set_optimizer_tensors(self, tensor_dict):
        """Overwrite the optimizer's existing state tensors from ``tensor_dict``.

        Only state entries already present in the optimizer are replaced.

        Raises:
            ValueError: If the key of an existing optimizer state tensor is
                missing from ``tensor_dict``.
        """
        optimizer = self.get_optimizer()
        state = optimizer.state_dict()

        # Replacing the value of an existing key does not resize the inner
        # dict, so it is safe while the generator is iterating over it.
        for key, state_key, name, current in self._iter_optimizer_tensors(state['state']):
            if key not in tensor_dict:
                raise ValueError('{} not in keys: {}'.format(key, list(tensor_dict.keys())))
            state['state'][state_key][name] = self._to_tensor_like(tensor_dict[key], current)
        optimizer.load_state_dict(state)

    def get_tensor_dict(self):
        """Return model and optimizer tensors as one ``{name: numpy array}`` dict.

        Model entries use their ``state_dict()`` names; optimizer entries are
        added afterwards using the keys from ``get_optimizer_tensors``.
        """
        # FIXME: should we use self.parameters()??? Unclear if load_state_dict() is better or simple assignment is better
        # for now, state dict gives us names, which is good

        # FIXME: can this have values other than the tensors????
        tensors = {name: self._to_numpy(value) for name, value in self.state_dict().items()}
        tensors.update(self.get_optimizer_tensors())
        return tensors

    def set_tensor_dict(self, tensor_dict):
        """Load model and optimizer tensors from ``tensor_dict``.

        Each new tensor is created with the dtype and device of the tensor it
        replaces.

        Raises:
            KeyError: If a model ``state_dict()`` name is missing from ``tensor_dict``.
            ValueError: If an existing optimizer state key is missing from ``tensor_dict``.
        """
        # FIXME: should we use self.parameters()??? Unclear if load_state_dict() is better or simple assignment is better
        # for now, state dict gives us names, which is good

        # the current model state tells us the correct dtype/device placements
        new_state = {name: self._to_tensor_like(tensor_dict[name], current)
                     for name, current in self.state_dict().items()}

        # set model state
        self.load_state_dict(new_state)

        # next we have the optimizer state
        self.set_optimizer_tensors(tensor_dict)


In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.utils
import torch.utils.data
import torch.nn.functional as F
import torch.optim as optim

class PyTorchMNISTCNN(PyTorchFLModel):
    """Two-conv, two-linear CNN for 1x28x28 MNIST images, trained with SGD.

    Targets are expected as 2-D one-hot rows; the class index is recovered
    with an argmax over dimension 1 in both training and validation.
    """

    def __init__(self, device, train_loader=None, val_loader=None):
        """Build the data pipeline, the network (moved to ``device``) and the optimizer.

        Args:
            device: Device the network and every batch are moved to.
            train_loader: Optional ready-made training ``DataLoader``.
            val_loader: Optional ready-made validation ``DataLoader``.
                If either loader is ``None``, MNIST is loaded via
                ``load_dataset('mnist')`` to build the missing one(s).
        """
        super(PyTorchMNISTCNN, self).__init__()

        self.device = device
        self.init_data_pipeline(train_loader, val_loader)
        self.init_network(device)
        self.init_optimizer()

    def create_loader(self, X, y, **kwargs):
        """Wrap arrays ``X`` and ``y`` in a ``DataLoader`` over float32 tensors.

        ``kwargs`` are passed to ``torch.utils.data.DataLoader``.
        """
        # One bulk conversion instead of building and stacking a tensor per
        # sample; torch.tensor always copies, so the inputs are not aliased.
        tX = torch.tensor(np.asarray(X), dtype=torch.float32)
        ty = torch.tensor(np.asarray(y), dtype=torch.float32)
        return torch.utils.data.DataLoader(torch.utils.data.TensorDataset(tX, ty), **kwargs)

    def init_data_pipeline(self, train_loader, val_loader):
        """Set ``self.train_loader`` / ``self.val_loader``, building defaults from MNIST when missing."""
        if train_loader is None or val_loader is None:
            X_train, y_train, X_val, y_val = load_dataset('mnist')
            # flat 784-vectors -> (N, channels=1, 28, 28) for Conv2d
            X_train = X_train.reshape([-1, 1, 28, 28])
            X_val = X_val.reshape([-1, 1, 28, 28])

        if train_loader is None:
            self.train_loader = self.create_loader(X_train, y_train, batch_size=64, shuffle=True)
        else:
            self.train_loader = train_loader

        if val_loader is None:
            # Accuracy does not depend on sample order, so validation is not shuffled.
            self.val_loader = self.create_loader(X_val, y_val, batch_size=64, shuffle=False)
        else:
            self.val_loader = val_loader

    def init_network(self, device):
        """Create the layers and move the module to ``device``."""
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        # 28 -> conv5 -> 24 -> pool2 -> 12 -> conv5 -> 8 -> pool2 -> 4, with 50 channels
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)
        self.to(device)

    def forward(self, x):
        """Return per-class log-probabilities (``log_softmax`` over dim 1) for batch ``x``."""
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def init_optimizer(self):
        """Create the SGD optimizer (lr=0.01, momentum=0.5) over this module's parameters."""
        self.optimizer = optim.SGD(self.parameters(), lr=0.01, momentum=0.5)

    def get_optimizer(self):
        """Return the optimizer created by ``init_optimizer``."""
        return self.optimizer

    def train_epoch(self):
        """Train for one pass over ``self.train_loader`` and return the mean batch loss."""
        # set to "training" mode
        self.train()

        losses = []

        for data, target in self.train_loader:
            data, target = data.to(self.device), target.to(self.device, dtype=torch.int64)
            self.optimizer.zero_grad()
            output = self(data)
            # forward() already returns log-probabilities, so use nll_loss;
            # cross_entropy would apply log_softmax a second time.
            loss = F.nll_loss(output, target.argmax(dim=1))
            loss.backward()
            self.optimizer.step()
            losses.append(loss.detach().cpu().numpy())

        return np.mean(losses)

    def get_training_data_size(self):
        """Return the number of samples in the training dataset."""
        return len(self.train_loader.dataset)

    def validate(self):
        """Return the fraction of validation samples whose predicted class matches the target."""
        self.eval()
        correct = 0

        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device, dtype=torch.int64)
                # Both sides are 1-D class indices, so eq() compares element
                # by element with no (batch x batch) broadcast.
                pred = self(data).argmax(dim=1)  # index of the max log-probability
                correct += pred.eq(target.argmax(dim=1)).sum().item()

        return correct / self.get_validation_data_size()

    def get_validation_data_size(self):
        """Return the number of samples in the validation dataset."""
        return len(self.val_loader.dataset)


In [5]:
# Use the GPU when one is available; otherwise fall back to the CPU so the notebook still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# Build the CNN on the selected device; with no loaders passed, it loads MNIST via load_dataset('mnist').
cnn = PyTorchMNISTCNN(device)

In [7]:
# Known issue: the optimizer exposes no state tensors (e.g. momentum buffers) until some
# training has run, so train for one epoch before exporting the tensor dict.
cnn.train_epoch()

0.7723057

In [8]:
# Snapshot the model and optimizer tensors (as numpy arrays) after the first epoch.
d = cnn.get_tensor_dict()

In [9]:
# Model state_dict names plus the '<index>_momentum_buffer' optimizer entries.
d.keys()

dict_keys(['fc1.weight', 'fc1.bias', 'conv1.bias', 'conv1.weight', 'conv2.bias', '3_momentum_buffer', 'conv2.weight', '2_momentum_buffer', '5_momentum_buffer', '1_momentum_buffer', '0_momentum_buffer', '6_momentum_buffer', 'fc2.bias', '4_momentum_buffer', '7_momentum_buffer', 'fc2.weight'])

In [10]:
# Validation accuracy after the first epoch (the state captured in `d`).
cnn.validate()

0.9435

In [11]:
# Train a second epoch so the live model moves away from the snapshot `d`.
cnn.train_epoch()

0.15161966

In [12]:
# Validation accuracy after the second epoch.
cnn.validate()

0.9713

In [13]:
# Restore the snapshot taken after the first epoch.
cnn.set_tensor_dict(d)

In [14]:
# If the restore worked, accuracy is back at the value measured right after the first epoch.
cnn.validate()

0.9435

In [15]:
# Export again so the restored state can be compared with the original snapshot.
d2 = cnn.get_tensor_dict()

In [16]:
# Round-trip check: both exports must have exactly the same keys and identical arrays.
# Key views compare by their set of keys. np.union1d on two key views wraps each view
# as a single object element, so it only ever compared key counts, not key names.
d.keys() == d2.keys() and all(np.array_equal(d[k], d2[k]) for k in d)

True

In [27]:
# Scratch check of how dict key views compare with `==`.
a = {'a': 1, 'b': 2}
b = {'b': 3, 'a': 3}
c = {'c': 4}

# `a` and `c` have different key sets, so their key views are not equal.
a.keys() == c.keys()

False