<img src="../../docs/images/pymsdtorch.png" width=600 />

# Supervised Image Denoising

Authors: Eric Roberts and Petrus Zwart

E-mail: PHZwart@lbl.gov, EJRoberts@lbl.gov
___

This notebook highlights some basic functionality with the pyMSDtorch package.

In this notebook we set up a Mixed-Scale Dense Network and train it to denoise images corrupted by Gaussian noise. Subsequently, we will train a number of Randomized Sparse Networks on the same task and show how to obtain error estimates via ensemble methods.

___


In [None]:
import sys
import os
import numpy as np
import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset

import h5py
from torchsummary import summary

In [None]:
from dlsia.core import helpers
from dlsia.core.train_scripts import train_regression
from dlsia.core.networks import msdnet, smsnet, baggins
from dlsia.test_data.twoD import build_test_data, torch_hdf5_loader
from dlsia.viz_tools import plots, draw_sparse_network

import matplotlib.pyplot as plt

## Create Data

We produce noisy single-class data consisting of peaks in Gaussian noise.

Parameters to toggle:

- batch_size -- choice of 64 is optimized for 24 GB graphics card
- N_imgs -- number of images in training set
- N_peaks -- number of circular peaks in each image
- N_xy -- size of images
- SNR -- signal-to-noise ratio; more noise for lower number

In [None]:
### Notebook-wide Settings ###
##############################

# Switches that enable or disable later cells
makeData = True          # when True, the data-generation cell builds the data sets
showNoisyData = True     # when True, a few training samples are plotted
use_scaled_data = True   # when True, inputs are read from "trax_obs_norm", else "trax_obs"

# Settings handed to every DataLoader
batch_size = 100         # number of samples per batch
num_workers = 0          # worker count passed to the DataLoader

In [None]:
### Generate Data ###
#####################

# Data generation is optional: it only runs when the makeData switch is set.
if makeData:
    N_imgs = 200        # number of images in the training set
    N_peaks = 8         # number of circular peaks in each image
    N_xy = 32           # size of the images
    SNR = 3             # signal-to-noise ratio; a lower number means more noise
    mask_radius = 1.0   # forwarded unchanged to the generator as `mask_radius`

    build_test_data.build_data_standard_sets_2d(
        n_imgs=N_imgs,
        n_peaks=N_peaks,
        n_xy=N_xy,
        snr=SNR,
        mask_radius=mask_radius,
    )

Data generator class above can generate the following:
- trax_GT -- ground truth
- trax_obs -- obstructed, noisy images
- trax_obs_norm -- noisy images linearly scaled to the interval [0,1]
- trax_mask -- binary masked images indicating peak (1) or background (0)

In [None]:
### Load Data ###
#################

# Select which stored input array is read: the scaled noisy images or the
# unscaled noisy images.
x_label = "trax_obs_norm" if use_scaled_data else "trax_obs"

f_train = "train_data_2d.hdf5"
f_test = "test_data_2d.hdf5"
f_validation = "validate_data_2d.hdf5"

# Every split pairs the selected input with the ground truth ("trax_GT").
MyData_train = torch_hdf5_loader.Hdf5Dataset2D(filename=f_train,
                                               x_label=x_label,
                                               y_label="trax_GT")
MyData_validation = torch_hdf5_loader.Hdf5Dataset2D(filename=f_validation,
                                                    x_label=x_label,
                                                    y_label="trax_GT")
MyData_test = torch_hdf5_loader.Hdf5Dataset2D(filename=f_test,
                                              x_label=x_label,
                                              y_label="trax_GT")

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(MyData_train,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers)

loader_params = {'batch_size': batch_size, 'shuffle': False, 'num_workers': num_workers}
validation_loader = DataLoader(MyData_validation, **loader_params)
test_loader = DataLoader(MyData_test, **loader_params)

In [None]:
### Show Noisy Data ###
#######################

if showNoisyData:
    # Only the first training batch is inspected; None means the loader was empty.
    batch = next(iter(train_loader), None)

    if batch is not None:
        # The loaders are built with y_label="trax_GT", so the second element
        # of a batch is the ground truth image, not the binary mask.
        noisy, ground_truth = batch

        # Show at most three samples so that a batch holding fewer than three
        # images does not raise an IndexError.
        n_rows = min(3, noisy.shape[0])

        plt.figure(figsize=(12, 10))
        for row in range(n_rows):
            # Left column: the noisy network input.
            plt.subplot(n_rows, 2, 2 * row + 1)
            plt.imshow(noisy[row, 0, :, :])
            plt.colorbar(shrink=0.8)
            if row == 0:
                plt.title('Noisy')

            # Right column: the matching regression target.
            plt.subplot(n_rows, 2, 2 * row + 2)
            plt.imshow(ground_truth[row, 0, :, :])
            plt.colorbar(shrink=0.8)
            if row == 0:
                plt.title('Ground Truth')

    # NOTE(review): the font size is set after the titles above were created,
    # so it presumably only affects text drawn later -- confirm this is intended.
    plt.rcParams.update({'font.size': 18})
    plt.tight_layout()
    plt.show()

