# Deep Recursive Transformer

### Imports

In [None]:
import torch
import torch.nn as nn
import torch.cuda as cuda
import torch.optim as optim
from torchvision import datasets, transforms
from timm.models.layers import trunc_normal_
from torch.utils.data import DataLoader
from torchsummary import summary

import numpy as np
import cv2 as cv
from PIL import Image
import time
import warnings
import time

from transformer_block import TransformerBlock, PatchEmbed, PatchUnEmbed
from data_loader import Rain800TrainData, Rain800ValData
from my_utils import batch_PSNR, batch_SSIM, output_to_image
from my_utils import save_ckp, load_ckp, base_path

### Global Parameters

In [3]:
# Side length of the square training images; also used to build the model's
# input resolution in the model-instantiation cell.
training_image_size = 56
# CUDA float tensor type, applied to the loss module in the training-setup cell.
# NOTE(review): this is CUDA-specific, whereas `device` below may fall back to CPU.
dtype = torch.cuda.FloatTensor
# Mini-batch size (referenced by the commented-out DataLoader in the Load Data cell).
batch_size = 5
# Seed the CPU and all CUDA random number generators.
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)
# Upper bound of the epoch loop in network_training.
epochs = 4600
# Learning rate handed to the Adam optimiser.
lr = 0.0001
# The error curve is plotted every `error_plot_freq` epochs (and on the last epoch).
error_plot_freq = 20
# Sentinel "very large" value used to initialise the tracked errors.
INT_MAX = 2147483647
# An epoch whose summed loss exceeds error_tolerence * previous epoch loss
# triggers a roll back to a saved checkpoint in network_training.
error_tolerence = 10
# Patch size handed to the model (assigned again in the model-instantiation cell).
patch_size = 1

# Paths: checkpoints are read from / written to <base path>/pretrained.
base_pth = base_path()
ckp_pth = base_pth + "/pretrained"

# Use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


### Load Data

In [49]:
### Prepare Data for Training
# train_dataset = Rain800TrainData(training_image_size, dataset_dir='/Rain-800/') #/Rain100L-Train/
# train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=False)

Loading Test1200 from Hi-Net Data


### Model Design

We inherit the Swin Transformer block with one modification for our network: WSA -> Conv2d. We design the recursive network so that all transformer blocks within the same residual unit share the same weights. We then simply stack these residual units recursively.

