In [None]:
%matplotlib inline
import os
import time
import shutil
import cv2
import imageio
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Apply a dark theme to both seaborn and matplotlib so every figure in the
# notebook is drawn on a dark background.
sns.set_style("dark")
plt.style.use("dark_background")

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets
import torchvision.transforms as transforms

from datetime import datetime
from IPython import display

import utils
import model
import calculate_anim_mask

# Every run gets its own artifact directory, keyed by the Unix timestamp at
# which this cell was executed.
training_timestamp = str(int(time.time()))
model_dir = f'trained_models/model_{training_timestamp}/'

# exist_ok avoids the check-then-create race and makes re-running the cell safe.
os.makedirs(model_dir, exist_ok=True)

# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (slowly) on machines without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def print_and_log(text):
    """Print ``text`` to stdout and append it to this run's ``log.txt``.

    Args:
        text: Any object; it is rendered with ``print``, so its ``str()``
            form followed by a newline is what gets written.

    The log file lives under the module-level ``model_dir`` and is opened in
    append mode, so repeated calls accumulate lines.
    """
    print(text)
    # Use a context manager so the file handle is closed deterministically
    # instead of relying on garbage collection.
    with open(f'{model_dir}/log.txt', 'a') as log_file:
        print(text, file=log_file)

In [None]:
# Keep a copy of this notebook next to the run's artifacts so the exact code
# used for the run can be recovered later.
notebook_path = './animation_gan.ipynb'
shutil.copy2(notebook_path, model_dir)

In [None]:
# Cache the preprocessing result on disk: load it when the cache file exists,
# otherwise compute it once and store it for later runs.
processed_data_path = 'preprocessed_data/processed_data.pt'

if os.path.isfile(processed_data_path):
    processed_data = torch.load(processed_data_path)
else:
    processed_data = utils.preprocess_data(batch_size=64, data_size=40_000)
    # torch.save does not create missing parent directories, so make sure the
    # cache directory exists before writing into it.
    os.makedirs(os.path.dirname(processed_data_path), exist_ok=True)
    torch.save(processed_data, processed_data_path)

In [None]:
batch_size = 8

# ToTensor followed by Normalize(mean=0.5, std=0.5) per channel; the plotting
# cells undo this with (x + 1.0) * 0.5.
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = torchvision.datasets.ImageFolder(root="data", transform=image_transform)

# drop_last=True: the label tensors and the animation mask are built for
# exactly `batch_size` samples, so a shorter final batch would cause a shape
# mismatch in the training loop. When the dataset size is a multiple of
# batch_size this changes nothing.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# Pull one batch from the loader and show its first image.
first_batch = next(iter(dataloader))
real_images, _ = first_batch

# Move the leading axis of the image tensor to the end for matplotlib.
cur_img = np.transpose(real_images[0].numpy(), (1, 2, 0))

plt.figure(figsize=(18, 8))
# Undo the (x - 0.5) / 0.5 normalisation before display.
plt.imshow((cur_img + 1.0) * 0.5)

In [None]:
img_i = 0    # which image of the batch to inspect
frame_i = 0  # 0-35

# Batch as a numpy array with axis 1 moved to the end, as expected by the helper call below.
real_images_last_axis = np.transpose(real_images.numpy(), (0, 2, 3, 1))
cur_img_cube = utils.get_img_cube(real_images_last_axis, img_i)

# Each frame takes three consecutive entries along the last axis of the cube.
frame_channels = slice(frame_i * 3, (frame_i + 1) * 3)
plt.imshow((cur_img_cube[:, :, frame_channels] + 1.0) * 0.5)

In [None]:
# Animate the selected real image cube, passing a path under model_dir to the
# helper, and display the file it returns.
real_gif_path = f"{model_dir}/real.gif"
anim_file = utils.animate_img_cube(cur_img_cube, real_gif_path)
display.Image(filename=anim_file)

In [None]:
# Animate the real batch (axis 1 moved to the end) and display the file the
# helper returns.
real_batch_last_axis = np.transpose(real_images.numpy(), (0, 2, 3, 1))
real_batch_gif_path = f"{model_dir}/real_4.gif"
anim_file = utils.animate_img_batch(real_batch_last_axis, real_batch_gif_path)
display.Image(filename=anim_file)

In [None]:
# Size of the latent noise vector used for the generator inputs in this notebook.
latent_size = 100

# Active generator. Alternatives kept for reference:
# generator_net = model.Generator2D(latent_size=latent_size, input_channels=36*3, feature_map_size=64).to(device)
# generator_net = model.Generator3D().to(device)
generator_net = model.Generator2D_10().to(device)

