# <font color=blue>Artifact reduction network</font>

In [None]:
import os
import torch
import time
from torch import nn as nn
from torch.nn import functional as F
from torch.utils import data
import numpy as np
import matplotlib.pyplot as plt

### Load training and validation patches

In [None]:
# Pre-extracted patch tensors. The "poly" files are used as network input and
# the "monolabel" files as the matching targets (paired by index — TODO confirm
# the files were written in the same order).
_patch_files = (
    './patch_seg_poly_LR.pt',
    './patch_seg_monolabel_LR.pt',
    './patch_seg_poly_val_LR.pt',
    './patch_seg_monolabel_val_LR.pt',
)
patch_poly_train, patch_mono_train, patch_poly_val, patch_mono_val = (
    torch.load(_path) for _path in _patch_files
)

### Custom dataset loader

In [None]:
class PCBcustomDataset(data.Dataset):
    """Characterizes PCB dataset as (input, label) patch pairs.

    Item ``i`` is ``patches_data[i]`` paired with ``patches_label[i]``. Each
    is given a leading axis of size 1 (``unsqueeze(0)``) to serve as the
    channel dimension.

    Args:
        patches_data: indexable collection of input patches whose elements
            support ``unsqueeze`` (e.g. a stacked tensor indexed along its
            first axis).
        patches_label: collection of target patches aligned with
            ``patches_data`` by index.

    Raises:
        ValueError: if the two collections have different lengths.
    """

    def __init__(self, patches_data, patches_label):
        # Fail early: a length mismatch would otherwise surface later as an
        # IndexError or as labels silently ignored.
        if len(patches_data) != len(patches_label):
            raise ValueError(
                'patches_data and patches_label must have the same length, '
                'got {} and {}'.format(len(patches_data), len(patches_label)))
        self.patches_data = patches_data
        self.patches_label = patches_label

    def __len__(self):
        """Denotes the total number of PCB samples"""
        return len(self.patches_data)

    def __getitem__(self, index):
        """Generates one sample of PCB data: ``(x_data, y_label)``."""
        # Select the sample and add the channel dimension.
        x_data = self.patches_data[index].unsqueeze(0)
        y_label = self.patches_label[index].unsqueeze(0)
        return x_data, y_label

### Visualise the patches 

In [None]:
import matplotlib.pyplot as plt

# Show one slice of one training patch: input on the left, target on the right.
Volume = 0   # index of the patch (first axis)
slicee = 9   # index of the slice inside that patch (second axis)
fig = plt.figure(figsize=(12, 12))
ax1, ax2 = (fig.add_subplot(1, 2, pos) for pos in (1, 2))
im = ax1.imshow(patch_poly_train[Volume, slicee, :, :], cmap='gray')
im2 = ax2.imshow(patch_mono_train[Volume, slicee, :, :], cmap='gray')
for axis, caption in ((ax1, 'Input'), (ax2, 'Actual')):
    axis.title.set_text(caption)

In [None]:
# Shapes of the training and validation input patch tensors.
train_shape, val_shape = patch_poly_train.shape, patch_poly_val.shape
print(train_shape, val_shape)

### Data loader

In [None]:
# Pair each split's input patches with its label patches.
dataset_pcb_train = PCBcustomDataset(patch_poly_train, patch_mono_train)
dataset_pcb_val = PCBcustomDataset(patch_poly_val, patch_mono_val)
for split_name, split_ds in (('train', dataset_pcb_train), ('val', dataset_pcb_val)):
    print('{} directory has {} samples'.format(split_name, len(split_ds)))

In [None]:
# Training batches are reshuffled every epoch.
trainloader = data.DataLoader(dataset_pcb_train, batch_size=4, shuffle=True,
                              num_workers=1, drop_last=False)
# Validation is only evaluated, so iterate it in a fixed order. With
# drop_last=False and a sample count that is not a multiple of the batch size,
# shuffling puts different samples into the short last batch each epoch. The
# mean of per-batch losses (used to pick the best checkpoint) would then change
# even for an unchanged model.
valloader = data.DataLoader(dataset_pcb_val, batch_size=4, shuffle=False,
                            num_workers=1, drop_last=False)

