# Import Libraries

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils

from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm

In [None]:
from model import Critic, Generator, init_weights
from utils import gradient_penalty, save_checkpoint, load_checkpoint

In [None]:
# Seed PyTorch's global random number generators so torch-based random draws
# (noise sampling, default DataLoader shuffling, ...) repeat between runs.
manual_seed = 999
print("Random Seed: {}".format(manual_seed))
torch.manual_seed(seed=manual_seed)

In [None]:
# Use the current CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

for label, value in (("Pytorch version", torch.__version__), ("Device", device)):
    print(f"{label}: {value}")

In [None]:
# Hyperparameters
# ---- Optimisation -----------------------------------------------------------
LEARNING_RATE = 1e-4   # Adam step size, shared by the generator and the critic
BATCH_SIZE = 64
NUM_EPOCHS = 1
CRITIC_ITERATIONS = 5  # critic updates performed per generator update
LAMBDA_GP = 10         # weight of the gradient-penalty term in the critic loss

# ---- Data -------------------------------------------------------------------
IMAGE_SIZE = 64
CHANNELS_IMG = 3
IMAGE_SHAPE = (CHANNELS_IMG, IMAGE_SIZE, IMAGE_SIZE)

# ---- Model ------------------------------------------------------------------
Z_DIM = 100            # length of the latent noise vector fed to the generator
FEATURES_CRITIC = 16
FEATURES_GEN = 16

# ---- Hardware ---------------------------------------------------------------
NGPU = 2               # GPUs for multi-GPU training (also passed to the model constructors)

In [None]:
# Preprocessing pipeline: resize to IMAGE_SIZE x IMAGE_SIZE, convert to a float
# tensor, then normalise every channel with mean 0.5 / std 0.5, i.e.
# x -> (x - 0.5) / 0.5, which maps [0, 1] pixel values onto [-1, 1].
transform = transforms.Compose([
    transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * CHANNELS_IMG, std=[0.5] * CHANNELS_IMG),
])

In [None]:
# Load the training split from ../data/chest_xray/train (ImageFolder expects one
# sub-directory per class).
dataset = datasets.ImageFolder(root="../data/chest_xray/train", transform=transform)

print(f"Dataset length: {len(dataset)}")
print(f"Dataset classes: {dataset.classes}")
print(f"Dataset class to idx mapping: {dataset.class_to_idx}")

print(f"Dataset sample: {dataset.samples[0]}")

# The transform normalises pixels to [-1, 1]; undo Normalize(0.5, 0.5) so imshow
# receives values in [0, 1] instead of clipping the negative half to black.
first_image, _ = dataset[0]
plt.imshow((first_image * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0))
plt.show()

In [None]:
# Shuffled mini-batches. Pinned host memory only speeds up host->GPU copies, so
# request it only when training on CUDA (on a CPU-only run it just warns).
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=device.type == "cuda",
)

# Show a grid of (up to) 64 images from one training batch. The images are only
# displayed, so there is no need to move them to `device` and back.
real_batch = next(iter(loader))
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(
    np.transpose(
        vutils.make_grid(real_batch[0][:64], padding=2, normalize=True),
        (1, 2, 0),
    )
)
plt.show()

In [None]:
generator: nn.Module = Generator(Z_DIM, IMAGE_SHAPE, FEATURES_GEN, NGPU).to(device)

# `device` is a torch.device, and comparing it to the string "cuda" is always
# False, so test `device.type` instead. Never ask DataParallel for more GPUs
# than are actually visible.
n_visible_gpus = torch.cuda.device_count() if device.type == "cuda" else 0
n_parallel_gpus = min(NGPU, n_visible_gpus)
if n_parallel_gpus > 1:
    generator = nn.DataParallel(generator, list(range(n_parallel_gpus)))

init_weights(generator)
print(generator)

In [None]:
critic: nn.Module = Critic(IMAGE_SHAPE, FEATURES_CRITIC, NGPU).to(device)

