# Imports

In [1]:
# general
import time
from datetime import datetime
import io
import numpy as np
import matplotlib.pyplot as plt
import PIL
from tqdm import tqdm
import wandb
from pynvml import *
import kornia

# torch
import torch
from torch import nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# custom imports
from lanczos import lanczos_2d as lanczos
from model.shiftnet_losses import define_loss
from utils.helper_functions import get_lr
from utils.helper_functions import plot_tensors
from utils.helper_functions import plot_tensors_extra_info
from utils.helper_functions import minmax_percentile
from utils.dataloader_spot import dataset_spot6

# Sesure DataLoader
from utils.dataloader import Dataset as dataset

# model
from model.shiftnet import ShiftNet

In [3]:
from lanczos.lanczos_2d import lanczos_shift
from lanczos.lanczos_2d import lanczos_kernel
from lanczos import lanczos_2d as lanczos
from model.shiftnet import get_thetas
from model.shiftnet import apply_shifts
from model.shiftnet import get_shift_loss
from model.shiftnet import ShiftNet
from model.shiftnet_losses import define_loss

# 0. Config

In [None]:
# Run configuration: optimisation hyper-parameters plus nested settings for
# the DataLoader and for Weights & Biases logging.
config = dict(
    lr=1e-5,
    batch_size=8,
    epochs=100,
    # Maximum random shift expressed as a fraction of the image size:
    # 10 (pixels) / 300, rounded to three decimals.
    # NOTE(review): 300 is presumably the image side length in pixels - confirm.
    shift_factor=round(10 / 300, 3),
    shiftnet_loss_relative=False,
    data_loader=dict(
        batch_size=8,
        num_workers=4,
        prefetch_factor=8,
    ),
    logging=dict(
        log_freq=1,     # scalar logging interval, in iterations
        image_freq=10,  # image logging interval, in iterations
        wandb_project="Siamese_ShiftNet",
        wandb_entity="simon-donike",
    ),
)

In [23]:
# Training data: the project dataset rooted at data/train/, served in shuffled,
# fixed-size batches (the last incomplete batch is dropped).
dataset_train = dataset_spot6("data/train/")
train_loader = DataLoader(
    dataset_train,
    batch_size=config["data_loader"]["batch_size"],
    num_workers=config["data_loader"]["num_workers"],
    # batches pre-loaded per worker; with 4 workers x 8 this gives 32 batches in flight
    prefetch_factor=config["data_loader"]["prefetch_factor"],
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)

In [24]:
def plot_tensor(t,title=""):
    """Display the first sample of a batched tensor with matplotlib.

    Args:
        t: tensor whose first element is a 3-D array; it is moved to the CPU
            and its axes are reordered with ``transpose(1, 2, 0)`` before
            plotting (assumes a channels-first batch - TODO confirm).
        title: text placed above the image.
    """
    # work on a copy so the caller's tensor is never touched
    first_sample = t.clone()[0]
    image = first_sample.cpu().detach().numpy().transpose(1,2,0)
    plt.imshow(minmax_percentile(image))
    plt.title(title)
    plt.show()
    
def shifter(im,shift_factor=0.05):
    """Apply a random translation to ``im`` and return the result.

    Args:
        im: image (or batch of images) accepted by torchvision's RandomAffine.
        shift_factor: value passed as both entries of RandomAffine's
            ``translate`` argument, i.e. the maximum shift as a fraction of
            the image size along each axis.

    Returns:
        The translated image; no rotation, scaling or shear is applied,
        sampling is nearest-neighbour and uncovered pixels are filled with 0.
    """
    transforms = torchvision.transforms
    random_shift = transforms.RandomAffine(
        degrees=0,
        translate=(shift_factor, shift_factor),
        scale=None,
        shear=None,
        interpolation=transforms.InterpolationMode.NEAREST,
        fill=0,
    )
    return random_shift(im)

def _first_sample_as_image(batch):
    """Take the first sample of a 4-D batch, reorder its axes with transpose(1, 2, 0) and normalise it."""
    return minmax_percentile(batch[0].cpu().detach().numpy().transpose(1,2,0))


def _first_channel_as_image(batch):
    """Take the first channel of the first sample as a normalised 2-D array."""
    return minmax_percentile(batch[0][0].cpu().detach().numpy())


