In [1]:
import torch
import torch.nn as nn
import numpy as np
from math import *

In [2]:
# Activation function: Swish / SiLU, i.e. x * sigmoid(x).
def activation(x):
    """Element-wise Swish (SiLU) activation.

    Args:
        x: input tensor of any shape.

    Returns:
        Tensor with the same shape as ``x``, equal to ``x * sigmoid(x)``.
    """
    gate = torch.sigmoid(x)
    return x * gate

In [3]:
# Build a small ResNet-style network with one residual block.
class Net(torch.nn.Module):
    """Fully connected net: input layer -> one residual block -> scalar output.

    The attribute names ``layer_in``, ``layer1``, ``layer2`` and ``layer_out``
    are the state-dict keys matched against the saved checkpoints, so they
    must not be renamed. Their assignment order fixes the order of
    ``parameters()`` / ``named_parameters()``.
    """

    def __init__(self, input_width, layer_width):
        super().__init__()
        self.layer_in = torch.nn.Linear(input_width, layer_width)
        self.layer1 = torch.nn.Linear(layer_width, layer_width)
        self.layer2 = torch.nn.Linear(layer_width, layer_width)
        self.layer_out = torch.nn.Linear(layer_width, 1)

    def forward(self, x):
        """Map ``x`` of shape (..., input_width) to shape (..., 1)."""
        hidden = self.layer_in(x)
        # Residual block 1: hidden + act(layer2(act(layer1(hidden))))
        branch = activation(self.layer1(hidden))
        branch = activation(self.layer2(branch))
        hidden = hidden + branch
        return self.layer_out(hidden)

In [4]:
# Problem dimension; used below as the network's input width.
dimension: int = 1

In [5]:
# Input width equals the problem dimension. The hidden width of 4 has to
# match the saved checkpoints (their layer weights are 4 units wide).
input_width = dimension
layer_width = 4

In [6]:
# Instantiate the network. Its random initial values are replaced by each
# checkpoint load below, for every key the checkpoint provides.
net = Net(input_width=input_width, layer_width=layer_width)

In [7]:
# Load the DGM checkpoint into `net`.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Newer PyTorch versions accept `weights_only=True` for plain state dicts.
pretrained_dict = torch.load('net_params_DGM.pkl', map_location='cpu')

# strict=False loads every key that `net` has and ignores extra keys in the
# file. That matches filtering the file's keys and updating net.state_dict()
# by hand, and shape mismatches still raise.
# Unlike the manual merge, it also reports which of the network's keys the
# file did NOT provide. Those parameters keep whatever `net` currently holds
# (random init or a previously loaded checkpoint), so any non-empty
# `missing_keys` shown below means the printed values are not all from this file.
net.load_state_dict(pretrained_dict, strict=False)

<All keys matched successfully>

In [8]:
# Flatten every parameter of `net` (DGM weights) into one 1-D vector, in
# parameters() order, printing each flattened parameter along the way.
# The leading 0 entry is a placeholder kept so that param_DGM, param_DRM and
# param_DRM_tilde all have the same length. It is 0 in every vector, so it
# does not change the distances between them.
param_chunks = [torch.Tensor([0.])]
for param in net.parameters():
    # reshape(-1) works for parameters of any rank, not just 1-D/2-D.
    flat = param.detach().cpu().reshape(-1)
    print(flat)
    param_chunks.append(flat)
# Concatenate once instead of re-copying the growing vector on every step.
param_DGM = torch.cat(param_chunks, dim=0)

tensor([-0.2179, -0.7552,  0.4469,  1.1701])
tensor([-0.3287,  0.4343, -0.5646,  0.1575])
tensor([-0.2099,  0.3082, -0.4854, -0.2760,  0.1566, -0.1074, -0.4325,  0.1162,
        -0.4324, -0.2633,  0.0114, -0.0235,  0.4624, -0.1638,  0.4224, -0.3636])
tensor([-0.3637,  0.2681, -0.4810,  0.1562])
tensor([ 0.1101, -0.0567,  0.2634, -0.4340, -0.4977, -0.3339, -0.2240, -0.2326,
         0.3019, -0.3750,  0.3615,  0.0318, -0.4894, -0.2256, -0.4991, -0.4308])
tensor([ 0.0697,  0.1143,  0.3273, -0.0598])
tensor([ 1.1556, -1.7318,  0.8062, -0.2845])
tensor([-0.2985])


In [9]:
# Load the DRM checkpoint into `net`, replacing the previously loaded weights.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Newer PyTorch versions accept `weights_only=True` for plain state dicts.
pretrained_dict = torch.load('net_params_DRM.pkl', map_location='cpu')

# strict=False loads every key that `net` has and ignores extra keys in the
# file. That matches filtering the file's keys and updating net.state_dict()
# by hand, and shape mismatches still raise.
# It also reports keys the file did NOT provide. Those parameters silently
# keep the values from the previously loaded checkpoint, so a non-empty
# `missing_keys` means the DRM vector below mixes in other weights.
net.load_state_dict(pretrained_dict, strict=False)

<All keys matched successfully>

In [10]:
# Flatten every parameter of `net` (DRM weights) into one 1-D vector, in
# parameters() order, printing each flattened parameter along the way.
# The leading 0 entry is a placeholder kept so that all three parameter
# vectors have the same length. It is 0 in every vector, so it does not
# change the distances between them.
param_chunks = [torch.Tensor([0.])]
for param in net.parameters():
    # reshape(-1) works for parameters of any rank, not just 1-D/2-D.
    flat = param.detach().cpu().reshape(-1)
    print(flat)
    param_chunks.append(flat)
