# SELF-SUPERVISED DENOISING: PART ONE
### Authors: Claire Birnie and Sixiu Liu, KAUST

Author websites: 
- https://cebirnie92.github.io/ 
- https://swagroup.kaust.edu.sa/people/detail/sixiu-liu-(%E5%88%98%E6%80%9D%E7%A7%80)

## Tutorial Overview

On completion of this tutorial you will have learnt how to write your own blind-spot denoising procedure that is trained in a self-supervised manner, i.e., the training data is the same as the inference data with no labels required!

### Methodology Recap
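We will implement the Noise2Void (N2V) blind-spot denoising methodology: a random subset of pixels in each training patch is replaced with values taken from neighbouring pixels, and the network is trained to predict the original values at those corrupted locations only.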

***

In [3]:
# Import necessary packages
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader 

# Our unet functions just to speed things up
from unet import UNet

***

# Step One - Data loading

In this example we are going to use a post-stack seismic section generated from the Hess VTI model. The post-stack section can be downloaded from: XXXX

***

# Step Two - Blindspot corruption of training data

In [None]:
# Create a function that randomly selects and corrupts pixels following N2V methodology
def multi_active_pixels(patch, 
                        num_activepixels, 
                        neighbourhood_radius=5, 
                        swap=True):
    """Corrupt randomly selected "active" pixels of a patch (N2V style).

    Active pixel locations are drawn uniformly (with replacement) over the
    first two axes of ``patch``. Each active pixel is paired with a
    neighbour at a random integer shift drawn, per axis, from the half-open
    interval ``[-r // 2 + r % 2, r // 2 + r % 2)`` with
    ``r = neighbourhood_radius`` (e.g. -2..2 for r=5, -2..1 for r=4).
    A (0, 0) shift is never used; neighbours falling outside the patch wrap
    around to the opposite side.

    Parameters
    ----------
    patch : numpy.ndarray
        Array with at least two dimensions; it is not modified.
    num_activepixels : int
        Number of active pixel locations to draw. Because locations are
        drawn with replacement, fewer distinct pixels may end up corrupted.
    neighbourhood_radius : int, optional
        Controls the shift window described above.
    swap : bool, optional
        If True, active pixels take the value of their neighbour; if False
        they are set to zero.

    Returns
    -------
    cp_ptch : numpy.ndarray
        Copy of ``patch`` with the active pixels corrupted.
    mask : numpy.ndarray
        Array like ``patch`` holding 0 at active pixels and 1 elsewhere.

    Raises
    ------
    ValueError
        If the shift window contains no non-zero shift (so a pixel could
        only be replaced by itself).
    """
    n_rad = neighbourhood_radius  # descriptive variable name too long
    n_rows, n_cols = patch.shape[0], patch.shape[1]

    # Select multiple locations for active pixels
    idx_aps = np.random.randint(0, n_rows, num_activepixels)
    idy_aps = np.random.randint(0, n_cols, num_activepixels)

    # Shift window, computed once so sampling and re-sampling always agree
    shft_lo = -n_rad // 2 + n_rad % 2
    shft_hi = n_rad // 2 + n_rad % 2

    # For each active pixel compute the shift to its replacement neighbour
    x_neigh_shft = np.random.randint(shft_lo, shft_hi, num_activepixels)
    y_neigh_shft = np.random.randint(shft_lo, shft_hi, num_activepixels)

    # Don't allow a pixel to be replaced with itself: wherever both shifts
    # are zero, redraw the x shift from the non-zero values of the window
    self_hits = (x_neigh_shft == 0) & (y_neigh_shft == 0)
    if self_hits.any():
        shft_options = np.arange(shft_lo, shft_hi)
        shft_options = shft_options[shft_options != 0]
        if shft_options.size == 0:
            raise ValueError(
                "neighbourhood_radius is too small to select a neighbour "
                "other than the active pixel itself"
            )
        x_neigh_shft[self_hits] = np.random.choice(shft_options,
                                                   int(self_hits.sum()))

    # Find x and y locations of neighbours, wrapping around the patch edges.
    # The modulo keeps indices valid even when a shift exceeds the patch size.
    idx_neigh = (idx_aps + x_neigh_shft) % n_rows
    idy_neigh = (idy_aps + y_neigh_shft) % n_cols

    # Combine x and y coordinates for active pixels and neighbouring pixels
    id_aps = (idx_aps, idy_aps)
    id_neigh = (idx_neigh, idy_neigh)

    # Make mask and corrupted patch
    mask = np.ones_like(patch)
    cp_ptch = patch.copy()
    mask[id_aps] = 0.
    if swap:
        # Values are read from the untouched input so that corruptions
        # never chain through already-corrupted pixels
        cp_ptch[id_aps] = patch[id_neigh]
    else:
        cp_ptch[id_aps] = 0

    return cp_ptch, mask


In [None]:
# Check the corruption function works


***

# Step Three - Set up network

In [None]:
# Pick the training device: start from the CPU and switch to the current
# CUDA device only when CUDA is usable and at least one device is present
device = 'cpu'
gpu_ready = torch.cuda.is_available() and torch.cuda.device_count() > 0
if not gpu_ready:
    print("No GPU available!")
else:
    print("Cuda installed! Running on GPU!")
    device = torch.device(torch.cuda.current_device())
    print(f'Device: {device} {torch.cuda.get_device_name(device)}')

In [None]:
# Build UNet
# The class is brought into scope by name (`from unet import UNet`), so it
# must be referenced as `UNet`; `unet.UNet` would raise a NameError because
# the module name itself is never bound.
network = UNet(input_channels=1, 
               output_channels=1, 
               hidden_channels=32, 
               levels=4).to(device)

# Apply the initialisation function to every submodule of the network.
# NOTE(review): `weights_init` is not defined or imported in the visible part
# of this notebook - confirm it is provided before running this cell.
network = network.apply(weights_init) 

In [None]:
# Training configuration: number of epochs, learning rate, loss and optimiser
n_epochs, lr = 101, 1e-4
criterion = nn.L1Loss()
optim = torch.optim.Adam(params=network.parameters(), lr=lr)

***

# Step Four - Training

In [None]:
# Zero-initialised per-epoch records of loss and accuracy for the training
# and validation sets (one entry per epoch)
(train_loss_history,
 train_accuracy_history,
 val_loss_history,
 val_accuracy_history) = (np.zeros(n_epochs) for _ in range(4))

# Seeded generator to hand to the DataLoaders for reproducibility
g = torch.Generator().manual_seed(0)

In [4]:
def n2v_train(model, 
              criterion, 
              optimizer, 
              data_loader, 
              device):
    """Run one training epoch over ``data_loader`` with a masked loss.

    For every batch the model prediction and the target are both multiplied
    by ``(1 - mask)`` before being passed to ``criterion``, so only the
    locations where ``mask`` is 0 contribute non-zero terms to the loss.

    Parameters
    ----------
    model : torch.nn.Module
        Network to train; it is switched to training mode.
    criterion : callable
        Loss taking ``(prediction, target)`` and returning a scalar tensor.
    optimizer : torch.optim.Optimizer
        Optimiser stepping the parameters of ``model``.
    data_loader : iterable
        Yields batches indexable as ``(input, target, mask)`` tensors.
    device : torch.device or str
        Device the batch tensors are moved to.

    Returns
    -------
    loss : float
        Mean of the per-batch loss values.
    accuracy : float
        Mean of the per-batch RMSE between the target and the prediction,
        computed over all samples (not only the masked ones).
    metrics : dict
        ``{'psnr': ..., 'rre': ...}``, both averaged over batches.
        NOTE(review): the original code returned these names without ever
        computing them. The definitions used here are an assumption to be
        confirmed: PSNR = 20*log10(max|target| / RMSE) in dB and
        RRE = ||target - prediction||_2 / ||target||_2. A perfect
        prediction or an all-zero target yields inf/nan for that batch.

    Raises
    ------
    ValueError
        If ``data_loader`` yields no batches.
    """
    # The progress bar is optional: fall back to plain iteration when tqdm
    # is not installed so that training does not depend on it
    try:
        from tqdm import tqdm as progress
    except ImportError:
        def progress(iterable):
            return iterable

    model.train()
    loss = 0.0
    accuracy = 0.0
    psnr = 0.0
    rre = 0.0
    n_batches = 0  # counted explicitly so loaders without __len__ also work

    for dl in progress(data_loader):
        X, y, mask = dl[0].to(device), dl[1].to(device), dl[2].to(device)
        optimizer.zero_grad()

        yprob = model(X)

        # Restrict the loss to the locations flagged by mask == 0
        active = 1 - mask
        ls = criterion(yprob * active, y * active)
        ls.backward()
        optimizer.step()

        # Metrics are evaluated in float64 on the CPU, outside the graph
        ypred = yprob.detach().cpu().numpy().astype(float).ravel()
        ytrue = y.detach().cpu().numpy().astype(float).ravel()
        residual = ytrue - ypred
        rmse = np.sqrt(np.mean(residual ** 2))

        loss += ls.item()
        accuracy += rmse
        # Division by zero is allowed to produce inf/nan rather than warn
        with np.errstate(divide='ignore', invalid='ignore'):
            psnr += 20 * np.log10(np.max(np.abs(ytrue)) / rmse)
            rre += np.linalg.norm(residual) / np.linalg.norm(ytrue)
        n_batches += 1

    if n_batches == 0:
        raise ValueError("data_loader yielded no batches")

    loss /= n_batches
    accuracy /= n_batches
    psnr /= n_batches
    rre /= n_batches

    return loss, accuracy, dict({'psnr': psnr, 'rre': rre})