## Model

In [None]:
class Autoencoder_VARN(nn.Module):
    """3-D convolutional encoder/decoder with additive skip connections.

    Encoder: four Conv3d+ReLU stages with stride (1, 2, 2). The depth axis
    keeps its size; H and W are downsampled by 2 at each stage.

    Decoder: at each encoder resolution (deepest first), trilinear
    interpolation to that resolution, a Conv3d(+ReLU), then the matching
    encoder feature map is added. A final conv returns to one channel, and the
    network input is added to the result (global residual).

    In training mode ``forward`` also returns an auxiliary one-channel output
    computed from the bottleneck and interpolated to the input size.
    """

    def __init__(self):
        super(Autoencoder_VARN, self).__init__()
        # encoder
        self.downlayer1 = nn.Sequential(nn.Conv3d(1, 12, kernel_size=(3, 5, 5), padding=(1, 2, 2), stride=(1, 2, 2)),
                                        nn.ReLU())
        self.downlayer2 = nn.Sequential(nn.Conv3d(12, 24, kernel_size=(3, 5, 5), padding=(1, 2, 2), stride=(1, 2, 2)),
                                        nn.ReLU())
        self.downlayer3 = nn.Sequential(nn.Conv3d(24, 48, kernel_size=(3, 5, 5), padding=(1, 2, 2), stride=(1, 2, 2)),
                                        nn.ReLU())
        self.downlayer4 = nn.Sequential(nn.Conv3d(48, 96, kernel_size=(3, 5, 5), padding=(1, 2, 2), stride=(1, 2, 2)),
                                        nn.ReLU())

        # bottleneck and auxiliary head
        self.bottleneck = nn.Sequential(nn.Conv3d(96, 96, kernel_size=(3, 5, 5), padding=(1, 2, 2)),
                                        nn.ReLU())
        self.aux_conv = nn.Sequential(nn.Conv3d(96, 1, kernel_size=(3, 5, 5), padding=(1, 2, 2)),
                                      )
        # decoder
        self.uplayer0 = nn.Sequential(nn.Conv3d(96, 48, kernel_size=(3, 5, 5), padding=(1, 2, 2)),
                                      nn.ReLU())
        self.uplayer1 = nn.Sequential(nn.Conv3d(48, 24, kernel_size=(3, 5, 5), padding=(1, 2, 2)),
                                      nn.ReLU())
        # NOTE: upsample1/upsample2 are not used by forward() (it calls
        # F.interpolate). They hold no parameters and are kept only so the
        # module's attributes stay unchanged.
        self.upsample1 = nn.Upsample(size=(5, 125, 125))
        self.uplayer2 = nn.Sequential(nn.Conv3d(24, 12, kernel_size=(3, 5, 5), padding=(1, 2, 2)),
                                      nn.ReLU())
        self.upsample2 = nn.Upsample(scale_factor=(1, 2, 2), mode="trilinear", align_corners=True)
        self.uplayer3 = nn.Sequential(nn.Conv3d(12, 1, kernel_size=(3, 5, 5), padding=(1, 2, 2)))

    @staticmethod
    def _resize(x, size):
        """Trilinearly resample a 5-D tensor ``x`` to spatial ``size`` (D, H, W)."""
        return F.interpolate(x, size=tuple(size), mode="trilinear", align_corners=True)

    def forward(self, x):
        """Run the network on ``x`` of shape (N, 1, D, H, W).

        Returns:
            In training mode, the tuple ``(output, aux_output)``; in eval mode,
            ``output`` only. Both have the same shape as ``x``.
        """
        x_original = x
        d, h, w = x_original.shape[2:]

        # Encoder; keep each stage's features for the additive skips.
        x_d1 = self.downlayer1(x)
        x_d2 = self.downlayer2(x_d1)
        x_d3 = self.downlayer3(x_d2)
        x = self.bottleneck(self.downlayer4(x_d3))

        # The auxiliary output is only returned in training mode, so do not
        # spend a conv plus a full-resolution interpolation on it in eval mode.
        x_aux = None
        if self.training:
            x_aux = self._resize(self.aux_conv(x), (d, h, w))

        # Decoder: resize to the next encoder resolution, convolve, add skip.
        x = self.uplayer0(self._resize(x, x_d3.shape[2:])) + x_d3
        x = self.uplayer1(self._resize(x, x_d2.shape[2:])) + x_d2
        x = self.uplayer2(self._resize(x, x_d1.shape[2:])) + x_d1
        x = self.uplayer3(self._resize(x, (d, h, w)))

        # Global residual connection (out-of-place add).
        x = x + x_original
        if self.training:
            return x, x_aux
        return x