def plot_siamese_info(hr,sr,hr_small,sr_small,encoded_hr,encoded_sr,new_images,thetas,mode="return"):
    """Build a 2x4 overview figure for the first sample of a training batch.

    Panels: HR, SR, their encodings, the two windows handed to ShiftNet, the
    shifted image, and an arrow plot of the predicted shift.

    Args:
        hr, sr, encoded_hr, encoded_sr, new_images: batched tensors; only
            element 0 is plotted (assumes channels-first layout - TODO confirm).
        hr_small, sr_small: batched tensors; only channel 0 of element 0 is plotted.
        thetas: tensor whose first row holds the predicted shift; entries 0
            and 1 are read as the X and Y components.
        mode: ``"show"`` displays the figure, ``"return"`` closes it and
            returns it, ``"log"`` uploads it to wandb under the key ``"image"``.

    Returns:
        The matplotlib figure when ``mode == "return"``, otherwise ``None``.

    Raises:
        ValueError: if ``mode`` is not one of the three supported values.
    """
    # Fail early: previously an unknown mode fell through all branches,
    # returned None silently and left the figure open.
    if mode not in ("show", "return", "log"):
        raise ValueError('mode must be "show", "return" or "log", got ' + repr(mode))

    # transform images
    hr, sr = _first_sample_as_image(hr), _first_sample_as_image(sr)
    encoded_hr, encoded_sr = _first_sample_as_image(encoded_hr), _first_sample_as_image(encoded_sr)
    hr_small, sr_small = _first_channel_as_image(hr_small), _first_channel_as_image(sr_small)
    new_images = _first_sample_as_image(new_images)

    # prepare thetas: shift prediction of the first sample
    values = thetas.detach().cpu()[0]

    # create image
    fig, axs = plt.subplots(2, 4,figsize=(20,10),facecolor='white')

    # plot images
    panels = [
        (axs[0,0], hr, "HR"),
        (axs[1,0], sr, "SR"),
        (axs[0,1], encoded_hr, "Encoded HR"),
        (axs[1,1], encoded_sr, "Encoded SR"),
        (axs[0,2], hr_small, "Encoded HR ShiftNet window"),
        (axs[1,2], sr_small, "Encoded SR ShiftNet window"),
        (axs[1,3], new_images, "Shifted Image"),
    ]
    for ax, image, title in panels:
        ax.imshow(image)
        ax.set_title(title)

    # draw arrow
    shift_ax = axs[0,3]
    shift_ax.arrow(0,0, -1*values[0],-1*values[1],length_includes_head=True,width=0.2)
    shift_ax.set_ylim(-10, 10) # fixed limits so consecutive plots are comparable
    shift_ax.set_xlim(-10, 10)
    shift_ax.set_title("Pred. Shifts (px)\nX:  "+str(round(float(values[0]),2))+"\nY:  " + str(round(float(-1*values[1]),2)))
    ticks = list(range(-9, 10))
    shift_ax.set_xticks(ticks)
    shift_ax.set_yticks(ticks)
    shift_ax.grid(alpha=0.4) # draw gridlines

    if mode=="show":
        plt.show()
        return None
    if mode == "return":
        # close this figure explicitly instead of whichever one is current
        plt.close(fig)
        return fig

    # mode == "log": render the figure into memory and send it to wandb
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    im = PIL.Image.open(buf)
    image = wandb.Image(im, caption="Image")
    wandb.log({"image":image})
    plt.close(fig)
    return None
        

In [25]:
class siamese_arm(nn.Module):
    '''Convolutional encoder used as one arm of the siamese set-up.

    Five stages of Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU with channel
    widths in_channel -> 32 -> 64 -> 64 -> 32 -> 1. The padding keeps the
    spatial size unchanged, so the output is a single-channel map of the same
    height and width as the input.
    '''

    def __init__(self,in_channel=1):
        '''
        Args:
            in_channel: (int) number of channels of the images to be encoded
        '''
        super().__init__()

        # channel counts at the stage boundaries
        widths = [in_channel, 32, 64, 64, 32, 1]
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            stage = nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            )
            # registered as layer1 ... layer5
            setattr(self, "layer{}".format(idx), stage)

    def forward(self, x):
        '''
        Args:
            x: (tensor, B x in_channel x H x W) image batch to be encoded
        Returns:
            (tensor, B x 1 x H x W) encoded batch
        '''
        out = x
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            out = stage(out)
        return out
    


In [26]:
# Instantiate the two siamese encoder arms (3-channel input), on the target
# device and in training mode
arm_hr = siamese_arm(in_channel=3).to(device).train()
arm_sr = siamese_arm(in_channel=3).to(device).train()

