In [None]:
import argparse
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import albumentations as A
import albumentations.augmentations.functional as F
from albumentations.pytorch import ToTensorV2

from models import *
from watermarkDataset import WatermarkDataset
from vgg_loss import VGG19, VGG11
from utils import sample_images, val_loss, concat_images

import torch.nn as nn
import torch.nn.functional as F
import torch

import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# ---- Hyperparameters & run configuration ----

# Optimisation: Adam settings shared by the generator and the discriminator
lr = 3e-4
betas = (0.5, 0.999)

# Data loading
batch_size = 6
n_cpu = 6  # worker processes for the training DataLoader

# Schedule / resuming
epochs = 500
start_epoch = 0  # epoch index of the checkpoint to resume from when `load` is True
load = False
sample_interval = 500  # save sample images every N global batches
checkpoint_interval = 1  # save model weights every N epochs

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-channel (RGB) statistics, presumably of the training images — TODO confirm.
# No Normalize step is applied in the transforms, so they only feed `inv_normalize`.
mean = (0.6349, 0.5809, 0.5312)
std = (0.3283, 0.3288, 0.3534)

# Generator loss weights: L1 pixel-wise loss between translated image and real
# image, and the VGG perceptual loss.
lambda_pixel = 100
lambda_vgg = 1000

# Training crop size. The PatchGAN discriminator output is assumed to be
# 1 x (kernel_size / 16) x (kernel_size / 16), i.e. four 2x downsamplings in
# `Discriminator` — TODO confirm against models.py.
kernel_size = 256
patch = (1,) + (kernel_size // 2 ** 4,) * 2
# stride == kernel size, so this extracts non-overlapping kernel_size x kernel_size tiles
unfold = nn.Unfold(kernel_size=(kernel_size,) * 2, stride=(kernel_size,) * 2)

In [None]:
# ---- Augmentation pipelines ----
# `additional_targets` maps the extra key `image1` onto the same "image"
# processing, so the paired input/target images receive the *same* random crop
# and flips and stay pixel-aligned.
# NOTE(review): no Normalize step is applied, and ToTensorV2 does not rescale
# pixel values, so tensors keep whatever range WatermarkDataset yields —
# TODO confirm the dataset already returns values in the range the model expects.
train_transform = A.Compose(
    [
        A.RandomCrop(height=kernel_size, width=kernel_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        ToTensorV2(),
    ],
    additional_targets=dict(image1="image"),
)

# Validation: no cropping or flipping, only tensor conversion.
val_transform = A.Compose(
    [ToTensorV2()],
    additional_targets=dict(image1="image"),
)

# Inverse of Normalize(mean, std): (x - (-m/s)) / (1/s) == x * s + m per channel.
# NOTE(review): not referenced by the cells below, and the pipelines above do
# not normalize, so applying it to model outputs would distort them.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(mean, std)],
    std=[1 / s for s in std],
)

In [None]:
# ---- Loss functions ----
# Adversarial objective: MSE of discriminator patch scores against all-ones
# ("valid") / all-zeros ("fake") targets.
criterion_GAN = nn.MSELoss()
# Reconstruction objective: mean absolute pixel error between output and target.
criterion_pixelwise = nn.L1Loss()

In [None]:
# Initialize generator and discriminator: either resume both from the
# checkpoints of epoch `start_epoch` (when `load` is True) or start from fresh
# weights initialized by `weights_init_normal` from models.py.
generator = GeneratorUNet().to(device)
discriminator = Discriminator().to(device)
if load:
    # map_location remaps stored tensors onto the current device, so checkpoints
    # written on a GPU can be restored on a CPU-only machine.
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    generator.load_state_dict(
        torch.load(
            "saved_models/%s/generator_%d.pth" % ("watermarkdataset", start_epoch),
            map_location=device,
        )
    )
    discriminator.load_state_dict(
        torch.load(
            "saved_models/%s/discriminator_%d.pth" % ("watermarkdataset", start_epoch),
            map_location=device,
        )
    )
else:
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

In [None]:
# ---- Optimizers ----
# One Adam optimizer per network, both using the configured `lr` and `betas`.
# The generator's optimizer is built first, then the discriminator's.
optimizer_G, optimizer_D = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=betas)
    for net in (generator, discriminator)
)

In [None]:
# ---- Learning-rate schedulers ----
# Lower each optimizer's LR when the validation loss passed to `.step()` stops
# decreasing; the training loop steps them once per epoch.
lrs_G, lrs_D = (
    torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", verbose=True)
    for opt in (optimizer_G, optimizer_D)
)

In [None]:
# ---- Datasets and data loaders ----
# Each item is presumably an (input, target) image pair: the training loop uses
# batch[0] as the generator input and batch[1] as the target.

# Training: shuffled mini-batches loaded by `n_cpu` worker processes.
train = WatermarkDataset("data/train", transform=train_transform)
trainloader = DataLoader(
    train,
    shuffle=True,
    batch_size=batch_size,
    pin_memory=True,
    num_workers=n_cpu,
)