## Create mixed-scale dense convolutional neural network

Lots of options to customize. See dlsia/core/networks/msdnet.py

In [None]:
### Build Networks ###
#####################

# The network module is imported under an alias because this cell binds the
# name `msdnet` to the network instance; without the alias the instance would
# hide the module and the cell could not be run a second time.
from dlsia.core.networks import msdnet as msdnet_module

in_channels = 1
out_channels = 1
num_layers = 40            # Benchmarks used 100, 150, and 200
layer_width = 1            # Usually 1
max_dilation = 5           # Set to 10 in Pelt, Sethian paper
activation = nn.ReLU()
normalization = nn.BatchNorm2d
final_layer = None

msdnet = msdnet_module.MixedScaleDenseNetwork(in_channels=in_channels,
                                              out_channels=out_channels,
                                              num_layers=num_layers,
                                              layer_width=layer_width,
                                              max_dilation=max_dilation,
                                              activation=activation,
                                              normalization=normalization,
                                              final_layer=final_layer,
                                              convolution=nn.Conv2d)

# Count only the parameters that take part in optimisation.
pytorch_total_params = sum(p.numel() for p in msdnet.parameters() if p.requires_grad)
print("Total number of refineable parameters: ", pytorch_total_params)

## Train the model

We perform the following:

- specify number of epochs,
- select the L1 (mean absolute error) loss as our scoring criterion,
- select a learning rate of 1/100,
- choose the popular Adam optimizer for traversing the loss terrain

In [None]:
# Optimisation settings for the MSDNet training run
LEARNING_RATE = 1e-2        # step size handed to the Adam optimizer
epochs = 100                # number of training epochs
criterion = nn.L1Loss()     # loss function used as the training criterion

# Adam updates every parameter exposed by the network.
optimizer = optim.Adam(params=msdnet.parameters(), lr=LEARNING_RATE)


In [None]:
# Select the compute device and move the network onto it.
device = helpers.get_device()
msdnet.to(device)

# train_regression hands back two items: the network and the training results.
training_output = train_regression(
    msdnet,
    train_loader,
    validation_loader,
    epochs,
    criterion,
    optimizer,
    device=device,
    show=10,
)
msdnet, results = training_output

# Release cached GPU memory held by the allocator once training is done.
torch.cuda.empty_cache()

In [None]:
# Plot the training results returned by train_regression for the MSDNet run.
plots.plot_training_results_regression(results)

## Saving network and testing our model

- Below, we save model parameters using torch.save
    - You have the option of saving the full model, but saving the parameters and loading them into a newly instantiated network is more flexible

- Finally, we load testing data, pass it through the network, and save results as .png

In [None]:
def regression_metrics(preds, target):
    """
    Compute the correlation coefficient between the network predictions and
    the ground truth.

    Both tensors are moved to the CPU and flattened before they are handed to
    ``corcoef.cc``.

    :param preds: tensor holding the network predictions
    :param target: tensor holding the ground truth

    :returns: whatever ``corcoef.cc`` returns for the two flattened tensors
              (the Pearson correlation coefficient)
    """
    # `corcoef` is not among the notebook's top-level imports, so it is
    # imported here to keep the function callable.
    # NOTE(review): module path assumed to be dlsia.core.corcoef -- confirm.
    from dlsia.core import corcoef

    return corcoef.cc(preds.cpu().flatten(), target.cpu().flatten())


