In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import glob

import torch
from torchvision import transforms
import torch.nn as nn
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset

import torch.nn.functional as F
from torch.autograd import Variable
import math
from math import exp
from tqdm import tqdm

from kornia.filters.sobel import Sobel
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
from dataset import DEMDataset
from model import *
from utils import calculatePSNR

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Training on device:", device)

### Creating the train and test data loaders

In [None]:
# Both splits use the same stateless ToTensor pipeline; only the training
# split is shuffled.
_dem_to_tensor = transforms.Compose([transforms.ToTensor()])

train_loader = DataLoader(
    DEMDataset(load_dir='/kaggle/input/demdataset8020/train/', transform=_dem_to_tensor),
    batch_size=16,
    shuffle=True,
)
test_loader = DataLoader(
    DEMDataset(load_dir='/kaggle/input/demdataset8020/test/', transform=_dem_to_tensor),
    batch_size=8,
    shuffle=False,
)

### Visualising a sample batch (HR, LR and the HR Sobel edge map)

In [None]:
# Take one (hr, lr) batch and show its first sample next to the Sobel edge
# map of the HR image -- the same edge map gradientAwareLoss compares below.
sobel = Sobel()
hr, lr = next(iter(train_loader))
edges = sobel(hr)

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(36, 12))
panels = [
    (hr, "High resolution DEM"),
    (lr, "Low resolution DEM"),
    (edges, "Sobel edges of HR DEM"),
]
for ax, (batch_imgs, title) in zip(axes, panels):
    ax.imshow(batch_imgs[0].squeeze(), cmap='gray')
    ax.set_title(title)
    ax.set_yticklabels([])  # hide tick labels consistently on every panel
plt.show()

### Declaring the generator and discriminator

In [None]:
# Build both networks and place them on the selected device in one step
# (nn.Module.to returns the module itself).
generator = MobileSR().to(device)
# NOTE(review): (1, 128) presumably means input channels and feature width --
# confirm against Discriminator's definition in model.py.
discriminator = Discriminator(1, 128).to(device)

In [None]:
class gradientAwareLoss(nn.Module):
    """L1 distance between the Sobel edge maps of two image batches.

    This compares image gradients (edges) rather than raw values. The intent
    is to penalise blurred or shifted structure in the super-resolved output.

    The submodules are created without a hard-coded device. They follow the
    parent module when it is moved with ``.to(device)``. The previous
    ``.to('cuda')`` calls made construction fail on CPU-only machines.
    """

    def __init__(self):
        super().__init__()
        self.sobelFilter = Sobel()
        self.l1Loss = nn.L1Loss()

    def forward(self, hr, sr):
        """Return mean |Sobel(hr) - Sobel(sr)| (nn.L1Loss default 'mean' reduction).

        Args:
            hr: reference (ground-truth) image batch.
            sr: predicted image batch with the same shape as ``hr``.
        """
        hrEdgeMap = self.sobelFilter(hr)
        srEdgeMap = self.sobelFilter(sr)
        return self.l1Loss(hrEdgeMap, srEdgeMap)

In [None]:
LEARNING_RATE = 1e-5  # shared by the generator and discriminator optimizers

# Loss terms used by the training loop.
l1Loss = nn.L1Loss().to(device)
edgeLoss = gradientAwareLoss().to(device)
ssim = SSIM().to(device)
anotherl1Loss = nn.L1Loss().to(device)  # not referenced by the training/validation cells shown

# One Adam optimizer per network, generator first.
optim_G, optim_D = (
    torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
    for net in (generator, discriminator)
)

In [None]:
import gc
from torchvision.utils import make_grid
import wandb

In [None]:
# Open the Weights & Biases run that every later wandb.log call writes to.
wandb.init(name="exp2", project="SRTGAN-Bic")

In [None]:
num_epochs = 100
num_train_batches = float(len(train_loader))
num_val_batches = float(len(test_loader))