In [4]:
# a single transformer block working on the patch-token grid
class BasicBlock(nn.Module):
    """Thin wrapper around one ``TransformerBlock``.

    The wrapped block is built for a token grid of
    ``(H // patch_size, W // patch_size)``, where ``(H, W)`` is
    ``input_resolution``.

    Args:
        dim: Channel / embedding dimension passed to the transformer block.
        input_resolution: ``(H, W)`` pair describing the input resolution.
        num_heads: Number of heads passed to the transformer block.
        patch_size: Divisor applied to ``H`` and ``W`` to get the token grid.
    """

    def __init__(self, dim, input_resolution, num_heads, patch_size):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        height, width = input_resolution
        token_grid = (height // patch_size, width // patch_size)
        self.transformer = TransformerBlock(dim, token_grid, num_heads)

    def forward(self, x):
        """Apply the wrapped transformer block to ``x`` and return the result."""
        return self.transformer(x)

In [5]:
# two basic blocks applied repeatedly, every pass followed by a skip connection
class ResidualLayer(nn.Module):
    """Residual unit built from two ``BasicBlock`` instances and a 3x3 convolution.

    The same two blocks are reused on every one of the ``residual_depth``
    passes, so their weights are shared across passes. After the passes the
    token sequence is viewed as a feature map, convolved, and flattened back
    into a token sequence.

    Args:
        dim: Channel dimension of the tokens and of the output convolution.
        input_resolution: ``(H, W)`` pair describing the input resolution.
        residual_depth: Number of times the block pair is applied.
        patch_size: Divisor applied to ``H`` and ``W`` to get the token grid.
    """

    def __init__(self, dim, input_resolution, residual_depth, patch_size):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.residual_depth = residual_depth
        self.input_resolution = input_resolution
        heads = 2  # both blocks use two attention heads
        self.block1 = BasicBlock(dim, input_resolution, heads, patch_size)
        self.block2 = BasicBlock(dim, input_resolution, heads, patch_size)
        self.conv_out = nn.Conv2d(dim, dim, 3, padding=1)

    def forward(self, x):
        """Run the residual passes on ``x`` of shape (B, L, C); returns (B, L, C)."""
        batch, _, channels = x.shape
        height, width = self.input_resolution
        grid_h = height // self.patch_size
        grid_w = width // self.patch_size

        identity = x
        for _ in range(self.residual_depth):
            # block2 runs first, then block1; the original input is added back
            # after every pass
            x = self.block2(x)
            x = self.block1(x)
            x = x + identity

        # convolution at the end of each residual block: B L C -> B C H W -> B L C
        feature_map = x.transpose(1, 2).view(batch, channels, grid_h, grid_w)
        return self.conv_out(feature_map).flatten(2).transpose(1, 2)

In [6]:
#recursive network based on residual units
class DeepRecursiveTransformer(nn.Module):
    """Stack of ``ResidualLayer`` units wrapped in two skip connections.

    Forward pass: normalise -> patch embed -> residual layers -> add the
    embedded input -> patch unembed -> add the normalised input -> denormalise.

    Args:
        dim: Embedding dimension used by every layer.
        input_resolution: ``(H, W)`` of the input images; must be square.
        patch_size: Patch size passed to the embedding layers and residual units.
        residual_depth: Number of passes inside every ``ResidualLayer``.
        recursive_depth: Number of stacked ``ResidualLayer`` units.
        exact_denormalise: When ``False`` (default) the historical
            de-normalisation constants are used, which keeps outputs of
            already-trained checkpoints unchanged. When ``True`` the
            de-normalisation is the exact inverse of the input normalisation.
    """

    def __init__(self, dim, input_resolution, patch_size, residual_depth, recursive_depth,
                 exact_denormalise=False):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.recursive_depth = recursive_depth
        self.input_resolution = input_resolution
        self.residual_depth = residual_depth
        self.H, self.W = self.input_resolution
        assert self.H == self.W, "Input hight and width should be the same"
        # input_conv1 / output_conv1 are not used by forward(); they are kept so
        # the parameter set (and therefore saved state dicts) stays the same.
        self.input_conv1 = nn.Conv2d(3, self.dim, 3, padding=1)
        self.patch_embed = PatchEmbed(img_size=self.H, patch_size = self.patch_size,
                                      in_chans=3, embed_dim=self.dim)
        self.patch_unembed = PatchUnEmbed(img_size=self.H, patch_size = self.patch_size,
                                          in_chans=self.dim, unembed_dim=3)
        self.recursive_layers = nn.ModuleList()
        for _ in range(self.recursive_depth):
            layer = ResidualLayer(self.dim, self.input_resolution, self.residual_depth, self.patch_size)
            self.recursive_layers.append(layer)
        self.output_conv1 = nn.Conv2d(self.dim, 3, 3, padding=1)
        #use imagenet mean and std for general domain normalisation
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        self.normalise_layer = transforms.Normalize(mean, std)
        # Normalize computes (x - mean) / std, so the exact inverse needs
        # mean' = -mean / std and std' = 1 / std. The historical constants use
        # mean' = -mean, which leaves a constant per-channel offset in the
        # output; they remain the default because existing checkpoints were
        # trained with them.
        inverse_std = tuple(1. / s for s in std)
        if exact_denormalise:
            inverse_mean = tuple(-m / s for m, s in zip(mean, std))
        else:
            inverse_mean = tuple(-m for m in mean)
        self.denormalise_layer = transforms.Normalize(inverse_mean, inverse_std)
        self.apply(self._init_weights)
        # not used by forward()
        self.activation = nn.LeakyReLU()

    #weight initialisation scheme
    def _init_weights(self, l):
        """Initialise one sub-module; meant to be passed to ``nn.Module.apply``.

        Linear layers get truncated-normal weights (std 0.02) and zero bias;
        LayerNorm layers get unit weight and zero bias. Other modules are left
        untouched.
        """
        if isinstance(l, nn.Linear):
            trunc_normal_(l.weight, std=.02)
            if l.bias is not None:
                nn.init.constant_(l.bias, 0)
        elif isinstance(l, nn.LayerNorm):
            nn.init.constant_(l.bias, 0)
            nn.init.constant_(l.weight, 1.0)

    def forward(self, x):
        """Map an input batch of shape (B, C, H, W) to an output of the same shape."""
        #normalise the data, input shape (B, C, H, W)
        x = self.normalise_layer(x)
        outer_shortcut = x
        x = self.patch_embed(x)
        inner_shortcut = x

        for layer in self.recursive_layers:
            x = layer(x)

        # inner skip connection, in token space
        x = torch.add(x, inner_shortcut)
        x = self.patch_unembed(x, (self.H//self.patch_size, self.W//self.patch_size))
        # outer skip connection, in (normalised) image space
        x = torch.add(x, outer_shortcut)
        x = self.denormalise_layer(x)
        return x #output shape (B, C, H, W)

In [9]:
# Hyper-parameters of the network instance used for training.
# Embedding dimension.
dim = 96
# Square input resolution derived from the training image size.
input_shape = (training_image_size, training_image_size)
# Patch size (same value as the one set in the global-parameters cell).
patch_size = 1
# Number of passes inside every residual unit.
residual_depth = 3
# Number of stacked residual units.
recursive_depth = 6
net = DeepRecursiveTransformer(dim, input_shape, patch_size, residual_depth, recursive_depth)
# Optional model summary (disabled):
# summary(net.cuda(), (3, training_image_size, training_image_size))

### Training Setup

In [54]:
# Move the model to the target device before building the optimiser, as the
# torch.optim documentation recommends, so the optimiser is constructed over
# the parameters that are actually trained.
net = net.to(device)
# Mean-squared-error loss, placed on the same device as the model instead of
# being tied to a CUDA-only tensor type.
criterion = nn.MSELoss().to(device)
optimiser = optim.Adam(net.parameters(), lr=lr)

In [55]:
#graph network error
def graph_error(error_list, name):
    """Plot the error curve and optionally save it under ``Losses/``.

    Args:
        error_list: Sequence of error values, one point per entry.
        name: File name of the saved figure; must end in ``.png``. Pass ``""``
            to only display the plot without saving it.

    Raises:
        ValueError: If ``name`` is non-empty and does not end in ``.png``.
    """
    import os  # local import: only needed to create the output directory

    should_save = name != ""
    if should_save and not name.endswith(".png"):
        raise ValueError("Suffix of file type is needed")

    save_dir = "Losses/" + name
    x = np.arange(len(error_list))
    y = np.asarray(error_list)
    # NOTE(review): `plt` is not among the imports visible at the top of this
    # file - presumably matplotlib.pyplot; it must be in scope before calling.
    plt.plot(x, y)
    plt.ylabel("Error")
    plt.xlabel("Epoches")
    if should_save:
        # savefig does not create missing directories
        os.makedirs(os.path.dirname(save_dir), exist_ok=True)
        plt.savefig(save_dir)
    plt.show()

### Network Training

In [56]:
def network_training(net, optimiser, criterion, loadCkp = False, loadBest=True, new_dataset=False):
    """Train ``net``, checkpointing every epoch and rolling back on loss spikes.

    Reads the module-level ``train_loader``, ``epochs``, ``device``,
    ``ckp_pth``, ``error_tolerence``, ``error_plot_freq`` and ``INT_MAX``.
    NOTE(review): the only visible construction of ``train_loader`` is
    commented out in the Load Data cell; it must exist before this is called.

    Args:
        net: Model to train.
        optimiser: Optimiser updating ``net``.
        criterion: Loss called as ``criterion(net_output, target)``.
        loadCkp: Resume from a saved file in ``ckp_pth`` before training.
        loadBest: With ``loadCkp``, load ``best_model.pt`` when True, otherwise
            ``checkpoint.pt``.
        new_dataset: With ``loadCkp``, keep only the loaded network and ignore
            the stored epoch, error history and best error.
    """
    error_list = []
    start_epoch = 0
    best_model_saved = False
    ckp_saved = False
    previous_batch_error = INT_MAX #initialise to a large value
    best_error = INT_MAX

    ###load checkpoint if required
    if loadCkp:
        ckp_saved = True
        best_model_saved = loadBest
        if loadBest:
            ckp_file = ckp_pth + "/best_model.pt"
            loaded_what = "the best model"
        else:
            ckp_file = ckp_pth + "/checkpoint.pt"
            loaded_what = "the checkpoint"
        #when training on a new dataset for the first time, we only load the network itself
        if new_dataset:
            net, _, _, _, _ = load_ckp(ckp_file, net, optimiser)
            print("Finished loading " + loaded_what + ", ignored the training history")
        else:
            net, optimiser, start_epoch, error_list, best_error = load_ckp(ckp_file, net, optimiser)
            print("Finished loading " + loaded_what)
            # A stored best error of None must be replaced before it is used
            # as the spike-detection reference, otherwise the tolerance
            # comparison below fails on None.
            if best_error is None:
                best_error = INT_MAX
            previous_batch_error = best_error

    for epoch in range(start_epoch, epochs):
        batch_error = 0
        epoch_start_time = time.time()

        ### iterate through the batches
        for data in train_loader:
            optimiser.zero_grad()
            # element 0 is the target, element 1 the network input
            target = data[0].to(device)
            net_input = data[1].to(device)
            net_output = net(net_input)
            loss = criterion(net_output, target)
            batch_error += loss.item()
            loss.backward()
            optimiser.step()

        ### find one epoch training time
        one_epoch_time = time.time() - epoch_start_time
        print("One epoch time: " + str(one_epoch_time))

        ### process the error information
        print('[%d] loss: %.3f' %(epoch + 1, batch_error))
        error_too_large = batch_error > error_tolerence*previous_batch_error
        ### if error is too large, roll back, otherwise save and continue
        if error_too_large and (best_model_saved or ckp_saved):
            if ckp_saved:
                print("Current error is too large, loading the last checkpoint")
                rollback_file = ckp_pth + "/checkpoint.pt"
            else:
                print("Current error is too large, loading the best model")
                rollback_file = ckp_pth + "/best_model.pt"
            # The stored epoch cannot rewind the running loop, and the
            # in-memory best_error only ever decreases during a run, so both
            # loaded values are discarded.
            net, optimiser, _, error_list, _ = load_ckp(rollback_file, net, optimiser)
        else:
            if error_too_large:
                print("Current error is too large, but cannot roll back")
            else:
                previous_batch_error = batch_error

            error_list.append(batch_error)

            # Update the best error before building the checkpoint so that the
            # latest checkpoint never stores a stale best_error.
            is_best = batch_error < best_error
            if is_best:
                best_error = batch_error

            checkpoint = {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': optimiser.state_dict(),
                'error_list': error_list,
                'best_error': best_error
            }
            ###save the latest model
            save_ckp(checkpoint, False, ckp_pth)
            ckp_saved = True

            ###if error is the smallest save it as the best model
            if is_best:
                save_ckp(checkpoint, True, ckp_pth)
                best_model_saved = True
                print("New Minimum Error Recorded!")

            if ((epoch+1) % error_plot_freq) == 0 or epoch == epochs-1:
                # the first recorded error is left out of the plot
                graph_error(error_list[1:], "")