In [1]:
"""
This notebook tests model and optimizer state save and restore from
one model instance to another, by inspecting the state using
model.state_dict() for model state and model.optimizer.state_dict()
for optimizer. This is an automated single test that runs a test over
a number of model and optimizer types.
"""

'\nThis notebook tests model and optimizer state save and restore from\none model instance to another, by inspecting the state using\nmodel.state_dict() for model state and model.optimizer.state_dict()\nfor optimizer. This is an automated single test that runs a test over\na number of model and optimizer types.\n'

In [2]:
# Here set the gpu to one that is not ocupied
import os
os.environ["CUDA_VISIBLE_DEVICES"]="3"

In [3]:
import pickle
import numpy as np
import torch
from copy import deepcopy

In [4]:
# maybe this is how we can test later, for now I am rewriting code below with the following changes
# No shuffle on train and val
# imports rewritten



# from tfedlrn.collaborator.pytorchmodels.pytorch2dunet import PyTorch2DUNet
# from tfedlrn.collaborator.pytorchmodels.pytorchmnistcnn import PyTorchMNISTCNN

In [5]:
# implementing my own data pipeline initializer in order to enforce reproducibility by
# selecting only one batch for training.

from tfedlrn.datasets import load_dataset
from tfedlrn.collaborator.pytorchmodels.pytorchflutils import pt_create_loader


def init_data_pipelines_no_shuffle(model_type):
    
    batch_size = 64
    
    if model_type == PyTorchMNISTCNN:
        
        X_train, y_train, X_val, y_val = load_dataset('mnist')
        X_train = X_train.reshape([-1, 1, 28, 28])
        X_val = X_val.reshape([-1, 1, 28, 28])
        
        # Here is the key reason for implementing this myself.
        # Reduce to producing only one batch for training.
        X_train, y_train = X_train[:batch_size], y_train[:batch_size]

        train_loader = pt_create_loader(X_train, y_train, batch_size=batch_size, shuffle=True)
        val_loader = pt_create_loader(X_val, y_val, batch_size=batch_size, shuffle=True)

        
    elif model_type == PyTorch2DUNet:
        
        data_by_institution = [load_dataset('BraTS17_institution',
                                                institution=i,
                                                channels_first=True) for i in range(10)]
        data_by_type = zip(*data_by_institution)
        data_by_type = [np.concatenate(d) for d in data_by_type]
        X_train, y_train, X_val, y_val = data_by_type
        
        # Here is the key reason for implementing this myself.
        # Reduce to producing only one batch for training.
        X_train, y_train = X_train[:batch_size], y_train[:batch_size]
        
        # Also reduce val loader for unet data so as to not take too long for validation
        X_val, y_val = X_val[:batch_size], y_val[:batch_size]

        train_loader = pt_create_loader(X_train, y_train, batch_size=batch_size, shuffle=True)
        val_loader = pt_create_loader(X_val, y_val, batch_size=batch_size, shuffle=True)         
         
    else:
        raise ValueError('This model type not supported.')
               
    return train_loader, val_loader
        


In [6]:
from tfedlrn.collaborator.pytorchmodels.pytorchmnistcnn import PyTorchMNISTCNN
from tfedlrn.collaborator.pytorchmodels.pytorch2dunet import PyTorch2DUNet

In [7]:
def initialize_model(device, lr, model_type, optimizer_type, 
                     no_momentum=False):
    train_loader, val_loader = init_data_pipelines_no_shuffle(model_type)
    if model_type == PyTorchMNISTCNN:
        cnn = PyTorchMNISTCNN(device=device, train_loader = train_loader, 
                              val_loader = val_loader)
    else:
        cnn = model_type(device=device, train_loader = train_loader, 
                         val_loader = val_loader, optimizer=optimizer_type)
    # modifying the learning rate to make a more substantial change when training
    # only on one batch, so as to detect after such training that the 
    # model was not restored correctly
    for group_idx, group in enumerate(cnn.optimizer.__dict__['param_groups']):
        cnn.optimizer.__dict__['param_groups'][group_idx]['lr'] = lr
    cnn.optimizer.__dict__['defaults']['lr'] = lr
    
    # modifying the momentum so as to test SGD with no state needed
    if no_momentum:
        for group_idx, group in enumerate(cnn.optimizer.__dict__['param_groups']):
            cnn.optimizer.__dict__['param_groups'][group_idx]['momentum'] = 0.0
        cnn.optimizer.__dict__['defaults']['momentum'] = 0.0
    
    
    return cnn

In [8]:
def equality_of_weights_and_biases(cnn_1, cnn_2):
    s_1 = cnn_1.state_dict()
    s_2 = cnn_2.state_dict()
    
    bool_over_layers = []
    for key in s_1:
        bool_over_layers.append(bool(torch.all(torch.eq(s_1[key], s_2[key]))))
        
    return np.all(bool_over_layers)


