In [None]:
# This script takes work from Pytorch MNIST and UNet Dev.ipynb, takes the UNet portion only, and makes 
# training and validation repeatable in order
# to test a method for saving and setting optimizer state (momentum, etc.) that remains consistent
# even when the restoring is done after restarting the kernel.

In [None]:
# run the cells below (up to ######) when restarting kernel

In [1]:
import abc

class FLModel(metaclass=abc.ABCMeta):

    @abc.abstractmethod
    def get_tensor_dict(self):
        """Returns all parameters for aggregation, including optimizer parameters, if appropriate"""
        pass

    @abc.abstractmethod
    def set_tensor_dict(self, tensor_dict):
        """Returns all parameters for aggregation, including optimizer parameters, if appropriate"""
        pass

    @abc.abstractmethod
    def train_epoch(self):
        pass

    @abc.abstractmethod
    def get_training_data_size(self):
        pass

    @abc.abstractmethod
    def validate(self):
        pass

    @abc.abstractmethod
    def get_validation_data_size(self):
        pass


In [2]:
import pickle
import glob
import os
import socket

import numpy as np
from math import ceil


def _get_dataset_func_map():
    return {
        'mnist': load_mnist,
#         'fashion-mnist': load_fashion_mnist,
#         'pubfig83': load_pubfig83,
#         'cifar10': load_cifar10,
#         'cifar20': load_cifar20,
#         'cifar100': load_cifar100,
#         'bsm': load_bsm,
#         'BraTS17': load_BraTS17,
    }


def get_dataset_list():
    return list(_get_dataset_func_map().keys())


def load_dataset(dataset, **kwargs):
    if dataset not in get_dataset_list():
        raise ValueError("Dataset {} not in list of datasets {get_dataset_list()}".format(dataset))
    return _get_dataset_func_map()[dataset](**kwargs)


def _get_dataset_dir(server='edwardsb-Z270x-UD5'):
    if server is None:
        server = socket.gethostname()
    server_to_path = {'spr-gpu01': os.path.join('/', 'raid', 'datasets'),
                      'edwardsb-Z270X-UD5': os.path.join('/', 'data'),
                      'msheller-ubuntu': os.path.join('/', 'home', 'msheller', 'datasets')}
    return server_to_path[server]


def _unpickle(file):
    with open(file, 'rb') as fo:
        d = pickle.load(fo, encoding='bytes')
    return d


def _read_mnist(path, **kwargs):
    X_train, y_train = _read_mnist_kind(path, kind='train', **kwargs)
    X_test, y_test = _read_mnist_kind(path, kind='t10k', **kwargs)

    return X_train, y_train, X_test, y_test


# from https://github.com/zalandoresearch/fashion-mnist/blob/master/utils/mnist_reader.py
def _read_mnist_kind(path, kind='train', one_hot=True, **kwargs):
    import os
    import gzip
    import numpy as np

    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % kind)

    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)

    with gzip.open(images_path, 'rb') as imgpath:
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)

    images = images.astype(float) / 255
    if one_hot:
        labels = _one_hot(labels.astype(np.int), 10)

    return images, labels


def load_mnist(**kwargs):
    path = os.path.join(_get_dataset_dir(), 'mnist', 'input_data')
    return _read_mnist(path, **kwargs)


def load_fashion_mnist(**kwargs):
    path = os.path.join(_get_dataset_dir(), 'fashion-mnist')
    return _read_mnist(path, **kwargs)


def _one_hot(y, n):
    return np.eye(n)[y]

In [3]:
import torch
import torch.nn as nn


class PyTorchFLModel(FLModel, nn.Module):
    """WIP code. Goal is to simplify porting a model to this framework.
    Currently, this creates a placeholder and assign op for every variable, which grows the graph considerably.
    Also, the abstraction for the tf.session isn't ideal yet."""

    def __init__(self):
        # calls nn.Module init
        super(PyTorchFLModel, self).__init__()

    @abc.abstractmethod
    def get_optimizer(self):
        pass

    def get_optimizer_tensors(self):
        optimizer = self.get_optimizer()

        tensor_dict = {}

        # NOTE: this gave inconsistent orderings across collaborators, so does not work
        # state = optimizer.state_dict()['state']

        # # FIXME: this is really fragile. Need to understand what could change here
        # for i, sk in enumerate(state.keys()):
        #     if isinstance(state[sk], dict):
        #         for k, v in state[sk].items():
        #             if isinstance(v, torch.Tensor):
        #                 tensor_dict['{}_{}'.format(i, k)] = v.cpu().numpy()

        # FIXME: not clear that this works consistently across optimizers
        # FIXME: hard-coded naming convention sucks and could absolutely break
        i = 0
        for group in optimizer.param_groups:
            for p in group['params']:
                tensor_dict['__opt_{}'.format(i)] = p.detach().cpu().numpy()
                i += 1

        return tensor_dict
                    
    def set_optimizer_tensors(self, tensor_dict):
        optimizer = self.get_optimizer()

        # NOTE: the state dict ordering wasn't consistent. We'd like to use load_state_dict rather than
        # directly setting the tensors, if possible, but it's not clear that we can
