<img src="../../docs/images/dlsia.png" width=600 />

# Demo on how to save and load models

**Authors:** Eric Roberts and Petrus Zwart

**E-mail:** PHZwart@lbl.gov, EJRoberts@lbl.gov

This notebook highlights some basic functionality of the dlsia package.

Using the dlsia framework, we initialize several convolutional neural networks, train each on a small dataset using the CPU, and show how to save and load them.

In [1]:
import numpy as np
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from dlsia.core import helpers, custom_losses, train_scripts
from dlsia.core.networks import msdnet, tunet, tunet3plus, smsnet

import matplotlib.pyplot as plt

## Create & load data 

### Generate random data

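Let's build some random data: 40 single-channel, 36-by-36 images. The ground truth is each image passed through a fixed, randomly initialized 11-by-11 convolution; the observations add uniform noise scaled by 0.5. Both are cropped to 32-by-32, and the first 20 images are used for training and the remaining 20 for testing.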

In [2]:
n_imgs = 40
n_channels = 1
n_xy = 36

random_data1 = torch.rand((n_imgs, n_channels, n_xy, n_xy))
k = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=11, stride=1, padding=5)
random_data2 = torch.rand((n_imgs, n_channels, n_xy, n_xy))
random_data_gt = k(random_data1) 
random_data_obs = k(random_data1)  + random_data2*0.50
K = 3
random_data_obs = random_data_obs[:,:,K:32+K,K:32+K].detach()
random_data_gt = random_data_gt[:,:,K:32+K,K:32+K].detach()

train_x = random_data_obs[:20,...] 
train_y = random_data_gt[:20,...]
test_x = random_data_obs[20:,...]
test_y = random_data_gt[20:,...]
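The image sizes in the cell above can be checked with a little arithmetic. The sketch below uses the output-size formula documented for PyTorch's Conv2d; the helper name `conv2d_output_size` is ours and is not part of dlsia or PyTorch.

```python
def conv2d_output_size(n_in, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis, following the Conv2d formula."""
    return (n_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

n_xy = 36
# kernel_size=11 with padding=5 and stride=1 preserves the spatial size
after_conv = conv2d_output_size(n_xy, kernel_size=11, stride=1, padding=5)

# The slice K:32+K keeps 32 pixels per axis, starting K pixels from the edge
K = 3
cropped = len(range(after_conv)[K:32 + K])

print(after_conv, cropped)
```

So the convolution leaves the images at 36-by-36, and the crop reduces them to the 32-by-32 shape that the TUNet constructors below expect.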

### Prep data 

We wrap the tensors in a PyTorch TensorDataset and DataLoader for ingestion by the dlsia pipeline. This lets us easily handle the data and load it iteratively, batch by batch, into the networks.

In [3]:
train_set = TensorDataset( train_x, train_y)
test_set = TensorDataset( test_x, test_y)

# Specify batch sizes
batch_size_train = 20 
batch_size_test  = 20

# Set Dataloader parameters (Note: we randomly shuffle the training set upon each pass)
train_loader_params = {'batch_size': batch_size_train,'shuffle': True}
test_loader_params  = {'batch_size': batch_size_test, 'shuffle': False}

# Build Dataloaders
train_loader = DataLoader(train_set, **train_loader_params)
test_loader  = DataLoader(test_set, **test_loader_params)
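To see how the batch size controls what a DataLoader yields, here is a small sketch on stand-in random tensors with the same shapes as the training data. The variable names and the alternative batch size of 8 are illustrative choices, not settings used in this notebook.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Stand-in data: 20 single-channel 32-by-32 images and targets
x = torch.rand(20, 1, 32, 32)
y = torch.rand(20, 1, 32, 32)
dataset = TensorDataset(x, y)

# batch_size equal to the dataset size gives one batch per pass
full_batch = DataLoader(dataset, batch_size=20, shuffle=False)
# a smaller batch size gives several batches; the last one may be partial
mini_batch = DataLoader(dataset, batch_size=8, shuffle=False)

full_sizes = [xb.shape[0] for xb, _ in full_batch]
mini_sizes = [xb.shape[0] for xb, _ in mini_batch]
print(full_sizes, mini_sizes)
```

With 20 samples and a batch size of 20, as above, each epoch performs a single optimizer step.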

## Construct Networks

dlsia offers a variety of different convolutional neural network architectures.

### MSDNet

Mixed-scale Dense networks that probe different length scales using dilated convolutions.

In [4]:
msdnet_model = msdnet.MixedScaleDenseNetwork(in_channels = 1,
                                             out_channels = 1, 
                                             num_layers=40)

### TUNet

Tuneable U-Nets with a variety of user-customizable parameters.

In [5]:
tunet_model = tunet.TUNet(image_shape=(32,32),
                          in_channels=1,
                          out_channels=1,
                          depth=3, 
                          base_channels=10)

### TUNet3+

A newer UNet modification connecting all encoder and decoder layers via carefully crafted upsampling/downsampling/convolution/concatenation bundles.

In [6]:
tunet3plus_model = tunet3plus.TUNet3Plus(image_shape=(32,32),
                                         in_channels=1,
                                         out_channels=1,
                                         depth=3,
                                         base_channels=10)

### SMSNets

Sparse Mixed-Scale Networks are lean, randomly and sparsely connected variants of MSDNets.

In [7]:
smsnet_model = smsnet.random_SMS_network(in_channels=1, 
                                         out_channels=1, 
                                         hidden_out_channels=[1],
                                         layers=40, 
                                         dilation_choices=[1,2,3,4],
                                         #layer_probabilities=layer_probabilities,
                                         network_type="Regression")

In [8]:
# View number of learnable parameters in each network
print("MSDNet :     ", helpers.count_parameters(msdnet_model), "parameters")
print("TUNet :      ", helpers.count_parameters(tunet_model), "parameters")
print("TUNet3plus : ", helpers.count_parameters(tunet3plus_model), "parameters")
print("SMSNet :     ", helpers.count_parameters(smsnet_model), "parameters")

MSDNet :      9182 parameters
TUNet :       46131 parameters
TUNet3plus :  49421 parameters
SMSNet :      735 parameters
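A common way to count learnable parameters in plain PyTorch is to sum the element counts of all tensors that require gradients. The sketch below assumes helpers.count_parameters does something similar; the notebook does not show its implementation, and `count_trainable_parameters` is a hypothetical name.

```python
import torch.nn as nn

def count_trainable_parameters(model):
    """Sum the number of elements over all parameters that require gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Same layer shape as the convolution used to generate the data:
# 1 * 1 * 11 * 11 weights plus 1 bias
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=11, stride=1, padding=5)
n_params = count_trainable_parameters(conv)
print(n_params)
```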


## Training

### Training parameters

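We choose the training hyperparameters: the number of epochs, the loss function, and the learning rate. Each network gets its own Adam optimizer.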

In [9]:
epochs = 30             
criterion = nn.MSELoss()  
learning_rate = 1e-2

# Define optimizers, one per network
optimizer_msd        = optim.Adam(msdnet_model.parameters(), lr=learning_rate)
optimizer_tunet      = optim.Adam(tunet_model.parameters(), lr=learning_rate)
optimizer_tunet3plus = optim.Adam(tunet3plus_model.parameters(), lr=learning_rate)
optimizer_smsnet     = optim.Adam(smsnet_model.parameters(), lr=learning_rate)

device = "cpu" 
#device = helpers.get_device()  # Uncomment to get detected GPU

print('Device we will compute on: ', device)   # cuda:0 for GPU. Else, CPU

Device we will compute on:  cpu


### Training loops

In [10]:
msdnet_model.to(device)   
msdnet_model, results = train_scripts.train_regression(msdnet_model,
                                                       train_loader,
                                                       test_loader,
                                                       epochs,
                                                       criterion,
                                                       optimizer_msd,
                                                       device,
                                                       show=10)
msdnet_model = msdnet_model.cpu()

# clear out unnecessary variables from device (GPU) memory
#torch.cuda.empty_cache()

Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 3.0587e-02 | Validation Loss: 2.9307e-02
Training CC: 0.2502   Validation CC  : 0.2717 


Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 2.5751e-02 | Validation Loss: 2.5685e-02
Training CC: 0.3672   Validation CC  : 0.3708 


Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 2.4834e-02 | Validation Loss: 2.4734e-02
Training CC: 0.4068   Validation CC  : 0.4111 
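For readers curious about what a call like this does, here is a minimal plain-PyTorch sketch of a regression training loop. It is not dlsia's train_regression, whose internals this notebook does not show; the function name, the returned history dictionary, and the toy data are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

def train_regression_sketch(model, train_loader, val_loader, epochs,
                            criterion, optimizer, device):
    """Hypothetical minimal loop: one training and one validation pass per epoch."""
    history = {"train_loss": [], "val_loss": []}
    model.to(device)
    for _ in range(epochs):
        model.train()
        running = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            running += loss.item()
        history["train_loss"].append(running / len(train_loader))

        model.eval()
        running = 0.0
        with torch.no_grad():  # no gradients needed for validation
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                running += criterion(model(x), y).item()
        history["val_loss"].append(running / len(val_loader))
    return model, history

# Toy problem: learn to halve the input with a single convolution
torch.manual_seed(0)
x = torch.rand(8, 1, 16, 16)
y = 0.5 * x
loader = DataLoader(TensorDataset(x, y), batch_size=4)
model = nn.Conv2d(1, 1, kernel_size=3, padding=1)
optimizer = optim.Adam(model.parameters(), lr=1e-2)
model, history = train_regression_sketch(model, loader, loader, 5,
                                         nn.MSELoss(), optimizer, "cpu")
```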


In [11]:
tunet_model.to(device)   
tunet_model, results = train_scripts.train_regression(tunet_model,
                                                      train_loader,
                                                      test_loader,
                                                      epochs,
                                                      criterion,
                                                      optimizer_tunet,
                                                      device,
                                                      show=10)
tunet_model = tunet_model.cpu()

Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 2.1023e-02 | Validation Loss: 2.0195e-02
Training CC: 0.5513   Validation CC  : 0.5673 


Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 1.6087e-02 | Validation Loss: 1.5932e-02
Training CC: 0.6799   Validation CC  : 0.6858 


Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 1.4304e-02 | Validation Loss: 1.4126e-02
Training CC: 0.7219   Validation CC  : 0.7252 


In [12]:
tunet3plus_model.to(device)   
tunet3plus_model, results = train_scripts.train_regression(tunet3plus_model,
                                                           train_loader,
                                                           test_loader,
                                                           epochs,
                                                           criterion,
                                                           optimizer_tunet3plus,
                                                           device,
                                                           show=10)
tunet3plus_model = tunet3plus_model.cpu()

Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 2.9115e-02 | Validation Loss: 2.7833e-02
Training CC: 0.4941   Validation CC  : 0.5236 


Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 1.7936e-02 | Validation Loss: 1.7812e-02
Training CC: 0.6395   Validation CC  : 0.6462 


Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 1.5100e-02 | Validation Loss: 1.4969e-02
Training CC: 0.7016   Validation CC  : 0.7050 


In [13]:
smsnet_model.to(device)   
smsnet_model, results = train_scripts.train_regression(smsnet_model,
                                                       train_loader,
                                                       test_loader,
                                                       epochs,
                                                       criterion,
                                                       optimizer_smsnet,
                                                       device,
                                                       show=10)
smsnet_model = smsnet_model.cpu()

Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 1.9555e-02 | Validation Loss: 1.8536e-02
Training CC: 0.5966   Validation CC  : 0.6203 


Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 1.5011e-02 | Validation Loss: 1.4870e-02
Training CC: 0.7039   Validation CC  : 0.7076 


Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 1.3773e-02 | Validation Loss: 1.3678e-02
Training CC: 0.7334   Validation CC  : 0.7355 
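The training logs report a CC value next to each loss. Assuming CC denotes the Pearson correlation coefficient between predictions and targets (the notebook does not define it), it can be sketched as follows. The function `pearson_cc` is a hypothetical helper, not dlsia's implementation.

```python
import torch

def pearson_cc(pred, target):
    """Pearson correlation between two tensors, computed over all elements."""
    p = pred.flatten() - pred.mean()
    t = target.flatten() - target.mean()
    return torch.dot(p, t) / (p.norm() * t.norm())

x = torch.arange(10, dtype=torch.float32)
cc_linear = pearson_cc(x, 2.0 * x + 1.0).item()   # perfectly correlated
cc_flipped = pearson_cc(x, -x).item()             # perfectly anti-correlated
print(cc_linear, cc_flipped)
```

Under this reading, a CC near 1 means the predictions track the ground truth closely up to scale and offset, which complements the MSE loss.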


## Save Networks

Each network class provides a save_network_parameters method for saving the trained network. Each call writes a .pt file containing:

- model's state_dict: the network parameters learned through optimization/minimization during training,
- model's topo_dict: the list of network hyperparameters needed to initialize the same architecture.

This follows standard PyTorch practice: rather than pickling the entire network object, we store only the weights and the hyperparameters, and later load the weights into a freshly constructed network of the same architecture.

In [14]:
msdnet_model.save_network_parameters("this_msdnet.pt")
smsnet_model.save_network_parameters("this_smsnet.pt")
tunet_model.save_network_parameters("this_tunet.pt")

## Load networks from file

Each network module provides a function that reads the .pt file, rebuilds the architecture from the stored hyperparameters, and loads the learned weights.

In [15]:
copy_msdnet = msdnet.MSDNetwork_from_file("this_msdnet.pt")
copy_smsnet = smsnet.SMSNetwork_from_file("this_smsnet.pt")
copy_tunet = tunet.TUNetwork_from_file("this_tunet.pt")
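The save and load steps above can be sketched in plain PyTorch. This is an illustration of the general pattern only: the `TinyNet` class, the `save_network` and `load_network` helpers, and the dictionary keys are hypothetical, and the actual layout of dlsia's .pt files may differ.

```python
import os
import tempfile

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in network; not a dlsia architecture."""
    def __init__(self, in_channels, out_channels, hidden_channels):
        super().__init__()
        # Hyperparameters needed to rebuild the same architecture
        self.topo = {"in_channels": in_channels,
                     "out_channels": out_channels,
                     "hidden_channels": hidden_channels}
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.body(x)

def save_network(model, path):
    # Store hyperparameters and weights together in one file
    torch.save({"topo_dict": model.topo, "state_dict": model.state_dict()}, path)

def load_network(path):
    checkpoint = torch.load(path, map_location="cpu")
    model = TinyNet(**checkpoint["topo_dict"])       # rebuild the architecture
    model.load_state_dict(checkpoint["state_dict"])  # restore learned weights
    return model

model = TinyNet(in_channels=1, out_channels=1, hidden_channels=4)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_net.pt")
    save_network(model, path)
    restored = load_network(path)

x = torch.rand(2, 1, 8, 8)
with torch.no_grad():
    max_diff = (model(x) - restored(x)).abs().max().item()
```

Keeping the hyperparameters next to the weights means the loader needs only the file path, which matches how the from-file functions above are called.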

### Verify loaded networks

We check each loaded copy against its original by comparing their outputs on the test images; the two should agree to within a small tolerance.

In [16]:
with torch.no_grad():
    r1 = msdnet_model(test_x)
    r2 = copy_msdnet(test_x)
delta = r1-r2
assert torch.max(torch.abs(delta)) < 1e-8

In [17]:
with torch.no_grad():
    r1 = smsnet_model(test_x)
    r2 = copy_smsnet(test_x)
delta = r1-r2
assert torch.max(torch.abs(delta)) < 1e-8

In [18]:
with torch.no_grad():
    r1 = tunet_model(test_x)
    r2 = copy_tunet(test_x)
delta = r1-r2
assert torch.max(torch.abs(delta)) < 1e-8
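A complementary check is to compare the stored weights directly instead of the outputs. The sketch below uses a small stand-in layer; `state_dicts_match` is a hypothetical helper, not part of dlsia.

```python
import copy

import torch
import torch.nn as nn

def state_dicts_match(model_a, model_b):
    """True if both models have the same parameter names and identical values."""
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    if list(sd_a.keys()) != list(sd_b.keys()):
        return False
    return all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a)

original = nn.Conv2d(1, 1, kernel_size=3, padding=1)
clone = copy.deepcopy(original)

# A copy whose weights are deliberately shifted, to show a failing comparison
perturbed = copy.deepcopy(original)
with torch.no_grad():
    perturbed.weight.add_(1.0)

same = state_dicts_match(original, clone)
different = state_dicts_match(original, perturbed)
print(same, different)
```

When comparing outputs instead, networks that contain dropout or batch normalization layers should generally be switched to evaluation mode with eval() first, so that both forward passes are deterministic.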