In [1]:
import argparse
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import albumentations as A
import albumentations.augmentations.functional as F
from albumentations.pytorch import ToTensorV2

from models import *
from watermarkDataset import WatermarkDataset
from vgg_loss import VGG19, VGG11

import torch.nn as nn
import torch.nn.functional as F
import torch

import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# TODO: remove the normalization step and then check the results.
# Open question: does normalization make the output images grey?

In [3]:
# hyperparameters
lr = 3e-4
betas = (0.5, 0.999)
batch_size = 6
n_cpu = 6
start_epoch = 100
load = True
epochs = 500
sample_interval = 500
checkpoint_interval = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mean = (0.6349, 0.5809, 0.5312)
std = (0.3283, 0.3288, 0.3534)
# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100 # was 100 orig
lambda_vgg = 1000 # orig 1000
# Calculate output of image discriminator (PatchGAN)
patch = (1, 256 // 2 ** 4, 256 // 2 ** 4)
kernel_size=256
unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size,kernel_size))

In [4]:
# transforms
train_transform = A.Compose([
        A.RandomCrop(width=kernel_size, height=kernel_size),
        #A.augmentations.transforms.Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        ToTensorV2(),
        ],  
        additional_targets={'image1': 'image'}
    )

val_transform = A.Compose([
        #A.augmentations.transforms.Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
        ToTensorV2(),
        ],  
        additional_targets={'image1': 'image'}
    )
inv_normalize = transforms.Normalize(
   mean= [-m/s for m, s in zip(mean, std)],
   std= [1/s for s in std]
)

In [5]:
# helper functions
def val_loss():
    """ BATCHSIZE SHOULD BE 1 """
    generator.eval()
    discriminator.eval()
    vggloss.eval()
    
    total_loss_G = torch.zeros([1], dtype=torch.float).to(device)
    total_loss_D = torch.zeros([1], dtype=torch.float).to(device)
    loop = tqdm(valloader)
    for i, batch in enumerate(loop):
        # Model inputs
        real_A = Variable(batch[0].type(Tensor))
        real_B = Variable(batch[1].type(Tensor))
        
        # reshape full images to patches of 256x256
        _,_,w,h = real_A.shape
        real_A = transforms.functional.pad(real_A, [0, 0, int((int(h/kernel_size)+1)*kernel_size)-h, int((int(w/kernel_size)+1)*kernel_size)-w])
        real_A = unfold(real_A).squeeze(0).T.view(-1,3,kernel_size,kernel_size)
        _,_,w,h = real_A.shape
        real_B = transforms.functional.pad(real_B, [0, 0, int((int(h/kernel_size)+1)*kernel_size)-h, int((int(w/kernel_size)+1)*kernel_size)-w])
        real_B = unfold(real_B).squeeze(0).T.view(-1,3,kernel_size,kernel_size)
        

        # split data up in single images to fit in memory
        b = real_A.shape[0]
        t_loss_G = torch.zeros([1], dtype=torch.float).to(device)
        t_loss_D = torch.zeros([1], dtype=torch.float).to(device)
        
        for k in range(b):
            real_A_ = real_A[k].clone().unsqueeze(0).to(device)
            real_B_ = real_B[k].clone().unsqueeze(0).to(device)
            # Adversarial ground truths
            valid = Variable(Tensor(np.ones((real_A_.size(0), *patch))), requires_grad=False).to(device)
            fake = Variable(Tensor(np.zeros((real_A_.size(0), *patch))), requires_grad=False).to(device)
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    # Gen
                    fake_B = generator(real_A_)
                    pred_fake = discriminator(fake_B, real_A_)
                    loss_GAN = criterion_GAN(pred_fake, valid)
                    # Pixel-wise loss
                    loss_pixel = criterion_pixelwise(fake_B, real_B_)
                    # Total loss
                    loss_G = loss_GAN + lambda_pixel * loss_pixel + lambda_vgg * vggloss.calc_loss(fake_B, real_B_)
                    t_loss_G += loss_G

                    # Disc
                    pred_real = discriminator(real_B_, real_A_)
                    loss_real = criterion_GAN(pred_real, valid)
                    # Fake loss
                    pred_fake = discriminator(fake_B.detach(), real_A_)
                    loss_fake = criterion_GAN(pred_fake, fake)
                    # Total loss
                    loss_D = 0.5 * (loss_real + loss_fake)
                    t_loss_D += loss_D
                    
        t_loss_G = t_loss_G/b
        t_loss_D = t_loss_D/b
        total_loss_G += t_loss_G
        total_loss_D += t_loss_D
                
    total_loss_G = total_loss_G/len(valloader)
    total_loss_D = total_loss_D/len(valloader)
    
    generator.train()
    discriminator.train()
    vggloss.train()
    return total_loss_G,total_loss_D
        
