In [1]:
# This notebook takes the UNet portion of "Pytorch MNIST and UNet Dev.ipynb" and makes training and
# validation repeatable, in order to test a method for saving and restoring optimizer state (momentum, etc.)
# that remains consistent even when the restore is done after restarting the kernel. The training method
# (train_partial_epoch) has also been reduced to training on a single batch (not a complete epoch) to make
# testing faster.

In [2]:
import abc

class FLModel(metaclass=abc.ABCMeta):
    """Abstract interface a model implements to take part in federated learning.

    A concrete model exposes its state as a dict of named tensors (so it can be
    aggregated and restored), and provides hooks for partial training and for
    validation, plus the sizes of its training and validation data.
    """

    @abc.abstractmethod
    def get_tensor_dict(self):
        """Returns all parameters for aggregation, including optimizer parameters, if appropriate"""
        pass

    @abc.abstractmethod
    def set_tensor_dict(self, tensor_dict):
        """Sets all parameters from `tensor_dict` (in the format returned by get_tensor_dict),
        including optimizer parameters, if appropriate"""
        pass

    @abc.abstractmethod
    def train_partial_epoch(self):
        """Trains on (part of) the training data; the return value is implementation-defined."""
        pass

    @abc.abstractmethod
    def get_training_data_size(self):
        """Returns the number of training samples available to this model."""
        pass

    @abc.abstractmethod
    def validate(self):
        """Evaluates the model on validation data and returns the resulting score."""
        pass

    @abc.abstractmethod
    def get_validation_data_size(self):
        """Returns the number of validation samples available to this model."""
        pass


In [3]:
import pickle
import glob
import os
import socket

import numpy as np
from math import ceil


def _get_dataset_func_map():
    """Map each supported dataset name to the function that loads it.

    Only 'mnist' is enabled here; loaders for 'fashion-mnist', 'pubfig83',
    'cifar10', 'cifar20', 'cifar100', 'bsm' and 'BraTS17' are disabled.
    """
    loaders_by_name = dict(mnist=load_mnist)
    return loaders_by_name


def get_dataset_list():
    """Return the names of all datasets that load_dataset accepts."""
    return [name for name in _get_dataset_func_map()]


def load_dataset(dataset, **kwargs):
    """Load the dataset called `dataset`, forwarding `kwargs` to its loader.

    Raises:
        ValueError: if `dataset` is not one of get_dataset_list().
    """
    available = get_dataset_list()
    if dataset not in available:
        # Both fields are positional: a named field such as '{get_dataset_list()}'
        # is not evaluated by str.format and would raise KeyError instead.
        raise ValueError("Dataset {} not in list of datasets {}".format(dataset, available))
    return _get_dataset_func_map()[dataset](**kwargs)


def _get_dataset_dir(server='edwardsb-Z270x-UD5'):
    if server is None:
        server = socket.gethostname()
    server_to_path = {'spr-gpu01': os.path.join('/', 'raid', 'datasets'),
                      'edwardsb-Z270X-UD5': os.path.join('/', 'data'),
                      'msheller-ubuntu': os.path.join('/', 'home', 'msheller', 'datasets')}
    return server_to_path[server]


def _unpickle(file):
    with open(file, 'rb') as fo:
        d = pickle.load(fo, encoding='bytes')
    return d


def _read_mnist(path, **kwargs):
    """Read the MNIST-format 'train' and 't10k' splits stored in directory `path`.

    Extra kwargs are passed to _read_mnist_kind. Returns
    (X_train, y_train, X_test, y_test).
    """
    splits = [_read_mnist_kind(path, kind=split, **kwargs) for split in ('train', 't10k')]
    (X_train, y_train), (X_test, y_test) = splits
    return X_train, y_train, X_test, y_test


# from https://github.com/zalandoresearch/fashion-mnist/blob/master/utils/mnist_reader.py
def _read_mnist_kind(path, kind='train', one_hot=True, **kwargs):
    import os
    import gzip
    import numpy as np

    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % kind)

    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)

    with gzip.open(images_path, 'rb') as imgpath:
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)

    images = images.astype(float) / 255
    if one_hot:
        labels = _one_hot(labels.astype(np.int), 10)

    return images, labels


