In [None]:
# adapted from https://github.com/facebookresearch/pytorch_GAN_zoo

%matplotlib inline
import os
import time
import shutil
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Plot styling: apply seaborn's "dark" style first, then matplotlib's
# "dark_background" style sheet on top of it.
sns.set_style("dark")
plt.style.use("dark_background")

import torch
import torch.optim as optim

from pytorch_GAN_zoo.models.loss_criterions import base_loss_criterions
from pytorch_GAN_zoo.models.loss_criterions.gradient_losses import WGANGPGradientPenalty
from pytorch_GAN_zoo.models.utils.utils import finiteCheck

from skimage.transform import resize

from datetime import datetime
from IPython import display

import utils
import pgan_model
import batch_iterator

# Whole-second Unix timestamp, used to give every run its own output folder.
training_timestamp = str(int(time.time()))
model_dir = f'trained_models/model_{training_timestamp}/'

# exist_ok=True keeps the cell safe to re-run and avoids the race between
# checking for the directory and creating it.
os.makedirs(model_dir, exist_ok=True)

# Use the first CUDA device when one is present; otherwise fall back to the
# CPU so the notebook can still be executed (slowly) on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def print_and_log(text):
    """Print ``text`` to stdout and append the same line to the run's log file.

    The log file is ``log.txt`` inside the module-level ``model_dir``.

    Args:
        text: Object to report; it is rendered by ``print`` both times, so
            anything with a string representation is accepted.
    """
    print(text)
    # The context manager closes the handle after every call instead of
    # leaving one open file object per logged line for the GC to clean up.
    with open(f'{model_dir}/log.txt', 'a') as log_file:
        print(text, file=log_file)

In [None]:
# Keep a copy of this notebook (with file metadata, via copy2) inside the
# run's output folder so the code that produced the run is stored with it.
shutil.copy2('./pgan_train.ipynb', model_dir)

In [None]:
# Load the preprocessed data for the first scale (4x4 according to the file
# name); the training loop keeps using this object while scale == 0.
# NOTE(review): torch.load unpickles the file -- only load files you trust.
processed_data = torch.load('preprocessed_data/processed_data_4x4.pt')

In [None]:
# Mini-batch size; also used for the latent batches and the step count in the
# training loop.
batch_size = 16
# Project-local iterator over processed_data; batches are pulled with
# next_batch().
batch_iter = batch_iterator.BatchIterator(processed_data, batch_size)

In [None]:
# Sanity check: pull one batch and show the first three channels of its first
# sample.
batch_images = batch_iter.next_batch()
# permute(1, 2, 0) moves axis 0 to the end for imshow; (x + 1) * 0.5
# presumably maps data stored in [-1, 1] to [0, 1] -- TODO confirm the range.
plt.imshow((batch_images[0][:3].permute(1, 2, 0)+1.0)*0.5)

In [None]:
# Image size associated with each scale index: starts at 4 and doubles with
# every scale (4, 8, 16, 32, 64).
img_sizes = {level: 4 * 2 ** level for level in range(5)}

# Begin at the lowest scale.
scale = 0

In [None]:
# Instantiate the project-local progressive-GAN networks with their default
# constructor arguments.
discriminator_net = pgan_model.PGANDiscriminator()
generator_net = pgan_model.PGANGenerator()

In [None]:
# Step size shared by both Adam optimizers.
learning_rate = 0.001

# Put both networks on the training device before handing their parameters
# to the optimizers.
discriminator_net.to(device)
generator_net.to(device)

# Only parameters that require gradients are optimised.
optimizer_d = optim.Adam(
    [param for param in discriminator_net.parameters() if param.requires_grad],
    lr=learning_rate,
    betas=[0, 0.99],
)
optimizer_g = optim.Adam(
    [param for param in generator_net.parameters() if param.requires_grad],
    lr=learning_rate,
    betas=[0, 0.99],
)

# Start from clean gradient buffers.
optimizer_d.zero_grad()
optimizer_g.zero_grad()

# Record both architectures in the console and in the run log.
print_and_log(generator_net)
print_and_log(discriminator_net)

In [None]:
# WGAN-GP criterion from pytorch_GAN_zoo; the training loop calls
# getCriterion(prediction, is_real) on it for both networks.
loss_criterion = base_loss_criterions.WGANGP(device)

# Weight of the extra discriminator term in the training loop:
# epsilon_d * sum(pred_real_d[:, 0] ** 2).
epsilon_d = 0.001

In [None]:
print_and_log(f"{datetime.now()} Starting Training")

n_scales = 5
model_alpha = 0.0
alpha_update_cons = 0.00125
epoch_per_scale = 16

# Optimisation steps per scale. 40000 is presumably the number of training
# samples per epoch -- TODO confirm against the dataset.
steps_per_scale = epoch_per_scale * 40000 // batch_size

# NOTE(review): the original cell ended BOTH loops with an unconditional
# `break`, so it only ever ran one step at scale 0. That looks like a leftover
# smoke test. The default keeps that behaviour; set this to False to run the
# full schedule (n_scales * steps_per_scale steps).
debug_single_step = True

# Preprocessed dataset for every scale. Scale 0 is listed too, so re-running
# this cell never starts from data left over from a higher scale.
scale_data_paths = {
    0: 'preprocessed_data/processed_data_4x4.pt',
    1: 'preprocessed_data/processed_data_8x8.pt',
    2: 'preprocessed_data/processed_data_16x16.pt',
    3: 'preprocessed_data/processed_data_32x32.pt',
    4: 'preprocessed_data/processed_data.pt',
}

show_scaled_img = False
normalize_t = torch.Tensor([0.5]).to(device)
scale_t = torch.Tensor([255]).to(device)