def train_one_epoch(epoch):
    """Run one adversarial training epoch over ``train_loader``.

    For each batch:
      1. The generator is updated with reconstruction + adversarial loss while
         the discriminator's parameters are frozen.
      2. The discriminator is updated on real HR images versus the generator's
         (detached) predictions for the same batch.

    Per-epoch averages are printed and sent to wandb in a single ``wandb.log``
    call, so they share one wandb step.

    Args:
        epoch: epoch index; used for printing and logging only.

    Returns:
        Mean training PSNR over the epoch's batches.
    """
    print(f"Epoch {epoch}: ", end="")

    l1_loss_per_epoch = 0.0
    edge_loss_per_epoch = 0.0
    ssim_loss_per_epoch = 0.0
    ssim_per_epoch = 0.0
    psnr_per_epoch = 0.0
    total_loss_per_epoch = 0.0
    G_adv_loss = 0.0
    D_adv_loss = 0.0

    generator.train()
    for hr, lr in tqdm(train_loader):
        # Cast both tensors up front so every loss (and torchmetrics' SSIM)
        # sees matching float32 inputs; previously only lr was cast here.
        lr_images = lr.to(device).float()
        hr_images = hr.to(device).float()

        # ---- generator update (discriminator weights frozen) ----
        discriminator.requires_grad_(False)
        optim_G.zero_grad()

        predicted_hr_images = generator(lr_images)
        predicted_hr_labels = discriminator(predicted_hr_images)
        # adversarial loss: reward the generator when D labels its output as real
        gf_loss = F.binary_cross_entropy_with_logits(
            predicted_hr_labels, torch.ones_like(predicted_hr_labels))

        # reconstruction loss: L1 and edge terms on x1000-scaled values, plus SSIM
        l1_loss_per_sample = l1Loss(hr_images * 1000, predicted_hr_images * 1000)
        ssim_per_sample = ssim(hr_images, predicted_hr_images)
        ssim_loss_per_sample = 1 - ssim_per_sample
        edge_loss = edgeLoss(hr_images * 1000, predicted_hr_images * 1000)
        reconstruction_loss = l1_loss_per_sample + 100 * ssim_loss_per_sample + 50 * edge_loss
        t_loss = reconstruction_loss + 50 * gf_loss

        t_loss.backward()
        optim_G.step()

        psnr_per_sample = calculatePSNR(hr_images.detach().cpu().numpy(),
                                        predicted_hr_images.detach().cpu().numpy())

        l1_loss_per_epoch += l1_loss_per_sample.item()
        edge_loss_per_epoch += edge_loss.item()
        ssim_loss_per_epoch += ssim_loss_per_sample.item()
        ssim_per_epoch += ssim_per_sample.item()
        psnr_per_epoch += psnr_per_sample
        total_loss_per_epoch += t_loss.item()
        G_adv_loss += gf_loss.item()

        # ---- discriminator update ----
        discriminator.requires_grad_(True)
        optim_D.zero_grad()
        # Reuse this batch's generator output, detached so no gradient flows
        # back into the generator. This avoids a second generator forward pass.
        fake_hr_images = predicted_hr_images.detach()
        adv_hr_real = discriminator(hr_images)
        adv_hr_fake = discriminator(fake_hr_images)
        df_loss = (F.binary_cross_entropy_with_logits(adv_hr_real, torch.ones_like(adv_hr_real))
                   + F.binary_cross_entropy_with_logits(adv_hr_fake, torch.zeros_like(adv_hr_fake)))
        D_adv_loss += df_loss.item()
        df_loss.backward()
        optim_D.step()

    n_batches = max(len(train_loader), 1)  # avoid ZeroDivisionError on an empty loader
    l1_loss_per_epoch /= n_batches
    edge_loss_per_epoch /= n_batches
    ssim_loss_per_epoch /= n_batches
    ssim_per_epoch /= n_batches
    psnr_per_epoch /= n_batches
    total_loss_per_epoch /= n_batches
    G_adv_loss /= n_batches
    D_adv_loss /= n_batches

    # One call -> all epoch metrics land on the same wandb step.
    wandb.log({
        "epoch": epoch,
        "Train L1 Loss": l1_loss_per_epoch,
        "Train Edge Loss": edge_loss_per_epoch,
        "Train SSIM Loss": ssim_loss_per_epoch,
        "Train Total Loss": total_loss_per_epoch,
        "Train SSIM": ssim_per_epoch,
        "Train PSNR": psnr_per_epoch,
        "Train G Adversarial Loss": G_adv_loss,
        "Train D Loss": D_adv_loss,
    })

    print(f"(Train) L1 Loss: {l1_loss_per_epoch:.3f} | SSIM Loss: {ssim_loss_per_epoch:.3f} | Edge Loss: {edge_loss_per_epoch:.3f} | Total Loss: {total_loss_per_epoch:.3f}")
    print(f"SSIM: {ssim_per_epoch:.3f} | PSNR: {psnr_per_epoch:.3f} | G Adv Loss: {G_adv_loss:.3f} | D Loss: {D_adv_loss:.3f}")

    torch.cuda.empty_cache()
    gc.collect()

    return psnr_per_epoch

