<img src="../../docs/images/dlsia.png" width=600 />

# Demo on how to save and load models

Authors: Eric Roberts and Petrus Zwart

E-mail: PHZwart@lbl.gov, EJRoberts@lbl.gov
___

This notebook highlights some basic functionality with the pyMSDtorch package.

Using the pyMSDtorch framework, we initialize convolutional neural networks, and show how to save and load them


In [1]:
import numpy as np
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from dlsia.core import helpers, custom_losses, train_scripts
from dlsia.core.networks import msdNet, tunet, tunet3plus, smsnet

import matplotlib.pyplot as plt

Build random data for demo purpose

In [2]:
random_data1 = torch.rand((40, 1, 36, 36))
k = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=11, stride=1, padding=5)
random_data2 = torch.rand((40, 1, 36, 36))
random_data_gt = k(random_data1) 
random_data_obs = k(random_data1)  + random_data2*0.50
K = 3
random_data_obs = random_data_obs[:,:,K:32+K,K:32+K].detach()
random_data_gt = random_data_gt[:,:,K:32+K,K:32+K].detach()

train_x = random_data_obs[:20,...] 
train_y = random_data_gt[:20,...]
test_x = random_data_obs[20:,...]
test_y = random_data_gt[20:,...]

train_set = TensorDataset( train_x, train_y)
test_set = TensorDataset( train_x, train_y)


# Specify batch sizes
batch_size_train = 20 
batch_size_test  = 20

# Set Dataloader parameters (Note: we randomly shuffle the training set upon each pass)
train_loader_params = {'batch_size': batch_size_train,'shuffle': True}
test_loader_params  = {'batch_size': batch_size_test, 'shuffle': False}

# Build Dataloaders
train_loader = DataLoader(train_set, **train_loader_params)
test_loader  = DataLoader(test_set, **test_loader_params)

Construct Networks

In [18]:
msdnet = MSDNet.MixedScaleDenseNetwork(in_channels = 1,
                                       out_channels = 1, 
                                       num_layers=40)
print("MSDNet :     ", helpers.count_parameters(msdnet), "parameters")

tunet = TUNet.TUNet(image_shape=(32,32),
                    in_channels=1,
                    out_channels=1,
                    depth=3, base_channels=10)
print("TUNet :      ", helpers.count_parameters(tunet), "parameters")

tunet3plus = TUNet3Plus.TUNet3Plus(image_shape=(32,32),
                                  in_channels=1,
                                  out_channels=1,
                                  depth=3,
                                  base_channels=10)
print("TUNet3plus : ", helpers.count_parameters(tunet3plus), "parameters")

smsnet = SMSNet.random_SMS_network(in_channels=1, 
                                   out_channels=1, 
                                   hidden_out_channels=[1],
                                   layers=40, 
                                   dilation_choices=[1,2,3,4],
                                   #layer_probabilities=layer_probabilities,
                                   network_type="Regression")

print("SMSNet :     ", helpers.count_parameters(smsnet), "parameters")

MSDNet :      7462 parameters
TUNet :       46131 parameters
TUNet3plus :  49421 parameters
SMSNet :      523 parameters


In [19]:
epochs = 30             
criterion = nn.L1Loss()  
LEARNING_RATE = 1e-2

# Define optimizers, one per network
optimizer_msd        = optim.Adam(msdnet.parameters(), lr=LEARNING_RATE)
optimizer_tunet      = optim.Adam(tunet.parameters(), lr=LEARNING_RATE)
optimizer_tunet3plus = optim.Adam(tunet3plus.parameters(), lr=LEARNING_RATE)
optimizer_smsnet     = optim.Adam(smsnet.parameters(), lr=LEARNING_RATE)

device = "cpu" 


In [20]:
msdnet.to(device)   
msdnet, results = train_scripts.train_regression(msdnet,
                                                 train_loader,
                                                 test_loader,
                                                 epochs,
                                                 criterion,
                                                 optimizer_msd,
                                                 device,
                                                 show=10)
msdnet = msdnet.cpu()

Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 3.0971e-01 | Validation Loss: 3.3526e-01
Training CC: 0.3246   Validation CC  : 0.3502 
Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 1.6845e-01 | Validation Loss: 1.5185e-01
Training CC: 0.4390   Validation CC  : 0.4625 
Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 1.0159e-01 | Validation Loss: 1.0797e-01
Training CC: 0.6644   Validation CC  : 0.6773 


In [21]:
tunet.to(device)   
tunet, results = train_scripts.train_regression(tunet,
                                                 train_loader,
                                                 test_loader,
                                                 epochs,
                                                 criterion,
                                                 optimizer_tunet,
                                                 device,
                                                 show=10)
tunet = tunet.cpu()

Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 5.6353e-01 | Validation Loss: 5.1034e-01
Training CC: -0.2148   Validation CC  : -0.2249 
Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 1.9419e-01 | Validation Loss: 1.7824e-01
Training CC: -0.0490   Validation CC  : -0.0185 
Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 1.4800e-01 | Validation Loss: 1.4394e-01
Training CC: 0.4056   Validation CC  : 0.4199 


In [22]:
tunet3plus.to(device)   
tunet3plus, results = train_scripts.train_regression(tunet3plus,
                                                 train_loader,
                                                 test_loader,
                                                 epochs,
                                                 criterion,
                                                 optimizer_tunet3plus,
                                                 device,
                                                 show=10)
tunet3plus = tunet3plus.cpu()

Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 1.3394e-01 | Validation Loss: 1.3871e-01
Training CC: 0.3538   Validation CC  : 0.3833 
Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 1.0467e-01 | Validation Loss: 1.0212e-01
Training CC: 0.6365   Validation CC  : 0.6568 
Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 9.2951e-02 | Validation Loss: 9.2530e-02
Training CC: 0.7260   Validation CC  : 0.7311 


In [23]:
smsnet.to(device)   
smsnet, results = train_scripts.train_regression(smsnet,
                                                 train_loader,
                                                 test_loader,
                                                 epochs,
                                                 criterion,
                                                 optimizer_smsnet,
                                                 device,
                                                 show=10)
smsnet = smsnet.cpu()

Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 1.4356e-01 | Validation Loss: 1.4924e-01
Training CC: 0.5461   Validation CC  : 0.5648 
Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 1.0884e-01 | Validation Loss: 1.0634e-01
Training CC: 0.6762   Validation CC  : 0.6854 
Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 9.8020e-02 | Validation Loss: 9.6028e-02
Training CC: 0.7339   Validation CC  : 0.7363 


In [24]:
msdnet.save_network_parameters("this_msdnet.pt")
smsnet.save_network_parameters("this_smsnet.pt")
tunet.save_network_parameters("this_tunet.pt")


In [25]:
copy_msdnet = MSDNet.MSDNetwork_from_file("this_msdnet.pt")
copy_smsnet = SMSNet.SMSNetwork_from_file("this_smsnet.pt")
copy_tunet = TUNet.TUNetwork_from_file("this_tunet.pt")

In [26]:
with torch.no_grad():
    r1 = msdnet(test_x)
    r2 = copy_msdnet(test_x)
delta = r1-r2
assert torch.max(torch.abs(delta)) < 1e-8

In [27]:
with torch.no_grad():
    r1 = smsnet(test_x)
    r2 = copy_smsnet(test_x)
delta = r1-r2
assert torch.max(torch.abs(delta)) < 1e-8

In [28]:
with torch.no_grad():
    r1 = tunet(test_x)
    r2 = copy_tunet(test_x)
delta = r1-r2
assert torch.max(torch.abs(delta)) < 1e-8