def equality_of_optimizer_state(cnn_1, cnn_2):
    os_1 = cnn_1.optimizer.state_dict()['state']
    os_2 = cnn_2.optimizer.state_dict()['state']
    
    key_groups_1 = [group['params'] for 
             group in cnn_1.optimizer.state_dict()['param_groups']]
    key_groups_2 = [group['params'] for 
             group in cnn_2.optimizer.state_dict()['param_groups']]
    bool_over_groups = []
    for key_group_1, key_group_2 in zip(key_groups_1, key_groups_2):
        for key_1, key_2 in zip(key_group_1, key_group_2):
            subdict_1 = os_1[key_1]
            subdict_2 = os_2[key_2]
            for subkey in subdict_1:
                if subkey == 'step':
                    eq = subdict_1[subkey] == subdict_2[subkey]
                else:
                    tensor_1 = subdict_1[subkey]
                    tensor_2 = subdict_2[subkey]
                    eq = bool(torch.all(torch.eq(tensor_1, tensor_2)))
                bool_over_groups.append(eq)
    return np.all(bool_over_groups)


def equality_of_models(cnn_1, cnn_2, no_momentum):
    if no_momentum:
        return equality_of_weights_and_biases(cnn_1, cnn_2)  
    else:
        return np.all([equality_of_weights_and_biases(cnn_1, cnn_2), 
                       equality_of_optimizer_state(cnn_1, cnn_2)])
            

In [9]:
def full_test():
    
    devices = [torch.device("cpu"), torch.device("cuda")]
    
    example_kwargs = [{'model_type': PyTorchMNISTCNN, 
                       'optimizer_type': 'SGD', 
                       'no_momentum': False},
                       {'model_type': PyTorchMNISTCNN, 
                       'optimizer_type': 'SGD', 
                       'no_momentum': False},
                       {'model_type': PyTorch2DUNet, 
                       'optimizer_type': 'SGD', 
                       'no_momentum': False},
                       {'model_type': PyTorch2DUNet, 
                       'optimizer_type': 'SGD', 
                       'no_momentum': True}, 
                       {'model_type': PyTorch2DUNet, 
                       'optimizer_type': 'RMSprop', 
                       'no_momentum': False},
                       {'model_type': PyTorch2DUNet, 
                       'optimizer_type': 'Adam', 
                       'no_momentum': False}]
    
    answers = []
                     
    for device in devices:
        
        for example_kwarg in example_kwargs:
            no_momentum = example_kwarg['no_momentum']
            cnn_1 = initialize_model(device=device, lr=0.01, **example_kwarg)
            cnn_2 = initialize_model(device=device, lr=0.01, **example_kwarg)
            
            cnn_1.train_epoch()
            cnn_2.train_epoch()
            
            # See that models are not equal.
            should_be_false = equality_of_models(cnn_1, cnn_2, no_momentum)
            
            model_weights_1 = cnn_1.get_tensor_dict()
            cnn_2.set_tensor_dict(model_weights_1)
            
            # See that models are now equal.
            should_be_true = equality_of_models(cnn_1, cnn_2, no_momentum)
            
            success = should_be_true and not should_be_false
            
            answers.append(success)
            
            print("Config: {}, {}".format(device, example_kwarg))
            print("Succesful?: {}\n".format(success))
    print("\n\nWas the whole test succesfull? {}".format(np.all(answers)))

In [10]:
full_test()

Config: cpu, {'optimizer_type': 'SGD', 'no_momentum': False, 'model_type': <class 'tfedlrn.collaborator.pytorchmodels.pytorchmnistcnn.PyTorchMNISTCNN'>}
Succesful?: True

Config: cpu, {'optimizer_type': 'SGD', 'no_momentum': False, 'model_type': <class 'tfedlrn.collaborator.pytorchmodels.pytorchmnistcnn.PyTorchMNISTCNN'>}
Succesful?: True

Config: cpu, {'optimizer_type': 'SGD', 'no_momentum': False, 'model_type': <class 'tfedlrn.collaborator.pytorchmodels.pytorch2dunet.PyTorch2DUNet'>}
Succesful?: True

Config: cpu, {'optimizer_type': 'SGD', 'no_momentum': True, 'model_type': <class 'tfedlrn.collaborator.pytorchmodels.pytorch2dunet.PyTorch2DUNet'>}
Succesful?: True

Config: cpu, {'optimizer_type': 'RMSprop', 'no_momentum': False, 'model_type': <class 'tfedlrn.collaborator.pytorchmodels.pytorch2dunet.PyTorch2DUNet'>}
Succesful?: True

Config: cpu, {'optimizer_type': 'Adam', 'no_momentum': False, 'model_type': <class 'tfedlrn.collaborator.pytorchmodels.pytorch2dunet.PyTorch2DUNet'>}
Succ