In [None]:
# adapted from https://github.com/facebookresearch/pytorch_GAN_zoo

%matplotlib inline
import os
import time
import shutil
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("dark")
plt.style.use("dark_background")

import torch
import torch.optim as optim

from skimage.transform import resize

from datetime import datetime
from IPython import display

import utils
import pgan_model
import batch_iterator

# Every training run writes into its own directory, keyed by the Unix timestamp at start-up.
training_timestamp = str(int(time.time()))
model_dir = f'trained_models/model_{training_timestamp}/'

# exist_ok=True replaces the separate exists-check and avoids the check-then-create race.
os.makedirs(model_dir, exist_ok=True)

# Train on the first CUDA GPU when one is available. Otherwise fall back to the CPU (with a
# visible warning) instead of failing later at the first `.to(device)` call.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    print("WARNING: CUDA is not available, training on CPU (this will be very slow).")
    device = torch.device("cpu")

In [None]:
def print_and_log(text, log_dir=None):
    """Print ``text`` to stdout and append the same line to ``log.txt``.

    Args:
        text: Any object; it is rendered exactly as ``print`` renders it
            (non-strings go through ``str()``, e.g. a network's module repr).
        log_dir: Directory that holds ``log.txt``. Defaults to the module-level
            ``model_dir`` of the current training run.
    """
    if log_dir is None:
        log_dir = model_dir
    print(text)
    # A context manager closes the handle deterministically on every call. The previous
    # `print(..., file=open(...))` left closing the file to garbage collection.
    with open(os.path.join(log_dir, 'log.txt'), 'a') as log_file:
        print(text, file=log_file)

In [None]:
# Keep a snapshot of this notebook next to the trained model. Only the on-disk copy of
# pgan_train.ipynb is copied; unsaved edits are not included.
shutil.copy2(src='./pgan_train.ipynb', dst=model_dir)

In [None]:
# Cache the preprocessed dataset on disk so re-running the notebook skips preprocessing.
PROCESSED_DATA_PATH = 'preprocessed_data/processed_data.pt'

if os.path.isfile(PROCESSED_DATA_PATH):
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if processed_data
    # is not made of plain tensors/containers this may need weights_only=False — confirm.
    processed_data = torch.load(PROCESSED_DATA_PATH)
else:
    processed_data = utils.preprocess_data(batch_size=64, data_size=40_000)
    # torch.save does not create parent directories, so create the cache folder first.
    os.makedirs(os.path.dirname(PROCESSED_DATA_PATH), exist_ok=True)
    torch.save(processed_data, PROCESSED_DATA_PATH)

In [None]:
# Mini-batch size used for training and the iterator that serves batches from processed_data.
# NOTE(review): preprocessing is called with batch_size=64 — confirm the two are meant to differ.
batch_size = 16
batch_iter = batch_iterator.BatchIterator(processed_data, batch_size)

In [None]:
# Preview one raw training sample. Pixel values are presumably in [0, 255] (hence the /255).
# They are mapped to [-1, 1] and then back to [0, 1] for display.
batch_images = batch_iter.next_batch()
first_image = batch_images[0].type(torch.FloatTensor)
cur_img_plane = (first_image / 255 - torch.Tensor([0.5])) / torch.Tensor([0.5])
# Show the first three channels, moved to the last axis as imshow expects.
plt.imshow((cur_img_plane[:3].permute(1, 2, 0) + 1.0) * 0.5)

In [None]:
# Side length in pixels at each progressive-growing scale: 4, 8, 16, 32, 64.
img_sizes = {level: 4 * 2 ** level for level in range(5)}

# Scale used by the preview cells; the training loop later rebinds `scale` as it iterates.
scale = 0

In [None]:
def resize_batch_s(batch_img, size):
    """Resize every image of a batch independently to ``size`` x ``size``.

    Each element is permuted with ``permute(1, 2, 0)`` (channel axis moved last) before
    ``skimage.transform.resize``. The results are stacked into one numpy array whose
    trailing axis is that channel axis.
    """
    resized_images = []
    for image in batch_img:
        channels_last = image.permute(1, 2, 0)
        resized_images.append(resize(channels_last, (size, size)))
    return np.array(resized_images)

def resize_batch_l(batch_img, size, batch_size=None):
    """Resize a whole batch with a single ``skimage.transform.resize`` call.

    Args:
        batch_img: 4-D batch tensor. Axis 1 is moved last via ``permute(0, 2, 3, 1)``
            (presumably (N, C, H, W) -> (N, H, W, C); confirm against callers).
        size: Target side length of each image.
        batch_size: Leading dimension of the output. Defaults to the actual number of
            images in ``batch_img``. If it differs from that number, ``resize`` interpolates
            *across images*, blending neighbouring samples, so leave it unset.

    Returns:
        numpy array of shape (batch_size, size, size, C). ``resize`` keeps the trailing
        axis when ``output_shape`` has one fewer dimension than the input.
    """
    if batch_size is None:
        batch_size = batch_img.shape[0]
    return resize(batch_img.permute(0, 2, 3, 1), (batch_size, size, size))