#         state = optimizer.state_dict()

#         # FIXME: this is really fragile. Need to understand what could change here
#         for i, sk in enumerate(state['state'].keys()):
#             if isinstance(state['state'][sk], dict):
#                 for k, v in state['state'][sk].items():
#                     if isinstance(v, torch.Tensor):
#                         key = '{}_{}'.format(i, k)
                        
#                         if key not in tensor_dict:
#                             raise ValueError('{} not in keys: {}'.format(key, list(tensor_dict.keys())))
                        
#                         state['state'][sk][k] = torch.Tensor(tensor_dict[key]).to(v.device)
#         optimizer.load_state_dict(state)
        
        # FIXME: not clear that this works consistently across optimizers
        # FIXME: hard-coded naming convention sucks and could absolutely break
        i = 0
        for group in optimizer.param_groups:
            for idx, p in enumerate(group['params']):
                old = group['params'][idx]
                new = torch.Tensor(tensor_dict['__opt_{}'.format(i)]).to(old.device)

    def get_tensor_dict(self):
        # FIXME: should we use self.parameters()??? Unclear if load_state_dict() is better or simple assignment is better
        # for now, state dict gives us names, which is good

        # FIXME: do both and sanity check each time?

        # FIXME: can this have values other than the tensors????
        state = self.state_dict()
        for k, v in state.items():
            state[k] = v.cpu().numpy() # get as a numpy array
        return {**state, **self.get_optimizer_tensors()}

    def set_tensor_dict(self, tensor_dict):
        # FIXME: should we use self.parameters()??? Unclear if load_state_dict() is better or simple assignment is better
        # for now, state dict gives us names, which is good
        
        # FIXME: do both and sanity check each time?

        # get the model state so that we can determine the correct tensor values/device placements
        model_state = self.state_dict()

        new_state = {}
        for k, v in model_state.items():
            new_state[k] = torch.Tensor(tensor_dict[k]).to(v.device)

        # set model state
        self.load_state_dict(new_state)

        # next we have the optimizer state
        self.set_optimizer_tensors(tensor_dict)


In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.utils
import torch.utils.data
import torch.nn.functional as F
import torch.optim as optim

class PyTorchMNISTCNN(PyTorchFLModel):

    def __init__(self, device, train_loader=None, val_loader=None):
        super(PyTorchMNISTCNN, self).__init__()

        self.device = device
        self.init_data_pipeline(train_loader, val_loader)
        self.init_network(device)
        self.init_optimizer()

    def create_loader(self, X, y, **kwargs):
        tX = torch.stack([torch.Tensor(i) for i in X])
        ty = torch.stack([torch.Tensor(i) for i in y])
        return torch.utils.data.DataLoader(torch.utils.data.TensorDataset(tX, ty), **kwargs)

    def init_data_pipeline(self, train_loader, val_loader):
        if train_loader is None or val_loader is None:
            X_train, y_train, X_val, y_val = load_dataset('mnist')
            X_train = X_train.reshape([-1, 1, 28, 28])
            X_val = X_val.reshape([-1, 1, 28, 28])

        if train_loader is None:
            self.train_loader = self.create_loader(X_train, y_train, batch_size=64, shuffle=True)
        else:
            self.train_loader = train_loader

        if val_loader is None:
            self.val_loader = self.create_loader(X_val, y_val, batch_size=64, shuffle=True)
        else:
            self.val_loader = val_loader

    def init_network(self, device):
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)
        self.to(device)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def init_optimizer(self):
        self.optimizer = optim.SGD(self.parameters(), lr=0.01, momentum=0.5)

    def get_optimizer(self):
        return self.optimizer

    def train_epoch(self):
        # set to "training" mode
        self.train()
        
        losses = []

        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device, dtype=torch.int64)
            self.optimizer.zero_grad()
            output = self(data)
            loss = F.cross_entropy(output, torch.max(target, 1)[1])
            loss.backward()
            self.optimizer.step()
            losses.append(loss.detach().cpu().numpy())
            
        return np.mean(losses)

    def get_training_data_size(self):
        return len(self.train_loader.dataset)

    def validate(self):
        self.eval()
        correct = 0

        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device, dtype=torch.int64)
                output = self(data)
                pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
                target = torch.max(target, 1)[1]
                # FIXME: there has to be a better way than exhaustive eq then diagonal
                eq = pred.eq(target).diag().sum().cpu().numpy()
                correct += eq

        return correct / self.get_validation_data_size()

    def get_validation_data_size(self):
        return len(self.val_loader.dataset)


