### Imports

In [None]:
# import torch
# import torch.nn as nn
# import torch.cuda as cuda
# import torch.optim as optim
# from torchvision import datasets, transforms
# from timm.models.layers import trunc_normal_
# from torch.utils.data import DataLoader
# from torchsummary import summary

# import numpy as np
# import cv2 as cv
# from PIL import Image
# import time
# import warnings
# import time

from model import DeepRecursiveTransformer
# from my_utils import batch_PSNR, batch_SSIM, output_to_image
# from my_utils import save_ckp, load_ckp, base_path

### Global Parameters

In [3]:
# --- Training hyper-parameters ------------------------------------------
training_image_size = 56   # side of the square network input (see input_shape below)
batch_size = 5
epochs = 4600
lr = 1e-4
error_plot_freq = 20       # plot the error curve every this many epochs
error_tolerence = 10       # roll back when an epoch's error exceeds this factor of the last accepted one
INT_MAX = 2**31 - 1        # large sentinel used to initialise "previous"/"best" errors
dtype = torch.cuda.FloatTensor  # legacy CUDA tensor type, used to cast the loss module

# Seed the CPU and every CUDA device RNG so runs are reproducible.
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)

# --- Paths ---------------------------------------------------------------
base_pth = base_path()
ckp_pth = base_pth + "/pretrained"  # checkpoints are read from / written to here

# Prefer the GPU when one is visible; fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### Load Data

In [49]:
### Prepare Data for Training
# train_dataset = Rain800TrainData(training_image_size, dataset_dir='/Rain-800/') #/Rain100L-Train/
# train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=False)

Loading Test1200 from Hi-Net Data


In [9]:
# Architecture hyper-parameters for DeepRecursiveTransformer.
dim = 96
input_shape = (training_image_size,) * 2  # square input of side training_image_size
patch_size = 1
residual_depth = 3
recursive_depth = 6

net = DeepRecursiveTransformer(
    dim,
    input_shape,
    patch_size,
    residual_depth,
    recursive_depth,
)
# Optional layer-by-layer summary (needs torchsummary and a CUDA device):
# summary(net.cuda(), (3, training_image_size, training_image_size))

### Training Setup

In [54]:
# Move the network to its target device *before* building the optimiser, as
# recommended by the torch.optim docs, so the optimiser is constructed over
# the parameters that will actually be trained.
net = net.to(device)
# MSELoss holds no parameters; moving it with the same `device` keeps the
# setup consistent with the CPU fallback instead of forcing a CUDA-only type.
criterion = nn.MSELoss().to(device)
optimiser = optim.Adam(net.parameters(), lr=lr)

In [55]:
def graph_error(error_list, name):
    """Plot an error curve and optionally save it as ``Losses/<name>``.

    Args:
        error_list: Sequence of error values, plotted against their index
            (one point per entry, x starting at 0).
        name: File name for the saved figure; must end with ``".png"``.
            Pass ``""`` to only display the plot without saving it.

    Raises:
        ValueError: If ``name`` is non-empty and does not end with ``".png"``.
    """
    import os

    # An empty name means "display only"; any other name must carry the suffix.
    if name and not name.endswith(".png"):
        raise ValueError("Suffix of file type is needed")
    x = np.arange(len(error_list))
    y = np.asarray(error_list)
    plt.plot(x, y)
    plt.ylabel("Error")
    plt.xlabel("Epochs")
    if name:
        save_dir = "Losses/" + name
        # savefig does not create missing directories, so make sure they exist.
        os.makedirs(os.path.dirname(save_dir), exist_ok=True)
        plt.savefig(save_dir)
    plt.show()

### Network Training

In [56]:
def _training_checkpoint(epoch, net, optimiser, error_list, best_error):
    """Build the dict handed to ``save_ckp`` after the 0-based ``epoch``."""
    return {
        # Stored as epoch + 1; presumably read back by load_ckp as the
        # resume start epoch — confirm in my_utils.
        'epoch': epoch + 1,
        'state_dict': net.state_dict(),
        'optimizer': optimiser.state_dict(),
        'error_list': error_list,
        'best_error': best_error,
    }