# Instantiate ShiftNet (1-channel input) together with its MAE loss
regis_model = ShiftNet(in_channel=1).to(device).train()
shiftnet_loss = define_loss("MAE")

# One Adam optimizer drives all three networks: both arms + ShiftNet
optimizer = torch.optim.Adam(
    [
        *arm_hr.parameters(),
        *arm_sr.parameters(),
        *regis_model.parameters(),
    ],
    lr=config["lr"],
)

In [None]:
# Training cell: for every HR batch a randomly shifted copy is generated, both
# are encoded, ShiftNet predicts the shift between the encodings, the shift is
# applied to the shifted copy and the resulting loss is back-propagated.
# Scalars and overview images are sent to Weights & Biases.

# logging settings
run_name = "Siamese_ShiftNet_"+datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
wandb.init(name=run_name,project=config["logging"]["wandb_project"], entity=config["logging"]["wandb_entity"],config=config)

try:
    # iterate over epochs (1-based)
    for epoch in range(1, config["epochs"]+1):

        # iterate over dataloader; `it` is the 1-based iteration within the epoch
        for it, hr in enumerate(tqdm(train_loader,ascii=True), start=1):
            hr = hr.to(device) # get HR image
            # generate shifted image; use the configured shift so the run matches
            # the config logged to wandb (the shifter default of 0.05 was used before)
            sr = shifter(hr, shift_factor=config["shift_factor"]).to(device)

            # encode HR and shifted image
            # NOTE(review): both inputs go through arm_hr; arm_sr is part of the
            # optimizer but never called. Confirm whether weight sharing is
            # intended or arm_sr should encode `sr`.
            encoded_hr = arm_hr(hr)
            encoded_sr = arm_hr(sr)

            # calculate predicted thetas from the two encodings
            thetas, hr_small, sr_small = get_thetas(encoded_hr,encoded_sr,regis_model,n_channels=sr.shape[1])
            # perform shift based on calculated thetas
            new_images,thetas = apply_shifts(sr,thetas,regis_model,n_channels=3)
            # calculate train loss according to defined function
            train_loss,hr_loss,new_images_loss = get_shift_loss(new_images,encoded_hr,shiftnet_loss,sr_small,hr_small,
                                                                relative_loss=config["shiftnet_loss_relative"])

            # train network
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            if it%config["logging"]["log_freq"]==0:
                # Log ShiftNet metrics as plain floats; no_grad keeps the
                # monitoring-only computations out of the autograd graph
                with torch.no_grad():
                    losses_shiftnet = {
                        "General/train_loss":train_loss.item(),
                        "ShiftNet/mae":torch.nn.functional.l1_loss(new_images,hr).item(),
                        "ShiftNet/mse":torch.nn.functional.mse_loss(new_images,hr).item(),
                        "ShiftNet/ssim_loss":kornia.losses.ssim_loss(new_images,hr,window_size=5).item(),
                        "ShiftNet/psnr_loss":kornia.losses.psnr_loss(new_images,hr,max_val=1.0).item(),
                        "ShiftNet/lr":get_lr(optimizer),
                        "ShiftNet/SR_encoded_mean":torch.mean(encoded_sr).item(),
                        "ShiftNet/HR_encoded_mean":torch.mean(encoded_hr).item()}
                # send to WandB
                wandb.log(losses_shiftnet)
            if it%config["logging"]["image_freq"]==0:
                plot_siamese_info(hr,sr,hr_small,sr_small,encoded_hr,encoded_sr,new_images,thetas,mode="log")
finally:
    # close the wandb run even if training is interrupted or raises
    wandb.finish()

100%|##############################################################################| 1250/1250 [08:11<00:00,  2.54it/s]
100%|##############################################################################| 1250/1250 [08:14<00:00,  2.53it/s]
100%|##############################################################################| 1250/1250 [08:21<00:00,  2.49it/s]
100%|##############################################################################| 1250/1250 [08:29<00:00,  2.46it/s]
100%|##############################################################################| 1250/1250 [08:42<00:00,  2.39it/s]
100%|##############################################################################| 1250/1250 [08:29<00:00,  2.45it/s]
100%|##############################################################################| 1250/1250 [08:37<00:00,  2.41it/s]
100%|##############################################################################| 1250/1250 [07:49<00:00,  2.66it/s]
100%|###################################