In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.utils
import torch.utils.data
import torch.nn.functional as F
import torch.optim as optim

from tfedlrn.datasets import load_dataset
from tfedlrn.collaborator.pytorchflmodel import PyTorchFLModel


def dice_coef(pred, target, smoothing=1.0):    
    intersection = (pred * target).sum(dim=(1, 2, 3))
    union = (pred + target).sum(dim=(1, 2, 3))
    
    return ((2 * intersection + smoothing) / (union + smoothing)).mean()


def dice_coef_loss(pred, target, smoothing=1.0):    
    intersection = (pred * target).sum(dim=(1, 2, 3))
    union = (pred + target).sum(dim=(1, 2, 3))
    
    term1 = -torch.log(2 * intersection + smoothing)
    term2 = torch.log(union + smoothing)
    
    return term1.mean() + term2.mean()


class PyTorch2DUNet(PyTorchFLModel):

    def __init__(self, device, train_loader=None, val_loader=None, optimizer='SGD'):
        super(PyTorch2DUNet, self).__init__()

        self.device = device
        self.init_data_pipeline(train_loader, val_loader)
        self.init_network(device)
        self.init_optimizer(optimizer)
        
    def create_loader(self, X, y, **kwargs):
        tX = torch.stack([torch.Tensor(i) for i in X])
        ty = torch.stack([torch.Tensor(i) for i in y])
        return torch.utils.data.DataLoader(torch.utils.data.TensorDataset(tX, ty), **kwargs)

    # FIXME: brats loading
    def init_data_pipeline(self, train_loader, val_loader):
        if train_loader is None or val_loader is None:
            # load all the institutions
            data_by_institution = [load_dataset('BraTS17_institution',
                                                channels_first=True,
                                                institution=i) for i in range(10)]
            data_by_type = zip(*data_by_institution)
            data_by_type = [np.concatenate(d) for d in data_by_type]
            X_train, y_train, X_val, y_val = data_by_type

        #TODO: Replace both shuffle=False below with shuffle=True, currently testing and so want reproducibility
        if train_loader is None:
            self.train_loader = self.create_loader(X_train, y_train, batch_size=64, shuffle=False)
        else:
            self.train_loader = train_loader

        if val_loader is None:
            self.val_loader = self.create_loader(X_val, y_val, batch_size=64, shuffle=False)
        else:
            self.val_loader = val_loader
            
    def init_network(self,
                     device,
                     initial_channels=1,
                     depth_per_side=5,
                     initial_filters=32):

        f = initial_filters
        
        # store our depth for our forward function
        self.depth_per_side = 5
        
        # parameter-less layers
        self.dropout = nn.Dropout(p=0.2)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
                
        # initial down layers
        conv_down_a = [nn.Conv2d(initial_channels, f, 3, padding=1)]
        conv_down_b = [nn.Conv2d(f, f, 3, padding=1)]
                
        # rest of the layers going down
        for i in range(1, depth_per_side):
            f *= 2
            conv_down_a.append(nn.Conv2d(f // 2, f, 3, padding=1))
            conv_down_b.append(nn.Conv2d(f, f, 3, padding=1))
            
        # going up, do all but the last layer
        conv_up_a = []
        conv_up_b = []
        for _ in range(depth_per_side-1):
            f //= 2
            # triple input channels due to skip connections
            conv_up_a.append(nn.Conv2d(f*3, f, 3, padding=1))
            conv_up_b.append(nn.Conv2d(f, f, 3, padding=1))
            
        # do the last layer
        self.conv_out = nn.Conv2d(f, 1, 1, padding=0)
        
        # all up/down layers need to to become fields of this object
        for i, (a, b) in enumerate(zip(conv_down_a, conv_down_b)):
            setattr(self, 'conv_down_{}a'.format(i+1), a)
            setattr(self, 'conv_down_{}b'.format(i+1), b)
            
        # all up/down layers need to to become fields of this object
        for i, (a, b) in enumerate(zip(conv_up_a, conv_up_b)):
            setattr(self, 'conv_up_{}a'.format(i+1), a)
            setattr(self, 'conv_up_{}b'.format(i+1), b)
        
        # send this to the device
        self.to(device)
        
    def forward(self, x):
        
        # gather up our up and down layer members for easier processing
        conv_down_a = [getattr(self, 'conv_down_{}a'.format(i+1)) for i in range(self.depth_per_side)]
        conv_down_b = [getattr(self, 'conv_down_{}b'.format(i+1)) for i in range(self.depth_per_side)]
        conv_up_a = [getattr(self, 'conv_up_{}a'.format(i+1)) for i in range(self.depth_per_side - 1)]
        conv_up_b = [getattr(self, 'conv_up_{}b'.format(i+1)) for i in range(self.depth_per_side - 1)]
        
        # we concatenate the outputs from the b layers
        concat_me = []
        pool = x

        # going down, wire each pair and then pool except the last
        for a, b in zip(conv_down_a, conv_down_b):
            out_down = F.relu(b(F.relu(a(pool))))
            # if not the last down b layer, pool it and add it to the concat list
            if b != conv_down_b[-1]:
                concat_me.append(out_down)
                pool = self.maxpool(out_down) # feed the pool into the next layer
        
        # reverse the concat_me layers
        concat_me = concat_me[::-1]

        # we start going up with the b (not-pooled) from previous layer
        in_up = out_down

        # going up, we need to zip a, b and concat_me
        for a, b, c in zip(conv_up_a, conv_up_b, concat_me):
            up = torch.cat([self.upsample(in_up), c], dim=1)
            in_up = F.relu(b(F.relu(a(up))))
        
        # finally, return the output
        return torch.sigmoid(self.conv_out(in_up))

    def init_optimizer(self, optimizer='SGD'):
        if optimizer == 'SGD':
            self.optimizer = optim.SGD(self.parameters(), lr=1e-3, momentum=0.9)
        elif optimizer == 'RMSprop':
            self.optimizer = optim.RMSprop(self.parameters(), lr=1e-5, momentum=0.9)
        elif optimizer == 'Adam':
            self.optimizer = optim.Adam(self.parameters(), lr=1e-5)
        else:
            raise ValueError()

    def get_optimizer(self):
        return self.optimizer

    def train_epoch(self):
        # set to "training" mode
        self.train()
        
        losses = []

        
        for batch_idx, (data, target) in enumerate(self.train_loader):
            
            # TODO: Remove below - storing cpu data to inspect below
            cpu_data, cpu_target = data, target
            
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self(data)
            loss = dice_coef_loss(output, target, smoothing=32.0)
            loss.backward()
            self.optimizer.step()
            losses.append(loss.detach().cpu().numpy())
            
            #TODO: Remove below- just performing one batch of training now for testing below
            info = "Sum of batch is data: {}, target: {}".format(np.sum(cpu_data.numpy(), axis=0), 
                                                                 np.sum(cpu_target.numpy(), axis=0))
            break
            
        return np.mean(losses), info

    def get_training_data_size(self):
        return len(self.train_loader.dataset)

    def validate(self):
        self.eval()
        dice = 0
        total_samples = 0

        with torch.no_grad():
            counter = 0
            for data, target in self.val_loader:
                samples = target.shape[0]
                total_samples += samples
                data, target = data.to(self.device), target.to(self.device)
                output = self(data)
                dice += dice_coef(output, target).cpu().numpy() * samples
                counter += 1
                if counter == 2:
                    break
        return dice / total_samples

    def get_validation_data_size(self):
        return len(self.val_loader.dataset)


In [10]:
device = torch.device("cuda")

In [None]:
#########################################################################################

In [11]:
############### EXPLORING OPTIMIZER PARAMETERS (SAVING AND RESTORING) ###################

In [None]:
#########################################################################################

In [12]:
cnn = PyTorch2DUNet(device)

In [13]:
# Testing that the data used for training is reproducible

In [18]:
cnn.train_epoch()

(5.480192,
 'Sum of batch is data: [[[-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  ...\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]]], target: [[[0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  ...\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]]]')

In [14]:
cnn.train_epoch()

(5.5863986,
 'Sum of batch is data: [[[-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  ...\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]]], target: [[[0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  ...\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]]]')

In [15]:
# testing that validation is reproducible

In [18]:
cnn.validate()

0.034646598491235636

In [19]:
cnn.validate()

0.034646598491235636

In [21]:
cnn.optimizer.state_dict().keys()

dict_keys(['state', 'param_groups'])

In [22]:
len(cnn.optimizer.state_dict()['param_groups'])

1

In [23]:
cnn.optimizer.state_dict()['param_groups'][0]

{'dampening': 0,
 'lr': 0.001,
 'momentum': 0.9,
 'nesterov': False,
 'params': [140698930495352,
  140698930495424,
  140698930492472,
  140698930492544,
  140698930492616,
  140698930492688,
  140698930492832,
  140698930492904,
  140698930492976,
  140698930493048,
  140698930493120,
  140698930493192,
  140698930493264,
  140698930493336,
  140698930493408,
  140698930493480,
  140698930493552,
  140698930493624,
  140698930493696,
  140698930493768,
  140698930493840,
  140698930493912,
  140698930493984,
  140698930494056,
  140698930494128,
  140698930494200,
  140698930494272,
  140698930494344,
  140698930494416,
  140698930494488,
  140698930494560,
  140698930494632,
  140698930494704,
  140698930494776,
  140698930494848,
  140698930494920,
  140698930494992,
  140698930495064],
 'weight_decay': 0}

In [24]:
# note though the keys are the same, they are not in the same order
# when saving and restoring, we will traverse in the order given by the list object to rely on python list ordering

In [25]:
cnn.optimizer.state_dict()['state'].keys()

dict_keys([140698930495424, 140698930492544, 140698930493480, 140698930494416, 140698930493408, 140698930494920, 140698930493048, 140698930495064, 140698930492688, 140698930494272, 140698930493840, 140698930493624, 140698930494992, 140698930492976, 140698930493336, 140698930494632, 140698930492616, 140698930493984, 140698930493264, 140698930493912, 140698930492904, 140698930494056, 140698930493552, 140698930494200, 140698930493192, 140698930493768, 140698930494704, 140698930494344, 140698930494848, 140698930492832, 140698930493696, 140698930494488, 140698930492472, 140698930495352, 140698930494128, 140698930493120, 140698930494560, 140698930494776])

In [26]:
cnn.optimizer.state_dict()['param_groups'][0]['params']

[140698930495352,
 140698930495424,
 140698930492472,
 140698930492544,
 140698930492616,
 140698930492688,
 140698930492832,
 140698930492904,
 140698930492976,
 140698930493048,
 140698930493120,
 140698930493192,
 140698930493264,
 140698930493336,
 140698930493408,
 140698930493480,
 140698930493552,
 140698930493624,
 140698930493696,
 140698930493768,
 140698930493840,
 140698930493912,
 140698930493984,
 140698930494056,
 140698930494128,
 140698930494200,
 140698930494272,
 140698930494344,
 140698930494416,
 140698930494488,
 140698930494560,
 140698930494632,
 140698930494704,
 140698930494776,
 140698930494848,
 140698930494920,
 140698930494992,
 140698930495064]

In [19]:
# getting the keys to be used for the dictionary: cnn.optimizer.state_dict()['state']
optimizer_keys = []
for group in cnn.optimizer.state_dict()['param_groups']:
    for key in group['params']:
        optimizer_keys.append(key)
print("there are {} optimizer keys.".format(len(optimizer_keys)))

there are 38 optimizer keys.


In [16]:
cnn.optimizer.state_dict()['state'][optimizer_keys[0]]

{'momentum_buffer': tensor([[[[4.2651e-05]],
 
          [[1.5496e-02]],
 
          [[5.5317e-03]],
 
          [[3.3534e-02]],
 
          [[5.5599e-03]],
 
          [[1.1801e-02]],
 
          [[1.9776e-02]],
 
          [[1.0708e-03]],
 
          [[3.6289e-08]],
 
          [[9.2248e-07]],
 
          [[3.3608e-02]],
 
          [[8.8546e-03]],
 
          [[1.6957e-02]],
 
          [[7.5179e-08]],
 
          [[0.0000e+00]],
 
          [[2.2632e-03]],
 
          [[1.4548e-02]],
 
          [[8.5252e-05]],
 
          [[3.7501e-06]],
 
          [[2.5499e-08]],
 
          [[5.6561e-03]],
 
          [[3.2854e-09]],
 
          [[1.8336e-02]],
 
          [[7.0124e-07]],
 
          [[2.1024e-04]],
 
          [[2.0376e-09]],
 
          [[2.9095e-03]],
 
          [[8.5573e-03]],
 
          [[2.9018e-02]],
 
          [[0.0000e+00]],
 
          [[6.2537e-04]],
 
          [[1.6738e-10]]]], device='cuda:0')}

In [17]:
# chekcing to see that 'momentum_buffer' is all that is there for this optimizer
[cnn.optimizer.state_dict()['state'][key].keys() for key in cnn.optimizer.state_dict()['param_groups'][0]['params']]

[dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys(['momentum_buffer']),
 dict_keys

In [None]:
# inspecting the order of sizes of the momentum buffers when using 'optimizer_keys'
# Note that you can run this over and over after restarting the kernel, but I perform
# that test more sytematically at the end of this notebook

In [20]:
len(cnn.optimizer.state_dict()['state'].keys())

38

In [21]:
cnn.optimizer.state_dict()['state'][optimizer_keys[0]]['momentum_buffer'].shape

torch.Size([1, 32, 1, 1])

In [22]:
cnn.optimizer.state_dict()['state'][optimizer_keys[1]]['momentum_buffer'].shape

torch.Size([1])

In [23]:
cnn.optimizer.state_dict()['state'][optimizer_keys[2]]['momentum_buffer'].shape

torch.Size([32, 1, 3, 3])

In [24]:
cnn.optimizer.state_dict()['state'][optimizer_keys[3]]['momentum_buffer'].shape

torch.Size([32])

In [25]:
cnn.optimizer.state_dict()['state'][optimizer_keys[4]]['momentum_buffer'].shape

torch.Size([32, 32, 3, 3])

In [26]:
cnn.optimizer.state_dict()['state'][optimizer_keys[5]]['momentum_buffer'].shape

torch.Size([32])

In [39]:
# testing restoring model and optimizer, should result in training to the same validation

In [41]:
cnn.validate()

0.034646598491235636

In [42]:
# save optimizer state
optimizer_state = {}
i = 0
for group in cnn.optimizer.state_dict()['param_groups']:
    for key in group['params']:
        optimizer_state[i] = cnn.optimizer.state_dict()['state'][key]['momentum_buffer'].detach().cpu().numpy()
        i += 1

In [43]:
# save model
model_weights = cnn.get_tensor_dict()

In [44]:
cnn.train_epoch()

(5.5859056,
 'Sum of batch is data: [[[-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  ...\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]]], target: [[[0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  ...\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]]]')

In [45]:
# see valiation is not the same now
cnn.validate()

0.03464584346511401

In [46]:
# restore saved model
cnn.set_tensor_dict(model_weights)

In [47]:
# see validation is back to previously observed for saved model
cnn.validate()

0.034646598491235636

In [None]:
# now look at a piece of the optimizer state to see that it does not match what we  had saved

In [48]:
cnn.optimizer.state_dict()['state'][optimizer_keys[0]]

{'momentum_buffer': tensor([[[[1.4488e-02]],
 
          [[9.3489e-06]],
 
          [[0.0000e+00]],
 
          [[3.2723e-02]],
 
          [[2.8575e-02]],
 
          [[1.8181e-04]],
 
          [[3.0210e-02]],
 
          [[4.8071e-05]],
 
          [[0.0000e+00]],
 
          [[8.3902e-04]],
 
          [[8.1692e-05]],
 
          [[5.9734e-02]],
 
          [[6.2871e-02]],
 
          [[5.2276e-02]],
 
          [[3.1715e-02]],
 
          [[5.3174e-02]],
 
          [[7.0356e-02]],
 
          [[0.0000e+00]],
 
          [[9.3755e-05]],
 
          [[1.5996e-02]],
 
          [[8.5440e-04]],
 
          [[7.3476e-02]],
 
          [[9.3773e-04]],
 
          [[5.3554e-02]],
 
          [[3.4916e-02]],
 
          [[0.0000e+00]],
 
          [[4.1385e-03]],
 
          [[1.3060e-02]],
 
          [[3.6511e-02]],
 
          [[1.0374e-06]],
 
          [[2.2838e-05]],
 
          [[5.2651e-03]]]], device='cuda:0')}

In [49]:
optimizer_state[0]

array([[[[1.01544615e-02]],

        [[6.54031192e-06]],

        [[0.00000000e+00]],

        [[2.29606144e-02]],

        [[2.00559720e-02]],

        [[1.27975727e-04]],

        [[2.12178156e-02]],

        [[3.37234596e-05]],

        [[0.00000000e+00]],

        [[5.89032075e-04]],

        [[5.73100406e-05]],

        [[4.18476686e-02]],

        [[4.39579785e-02]],

        [[3.66692208e-02]],

        [[2.23007090e-02]],

        [[3.73309031e-02]],

        [[4.93711829e-02]],

        [[0.00000000e+00]],

        [[6.55975964e-05]],

        [[1.12915374e-02]],

        [[5.99570572e-04]],

        [[5.15785143e-02]],

        [[6.58715959e-04]],

        [[3.75182629e-02]],

        [[2.44858358e-02]],

        [[0.00000000e+00]],

        [[2.92675500e-03]],

        [[9.18509439e-03]],

        [[2.56514139e-02]],

        [[7.30412751e-07]],

        [[1.60308919e-05]],

        [[3.68853193e-03]]]], dtype=float32)

In [50]:
# now restore old optizer state

In [51]:
# set optimizer state
i = 0
for group in cnn.optimizer.state_dict()['param_groups']:
    for key in group['params']:
        tensor = cnn.optimizer.state_dict()['state'][key]['momentum_buffer']
        cnn.optimizer.state_dict()['state'][key]['momentum_buffer'] = torch.Tensor(optimizer_state[i]).to(tensor.device)
        i += 1

In [52]:
# now check again to see that this time the actual piece of state matches what we had saved

In [53]:
cnn.optimizer.state_dict()['state'][optimizer_keys[0]]

{'momentum_buffer': tensor([[[[1.0154e-02]],
 
          [[6.5403e-06]],
 
          [[0.0000e+00]],
 
          [[2.2961e-02]],
 
          [[2.0056e-02]],
 
          [[1.2798e-04]],
 
          [[2.1218e-02]],
 
          [[3.3723e-05]],
 
          [[0.0000e+00]],
 
          [[5.8903e-04]],
 
          [[5.7310e-05]],
 
          [[4.1848e-02]],
 
          [[4.3958e-02]],
 
          [[3.6669e-02]],
 
          [[2.2301e-02]],
 
          [[3.7331e-02]],
 
          [[4.9371e-02]],
 
          [[0.0000e+00]],
 
          [[6.5598e-05]],
 
          [[1.1292e-02]],
 
          [[5.9957e-04]],
 
          [[5.1579e-02]],
 
          [[6.5872e-04]],
 
          [[3.7518e-02]],
 
          [[2.4486e-02]],
 
          [[0.0000e+00]],
 
          [[2.9268e-03]],
 
          [[9.1851e-03]],
 
          [[2.5651e-02]],
 
          [[7.3041e-07]],
 
          [[1.6031e-05]],
 
          [[3.6885e-03]]]], device='cuda:0')}

In [54]:
# now train to get a new validation
cnn.train_epoch()

(5.5859056,
 'Sum of batch is data: [[[-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  ...\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]]], target: [[[0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  ...\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]]]')

In [56]:

cnn.validate()

0.03464584346511401

In [57]:
# now restore model and optimizer to see that training results in the same validation after training

In [58]:
cnn.set_tensor_dict(model_weights)

In [60]:
# set optimizer state
i = 0
for group in cnn.optimizer.state_dict()['param_groups']:
    for key in group['params']:
        tensor = cnn.optimizer.state_dict()['state'][key]['momentum_buffer']
        cnn.optimizer.state_dict()['state'][key]['momentum_buffer'] = torch.Tensor(optimizer_state[i]).to(tensor.device)
        i += 1

In [61]:
# see that validation has changed back to restored moedl validation
cnn.validate()

0.034646598491235636

In [62]:

cnn.train_epoch()

(5.5859056,
 'Sum of batch is data: [[[-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  ...\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]\n  [-50.2968 -50.2968 -50.2968 ... -50.2968 -50.2968 -50.2968]]], target: [[[0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  ...\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]\n  [0. 0. 0. ... 0. 0. 0.]]]')

In [63]:

cnn.validate()

0.03464584346511401

In [None]:
############### MOVING TO REPRODUCIBILITY OVER PROCESSES (SAVING DURING ONE, RESTORING DURING ANOTHER) ###################

In [11]:
def get_optimizer_key_orders(device):
    model= PyTorch2DUNet(device)
    model.train_epoch()
    # print("model train loss: {}".format(model.train_epoch()))
    # print("model val result: {}".format(model.validate()))
    optimizer_keys = []
    for group in model.optimizer.state_dict()['param_groups']:
        for key in group['params']:
            optimizer_keys.append(key)
    # print("optimizer_keys: {}".format(optimizer_keys))
    # print("state_dict state keys: {}".format(model.optimizer.state_dict()['state'].keys()))
    # print("there are {} optimizer keys.".format(len(optimizer_keys)))
    sizes = [model.optimizer.state_dict()['state'][key]['momentum_buffer'].detach().cpu().numpy().shape 
             for key in optimizer_keys]
    # print("optimizer info array sizes list:\n {}".format(sizes))
    
    return model, optimizer_keys, sizes
    

In [12]:
# testing order of shapes is the same for multiple models created in the same process
shape_list = None
for i in range(20):
    _, _, list =  get_optimizer_key_orders(device)
    if i==0:
        shape_list = list
    print("Is the index {} shape list the same as the index 0 list? {}".format(i, list==shape_list))

Is the index 0 shape list the same as the index 0 list? True
Is the index 1 shape list the same as the index 0 list? True
Is the index 2 shape list the same as the index 0 list? True
Is the index 3 shape list the same as the index 0 list? True
Is the index 4 shape list the same as the index 0 list? True
Is the index 5 shape list the same as the index 0 list? True
Is the index 6 shape list the same as the index 0 list? True
Is the index 7 shape list the same as the index 0 list? True
Is the index 8 shape list the same as the index 0 list? True
Is the index 9 shape list the same as the index 0 list? True
Is the index 10 shape list the same as the index 0 list? True
Is the index 11 shape list the same as the index 0 list? True
Is the index 12 shape list the same as the index 0 list? True
Is the index 13 shape list the same as the index 0 list? True
Is the index 14 shape list the same as the index 0 list? True
Is the index 15 shape list the same as the index 0 list? True
Is the index 16 sh

In [None]:
# saving order to disk

In [13]:
import pickle
with open('brandons_shape_list.pkl', 'wb') as file:
    pickle.dump(shape_list, file)

In [14]:
with open('brandons_shape_list.pkl', 'rb') as file:
    restored_shape_list = pickle.load(file)

In [15]:
restored_shape_list == shape_list

True

In [None]:
#################################### restarting kernel here 
#### running all begining cells and the cell defining get_optimizer_key_orders ###########################

In [13]:
# pulling prvious process' shape list from disk
with open('brandons_shape_list.pkl', 'rb') as file:
    shape_list = pickle.load(file)
for i in range(10):
    _, _, list =  get_optimizer_key_orders(device)
    print("Is the index {} shape list the same as the list from last process? {}".format(i, list==shape_list))

Is the index 0 shape list the same as the list from last process? True
Is the index 1 shape list the same as the list from last process? True
Is the index 2 shape list the same as the list from last process? True
Is the index 3 shape list the same as the list from last process? True
Is the index 4 shape list the same as the list from last process? True
Is the index 5 shape list the same as the list from last process? True
Is the index 6 shape list the same as the list from last process? True
Is the index 7 shape list the same as the list from last process? True
Is the index 8 shape list the same as the list from last process? True
Is the index 9 shape list the same as the list from last process? True


In [None]:
# now trying by getting the keys directly from the dict

In [14]:
def get_optimizer_key_orders_from_dict(device):
    model= PyTorch2DUNet(device)
    model.train_epoch()
    # print("model train loss: {}".format(model.train_epoch()))
    # print("model val result: {}".format(model.validate()))
    optimizer_keys = model.optimizer.state_dict()['state'].keys()
    sizes = [model.optimizer.state_dict()['state'][key]['momentum_buffer'].detach().cpu().numpy().shape 
             for key in optimizer_keys]
    # print("optimizer info array sizes list:\n {}".format(sizes))
    
    return model, optimizer_keys, sizes
    

In [15]:
shape_list = None
for i in range(4):
    _, _, list =  get_optimizer_key_orders_from_dict(device)
    if i==0:
        shape_list = list
    print("Is the index {} shape list the same as the index 0 list? {}".format(i, list==shape_list))

Is the index 0 shape list the same as the index 0 list? True
Is the index 1 shape list the same as the index 0 list? False
Is the index 2 shape list the same as the index 0 list? False
Is the index 3 shape list the same as the index 0 list? False