def network_training(net, optimiser, criterion, loadCkp=False, loadBest=True, new_dataset=False, loader=None):
    """Train ``net`` for the global ``epochs`` count with checkpointing and roll-back.

    After each epoch, the summed training loss is compared with the last
    accepted epoch's loss. If it exceeds ``error_tolerence`` times that value
    and a saved model exists, weights, optimiser state and error history are
    restored from disk and the epoch is discarded. Otherwise a checkpoint is
    saved, plus a best-model file when the loss is a new minimum. The error
    curve is plotted every ``error_plot_freq`` epochs and after the last one.

    Args:
        net: Model to train; batches are moved to the global ``device``, so
            the model is expected to live there too.
        optimiser: Optimiser over ``net``'s parameters.
        criterion: Loss called as ``criterion(output, target)``.
        loadCkp: Load a file from ``ckp_pth`` before training.
        loadBest: With ``loadCkp``, load ``best_model.pt`` instead of
            ``checkpoint.pt``.
        new_dataset: With ``loadCkp``, keep only the returned network and
            ignore the stored epoch, error history and best error.
        loader: Iterable of batches where ``batch[0]`` is the target and
            ``batch[1]`` the network input. Defaults to the module-level
            ``train_loader``.
    """
    if loader is None:
        loader = train_loader

    error_list = []
    start_epoch = 0
    best_model_saved = False
    ckp_saved = False
    previous_batch_error = INT_MAX  # error of the last accepted epoch; start large
    best_error = INT_MAX

    ### load checkpoint if required
    if loadCkp:
        # Either file counts as "a checkpoint exists" for roll-back purposes.
        # NOTE(review): with loadBest, roll-back still reads checkpoint.pt,
        # which is assumed to exist alongside best_model.pt.
        ckp_saved = True
        best_model_saved = loadBest
        if loadBest:
            ckp_file, label = ckp_pth + "/best_model.pt", "best model"
        else:
            ckp_file, label = ckp_pth + "/checkpoint.pt", "checkpoint"
        if new_dataset:
            # First run on a new dataset: keep the weights, drop the history.
            net, _, _, _, _ = load_ckp(ckp_file, net, optimiser)
            print(f"Finished loading the {label}, ignored the training history")
        else:
            net, optimiser, start_epoch, error_list, best_error = load_ckp(ckp_file, net, optimiser)
            print(f"Finished loading the {label}")
            # A stored best_error may be None. Normalise it *before* it becomes
            # the roll-back reference; otherwise the comparison below raises TypeError.
            if best_error is None:
                best_error = INT_MAX
            previous_batch_error = best_error

    for epoch in range(start_epoch, epochs):
        net.train()  # net may have been swapped/reloaded by a roll-back
        batch_error = 0.0  # summed loss over every batch of this epoch
        epoch_start_time = time.perf_counter()

        ### iterate through the batches
        for data in loader:
            optimiser.zero_grad()
            target = data[0].to(device)
            net_input = data[1].to(device)
            loss = criterion(net(net_input), target)
            batch_error += loss.item()
            loss.backward()
            optimiser.step()

        ### report timing and error
        print(f"One epoch time: {time.perf_counter() - epoch_start_time}")
        print(f"[{epoch + 1}] loss: {batch_error:.3f}")

        ### if error is too large, roll back, otherwise save and continue
        too_large = batch_error > error_tolerence * previous_batch_error
        if too_large and (ckp_saved or best_model_saved):
            if ckp_saved:
                print("Current error is too large, loading the last checkpoint")
                rollback_file = ckp_pth + "/checkpoint.pt"
            else:
                print("Current error is too large, loading the best model")
                rollback_file = ckp_pth + "/best_model.pt"
            # Restore weights, optimiser state and history. The stored epoch is
            # ignored because the loop keeps counting forward. The stored
            # best_error is ignored because the in-memory value is never older
            # than the file's.
            net, optimiser, _, error_list, _ = load_ckp(rollback_file, net, optimiser)
        else:
            if too_large:
                print("Current error is too large, but cannot roll back")
            else:
                previous_batch_error = batch_error

            error_list.append(batch_error)
            ### save the latest model
            save_ckp(_training_checkpoint(epoch, net, optimiser, error_list, best_error), False, ckp_pth)
            ckp_saved = True

            ### if error is the smallest, save it as the best model
            if batch_error < best_error:
                best_error = batch_error
                save_ckp(_training_checkpoint(epoch, net, optimiser, error_list, best_error), True, ckp_pth)
                best_model_saved = True
                print("New Minimum Error Recorded!")

            # Skip the first (typically largest) entry so the curve stays readable.
            if (epoch + 1) % error_plot_freq == 0 or epoch == epochs - 1:
                graph_error(error_list[1:], "")