def _show_prediction_panels(panels, batch_number, image_number):
    """
    Draw one row of image panels for a single sample and display the figure.

    :param panels: list of ``(title, image, imshow_kwargs)`` tuples, drawn
                   left to right
    :param batch_number: 1-based batch counter, used in the printed header
    :param image_number: index of the sample inside the batch
    """
    print(f'Images for batch # {batch_number}, number {image_number}')
    plt.figure(figsize=(22, 5))
    for position, (title, image, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.imshow(image, **imshow_kwargs)
        plt.colorbar(shrink=0.8)
        plt.title(title)

    plt.suptitle("MSDNet Predictions", size=24)
    plt.rcParams.update({'font.size': 16})
    plt.tight_layout()

    plt.show()


def segment_imgs(testloader, net, plot=True, std=False, device=None):
    """
    Run a network over every batch of a dataloader and collect the results.

    For each batch the input is cast to float32, moved to ``device`` and passed
    through ``net``. Predictions, inputs and ground truth are gathered on the
    CPU and concatenated along the batch dimension.

    :param testloader: the pyTorch dataloader yielding ``(noisy, target)`` batches
    :param net: the trained network; it is moved to ``device`` before use
    :param plot: when True, up to the first 10 samples of every batch are
                 plotted with matplotlib
    :param std: when True, the network is called as ``net(noisy, 'cpu', True)``
                and is expected to return ``(output, sigmas)``; the sigmas are
                only plotted, they are not returned
    :param device: device to run on; when None, ``helpers.get_device()`` is used

    :returns seg_imgs: the predicted images, concatenated into a single tensor
    :returns noisy_imgs: the input images, concatenated into a single tensor
    :returns target_imgs: the ground truth images, concatenated into a single tensor

    If the dataloader yields no batches, three empty tensors are returned.
    """
    if device is None:
        device = helpers.get_device()

    torch.cuda.empty_cache()

    # Moving the network is loop-invariant, so it is done once up front.
    # NOTE(review): the train/eval mode of `net` is left untouched here --
    # confirm it is in eval mode if it contains normalization layers.
    net = net.to(device)

    # Collect per-batch tensors and concatenate once at the end.
    predictions = []
    inputs = []
    targets = []

    with torch.no_grad():
        for batch_number, (noisy, target) in enumerate(testloader, start=1):
            noisy = noisy.to(device=device, dtype=torch.float32)

            sigmas = None
            if std:
                output, sigmas = net(noisy, 'cpu', True)
                sigmas = sigmas.detach().cpu()
            else:
                output = net(noisy)

            output = output.detach().cpu()
            noisy_cpu = noisy.detach().cpu()
            target_cpu = target.detach().cpu()

            predictions.append(output)
            inputs.append(noisy_cpu)
            targets.append(target_cpu)

            if plot:
                # Never index past the end of a batch smaller than 10.
                for j in range(min(10, output.shape[0])):
                    panels = [
                        ('Noisy', noisy_cpu[j, 0, :, :], {}),
                        ('Prediction', output[j, 0, :, :], {}),
                    ]
                    if std:
                        panels.append(('Sigmas', sigmas[j, 0, :, :], {}))
                        panels.append(('Signal to Noise',
                                       output[j, 0, :, :] / sigmas[j, 0, :, :],
                                       {'vmax': 30}))
                    panels.append(('Ground Truth', target_cpu[j, 0, :, :], {}))
                    _show_prediction_panels(panels, batch_number, j)

    torch.cuda.empty_cache()

    if not predictions:
        # Nothing was processed: hand back empty tensors rather than failing.
        return torch.empty(0), torch.empty(0), torch.empty(0)

    seg_imgs = torch.cat(predictions, 0)
    noisy_imgs = torch.cat(inputs, 0)
    target_imgs = torch.cat(targets, 0)
    return seg_imgs, noisy_imgs, target_imgs

In [None]:
# Run the trained MSDNet over the test loader; with the default plot=True,
# segment_imgs also plots samples from every batch.
output, noisy, target  = segment_imgs(test_loader, msdnet)

In [None]:
### Build and train an ensemble of random sparse networks ###
#############################################################

in_channels = 1     # single-channel input image
out_channels = 1    # single-channel regression output
num_layers = 40
alpha = 0.20
gamma = 0.0
max_k = 6
min_k = 3
hidden_out_channels = [1]
dilation_choices = [1, 2, 3, 4, 5]
layer_probabilities = {'LL_alpha': alpha,
                       'LL_gamma': gamma,
                       'LL_max_degree': max_k,
                       'LL_min_degree': min_k,
                       'IL': 0.25,
                       'LO': 0.25,
                       'IO': False}
sizing_settings = {'stride_base': 2,  # better keep this at 2
                   'min_power': 0,
                   'max_power': 0}
network_type = "Regression"
nets = []
N_networks = 7

# Training settings are identical for every ensemble member, so they are
# defined once instead of on every pass through the loop.
epochs = 100                # Set number of epochs
criterion = nn.L1Loss()     # L1 loss for the regression (denoising) task
LEARNING_RATE = 1e-2
device = helpers.get_device()

for ii in range(N_networks):
    torch.cuda.empty_cache()
    print(f"Network {ii + 1}")
    net = smsnet.random_SMS_network(in_channels=in_channels,
                                    out_channels=out_channels,
                                    in_shape=(32, 32),
                                    out_shape=(32, 32),
                                    sizing_settings=sizing_settings,
                                    layers=num_layers,
                                    dilation_choices=dilation_choices,
                                    hidden_out_channels=hidden_out_channels,
                                    layer_probabilities=layer_probabilities,
                                    network_type=network_type)

    # lets plot the network
    net_plot, dil_plot, chan_plot = draw_sparse_network.draw_network(net)
    plt.show()

    nets.append(net)

    print("Start training")
    pytorch_total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print("Total number of refineable parameters: ", pytorch_total_params)

    # Each network gets its own optimizer bound to its own parameters.
    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    net = net.to(device)

    # Validate on the validation split: the test split is reserved for the
    # final evaluation of the ensemble and must not steer training.
    tmp = train_regression(net,
                           train_loader,
                           validation_loader,
                           epochs,
                           criterion,
                           optimizer,
                           device,
                           show=10)

    # Move the trained network off the device before building the next one.
    net = net.cpu()
    plots.plot_training_results_regression(tmp[1]).show()
    

In [None]:
# Combine the networks collected in `nets` into a single bagged model.
# NOTE(review): the meaning of the third positional argument (False) is not
# visible in this notebook -- confirm against baggins.model_baggin.
bagged_model = baggins.model_baggin(nets,"regression", False)

In [None]:
# Run the bagged model over the test loader; std=True makes segment_imgs ask
# the model for sigmas as well and plot them next to each prediction.
output, noisy, target  = segment_imgs(test_loader, bagged_model, std=True)