def sample_images(batches_done, N):
    """ Save the first N generated images from the dataset """
    generator.eval()
    generator.to("cpu")
    
    imgs = []
    for i, batch in enumerate(valloader):
        if i >= N:
            break
        t = []
        real_A = Variable(batch[0].type(Tensor))
        real_B = Variable(batch[1].type(Tensor))
        t.append(real_A.detach())
        
        _,_,w,h = real_A.shape
        real_A = transforms.functional.pad(real_A, [0, 0, int((int(h/kernel_size)+1)*kernel_size)-h, int((int(w/kernel_size)+1)*kernel_size)-w])
        real_A = unfold(real_A).squeeze(0).T.view(-1,3,kernel_size,kernel_size)
        fake_B = generator(real_A)
        
        new_w = int((int(w/kernel_size)+1)*kernel_size)
        new_h = int((int(h/kernel_size)+1)*kernel_size)
        fake_B = fake_B.permute(1,2,3,0)
        fake_B = fake_B.reshape(1,3*kernel_size*kernel_size,fake_B.shape[3])
        fold = nn.Fold((new_w,new_h), (kernel_size,kernel_size), stride=(kernel_size,kernel_size))
        fake_B = fold(fake_B)
        
        fake_B = fake_B[:, :, 0:w, 0:h]
        t.append(fake_B.detach())
        t.append(real_B.detach())
        imgs.append(tuple(t))
    
    os.mkdir(f"images/watermarkdataset/{batches_done}")
    for j, (a,fb,b) in enumerate(imgs):
        img_sample = torch.cat((a, fb, b), -2)
        save_image(img_sample, f"images/watermarkdataset/{batches_done}/{j}.jpg")
    
    generator.to(device)
    generator.train()
    return imgs

In [6]:
# Loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

In [7]:
# Initialize generator and discriminator
generator = GeneratorUNet().to(device)
discriminator = Discriminator().to(device)
if load:
    generator.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % ("watermarkdataset", start_epoch)))
    discriminator.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % ("watermarkdataset", start_epoch)))
else:
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

In [8]:
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=3e-7, betas=betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=3e-5, betas=betas)

In [9]:
# lr scheduler
lrs_G = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_G, "min", verbose=True)
lrs_D = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_D, "min", verbose=True)

In [10]:
# dataset and dataloader
train = WatermarkDataset("data/train", transform=train_transform)
trainloader = DataLoader(
    train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
    pin_memory=True,
)

val = WatermarkDataset("data/valid", transform=val_transform)
valloader = DataLoader(
    val,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
)

In [11]:
Tensor = torch.FloatTensor

In [12]:
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

In [13]:
vggloss = VGG11()
vggloss = vggloss.to(device)

In [14]:
prev_time = time.time()
if start_epoch == 0:
    start_epoch = -1