def load_mnist(**kwargs):
    """Load MNIST from '<dataset dir>/mnist/input_data'; kwargs go to _read_mnist."""
    return _read_mnist(os.path.join(_get_dataset_dir(), 'mnist', 'input_data'), **kwargs)


def load_fashion_mnist(**kwargs):
    """Load Fashion-MNIST from '<dataset dir>/fashion-mnist'; kwargs go to _read_mnist."""
    return _read_mnist(os.path.join(_get_dataset_dir(), 'fashion-mnist'), **kwargs)


def _one_hot(y, n):
    return np.eye(n)[y]

In [4]:
import torch
import torch.nn as nn


class PyTorchFLModel(FLModel, nn.Module):
    """Base class that adapts a PyTorch nn.Module to the FLModel interface.

    Subclasses build their network, implement get_optimizer(), and provide the
    training/validation hooks. This class converts the model's state_dict()
    entries and the optimizer's per-parameter 'momentum_buffer' state to and
    from a flat dict of numpy arrays.

    Optimizer tensors are stored under '__opt_<i>', where <i> is the position of
    the parameter when iterating optimizer.param_groups in order (the same
    numbering optimizer.state_dict() uses). Only 'momentum_buffer' is handled;
    any other per-parameter optimizer state is not transferred.
    """

    def __init__(self):
        # calls nn.Module init
        super(PyTorchFLModel, self).__init__()

    @abc.abstractmethod
    def get_optimizer(self):
        """Return the torch.optim optimizer used to train this model."""
        pass

    def _indexed_optimizer_params(self):
        """Return [(i, param), ...] over the optimizer's param_groups, in order."""
        optimizer = self.get_optimizer()
        params = [p for group in optimizer.param_groups for p in group['params']]
        return list(enumerate(params))

    def get_optimizer_tensors(self):
        """Return the optimizer's momentum buffers as {'__opt_<i>': numpy array}.

        Parameters with no optimizer state, or a None buffer, are omitted (e.g.
        before the first optimizer step).

        Raises:
            KeyError: if a parameter has optimizer state without a
                'momentum_buffer' entry, which this format cannot represent.
        """
        optimizer = self.get_optimizer()
        tensor_dict = {}
        for i, param in self._indexed_optimizer_params():
            # .get() avoids inserting an empty entry into the optimizer's state dict
            param_state = optimizer.state.get(param, {})
            if param_state and 'momentum_buffer' not in param_state:
                raise KeyError(
                    "Optimizer state for parameter {} has no 'momentum_buffer' (keys: {}); "
                    "only momentum-based optimizers are supported".format(i, sorted(param_state)))
            buf = param_state.get('momentum_buffer')
            if buf is not None:
                # copy: on CPU, .numpy() shares memory with the buffer the optimizer updates in place
                tensor_dict['__opt_{}'.format(i)] = buf.detach().cpu().numpy().copy()
        return tensor_dict

    def set_optimizer_tensors(self, tensor_dict):
        """Restore momentum buffers from `tensor_dict` (format of get_optimizer_tensors).

        Works whether or not the optimizer already has state for a parameter.
        Buffers are created with the parameter's dtype and device. A parameter with
        no '__opt_<i>' entry has its optimizer state cleared, mirroring a parameter
        that had no momentum buffer when the snapshot was taken.
        """
        optimizer = self.get_optimizer()
        for i, param in self._indexed_optimizer_params():
            key = '__opt_{}'.format(i)
            if key in tensor_dict:
                # torch.tensor always copies, so later in-place optimizer updates
                # never write back into the caller's numpy arrays.
                optimizer.state.setdefault(param, {})['momentum_buffer'] = torch.tensor(
                    tensor_dict[key], dtype=param.dtype, device=param.device)
            else:
                optimizer.state.pop(param, None)

    def get_tensor_dict(self):
        """Return model state (keyed by state_dict() name) plus optimizer momentum
        buffers (keyed '__opt_<i>'), all as CPU numpy arrays."""
        # copy: on CPU, .numpy() shares memory with the live parameter, so a snapshot
        # would otherwise change as training continues.
        model_tensors = {name: tensor.detach().cpu().numpy().copy()
                         for name, tensor in self.state_dict().items()}
        return {**model_tensors, **self.get_optimizer_tensors()}

    def set_tensor_dict(self, tensor_dict):
        """Load model state and optimizer momentum buffers from `tensor_dict`.

        `tensor_dict` must contain every key of self.state_dict(); each value is
        converted to the dtype and device of the tensor it replaces.
        """
        new_state = {
            name: torch.as_tensor(tensor_dict[name], dtype=current.dtype, device=current.device)
            for name, current in self.state_dict().items()
        }
        # load_state_dict copies values into the existing parameters/buffers
        self.load_state_dict(new_state)

        # next we have the optimizer state
        self.set_optimizer_tensors(tensor_dict)