# Active discriminator. Alternatives kept for reference:
# discriminator_net = model.Discriminator2D().to(device)
# discriminator_net = model.DiscriminatorTemporal().to(device)
# discriminator_net = model.Discriminator2D(input_channels=36*3, output_channels=36, feature_map_size=36*4, groups=36).to(device)
# discriminator_net = model.DiscriminatorSheet(input_channels=3, output_channels=1, feature_map_size=4).to(device)
# discriminator_net = model.Discriminator3D().to(device)
discriminator_net = model.Discriminator2D_10().to(device)

# Apply the project's weight initialiser to every submodule of both networks
# (generator first, then discriminator).
for network in (generator_net, discriminator_net):
    network.apply(model.weights_init)

# Record both architectures on stdout and in the run log.
for network in (generator_net, discriminator_net):
    print_and_log(network)

In [None]:
# Adam hyperparameters shared by both optimizers, and the binary
# cross-entropy criterion used for the adversarial loss.
learning_rate = 0.0002
beta1 = 0.5
criterion = nn.BCELoss()

# Four fixed latent vectors, reused by the training loop for its periodic snapshots.
fixed_noise2d = torch.randn(4, latent_size, 1, 1, device=device)
# fixed_noise3d = torch.randn(4, latent_size, 1, 1, 1, device=device)

# BCELoss needs floating-point targets. With an integer fill value torch.full
# infers an integer dtype on current PyTorch releases, which makes the loss
# raise a dtype error, so the labels are created as float explicitly.
real_label = torch.full((batch_size,), 1.0, dtype=torch.float, device=device)
fake_label = torch.full((batch_size,), 0.0, dtype=torch.float, device=device)

discriminator_optimizer = optim.Adam(discriminator_net.parameters(), lr=learning_rate, betas=(beta1, 0.999))
generator_optimizer = optim.Adam(generator_net.parameters(), lr=learning_rate, betas=(beta1, 0.999))

In [None]:
# Mask used by the training loop as the condition of torch.where on generator
# outputs; built for exactly `batch_size` samples and moved to the training device.
# 2d
anim_mask_host = calculate_anim_mask.calculate_anim_mask(batch_size)
anim_mask = anim_mask_host.to(device)
# 3d
# anim_mask = calculate_anim_mask.calculate_anim_mask(batch_size).reshape(-1, 3, 36, 64, 64).to(device)

In [None]:
# Adversarial training loop. Each iteration performs one discriminator update
# (real batch + detached fake batch) followed by one generator update, logs
# progress every 250 batches, and stores generator snapshots from the fixed
# noise every 2500 iterations as well as on the very last batch.
num_epochs = 40

# Per-iteration loss histories (plotted and saved by later cells).
generator_losses = []
discriminator_losses = []
# Accumulates the snapshots produced from fixed_noise2d; None until the first
# snapshot is taken (which happens at train_iters == 0).
generated_img_cubes = None
# Global iteration counter across all epochs.
train_iters = 0
# Single-element tensor holding 1.0, used as the replacement value in torch.where.
torch_ones = torch.tensor([1.0]).to(device)

