In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import Tensor
from torch.utils.data import DataLoader, random_split
from torch import optim
from torch.optim import lr_scheduler


from utils.dataloader import Dataset
from model.model import Discriminator, Generator, ContentLoss

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


## Build Dataloader and Dataset

In [7]:
working_directory = "/home/simon/CDE_UBS/thesis/SR-GAN/SRGAN"
folder_path = "/home/simon/CDE_UBS/thesis/testing/data_f4/"
dataset_file = "/home/simon/CDE_UBS/thesis/testing/data_f4_pkls/df_saved_images.pkl"


dataset = Dataset(folder_path,dataset_file,transform="histogram_matching",sen2_amount=1,
                  sen2_tile="all",location="ubs")
loader_train  = DataLoader(dataset,batch_size=2, shuffle=True, num_workers=0,pin_memory=True,
                    drop_last=True,prefetch_factor=2)

# Implement Train Loop

In [8]:
def build_model(device) -> nn.Module:
    discriminator = Discriminator().to(device)
    generator = Generator().to(device)
    return discriminator, generator

def define_optimizer(discriminator: nn.Module, generator: nn.Module,model_lr,model_betas) -> [optim.Adam, optim.Adam]:
    d_optimizer = optim.Adam(discriminator.parameters(), model_lr, model_betas) # LR & model_betas
    g_optimizer = optim.Adam(generator.parameters(), model_lr, model_betas)     # LR & model_betas
    return d_optimizer, g_optimizer


def define_scheduler(d_optimizer: optim.Adam, g_optimizer: optim.Adam, epochs,optimizer_gamma) -> [lr_scheduler.StepLR, lr_scheduler.StepLR]:
    epochs=10
    d_scheduler = lr_scheduler.StepLR(d_optimizer, epochs // 2, optimizer_gamma) # o.1 = optimizer gamma
    g_scheduler = lr_scheduler.StepLR(g_optimizer, epochs // 2, optimizer_gamma) # o.1 = optimizer gamma
    return d_scheduler, g_scheduler

def define_loss(device) -> [nn.MSELoss, nn.MSELoss, ContentLoss, nn.BCEWithLogitsLoss]:
    psnr_criterion = nn.MSELoss().to(device) # why psnr MSE?
    pixel_criterion = nn.MSELoss().to(device)
    content_criterion = ContentLoss().to(device)
    adversarial_criterion = nn.BCEWithLogitsLoss().to(device)
    return psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion


In [9]:
# config
epochs = 10

# loss weights
pixel_weight = 1.0
content_weight = 1.0
adversarial_weight = 0.001

#optim config
model_lr = 1e-4
model_betas = (0.9, 0.999)

# scheduler config
optimizer_step_size = epochs // 2
optimizer_gamma = 0.1

# create models, optimizers, schedulers, losses
generator,discriminator  = build_model(device)
d_optimizer, g_optimizer = define_optimizer(discriminator, generator,model_lr,model_betas)
d_scheduler, g_scheduler = define_scheduler(d_optimizer, g_optimizer,epochs,optimizer_gamma)
psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion = define_loss(device)

In [16]:
# training loop
for e in range(epochs):
    for (lr, hr) in loader_train:
        
        # set all models to train mode
        generator.train()
        discriminator.train()
        
        # send data to device
        lr = lr.to(device)
        hr = hr.to(device)
        
        # set sample labels of input for discriminator
        real_label = torch.full([lr.size(0), 1], 1.0, dtype=lr.dtype, device=device)
        fake_label = torch.full([lr.size(0), 1], 0.0, dtype=lr.dtype, device=device)
        
        # use generator to create super-resolution images
        sr = generator(lr)
        
        """
        start training of DISCRIMINATOR
        """
        # grad required for discriminator
        for p in discriminator.parameters():
            p.requires_grad = True
            
        # Initialize the discriminator optimizer gradient
        d_optimizer.zero_grad()
        
        # Calculate the loss of the discriminator on the high-resolution image
        hr_output = discriminator(hr)
        d_loss_hr = adversarial_criterion(hr_output, real_label)
        
        # calculate the loss of the discimonator on the super-resolution image
        sr_output = discriminator(sr.detach())
        d_loss_sr = adversarial_criterion(sr_output, fake_label)
        
        # Count discriminator total loss
        d_loss = d_loss_hr + d_loss_sr
        """
        End training DISCRIMINATOR
        ____________________________________
        
        Start training GENERATOR
        """
        # At this stage, the discriminator no needs to require a derivative gradient
        for p in discriminator.parameters():
            p.requires_grad = False

        # Initialize the generator optimizer gradient
        g_optimizer.zero_grad()
        
        # Calculate the loss of the generator on the super-resolution image
        output = discriminator(sr)
        pixel_loss = config.pixel_weight * pixel_criterion(sr, hr.detach()) # detaching hr from generator
        content_loss = config.content_weight * content_criterion(sr, hr.detach())
        adversarial_loss = config.adversarial_weight * adversarial_criterion(output, real_label)
        
        # Count discriminator total loss
        g_loss = pixel_loss + content_loss + adversarial_loss
        
        # # obtain the gradients with respect to the loss - generator
        g_loss.backward()
        
        # Update generator parameters, perform one step of gradient descent
        g_optimizer.step()
        # scaler.update() waht does this do
        
        """
        End training GENERATOR
        """
        
        # Calculate the scores of the two images on the discriminator
        d_hr_probability = torch.sigmoid(torch.mean(hr_output))
        d_sr_probability = torch.sigmoid(torch.mean(sr_output))
        
        # measure accuracy and record loss
        psnr = 10. * torch.log10(1. / psnr_criterion(sr, hr))
        pixel_losses.update(pixel_loss.item(), lr.size(0))
        content_losses.update(content_loss.item(), lr.size(0))
        adversarial_losses.update(adversarial_loss.item(), lr.size(0))
        d_hr_probabilities.update(d_hr_probability.item(), lr.size(0))
        d_sr_probabilities.update(d_sr_probability.item(), lr.size(0))
        psnres.update(psnr.item(), lr.size(0))
        
        break



RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x12800 and 18432x1024)