for epoch in range(start_epoch+1, epochs):
    loop = tqdm(trainloader)
    for i, batch in enumerate(loop):
        
        # Model inputs
        real_A = Variable(batch[0].type(Tensor)).to(device)
        real_B = Variable(batch[1].type(Tensor)).to(device)
        
        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False).to(device)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False).to(device)

        # ------------------
        #  Train Generators
        # ------------------
        optimizer_G.zero_grad()
        with torch.cuda.amp.autocast():
            # GAN loss
            fake_B = generator(real_A)
            pred_fake = discriminator(fake_B, real_A)
            loss_GAN = criterion_GAN(pred_fake, valid)
            # Pixel-wise loss
            loss_pixel = criterion_pixelwise(fake_B, real_B)

            # Total loss
            #print(f"loss gan: {loss_GAN} loss pixel: {lambda_pixel * loss_pixel} loss vgg: {lambda_vgg*vggloss.calc_loss(fake_B, real_B)}")
            loss_G = loss_GAN + lambda_pixel * loss_pixel + lambda_vgg * vggloss.calc_loss(fake_B, real_B)
            
            g_scaler.scale(loss_G).backward()
            g_scaler.step(optimizer_G)
            g_scaler.update()
        
            
        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        with torch.cuda.amp.autocast():
            # Real loss
            pred_real = discriminator(real_B, real_A)
            loss_real = criterion_GAN(pred_real, valid)

            # Fake loss
            pred_fake = discriminator(fake_B.detach(), real_A)
            loss_fake = criterion_GAN(pred_fake, fake)

            # Total loss
            loss_D = 0.5 * (loss_real + loss_fake)

            d_scaler.scale(loss_D).backward()
            d_scaler.step(optimizer_D)
            d_scaler.update()
        
        
        # --------------
        #  Log Progress
        # --------------
        loop.set_description(f"epoch: {epoch} loss discriminator: {loss_D:.5f} loss generator: {loss_G:.5f}")
        

        # If at sample interval save image
        batches_done = epoch * len(trainloader) + i
        if batches_done % sample_interval == 0:
            sample_images(batches_done, 10)
          
    # check if lr should be updated
    val_loss_G, val_loss_D = val_loss()
    print(f"epoch {epoch} with G and D val loss:", val_loss_G, val_loss_D)
    lrs_D.step(val_loss_D)
    lrs_G.step(val_loss_G)  
    
    # Save model checkpoints
    if epoch % checkpoint_interval == 0:
        torch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % ("watermarkdataset", epoch))
        torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % ("watermarkdataset", epoch))