print_and_log(f"{datetime.now()} Starting Training")
for epoch in range(num_epochs):
    # NOTE: the loop variable overwrites the notebook-level `real_images`; the
    # class labels yielded by the loader are ignored.
    for batch_index, (real_images, _) in enumerate(dataloader, 0):
        
        # ---- Discriminator update: real batch ----
        discriminator_net.zero_grad()
        
        # 3d data
        # real_inputs = utils.convert_to_img_cube(real_images, batch_size).to(device)
        # 2d data
        real_inputs = utils.convert_to_img_plane(real_images, batch_size).to(device)
        # spatio-temporal data
        # real_inputs = utils.convert_to_img_plane(real_images, batch_size).reshape(-1, 3, 64, 64).to(device)
        
        output = discriminator_net(real_inputs).view(-1)
        discriminator_loss_real = criterion(output, real_label)
        
        discriminator_loss_real.backward()
        # Mean discriminator output on the real batch (logging only).
        real_output = output.mean().item()
        
        # ---- Discriminator update: fake batch ----
        # Fresh standard-normal latent vectors for this iteration.
        noise2d = torch.randn(batch_size, latent_size, 1, 1, device=device)
        # noise3d = torch.randn(batch_size, latent_size, 1, 1, 1, device=device)

        generated_inputs = generator_net(noise2d)
        # generated_inputs = generator_net(noise3d)
        # Wherever anim_mask is true the generator output is replaced by 1.0.
        # NOTE(review): presumably this pins the masked region to a constant so
        # only the animated area is learned — confirm against calculate_anim_mask.
        generated_inputs = torch.where(anim_mask, torch_ones, generated_inputs)
        
        # discriminator:3D
        # output = discriminator_net(generated_inputs.detach().reshape(-1, 3, 36, 64, 64)).view(-1)
        # discriminator:2D
        # detach() keeps this backward pass from propagating into the generator.
        output = discriminator_net(generated_inputs.detach().reshape(-1, 108, 64, 64)).view(-1)
        # discriminator:spatio-temporal
        # output = discriminator_net(generated_inputs.detach().reshape(-1, 3, 64, 64)).view(-1)
        
        discriminator_loss_fake = criterion(output, fake_label)

        # Gradients accumulate on top of those from the real batch.
        discriminator_loss_fake.backward()
        # Mean discriminator output on fakes before the discriminator step (logging only).
        fake_output1 = output.mean().item()

        # Sum is used only for logging/history; both backward passes already ran.
        discriminator_loss = discriminator_loss_real + discriminator_loss_fake

        discriminator_optimizer.step()
        
        # ---- Generator update ----
        generator_net.zero_grad()
        
        # discriminator:3D
        # output = discriminator_net(generated_inputs.reshape(-1, 3, 36, 64, 64)).view(-1)
        # discriminator:2D
        # Same fake batch, not detached, scored by the just-updated discriminator.
        output = discriminator_net(generated_inputs.reshape(-1, 108, 64, 64)).view(-1)
        # discriminator:spatio-temporal
        # output = discriminator_net(generated_inputs.reshape(-1, 3, 64, 64)).view(-1)
        
        # The generator is trained against the "real" target: its loss is low
        # when the discriminator scores the fakes as real.
        generator_loss = criterion(output, real_label)
        
        generator_loss.backward()
        # Mean discriminator output on fakes after the discriminator step (logging only).
        fake_output2 = output.mean().item()
        
        generator_optimizer.step()
        
        # Progress line every 250 batches of the current epoch.
        if batch_index % 250 == 0:
            print_and_log(f"{datetime.now()} [{epoch:02d}/{num_epochs}][{batch_index:04d}/{len(dataloader)}]\t"
                          f"D_Loss:{discriminator_loss.item():.4f} G_Loss:{generator_loss.item():.4f} Real:{real_output:.4f} "
                          f"Fake1:{fake_output1:.4f} Fake2:{fake_output2:.4f}")
        
        generator_losses.append(generator_loss.item())
        discriminator_losses.append(discriminator_loss.item())
        
        # Snapshot the generator on the fixed noise every 2500 iterations and
        # on the final batch of the final epoch.
        if (train_iters % 2500 == 0) or ((epoch == num_epochs-1) and (batch_index == len(dataloader)-1)):
            with torch.no_grad():
                generated_inputs = generator_net(fixed_noise2d).detach()
                # generated_inputs = generator_net(fixed_noise3d).detach()
                # Only the first four mask entries are needed for the four fixed samples.
                generated_inputs = torch.where(anim_mask[:4], torch_ones, generated_inputs)
                # 2d
                # unsqueeze(0) adds a leading snapshot axis; the transpose moves
                # axis 2 to the end (assumes a (N, C, H, W) generator output —
                # TODO confirm).
                generated_inputs = generated_inputs.unsqueeze(0).cpu().numpy().transpose(0, 1, 3, 4, 2) 
                # 3d
                # generated_inputs = generated_inputs.reshape(-1, 108, 64, 64).unsqueeze(0).cpu().numpy().transpose(0, 1, 3, 4, 2) 
                # Stack snapshots along axis 0 in chronological order.
                generated_img_cubes = generated_inputs if generated_img_cubes is None else np.concatenate((generated_img_cubes, generated_inputs), axis=0)

        train_iters += 1

In [None]:
# Plot the per-iteration generator and discriminator losses on one axes,
# write the figure under model_dir, then show it inline.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.plot(generator_losses, label="Generator")
loss_ax.plot(discriminator_losses, label="Discriminator")
loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
loss_fig.savefig(f"{model_dir}/loss.png")
plt.show()