# Fixed latent vectors so that snapshots taken at different steps are comparable.
fixed_latent = torch.randn(4, 512).to(device)
generated_img_cubes = []

for scale in range(0, n_scales):
    # Every scale after the first starts at alpha = 1.0, which then decays
    # towards 0 during the scale (see the update at the top of the step loop).
    if scale > 0:
        model_alpha = 1.0

    # NOTE(review): torch.load unpickles the file -- only load trusted data.
    if scale in scale_data_paths:
        processed_data = torch.load(scale_data_paths[scale])

    batch_iter = batch_iterator.BatchIterator(processed_data, batch_size)
    print_and_log(f"{datetime.now()} Starting scale:{scale}")

    if show_scaled_img:
        batch_images = batch_iter.next_batch()
        if scale < 3:
            plt.imshow((batch_images[0][:3].permute(1, 2, 0)+1.0)*0.5)
            plt.show()
        else:
            # Same /255 and (x - 0.5) / 0.5 conversion as in the step loop,
            # done on the CPU for display only.
            cur_img_plane = ((batch_images[0].type(torch.FloatTensor) / torch.Tensor([255])) - torch.Tensor([0.5])) / torch.Tensor([0.5])
            plt.imshow((cur_img_plane[:3].permute(1, 2, 0)+1.0)*0.5)
            plt.show()

    for batch_step in range(1, steps_per_scale + 1):
        # Every 25 steps lower alpha by alpha_update_cons, snapping to 0 once
        # it falls below 1e-5.
        if batch_step % 25 == 0 and model_alpha > 0:
            model_alpha = model_alpha - alpha_update_cons
            model_alpha = 0.0 if model_alpha < 1e-5 else model_alpha

        batch_images = batch_iter.next_batch()
        if scale >= 3:
            # Scales 3 and 4 are converted here: cast to float, divide by 255,
            # then (x - 0.5) / 0.5. Presumably these files store raw 8-bit
            # pixels -- TODO confirm.
            batch_images = ((batch_images.type(torch.FloatTensor).to(device) / scale_t) - normalize_t) / normalize_t
        else:
            batch_images = batch_images.to(device)

        discriminator_net.set_alpha(model_alpha)
        generator_net.set_alpha(model_alpha)

        # ---- Discriminator update ----
        optimizer_d.zero_grad()

        pred_real_d = discriminator_net(batch_images, False)

        loss_d = loss_criterion.getCriterion(pred_real_d, True)

        input_latent = torch.randn(batch_size, 512).to(device)

        # detach(): the generator receives no gradient from this step.
        pred_fake_g = generator_net(input_latent).detach()
        pred_fake_d = discriminator_net(pred_fake_g, False)

        loss_d_fake = loss_criterion.getCriterion(pred_fake_d, False)

        loss_d += loss_d_fake

        # Called with backward=True (presumably it back-propagates the penalty
        # itself -- TODO confirm); the returned value is kept for inspection
        # but is not added to loss_d.
        loss_d_grad = WGANGPGradientPenalty(batch_images, pred_fake_g, discriminator_net, weight=10.0, backward=True)

        # Extra term: epsilon_d * sum of squared real-sample outputs.
        loss_epsilon = (pred_real_d[:, 0] ** 2).sum() * epsilon_d
        loss_d += loss_epsilon

        loss_d.backward(retain_graph=True)
        finiteCheck(discriminator_net.parameters())
        optimizer_d.step()

        # ---- Generator update ----
        optimizer_d.zero_grad()
        optimizer_g.zero_grad()

        input_noise = torch.randn(batch_size, 512).to(device)

        pred_fake_g = generator_net(input_noise)

        pred_fake_d, phi_g_fake = discriminator_net(pred_fake_g, True)

        loss_g_fake = loss_criterion.getCriterion(pred_fake_d, True)
        loss_g_fake.backward(retain_graph=True)

        finiteCheck(generator_net.parameters())
        optimizer_g.step()

        if batch_step == 1 or batch_step % 100 == 0:
            print_and_log(f"{datetime.now()} [{scale}/{n_scales}][{batch_step:05d}/{steps_per_scale}], Alpha:{model_alpha:.4f} "
                          f"Loss_G:{loss_g_fake.item():.4f}\tLoss_D:{loss_d.item():.4f}")

        # Snapshot the generator on the fixed latents at steps 1, 5001, 10001, ...
        # transpose(0, 2, 3, 1) moves axis 1 of the output to the last position.
        if batch_step % 5000 == 1:
            with torch.no_grad():
                generated_inputs = generator_net(fixed_latent).detach()
                generated_img_cubes += [generated_inputs.cpu().numpy().transpose(0, 2, 3, 1)]

        if debug_single_step:
            break
    if debug_single_step:
        break

# Snapshots from different scales can differ in shape, and np.array cannot
# stack those. Convert only when every snapshot has the same shape; otherwise
# keep the list, which supports the same [-1][i] indexing.
if len({cube.shape for cube in generated_img_cubes}) <= 1:
    generated_img_cubes = np.array(generated_img_cubes)

In [None]:
# Which fixed-latent sample and which frame of it to display.
img_i = 0
frame_i = 0  # 0-35

# Latest snapshot, selected sample; the training cell stored it with the
# generator's axis 1 moved last (presumably height, width, channels).
eval_img_cube = generated_img_cubes[-1][img_i]

# Each frame occupies three consecutive entries on the last axis.
channel_start = 3 * frame_i
frame_rgb = eval_img_cube[:, :, channel_start:channel_start + 3]
plt.imshow(0.5 * (frame_rgb + 1.0))