epoch: 74 loss discriminator: 0.00092 loss generator: 3.23484: 100%|███████████████| 3500/3500 [47:30<00:00,  1.23it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 74 with G and D val loss: tensor([2.7406], device='cuda:0') tensor([0.0220], device='cuda:0')


epoch: 75 loss discriminator: 0.00321 loss generator: 3.12376: 100%|███████████████| 3500/3500 [47:45<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 75 with G and D val loss: tensor([2.7480], device='cuda:0') tensor([0.0234], device='cuda:0')


epoch: 76 loss discriminator: 0.05467 loss generator: 3.37285: 100%|███████████████| 3500/3500 [47:35<00:00,  1.23it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 76 with G and D val loss: tensor([2.6764], device='cuda:0') tensor([0.0260], device='cuda:0')


epoch: 77 loss discriminator: 0.01607 loss generator: 1.77585: 100%|███████████████| 3500/3500 [47:33<00:00,  1.23it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 77 with G and D val loss: tensor([2.7538], device='cuda:0') tensor([0.0218], device='cuda:0')


epoch: 78 loss discriminator: 0.00112 loss generator: 4.63670: 100%|███████████████| 3500/3500 [47:32<00:00,  1.23it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 78 with G and D val loss: tensor([2.7395], device='cuda:0') tensor([0.0222], device='cuda:0')


epoch: 79 loss discriminator: 0.00115 loss generator: 2.85737: 100%|███████████████| 3500/3500 [47:35<00:00,  1.23it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 79 with G and D val loss: tensor([2.7507], device='cuda:0') tensor([0.0263], device='cuda:0')


epoch: 80 loss discriminator: 0.00403 loss generator: 3.80876: 100%|███████████████| 3500/3500 [47:39<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 80 with G and D val loss: tensor([2.7690], device='cuda:0') tensor([0.0213], device='cuda:0')


epoch: 81 loss discriminator: 0.00110 loss generator: 2.60502: 100%|███████████████| 3500/3500 [47:42<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:52<00:00,  1.85s/it]


epoch 81 with G and D val loss: tensor([2.7033], device='cuda:0') tensor([0.0232], device='cuda:0')


epoch: 82 loss discriminator: 0.00902 loss generator: 1.93335: 100%|███████████████| 3500/3500 [47:41<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:54<00:00,  1.85s/it]


epoch 82 with G and D val loss: tensor([2.7166], device='cuda:0') tensor([0.0215], device='cuda:0')


epoch: 83 loss discriminator: 0.00317 loss generator: 2.13851: 100%|███████████████| 3500/3500 [47:47<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:53<00:00,  1.85s/it]


epoch 83 with G and D val loss: tensor([2.7282], device='cuda:0') tensor([0.0215], device='cuda:0')


epoch: 84 loss discriminator: 0.00281 loss generator: 2.12442: 100%|███████████████| 3500/3500 [47:44<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:54<00:00,  1.85s/it]


epoch 84 with G and D val loss: tensor([2.7227], device='cuda:0') tensor([0.0240], device='cuda:0')


epoch: 85 loss discriminator: 0.00192 loss generator: 1.90216: 100%|███████████████| 3500/3500 [47:45<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 85 with G and D val loss: tensor([2.7361], device='cuda:0') tensor([0.0234], device='cuda:0')


epoch: 86 loss discriminator: 0.08412 loss generator: 1.47916: 100%|███████████████| 3500/3500 [47:45<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:56<00:00,  1.86s/it]


epoch 86 with G and D val loss: tensor([2.7320], device='cuda:0') tensor([0.0252], device='cuda:0')


epoch: 87 loss discriminator: 0.00321 loss generator: 2.42090: 100%|███████████████| 3500/3500 [48:02<00:00,  1.21it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:54<00:00,  1.85s/it]


epoch 87 with G and D val loss: tensor([2.6629], device='cuda:0') tensor([0.0288], device='cuda:0')


epoch: 88 loss discriminator: 0.00139 loss generator: 2.81890: 100%|███████████████| 3500/3500 [47:48<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:54<00:00,  1.85s/it]


epoch 88 with G and D val loss: tensor([2.7652], device='cuda:0') tensor([0.0199], device='cuda:0')


epoch: 89 loss discriminator: 0.00182 loss generator: 2.72431: 100%|███████████████| 3500/3500 [47:26<00:00,  1.23it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 89 with G and D val loss: tensor([2.7128], device='cuda:0') tensor([0.0212], device='cuda:0')


epoch: 90 loss discriminator: 0.00135 loss generator: 5.14628: 100%|███████████████| 3500/3500 [47:17<00:00,  1.23it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 90 with G and D val loss: tensor([2.7522], device='cuda:0') tensor([0.0208], device='cuda:0')


epoch: 91 loss discriminator: 0.00111 loss generator: 3.12510: 100%|███████████████| 3500/3500 [47:35<00:00,  1.23it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 91 with G and D val loss: tensor([2.7203], device='cuda:0') tensor([0.0211], device='cuda:0')


epoch: 92 loss discriminator: 0.00116 loss generator: 2.66758: 100%|███████████████| 3500/3500 [47:43<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 92 with G and D val loss: tensor([2.7370], device='cuda:0') tensor([0.0196], device='cuda:0')


epoch: 93 loss discriminator: 0.04936 loss generator: 1.76056: 100%|███████████████| 3500/3500 [47:47<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:55<00:00,  1.86s/it]


epoch 93 with G and D val loss: tensor([2.7216], device='cuda:0') tensor([0.0201], device='cuda:0')


epoch: 94 loss discriminator: 0.00193 loss generator: 3.70896: 100%|███████████████| 3500/3500 [47:43<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 94 with G and D val loss: tensor([2.6958], device='cuda:0') tensor([0.0219], device='cuda:0')


epoch: 95 loss discriminator: 0.00348 loss generator: 3.35286: 100%|███████████████| 3500/3500 [47:42<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:54<00:00,  1.86s/it]


epoch 95 with G and D val loss: tensor([2.7164], device='cuda:0') tensor([0.0218], device='cuda:0')


epoch: 96 loss discriminator: 0.00374 loss generator: 2.95530: 100%|███████████████| 3500/3500 [47:42<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:53<00:00,  1.85s/it]


epoch 96 with G and D val loss: tensor([2.7466], device='cuda:0') tensor([0.0216], device='cuda:0')


epoch: 97 loss discriminator: 0.00156 loss generator: 3.85497: 100%|███████████████| 3500/3500 [47:41<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:54<00:00,  1.85s/it]


epoch 97 with G and D val loss: tensor([2.6875], device='cuda:0') tensor([0.0214], device='cuda:0')


epoch: 98 loss discriminator: 0.00525 loss generator: 11.30066: 100%|██████████████| 3500/3500 [47:42<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 98 with G and D val loss: tensor([2.7215], device='cuda:0') tensor([0.0209], device='cuda:0')
Epoch    25: reducing learning rate of group 0 to 3.0000e-07.


epoch: 99 loss discriminator: 0.01572 loss generator: 2.85177: 100%|███████████████| 3500/3500 [47:42<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:51<00:00,  1.84s/it]


epoch 99 with G and D val loss: tensor([2.6768], device='cuda:0') tensor([0.0229], device='cuda:0')


epoch: 100 loss discriminator: 0.00547 loss generator: 1.64480: 100%|██████████████| 3500/3500 [47:48<00:00,  1.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 191/191 [05:55<00:00,  1.86s/it]


epoch 100 with G and D val loss: tensor([2.6952], device='cuda:0') tensor([0.0228], device='cuda:0')


epoch: 101 loss discriminator: 0.00277 loss generator: 2.93362:  69%|█████████▋    | 2413/3500 [33:06<14:54,  1.21it/s]


KeyboardInterrupt: 

In [15]:
def sample_images(batches_done, N):
    """ Save the first N generated images from the dataset """
    generator.eval()
    generator.to("cpu")
    
    imgs = []
    for i, batch in enumerate(tqdm(valloader)):
        if i >= N:
            break
        t = []
        real_A = Variable(batch[0].type(Tensor))
        real_B = Variable(batch[1].type(Tensor))
        t.append(real_A.detach())
        
        _,_,w,h = real_A.shape
        real_A = transforms.functional.pad(real_A, [0, 0, int((int(h/kernel_size)+1)*kernel_size)-h, int((int(w/kernel_size)+1)*kernel_size)-w])
        real_A = unfold(real_A).squeeze(0).T.view(-1,3,kernel_size,kernel_size)
        fake_B = generator(real_A)
        
        new_w = int((int(w/kernel_size)+1)*kernel_size)
        new_h = int((int(h/kernel_size)+1)*kernel_size)
        fake_B = fake_B.permute(1,2,3,0)
        fake_B = fake_B.reshape(1,3*kernel_size*kernel_size,fake_B.shape[3])
        fold = nn.Fold((new_w,new_h), (kernel_size,kernel_size), stride=(kernel_size,kernel_size))
        fake_B = fold(fake_B)
        
        fake_B = fake_B[:, :, 0:w, 0:h]
        t.append(fake_B.detach())
        t.append(real_B.detach())
        imgs.append(tuple(t))
    
    os.mkdir(f"images/watermarkdataset/{batches_done}")
    for j, (a,fb,b) in enumerate(imgs):
        img_sample = torch.cat((a, fb, b), -2)
        save_image(img_sample, f"images/watermarkdataset/{batches_done}/{j}.jpg")
    
    generator.to(device)
    generator.train()
    return imgs