In [None]:
def valid_one_epoch(epoch):
    """Evaluate the generator on ``test_loader`` against a bilinear baseline.

    Computes mean SSIM and PSNR for two outputs: the generator's predictions,
    and plain bilinear upsampling of the LR input. The averages are logged to
    wandb together with one set of example image grids (taken from the first
    test batch) in a single ``wandb.log`` call. A summary is printed.

    Args:
        epoch: epoch index, included in the wandb log.

    Returns:
        Mean PSNR of the generator's predictions over the test set.
    """
    ssim_per_epoch = 0.0
    psnr_per_epoch = 0.0
    b_ssim_per_epoch = 0.0
    b_psnr_per_epoch = 0.0
    sample_images = {}

    generator.eval()
    with torch.no_grad():
        for batch_idx, (hr, lr) in enumerate(tqdm(test_loader)):
            # Same float32 casting as in training.
            batched_hr = hr.to(device).float()
            batched_lr = lr.to(device).float()
            predicted_sr = generator(batched_lr)

            # Upsample directly to the HR spatial size. For 2x data this is
            # identical to scale_factor=2, and it never mismatches the target.
            bilinear_sr = F.interpolate(batched_lr, size=batched_hr.shape[-2:], mode='bilinear')

            hr_np = batched_hr.cpu().numpy()
            ssim_per_epoch += ssim(batched_hr, predicted_sr).item()
            psnr_per_epoch += calculatePSNR(hr_np, predicted_sr.cpu().numpy())

            b_ssim_per_epoch += ssim(batched_hr, bilinear_sr).item()
            b_psnr_per_epoch += calculatePSNR(hr_np, bilinear_sr.cpu().numpy())

            # Upload example grids once per epoch, not once per test batch.
            if batch_idx == 0:
                sample_images = {
                    "Original LR": wandb.Image(make_grid(batched_lr[:4]), caption="Low Resolution DEM"),
                    "Original HR": wandb.Image(make_grid(batched_hr[:4]), caption="High Resolution DEM"),
                    # key spelling kept as-is so existing wandb panels keep working
                    "Reconstruced": wandb.Image(make_grid(predicted_sr[:4]), caption="Reconstructed High Resolution DEM"),
                    "Bilinear": wandb.Image(make_grid(bilinear_sr[:4]), caption="Bilinear High Resolution DEM"),
                }

    n_batches = max(len(test_loader), 1)  # avoid ZeroDivisionError on an empty loader
    ssim_per_epoch /= n_batches
    psnr_per_epoch /= n_batches
    b_ssim_per_epoch /= n_batches
    b_psnr_per_epoch /= n_batches

    wandb.log({
        "epoch": epoch,
        **sample_images,
        "Test Predicted SSIM": ssim_per_epoch,
        "Test Predicted PSNR": psnr_per_epoch,
        "Bilinear SSIM": b_ssim_per_epoch,
        "Bilinear PSNR": b_psnr_per_epoch,
    })

    print(f"(Val) SSIM: {ssim_per_epoch:.3f} | PSNR: {psnr_per_epoch:.3f}")
    print(f"(Bil) SSIM: {b_ssim_per_epoch:.3f} | PSNR: {b_psnr_per_epoch:.3f}")

    torch.cuda.empty_cache()
    gc.collect()

    return psnr_per_epoch

In [None]:
# Train for num_epochs epochs.
# - The generator with the best validation PSNR so far is checkpointed to disk.
# - If validation PSNR falls for `rollback_patience` consecutive epochs, the
#   generator weights are rolled back to that checkpoint.
rollback_patience = 5
best_psnr = 0
best_model_path = None
count = 0
prev_psnr = 0
for i in range(num_epochs):
    torch.cuda.empty_cache()
    gc.collect()
    train_psnr = train_one_epoch(i)
    valid_psnr = valid_one_epoch(i)

    if valid_psnr >= prev_psnr:
        count = 0
    else:
        count += 1
        if count >= rollback_patience and best_model_path is not None:
            # load_state_dict updates the module in place and returns a
            # (missing_keys, unexpected_keys) report. It must not be assigned
            # back to `generator`, or the model object would be lost.
            generator.load_state_dict(torch.load(best_model_path, map_location=device))
            count = 0
    # Compare against the previous epoch, not against the initial 0.
    prev_psnr = valid_psnr

    if valid_psnr > best_psnr:
        best_psnr = valid_psnr
        best_model_path = f"best_model_{best_psnr}.pt"
        torch.save(generator.state_dict(), best_model_path)