In [None]:
def weights_init(m):
    """Initialise ``nn.Conv3d`` layers in place; meant for ``Module.apply``.

    Weights are drawn with Xavier-uniform initialisation and the bias, when the
    layer has one, is set to zero. Modules of any other type are left
    untouched.

    Args:
        m: the module currently being visited.
    """
    if isinstance(m, nn.Conv3d):
        # nn.init functions run under no_grad, so no `.data` access is needed.
        nn.init.xavier_uniform_(m.weight)
        # A Conv3d built with bias=False has m.bias set to None.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
# Build the model on the GPU, wrap it in DataParallel, then initialise every
# Conv3d via weights_init.
_varn = Autoencoder_VARN().cuda()
net = nn.DataParallel(_varn)
net.apply(weights_init)

## Optimizer & Loss

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

costfunction = nn.SmoothL1Loss(reduction='mean')
# No values were provided for these hyper-parameters (the cell held a `--`
# placeholder, which is a SyntaxError). They are set to torch.optim.Adam's
# documented defaults so the notebook runs.
# TODO: replace with the learning rate / weight decay actually used for VARN.
learning_rate = 1e-3
weight_decay = 0.0
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)

## Training

In [None]:
# SSIM metric
from pytorch_msssim import *
from tqdm import tqdm_notebook as tqdm

In [None]:
print('--- Training of VARN--- ')
training_start_time = time.time()
num_epochs = 500
weight = 0.3  # weight of the auxiliary-output loss in the total training loss
train_losses, val_losses = [], []
train_acc, val_acc = [], []
# Lowest mean validation loss seen so far. Starting at +inf guarantees a
# checkpoint after the first epoch. A fixed start value such as 1 would never
# save if the loss stayed above it, leaving the evaluation cell nothing to load.
val_loss_temp = float('inf')
for e in tqdm(range(num_epochs)):
    # ---------------- training pass ----------------
    running_loss = 0.0
    ssim_train = 0.0
    net.train()
    for volume, labels in trainloader:
        volume = volume.cuda().float()
        labels = labels.cuda().float()

        # In train mode the model returns (main output, auxiliary output).
        output, output_aux = net(volume)
        loss1 = costfunction(output, labels)
        loss_aux = costfunction(output_aux, labels)
        loss = loss1 + weight * loss_aux
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Same SSIM metric as validation, so train_acc holds real values.
        with torch.no_grad():
            ssim_train += SSIM_accuracy(output.detach(), labels,
                                        data_range=labels.max() - labels.min()).item()

    # ---------------- validation pass ----------------
    val_loss = 0.0
    ssim_val = 0.0
    net.eval()
    with torch.no_grad():
        for volume, labels in valloader:
            # Same dtype handling as the training pass.
            volume = volume.cuda().float()
            labels = labels.cuda().float()
            outputs = net(volume)
            # .item() keeps the sum a Python float. A GPU tensor here would end
            # up in val_losses and the checkpoint, and matplotlib cannot plot it.
            val_loss += costfunction(outputs, labels).item()
            accuracy = SSIM_accuracy(outputs, labels, data_range=labels.max() - labels.min())
            ssim_val += accuracy.item()

    epoch_train_loss = running_loss / len(trainloader)
    epoch_val_loss = val_loss / len(valloader)
    epoch_train_ssim = ssim_train / len(trainloader)
    epoch_val_ssim = ssim_val / len(valloader)

    # save best model
    if epoch_val_loss < val_loss_temp:
        torch.save({
            'epoch': e,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_val_loss,
        }, './AR_aux_loss_bestval_LR.pt')
        val_loss_temp = epoch_val_loss

    train_losses.append(epoch_train_loss)
    val_losses.append(epoch_val_loss)
    train_acc.append(epoch_train_ssim)
    val_acc.append(epoch_val_ssim)
    print("Epoch: {}/{}.. ".format(e + 1, num_epochs),
          "Training Loss: {:.3f}.. ".format(epoch_train_loss),
          "val Loss: {:.3f}.. ".format(epoch_val_loss),
          "SSIM metric: {:.3f}.. ".format(epoch_val_ssim))