# `device` is a torch.device, and comparing it to the string "cuda" is always
# False, so test `device.type` instead. Never ask DataParallel for more GPUs
# than are actually visible.
n_visible_gpus = torch.cuda.device_count() if device.type == "cuda" else 0
n_parallel_gpus = min(NGPU, n_visible_gpus)
if n_parallel_gpus > 1:
    critic = nn.DataParallel(critic, list(range(n_parallel_gpus)))

init_weights(critic)
print(critic)

In [None]:
# Generator optimiser: Adam with beta1 = 0 and beta2 = 0.9, the settings
# recommended for WGAN-GP training.
opt_gen = optim.Adam(
    params=generator.parameters(),
    lr=LEARNING_RATE,
    betas=(0.0, 0.9),
)
print(opt_gen)

In [None]:
# Critic optimiser: same Adam configuration as the generator's.
opt_critic = optim.Adam(
    params=critic.parameters(),
    lr=LEARNING_RATE,
    betas=(0.0, 0.9),
)
print(opt_critic)

In [None]:
# TensorBoard writers: real and generated image grids are written to separate
# run directories so they can be compared side by side.
# (Plain strings: these paths have no placeholders, so the f-prefix was unused.)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")

In [None]:
# Fixed latent vectors, reused during training so that the TensorBoard
# snapshots of generated images are comparable across steps.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)

# Preview the (still untrained) generator. The output is only displayed, so run
# the forward pass without building an autograd graph.
# NOTE(review): the generator is still in its default training mode here, so
# any batch-norm layers will update their running statistics on this pass.
with torch.no_grad():
    preview = generator(fixed_noise).cpu()

plt.imshow(
    np.transpose(
        vutils.make_grid(preview[:1], padding=2, normalize=True),
        (1, 2, 0),
    )
)
plt.title("Fake images")
plt.show()

In [None]:
# Training bookkeeping: the TensorBoard global step and per-batch loss histories.
step = 0
generator_losses, critic_losses = [], []

In [None]:
generator.train(mode=True)  # training mode: affects dropout / batch-norm layers, if any

In [None]:
critic.train(mode=True)  # training mode: affects dropout / batch-norm layers, if any

In [None]:
%matplotlib inline
for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(loader):
        print(f"Epoch: {epoch}, Batch: {batch_idx}")
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # => min -E[critic(real)] + E[critic(fake)]
        critic_loss = 0
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = generator(noise)

            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)

            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + LAMBDA_GP * gp
            )
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

            critic_loss += loss_critic.item()
        
        critic_losses.append(critic_loss / CRITIC_ITERATIONS)

        # Train Generator: min -E[critic(gen_fake)] <-> max E[critic(gen_fake)]
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        generator.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        generator_losses.append(loss_gen.item())

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0 and batch_idx > 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = generator(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = vutils.make_grid(
                    real[:32], normalize=True, padding=2
                )
                img_grid_fake = vutils.make_grid(
                    fake[:32], normalize=True, padding=2
                )

                writer_real.add_image(
                    "Real images", img_grid_real, global_step=step
                )
                writer_fake.add_image("Fake images", img_grid_fake, global_step=step)

            step += 1
    
    # Print images after each epoch
    with torch.no_grad():
        sample_noise = torch.randn(64, Z_DIM, 1, 1).to(device)
        sample = generator(sample_noise)
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title(f"Generated images after epoch {epoch}")
        plt.imshow(
            np.transpose(
                vutils.make_grid(sample[:64], padding=2, normalize=True).cpu(),
                (1, 2, 0),
            )
        )
        plt.show()
    
    # Save model after each epoch
    if epoch % 10 == 0:
        torch.save(generator.state_dict(), f"model_states/generator_epoch_{epoch}.pth")
        torch.save(critic.state_dict(), f"model_states/critic_epoch_{epoch}.pth")

# Save model after training
torch.save(generator.state_dict(), "generator.pth")
torch.save(critic.state_dict(), "critic.pth")

In [None]:
# Plot the per-batch loss curves: the generator loss, and the critic loss
# averaged over its inner iterations.
curves = (
    (generator_losses, "Generator", "green"),
    (critic_losses, "Critic", "red"),
)
for values, curve_name, colour in curves:
    plt.plot(values, label=curve_name, alpha=0.5, color=colour)

plt.title("Losses")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()