In [None]:
import os

# Select which GPUs are visible to this process (comma-separated CUDA device indices).
# os.environ['CUDA_VISIBLE_DEVICES'] = '1, 2, 3'

In [None]:
import torch
import torch.nn as nn
import torch.cuda as cuda
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.utils as vutils

import numpy as np
from PIL import Image
import imageio
import pickle
import matplotlib.pyplot as plt
%matplotlib inline

from discriminator import Discriminator
from generator import Generator
from trainer import Trainer
from config import Config
import utils

In [None]:
# One seed shared by every random number generator seeded below.
SEED = 42069

# NumPy's global generator.
np.random.seed(seed=SEED)

# PyTorch's global seed, followed by an explicit seed for every CUDA device.
torch.manual_seed(seed=SEED)
torch.cuda.manual_seed_all(seed=SEED)

In [None]:
# Create a config object. Later cells read their settings from it
# (model sizes, checkpoint path, device, dataset roots, batch size, ...).
config = Config()

In [None]:
# Build the Generator and Discriminator, optionally restore their weights from a
# checkpoint, move them to config.device, and wrap them in nn.DataParallel when
# config.data_parallel is set.
# When config.pretrained is set, the weights saved under index
# (config.start_epoch - 1) are restored [epochs start from 0].


def _restore_weights(model, name):
    """Overwrite `model`'s parameters with the checkpointed ones, in place.

    Args:
        model: Module whose state dict is replaced.
        name: File-name prefix of the checkpoint ('generator' or 'discriminator').
    """
    checkpoint_file = (config.checkpoint_path
                       + 'models/{}_{}.pth'.format(name, config.start_epoch - 1))
    # map_location='cpu' lets the file be read on a machine whose devices differ
    # from the ones it was saved on (e.g. no GPU, or fewer GPUs). The model is
    # moved to config.device only after its weights have been restored.
    model.load_state_dict(torch.load(checkpoint_file, map_location='cpu'))


generator = Generator(z_dim=config.z_dim, num_classes=config.num_classes,
                      base_width=config.base_width,
                      base_filters=config.base_filters,
                      use_attention=config.use_attention)

if config.pretrained:
    _restore_weights(generator, 'generator')

generator = generator.to(config.device)

discriminator = Discriminator(config.num_classes,
                              base_filters=config.base_filters,
                              use_attention=config.use_attention,
                              use_dropout=config.use_dropout)

if config.pretrained:
    _restore_weights(discriminator, 'discriminator')

discriminator = discriminator.to(config.device)

# Wrapping happens after the weights are loaded, so the state dicts above are
# applied to the bare (unwrapped) modules.
if config.data_parallel:
    generator = nn.DataParallel(generator)
    discriminator = nn.DataParallel(discriminator)

In [None]:
# Build the train and test dataloaders from the dataset roots and batch size
# given in the config.
train_dataloader, test_dataloader = utils.get_dataloaders(
    config.train_root,
    config.test_root,
    batch_size=config.batch_size,
)

In [None]:
# Hand the config, the training dataloader and both networks to a Trainer.
trainer = Trainer(
    config,
    train_dataloader,
    generator=generator,
    discriminator=discriminator,
)

In [None]:
# Commence the training.
# If the run is stopped before it finishes, set ABORTED accordingly in the next
# cell before dumping the metrics.
trainer.train()

In [None]:
# Toggle this depending on whether the training was stopped or whether it naturally finished.
ABORTED = True

if ABORTED:
    # Workaround: dump_metrics() is called with the epoch counter stepped back by
    # one so that the file names do not use the current epoch. The counter is
    # restored afterwards (even if the dump raises), so re-running this cell does
    # not keep decrementing trainer.current_epoch.
    trainer.current_epoch -= 1
    try:
        trainer.dump_metrics()
    finally:
        trainer.current_epoch += 1
else:
    # Dump the metrics to pickle files.
    trainer.dump_metrics()

In [None]:
# Plot the generator and discriminator losses recorded by the trainer
# (generated, saved as 'losses.png', and displayed), without superimposing them.
utils.plot_losses('losses.png', g_loss=trainer.g_loss, d_loss=trainer.d_loss, superimpose=False)

In [None]:
# Assemble the images under <checkpoint_path>samples into a single GIF written to
# <checkpoint_path>interpolation.gif.
# TODO: GIFs take too much space. Maybe ffmpeg to movie and convert to gif of low quality?
utils.create_interpolation(
    filename=config.checkpoint_path + 'interpolation.gif',
    im_path=config.checkpoint_path + 'samples',
)