# save complete model (weights after the final epoch)
torch.save(net.state_dict(), './AR_aux_loss_LR.pt')
print('Training finished in {}'.format(time.time() - training_start_time))

In [None]:
# Loss curves per epoch. float() accepts Python numbers and single-element
# tensors (CPU or CUDA), so the curves plot even if the lists hold GPU
# tensors, which matplotlib cannot convert on its own.
fig = plt.figure(figsize=(9, 7))
ax = fig.add_subplot(111)
ax.plot([float(v) for v in train_losses], label="Training Loss")
ax.plot([float(v) for v in val_losses], label="Validation Loss")
ax.set_xlabel("epochs")
ax.set_ylabel("loss")
ax.set_title("Loss vs Epochs")
ax.legend()

### Evaluation

In [None]:
# Rebuild the network and restore the best-validation checkpoint written by the
# training loop. Its epoch and validation loss are kept for reference.
checkpoint = torch.load('./AR_aux_loss_bestval_LR.pt')
model = nn.DataParallel(Autoencoder_VARN().cuda())
model.load_state_dict(checkpoint['model_state_dict'])
epoch, loss = checkpoint['epoch'], checkpoint['loss']
model.eval()

In [None]:
# NOTE(review): the "test" set is the validation patches again, i.e. the same
# data used to pick the best checkpoint, so these metrics are not from an
# independent test set. One patch per batch, in dataset order.
dataset_pcb_test = PCBcustomDataset(patch_poly_val, patch_mono_val)
testloader = data.DataLoader(
    dataset_pcb_test,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    drop_last=False,
)

In [None]:
mse_loss = nn.MSELoss(reduction='mean')
testloss, ssim, mse, smoothl1 = [], [], [], []
# NOTE(review): SSIM data_range is fixed at 20 here, whereas the training
# loop's validation SSIM uses each batch's label range, so the two SSIM
# figures are not directly comparable. Confirm 20 matches the data's range.
ssim_data_range = 20
model.eval()
with torch.no_grad():
    for volume, labels in tqdm(testloader):
        # Cast to float32, as the training loop does for its inputs.
        volume = volume.cuda().float()
        labels = labels.cuda().float()
        outputs = model(volume)
        ssim_accuracy = SSIM_accuracy(outputs, labels, data_range=ssim_data_range)
        print(ssim_accuracy)
        # Keep plain Python floats so np.mean / np.std operate on numbers.
        ssim.append(ssim_accuracy.item())
        mse_accuracy = mse_loss(outputs, labels)
        smoothl1_loss = costfunction(outputs, labels)
        smoothl1.append(smoothl1_loss.item())
        mse.append(mse_accuracy.item())
# The std was previously passed to format() without a placeholder and dropped.
print('Mean SSIM accuracy is {} (std {})'.format(np.mean(ssim), np.std(ssim)))
        

In [None]:
%matplotlib notebook
slicee = 20
fig = plt.figure(figsize=(12,5))
ax1 = fig.add_subplot(1,2,1)
ax2 = fig.add_subplot(1,3,2)
ax3 = fig.add_subplot(1,2,2)
im = ax1.imshow(pred_poly[0,0,slicee,:,:],cmap='gray')
clim=im.properties()['clim']
ax2.imshow(pred_mono[0,0,slicee,:,:], clim=clim, cmap='gray')
ax3.imshow(prediction[0,0,slicee,:,:].cpu().detach().numpy(), clim = clim, cmap= 'gray')
fig.colorbar(im, ax=(ax1,ax2,ax3), shrink=0.2)
ax1.title.set_text('Artifact volume')
ax2.title.set_text('Ground truth')
ax3.title.set_text('Predicted')

plt.show()