In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.utils
import torch.utils.data
import torch.nn.functional as F
import torch.optim as optim

from tfedlrn.datasets import load_dataset
# from tfedlrn.collaborator.pytorchflmodel import PyTorchFLModel


def dice_coef(pred, target, smoothing=1.0):
    """Smoothed Dice coefficient averaged over the batch.

    Sums are taken over dims 1-3 (so inputs are presumably (N, C, H, W) -- TODO
    confirm); per sample this is (2*sum(p*t) + s) / (sum(p + t) + s).
    """
    reduce_dims = (1, 2, 3)
    overlap = torch.sum(pred * target, dim=reduce_dims)
    total = torch.sum(pred + target, dim=reduce_dims)
    per_sample = (2 * overlap + smoothing) / (total + smoothing)
    return per_sample.mean()


def dice_coef_loss(pred, target, smoothing=1.0):
    """Log-form smoothed Dice loss over dims 1-3, averaged over the batch.

    Equals mean(log(sum(p + t) + s)) - mean(log(2*sum(p*t) + s)).
    """
    reduce_dims = (1, 2, 3)
    overlap = torch.sum(pred * target, dim=reduce_dims)
    total = torch.sum(pred + target, dim=reduce_dims)
    neg_log_numerator = -torch.log(2 * overlap + smoothing)
    log_denominator = torch.log(total + smoothing)
    return neg_log_numerator.mean() + log_denominator.mean()