# Concatenate once instead of re-copying the growing vector on every step.
param_DRM = torch.cat(param_chunks, dim=0)

tensor([-1.5141,  0.6419,  1.7847, -0.0447])
tensor([-0.2507,  0.4893, -0.7086,  0.4046])
tensor([-0.2099,  0.3082, -0.4854, -0.2760,  0.1566, -0.1074, -0.4325,  0.1162,
        -0.4324, -0.2633,  0.0114, -0.0235,  0.4624, -0.1638,  0.4224, -0.3636])
tensor([-0.3637,  0.2681, -0.4810,  0.1562])
tensor([ 0.1101, -0.0567,  0.2634, -0.4340, -0.4977, -0.3339, -0.2240, -0.2326,
         0.3019, -0.3750,  0.3615,  0.0318, -0.4894, -0.2256, -0.4991, -0.4308])
tensor([ 0.0697,  0.1143,  0.3273, -0.0598])
tensor([ 0.4888,  0.5074, -0.1846, -0.9395])
tensor([-0.1878])


In [11]:
# Load the DRM-to-DGM checkpoint into `net`, replacing the previously loaded weights.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Newer PyTorch versions accept `weights_only=True` for plain state dicts.
pretrained_dict = torch.load('net_params_DRM_to_DGM.pkl', map_location='cpu')

# strict=False loads every key that `net` has and ignores extra keys in the
# file. That matches filtering the file's keys and updating net.state_dict()
# by hand, and shape mismatches still raise.
# It also reports keys the file did NOT provide. Those parameters silently
# keep the values from the previously loaded checkpoint, so a non-empty
# `missing_keys` means the DRM_tilde vector below mixes in other weights.
net.load_state_dict(pretrained_dict, strict=False)

<All keys matched successfully>

In [12]:
# Flatten every parameter of `net` (DRM-to-DGM weights) into one 1-D vector,
# in parameters() order, printing each flattened parameter along the way.
# The leading 0 entry is a placeholder kept so that all three parameter
# vectors have the same length. It is 0 in every vector, so it does not
# change the distances between them.
param_chunks = [torch.Tensor([0.])]
for param in net.parameters():
    # reshape(-1) works for parameters of any rank, not just 1-D/2-D.
    flat = param.detach().cpu().reshape(-1)
    print(flat)
    param_chunks.append(flat)
# Concatenate once instead of re-copying the growing vector on every step.
param_DRM_tilde = torch.cat(param_chunks, dim=0)

tensor([-1.4766,  0.7718,  1.7552, -0.0758])
tensor([-0.2490,  0.7030, -0.7673,  0.3942])
tensor([-0.2099,  0.3082, -0.4854, -0.2760,  0.1566, -0.1074, -0.4325,  0.1162,
        -0.4324, -0.2633,  0.0114, -0.0235,  0.4624, -0.1638,  0.4224, -0.3636])
tensor([-0.3637,  0.2681, -0.4810,  0.1562])
tensor([ 0.1101, -0.0567,  0.2634, -0.4340, -0.4977, -0.3339, -0.2240, -0.2326,
         0.3019, -0.3750,  0.3615,  0.0318, -0.4894, -0.2256, -0.4991, -0.4308])
tensor([ 0.0697,  0.1143,  0.3273, -0.0598])
tensor([ 0.3549,  0.5734, -0.1157, -0.9831])
tensor([-0.0702])


In [13]:
# L2 and Linf distances between the DGM and DRM parameter vectors.
diff = param_DGM - param_DRM
error = torch.sum(diff**2)**0.5
print('L2 Distance of two DGM and DRM min is: ', error)
# Linf = max |diff|, computed directly. Going through sqrt(max(diff**2))
# squares every entry, which adds rounding and can underflow or overflow
# in float32.
error_max = diff.abs().max()
print('Linf Distance of two DGM and DRM min is: ', error_max)

L2 Distance of two DGM and DRM min is:  tensor(3.7243)
Linf Distance of two DGM and DRM min is:  tensor(2.2392)


In [14]:
# L2 and Linf distances between the DGM and DRM_tilde parameter vectors.
diff = param_DGM - param_DRM_tilde
error = torch.sum(diff**2)**0.5
print('L2 Distance of two DGM and DRM_tilde min is: ', error)
# Linf = max |diff|, computed directly. Going through sqrt(max(diff**2))
# squares every entry, which adds rounding and can underflow or overflow
# in float32.
error_max = diff.abs().max()
print('Linf Distance of two DGM and DRM_tilde min is: ', error_max)

L2 Distance of two DGM and DRM_tilde min is:  tensor(3.8342)
Linf Distance of two DGM and DRM_tilde min is:  tensor(2.3052)


In [15]:
# L2 and Linf distances between the DRM and DRM_tilde parameter vectors.
diff = param_DRM - param_DRM_tilde
error = torch.sum(diff**2)**0.5
print('L2 Distance of two DRM and DRM_tilde min is: ', error)
# Linf = max |diff|, computed directly. Going through sqrt(max(diff**2))
# squares every entry, which adds rounding and can underflow or overflow
# in float32.
error_max = diff.abs().max()
print('Linf Distance of two DRM and DRM_tilde min is: ', error_max)

L2 Distance of two DRM and DRM_tilde min is:  tensor(0.3349)
Linf Distance of two DRM and DRM_tilde min is:  tensor(0.2138)