In [None]:
# Persist the trained weights (discriminator first, then generator).
trained_networks = (
    ("net_discriminator.pth", discriminator_net),
    ("net_generator.pth", generator_net),
)
for weights_name, network in trained_networks:
    torch.save(network.state_dict(), f"{model_dir}/{weights_name}")

# Persist the loss histories as numpy arrays (generator first, then discriminator).
loss_histories = (
    ("losses_generator.npy", generator_losses),
    ("losses_discriminator.npy", discriminator_losses),
)
for history_name, history in loss_histories:
    np.save(f"{model_dir}/{history_name}", np.array(history))

# Persist every snapshot collected from the fixed noise during training.
np.save(f"{model_dir}/generated_img_cubes.npy", generated_img_cubes)

In [None]:
img_i = 0    # which of the fixed-noise samples to inspect
frame_i = 0  # 0-35

# Most recent snapshot, selected sample.
eval_img_cube = generated_img_cubes[-1][img_i]

# One frame = three consecutive entries along the last axis; the same
# (x + 1.0) * 0.5 rescaling used for the real images is applied.
eval_frame = (eval_img_cube[:, :, frame_i * 3:(frame_i + 1) * 3] + 1.0) * 0.5

# Save a 4x nearest-neighbour upscaled copy with the run artifacts, then show
# the frame at its native size.
eval_frame_upscaled = cv2.resize(eval_frame, dsize=(0, 0), fx=4.0, fy=4.0, interpolation=cv2.INTER_NEAREST)
plt.imsave(f"{model_dir}/sample_output.png", eval_frame_upscaled)
plt.imshow(eval_frame)

In [None]:
# Animate the selected generated image cube, passing a path under model_dir to
# the helper, and display the file it returns.
generated_gif_path = f"{model_dir}/generated.gif"
anim_file = utils.animate_img_cube(eval_img_cube, generated_gif_path)
display.Image(filename=anim_file)

In [None]:
# Animate all samples of the most recent snapshot and display the file the
# helper returns.
latest_snapshot = generated_img_cubes[-1]
generated_batch_gif_path = f"{model_dir}/generated_4.gif"
anim_file = utils.animate_img_batch(latest_snapshot, generated_batch_gif_path, get_img_cubes=False)
display.Image(filename=anim_file)

In [None]:
# Animate the full stack of training snapshots (training_outputs=True) and
# display the file the helper returns.
training_gif_path = f"{model_dir}/training.gif"
anim_file = utils.animate_img_cube(generated_img_cubes, training_gif_path, training_outputs=True)
display.Image(filename=anim_file)

In [None]:
# Batch variant of the training-progress animation over all snapshots; display
# the file the helper returns.
training_batch_gif_path = f"{model_dir}/training_4.gif"
anim_file = utils.animate_img_batch(
    generated_img_cubes,
    training_batch_gif_path,
    get_img_cubes=False,
    training_outputs=True,
)
display.Image(filename=anim_file)

In [None]:
# Show sample frames of the first image in `real_images`. After the training
# cell has run, `real_images` holds the last batch seen by the training loop.
plt.figure(figsize=(18, 10))
real_sample_frames = utils.get_sample_frames(real_images[0])
plt.imshow((real_sample_frames + 1.0) * 0.5)

In [None]:
# Read the frames of a generated GIF, shrink each to a quarter of its size and
# show a selection of sample frames.
# NOTE(review): this reads from "outputs/" while the cells above write their
# GIFs under model_dir — confirm the file is meant to be copied there by hand.
# The reader is used as a context manager so the GIF file is closed after the
# frames have been read.
with imageio.get_reader("outputs/generated.gif") as gif_reader:
    generated_gif_frames = np.array([
        cv2.resize(frame, dsize=(0, 0), fx=0.25, fy=0.25) for frame in gif_reader
    ])

plt.figure(figsize=(9*2, 5*2))
plt.imshow(utils.get_sample_frames(generated_gif_frames, True))

In [None]:
# Read the frames of the batch GIF, shrink each to half its size and show a
# selection of sample frames.
# NOTE(review): this reads from "outputs/" while the cells above write their
# GIFs under model_dir — confirm the file is meant to be copied there by hand.
# The reader is used as a context manager so the GIF file is closed after the
# frames have been read.
with imageio.get_reader("outputs/generated_4.gif") as gif_reader:
    generated_batch_gif_frames = np.array([
        cv2.resize(frame, dsize=(0, 0), fx=0.5, fy=0.5) for frame in gif_reader
    ])

plt.figure(figsize=(9*2, 5*2))
plt.imshow(utils.get_sample_frames(generated_batch_gif_frames, True, True))