In [None]:
# Preview one image of the batch after resizing it to the current scale.
img_i = 0      # index of the image within the batch
frame_i = 0    # first of the 3 channels shown (original note: 0-35)
# NOTE(review): the slice takes 3 channels, so frame_i above (channel count - 3) yields fewer
# than 3 channels — confirm the intended range.
resized_batch = resize_batch_s(batch_images, img_sizes[scale])
plt.imshow(resized_batch[img_i, :, :, frame_i:frame_i + 3])

In [None]:
# Convert the channels-last resized numpy batch to a channels-first float tensor.
resized_tensor = torch.from_numpy(resized_batch.transpose((0, 3, 1, 2))).type(torch.FloatTensor)
# Normalise the first image with (x - 0.5) / 0.5, then undo it for display (first 3 channels).
cur_img_plane = (resized_tensor[0] - torch.Tensor([0.5])) / torch.Tensor([0.5])
plt.imshow((cur_img_plane[:3].permute(1, 2, 0) + 1.0) * 0.5)

In [None]:
# Build the discriminator and generator from the local pgan_model module using their
# constructor defaults; the architectures are defined in pgan_model.
discriminator_net = pgan_model.PGANDiscriminator()
generator_net = pgan_model.PGANGenerator()

In [None]:
learning_rate = 0.001

# Move the parameters to the training device before the optimizers are built over them.
discriminator_net.to(device)
generator_net.to(device)

# One Adam optimizer per network, over its trainable parameters only.
# beta1 = 0 means the first-moment estimate is just the current gradient (no momentum averaging).
optimizer_d = optim.Adam(
    [param for param in discriminator_net.parameters() if param.requires_grad],
    lr=learning_rate,
    betas=[0, 0.99],
)
optimizer_g = optim.Adam(
    [param for param in generator_net.parameters() if param.requires_grad],
    lr=learning_rate,
    betas=[0, 0.99],
)

# Start from clean gradient buffers.
optimizer_d.zero_grad()
optimizer_g.zero_grad()

# Record both architectures on screen and in the run log.
print_and_log(generator_net)
print_and_log(discriminator_net)

In [None]:
print_and_log(f"{datetime.now()} Starting Training")

# Progressive-growing schedule: scales 0..n_scales-1 index img_sizes (4 -> 64 pixels per side).
n_scales = 5
# Blend factor: stays 0.0 during scale 0, restarts at 1.0 for every later scale and then
# decays towards 0.0 (presumably the fade-in weight of newly added layers — confirm in pgan_model).
model_alpha = 0.0
# Subtracted from model_alpha every 25 steps: 1.0 / 0.00125 = 800 decrements = 20,000 steps to reach 0.0.
alpha_update_cons = 0.00125
# Multiplier on batch_iter.size for the number of steps per scale
# (reads as "epochs" if batch_iter.size counts samples — TODO confirm).
epoch_per_scale = 16

# Device-side constant for the (x - 0.5) / 0.5 normalisation of each batch. It maps [0, 1]
# to [-1, 1], assuming resize_batch_s returns values in [0, 1] — TODO confirm.
normalize_t = torch.Tensor([0.5]).to(device)

for scale in range(0, n_scales):
    # Every scale after the first restarts the blend factor at 1.0.
    if scale > 0:
        model_alpha = 1.0
        
    # One batch per step: epoch_per_scale * batch_iter.size // batch_size steps per scale.
    for batch_step in range(1, (epoch_per_scale*batch_iter.size//batch_size)+1):
        # Decay alpha every 25 steps; snap tiny floating-point residue to exactly 0.0.
        # NOTE(review): if a scale runs fewer than 20,000 steps, alpha never reaches 0.0
        # before the next scale resets it.
        if batch_step % 25 == 0 and model_alpha > 0:
            model_alpha = model_alpha - alpha_update_cons
            model_alpha = 0.0 if model_alpha < 1e-5 else model_alpha
            
        # NOTE(review): uses print rather than print_and_log, so these progress lines are
        # not written to log.txt.
        if batch_step == 1 or batch_step % 100 == 0:
            print(f"{datetime.now()} scale:{scale}, step:{batch_step}, alpha:{model_alpha}")
            
        # Fetch the next batch and resize it on the CPU to this scale's side length.
        # Convert it from channels-last numpy to a channels-first float tensor on `device`,
        # then normalise it.
        batch_images = batch_iter.next_batch()
        batch_images = resize_batch_s(batch_images, img_sizes[scale])
        batch_images = (torch.from_numpy(batch_images.transpose((0, 3, 1, 2))).type(torch.FloatTensor).to(device) - normalize_t) / normalize_t