# Validation: one image at a time, in a fixed order.
val = WatermarkDataset("data/valid", transform=val_transform)
valloader = DataLoader(val, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

In [None]:
# Tensor type that batches are cast to before being moved to `device`. This is
# the CPU float32 type, so the cast happens on the CPU. It is also passed to the
# `sample_images` / `val_loss` helpers from utils.
Tensor = torch.FloatTensor

In [None]:
# Separate loss scalers for the generator and discriminator mixed-precision
# updates. Enabled only when CUDA is available. Without CUDA the scaler would
# disable itself anyway (with a warning); this makes that explicit and quiet.
g_scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
d_scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

In [None]:
# Perceptual-loss helper; its `calc_loss(output, target)` term is weighted by
# `lambda_vgg` in the generator objective.
# NOTE(review): confirm VGG11 freezes its weights / runs in eval mode, otherwise
# every generator backward pass also computes unused gradients for it.
vggloss = VGG11().to(device)

In [None]:
# ---- Training loop ----
# Per batch: one generator update (adversarial + weighted L1 + weighted VGG
# loss), then one discriminator update, using mixed precision on CUDA. After
# each epoch, the validation losses drive the ReduceLROnPlateau schedulers and
# the weights are checkpointed.

# On a CPU-only machine, CUDA autocast would just warn and disable itself.
use_amp = device.type == "cuda"

# torch.save does not create missing directories.
os.makedirs("saved_models/%s" % "watermarkdataset", exist_ok=True)

# start_epoch == 0 means a fresh run; otherwise resume after the checkpointed
# epoch. This uses a new name so that re-running this cell never mutates the
# `start_epoch` setting read by the model-loading cell.
first_epoch = 0 if start_epoch == 0 else start_epoch + 1

prev_time = time.time()
for epoch in range(first_epoch, epochs):
    # The utils helpers (sample_images / val_loss) may switch the networks to
    # eval mode — TODO confirm. Restoring train mode is a no-op if they do not.
    generator.train()
    discriminator.train()

    loop = tqdm(trainloader)
    for i, batch in enumerate(loop):
        # Model inputs: batch[0] is the generator input, batch[1] the target image
        real_A = batch[0].type(Tensor).to(device)
        real_B = batch[1].type(Tensor).to(device)

        # Adversarial ground truths: one target value per discriminator output patch
        valid = torch.ones((real_A.size(0), *patch), device=device)
        fake = torch.zeros((real_A.size(0), *patch), device=device)

        # ------------------
        #  Train Generator
        # ------------------
        optimizer_G.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            fake_B = generator(real_A)
            # GAN loss: the generator wants its output scored as real
            pred_fake = discriminator(fake_B, real_A)
            loss_GAN = criterion_GAN(pred_fake, valid)
            # Pixel-wise and perceptual reconstruction losses
            loss_pixel = criterion_pixelwise(fake_B, real_B)
            loss_vgg = vggloss.calc_loss(fake_B, real_B)
            loss_G = loss_GAN + lambda_pixel * loss_pixel + lambda_vgg * loss_vgg
        # Backward and optimizer step belong outside the autocast region.
        g_scaler.scale(loss_G).backward()
        g_scaler.step(optimizer_G)
        g_scaler.update()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # This also discards the D gradients left by the generator's backward pass.
        optimizer_D.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            # Real loss
            pred_real = discriminator(real_B, real_A)
            loss_real = criterion_GAN(pred_real, valid)
            # Fake loss (detached so no gradient flows back into the generator)
            pred_fake = discriminator(fake_B.detach(), real_A)
            loss_fake = criterion_GAN(pred_fake, fake)
            loss_D = 0.5 * (loss_real + loss_fake)
        d_scaler.scale(loss_D).backward()
        d_scaler.step(optimizer_D)
        d_scaler.update()

        # --------------
        #  Log Progress
        # --------------
        loop.set_description(
            f"epoch: {epoch} loss discriminator: {loss_D.item():.5f} loss generator: {loss_G.item():.5f}"
        )

        # If at sample interval save image
        batches_done = epoch * len(trainloader) + i
        if batches_done % sample_interval == 0:
            sample_images(batches_done, 10, generator, valloader, Tensor, kernel_size, unfold, device)

    # check if lr should be updated
    val_loss_G, val_loss_D = val_loss(generator, discriminator, vggloss, valloader, Tensor, kernel_size, unfold, patch, criterion_GAN, criterion_pixelwise, lambda_pixel, lambda_vgg, device)
    print(f"epoch {epoch} with G and D val loss:", val_loss_G, val_loss_D)
    lrs_D.step(val_loss_D)
    lrs_G.step(val_loss_G)

    # Save model checkpoints
    if epoch % checkpoint_interval == 0:
        torch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % ("watermarkdataset", epoch))
        torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % ("watermarkdataset", epoch))