class PyTorch2DUNet(PyTorchFLModel):
    """2D U-Net with a sigmoid output, trained with the log-form Dice loss.

    Without explicit loaders, data comes from load_dataset('BraTS17_institution')
    for institutions 0-9 (channels-first), concatenated across institutions.
    """

    def __init__(self, device, train_loader=None, val_loader=None, optimizer='SGD'):
        super(PyTorch2DUNet, self).__init__()

        self.device = device
        self.init_data_pipeline(train_loader, val_loader)
        self.init_network(device)
        self.init_optimizer(optimizer)

    def create_loader(self, X, y, **kwargs):
        """Wrap samples X and labels y (as float tensors) in a DataLoader; kwargs go to DataLoader."""
        tX = torch.stack([torch.Tensor(i) for i in X])
        ty = torch.stack([torch.Tensor(i) for i in y])
        return torch.utils.data.DataLoader(torch.utils.data.TensorDataset(tX, ty), **kwargs)

    # FIXME: brats loading
    def init_data_pipeline(self, train_loader, val_loader):
        """Use the given loaders, building any missing one from the BraTS17 institution data."""
        if train_loader is None or val_loader is None:
            # load all the institutions
            data_by_institution = [load_dataset('BraTS17_institution',
                                                channels_first=True,
                                                institution=i) for i in range(10)]
            data_by_type = zip(*data_by_institution)
            data_by_type = [np.concatenate(d) for d in data_by_type]
            X_train, y_train, X_val, y_val = data_by_type

        #TODO: Replace both shuffle=False below with shuffle=True, currently testing and so want reproducibility
        if train_loader is None:
            self.train_loader = self.create_loader(X_train, y_train, batch_size=64, shuffle=False)
        else:
            self.train_loader = train_loader

        if val_loader is None:
            self.val_loader = self.create_loader(X_val, y_val, batch_size=64, shuffle=False)
        else:
            self.val_loader = val_loader

    def init_network(self,
                     device,
                     initial_channels=1,
                     depth_per_side=5,
                     initial_filters=32):
        """Build the U-Net layers and move the model to `device`.

        Args:
            initial_channels: input channels.
            depth_per_side: number of resolution levels (>= 1); filters double per level going down.
            initial_filters: filters at the top level.

        Raises:
            ValueError: if depth_per_side < 1.
        """
        if depth_per_side < 1:
            raise ValueError("depth_per_side must be >= 1, got {}".format(depth_per_side))

        f = initial_filters

        # store our depth for our forward function (must match the layers built below)
        self.depth_per_side = depth_per_side

        # parameter-less layers (self.dropout is created but not used in forward)
        self.dropout = nn.Dropout(p=0.2)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # initial down layers
        conv_down_a = [nn.Conv2d(initial_channels, f, 3, padding=1)]
        conv_down_b = [nn.Conv2d(f, f, 3, padding=1)]

        # rest of the layers going down
        for i in range(1, depth_per_side):
            f *= 2
            conv_down_a.append(nn.Conv2d(f // 2, f, 3, padding=1))
            conv_down_b.append(nn.Conv2d(f, f, 3, padding=1))

        # going up, do all but the last layer
        conv_up_a = []
        conv_up_b = []
        for _ in range(depth_per_side-1):
            f //= 2
            # triple input channels due to skip connections
            conv_up_a.append(nn.Conv2d(f*3, f, 3, padding=1))
            conv_up_b.append(nn.Conv2d(f, f, 3, padding=1))

        # do the last layer
        # NOTE: module creation/registration order is kept as-is; it fixes the
        # parameter order (and hence the optimizer's '__opt_<i>' numbering).
        self.conv_out = nn.Conv2d(f, 1, 1, padding=0)

        # all up/down layers need to to become fields of this object
        for i, (a, b) in enumerate(zip(conv_down_a, conv_down_b)):
            setattr(self, 'conv_down_{}a'.format(i+1), a)
            setattr(self, 'conv_down_{}b'.format(i+1), b)

        for i, (a, b) in enumerate(zip(conv_up_a, conv_up_b)):
            setattr(self, 'conv_up_{}a'.format(i+1), a)
            setattr(self, 'conv_up_{}b'.format(i+1), b)

        # send this to the device
        self.to(device)

    def _conv_layers(self):
        """Return the (down_a, down_b, up_a, up_b) conv lists, each ordered from level 1."""
        depth = self.depth_per_side
        down_a = [getattr(self, 'conv_down_{}a'.format(i + 1)) for i in range(depth)]
        down_b = [getattr(self, 'conv_down_{}b'.format(i + 1)) for i in range(depth)]
        up_a = [getattr(self, 'conv_up_{}a'.format(i + 1)) for i in range(depth - 1)]
        up_b = [getattr(self, 'conv_up_{}b'.format(i + 1)) for i in range(depth - 1)]
        return down_a, down_b, up_a, up_b

    def forward(self, x):
        """Run the U-Net on `x` and return per-pixel sigmoid probabilities (1 channel)."""
        conv_down_a, conv_down_b, conv_up_a, conv_up_b = self._conv_layers()

        # outputs of the down b layers, kept for the skip connections
        concat_me = []
        pool = x
        last_level = len(conv_down_b) - 1

        # going down, wire each pair and then pool except at the last level
        for level, (a, b) in enumerate(zip(conv_down_a, conv_down_b)):
            out_down = F.relu(b(F.relu(a(pool))))
            if level < last_level:
                concat_me.append(out_down)
                pool = self.maxpool(out_down)  # feed the pool into the next level

        # going up, start from the deepest (not pooled) output and consume skips deepest-first
        in_up = out_down
        for a, b, c in zip(conv_up_a, conv_up_b, reversed(concat_me)):
            up = torch.cat([self.upsample(in_up), c], dim=1)
            in_up = F.relu(b(F.relu(a(up))))

        return torch.sigmoid(self.conv_out(in_up))

    def init_optimizer(self, optimizer='SGD'):
        """Create self.optimizer: one of 'SGD', 'RMSprop' or 'Adam'.

        Raises:
            ValueError: for any other optimizer name.
        """
        if optimizer == 'SGD':
            self.optimizer = optim.SGD(self.parameters(), lr=1e-3, momentum=0.9)
        elif optimizer == 'RMSprop':
            self.optimizer = optim.RMSprop(self.parameters(), lr=1e-5, momentum=0.9)
        elif optimizer == 'Adam':
            self.optimizer = optim.Adam(self.parameters(), lr=1e-5)
        else:
            raise ValueError("Unknown optimizer {!r}; expected 'SGD', 'RMSprop' or 'Adam'".format(optimizer))

    def get_optimizer(self):
        """Return the optimizer created by init_optimizer."""
        return self.optimizer

    def train_partial_epoch(self, num_batches=1):
        """Train on the first `num_batches` batches of the training loader (all if None).

        Returns:
            (mean loss over the trained batches, debug string with the sums over
            axis 0 of the last trained batch's data and target) -- the string is used
            to check that identical data is seen across runs.

        Raises:
            ValueError: if num_batches < 1 or the loader yields no batches.
        """
        if num_batches is not None and num_batches < 1:
            raise ValueError("num_batches must be >= 1 or None, got {}".format(num_batches))

        # set to "training" mode
        self.train()

        losses = []
        last_data = last_target = None
        for data, target in self.train_loader:
            # TODO: Remove - keep the CPU batch for the reproducibility summary below
            last_data, last_target = data, target

            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self(data)
            loss = dice_coef_loss(output, target, smoothing=32.0)
            loss.backward()
            self.optimizer.step()
            losses.append(loss.detach().cpu().numpy())

            # stop right after the requested batch so no extra batch is fetched
            if num_batches is not None and len(losses) >= num_batches:
                break

        if not losses:
            raise ValueError("train_partial_epoch: the training loader yielded no batches")

        #TODO: Remove below - summary of the last batch used for reproducibility testing
        info = "Sum of batch is data: {}, target: {}".format(np.sum(last_data.cpu().numpy(), axis=0),
                                                             np.sum(last_target.cpu().numpy(), axis=0))
        return np.mean(losses), info

    def get_training_data_size(self):
        """Return the number of samples in the training dataset."""
        return len(self.train_loader.dataset)

    def validate(self, num_batches=2):
        """Return the sample-weighted mean Dice coefficient over the first
        `num_batches` validation batches (all if None).

        Raises:
            ValueError: if no validation samples were evaluated.
        """
        self.eval()
        dice = 0
        total_samples = 0

        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.val_loader):
                samples = target.shape[0]
                total_samples += samples
                data, target = data.to(self.device), target.to(self.device)
                output = self(data)
                dice += dice_coef(output, target).cpu().numpy() * samples
                if num_batches is not None and batch_idx + 1 >= num_batches:
                    break

        if total_samples == 0:
            raise ValueError("validate: the validation loader yielded no samples")
        return dice / total_samples

    def get_validation_data_size(self):
        """Return the number of samples in the validation dataset."""
        return len(self.val_loader.dataset)


In [6]:
# All training/validation below runs on the GPU; this notebook assumes CUDA is available.
device = torch.device("cuda")

In [7]:
#######################################################################################################

In [8]:
############### EXPLORING OPTIMIZER PARAMETERS (SAVING AND RESTORING over processes) ###################

In [9]:
#######################################################################################################

In [10]:
# Build the U-Net with an SGD (momentum) optimizer on the selected device.
cnn = PyTorch2DUNet(device, optimizer='SGD')

In [11]:
# Testing that the data used for training is reproducible by looking at its sum

In [12]:
# One training step; returns (mean loss, summary of the batch's sums).
cnn.train_partial_epoch()

(5.55046,
 'Sum of batch is data: [[[-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  ...\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]]], target: [[[0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  ...\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]]]')

In [13]:
# Second step: the loss changes (weights were updated) but the batch summary should be identical.
cnn.train_partial_epoch()

(5.5501776,
 'Sum of batch is data: [[[-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  ...\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]]], target: [[[0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  ...\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]]]')

In [14]:
# testing that validation is reproducible for a given model

In [15]:
# Validation on the unchanged model...
cnn.validate()

0.03433784192020539

In [16]:
# ...must return exactly the same value when repeated.
cnn.validate()

0.03433784192020539

In [17]:
############################################################################################
# test restoring model and optimizer: training from the restored state should reach the same validation ##
############################################################################################

In [18]:
# Reference validation score of the model we are about to save.
cnn.validate()

0.03433784192020539

In [19]:
# save model: snapshot of model weights plus optimizer tensors, as a dict of numpy arrays
model_weights = cnn.get_tensor_dict()

In [20]:
# train for some more, then see that model is different now

In [21]:
# Train one more step so the model moves away from the saved snapshot.
cnn.train_partial_epoch()

(5.5496416,
 'Sum of batch is data: [[[-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  ...\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]]], target: [[[0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  ...\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]]]')

In [22]:
# see validation is not the same now
cnn.validate()

0.03433564520673826

In [23]:
# restore saved model (weights and optimizer tensors) from the in-memory snapshot
cnn.set_tensor_dict(model_weights)

In [24]:
# see validation is back to previously observed for saved model
# (should equal the score recorded just before saving)
cnn.validate()

0.03433784192020539

In [25]:
# check that re-extracting the full state gives exactly what we saved before

In [26]:
# Re-snapshot the restored model for comparison with model_weights.
model_weights_2 = cnn.get_tensor_dict()

In [27]:
# True only if both snapshots have the same keys and every array is identical.
# np.array_equal also returns False on a shape mismatch instead of broadcasting.
bool_array = np.array([np.array_equal(model_weights[key], model_weights_2[key])
                       for key in model_weights])
set(model_weights) == set(model_weights_2) and bool(np.all(bool_array))


True

In [28]:
# now train again and see you get back to the place you did before after running 'train_partial_epoch' once

In [29]:
# Train one step from the restored state; the loss should match the step taken right after saving.
cnn.train_partial_epoch()

(5.5496416,
 'Sum of batch is data: [[[-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  ...\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]]], target: [[[0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  ...\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]]]')

In [30]:
# Should equal the validation observed after the first post-save training step.
cnn.validate()

0.03433564520673826

In [31]:
#######################################

In [32]:
# ----now test that restoring can happen across processes

In [33]:
########################################

In [34]:
# saving model to disk

In [35]:
import pickle
# Persist the snapshot (to the current working directory) so it can be restored after a kernel restart.
with open('saved_model.pkl', 'wb') as file:
    pickle.dump(model_weights, file)

In [36]:
###### RESTART KERNEL HERE, then run the top cells of this notebook up to device initialization ###############

In [7]:
###### Then return to this point and run cells below ####################

In [8]:
#initiate model and train a bit, then restore model from disk

In [9]:
# Fresh model with newly initialized weights (default optimizer 'SGD').
cnn = PyTorch2DUNet(device=device)

In [10]:
# Train one step so the fresh model (and its optimizer state) differs from the saved snapshot.
cnn.train_partial_epoch()

(5.4738846,
 'Sum of batch is data: [[[-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  ...\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]]], target: [[[0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  ...\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]]]')

In [11]:
# see validation is not the same as the ones before
cnn.validate()

0.034504899696912616

In [12]:
# get model weights from disk

In [13]:
# NOTE: pickle.load can execute arbitrary code -- only load files you trust (here, the file written above).
with open('saved_model.pkl', 'rb') as file:
    model_weights = pickle.load(file)

In [14]:
# restore model (weights and optimizer tensors) to saved values
cnn.set_tensor_dict(model_weights)

In [15]:
# see validation is back to previously observed for saved model
# (should equal the pre-restart score of the saved model)
cnn.validate()

0.03433784192020539

In [16]:
# check that re-extracting the full state gives exactly what we saved before

In [17]:
# Re-snapshot the model restored from disk for comparison with model_weights.
model_weights_2 = cnn.get_tensor_dict()

In [18]:
# True only if both snapshots have the same keys and every array is identical.
# np.array_equal also returns False on a shape mismatch instead of broadcasting.
bool_array = np.array([np.array_equal(model_weights[key], model_weights_2[key])
                       for key in model_weights])
set(model_weights) == set(model_weights_2) and bool(np.all(bool_array))


True

In [19]:
# now train to get a new validation, and see that it matches what we had after training once before
# (the loss should also equal the pre-restart post-save training step)
cnn.train_partial_epoch()

(5.5496416,
 'Sum of batch is data: [[[-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  ...\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]]], target: [[[0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  ...\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]]]')

In [20]:

# Should equal the pre-restart validation after one training step from the saved state.
cnn.validate()

0.03433564520673826

In [None]:
############# Solution needs to be more general to accommodate different optimizers #############

In [None]:
####### For example, the keys for Adam are different ('momentum_buffer' throws a key error) #############