In [None]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as torch_utils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [None]:
# Seed every RNG library the notebook imports so Restart & Run All is
# reproducible (Python `random`, NumPy, and torch).
manualSeed = 123
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed);  # trailing ';' hides the Generator repr in the notebook

In [None]:
# Filesystem locations -- fill these in before running the notebook.
#   data_dir                -> root passed to datasets.ImageFolder
#   model_save_path         -> file the trained generator is torch.save'd to
#   animation_save_path     -> video file for the generator-progress animation
#   training_plot_save_path -> image file for the loss-curve figure
data_dir, model_save_path, animation_save_path, training_plot_save_path = (
    "",
    "",
    "",
    "",
)

In [None]:
# ---- Hardware ---------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Data -------------------------------------------------------------------
image_size = 64   # the model classes hard-code layer sizes for 64x64 images
batch_size = 64
workers = 2       # DataLoader worker processes
# Number of conditioning labels. Dataset labels index nn.Embedding(classes, ...)
# in both models, so the ImageFolder dataset must not have more classes than this.
classes = 5

# ---- Optimisation -----------------------------------------------------------
# Same value as the literal 10e-4 (= 0.001).
# NOTE(review): DCGAN setups commonly use 2e-4; 10e-4 may have been meant as
# 1e-4 -- confirm the intended learning rate.
lr = 1e-3
beta1 = 0.5       # Adam beta1
num_epochs = 50

# ---- GAN setup --------------------------------------------------------------
noise_dim = 100
real_label = 1.0
fake_label = 0.0

# ---- Progress monitoring ----------------------------------------------------
# A fixed noise/label batch lets generator snapshots be compared across training.
progress = []
fixed_noise = torch.randn(batch_size, noise_dim, 1, 1, device=device)
fixed_labels = torch.randint(0, classes, (batch_size, ), device=device)

In [None]:
# Discriminator
class Discriminator(nn.Module):
    """Conditional discriminator for 64x64 images.

    The integer class label is embedded into a single 64x64 plane and stacked
    onto the image as an extra channel. Four stride-2 conv blocks
    (64 -> 32 -> 16 -> 8 -> 4) and a final conv reduce this to one sigmoid
    score per sample.

    Args:
        classes: number of distinct labels the embedding accepts.
    """

    def __init__(self, classes) -> None:
        super().__init__()
        self.classes = classes

        # label -> 64-d vector -> 4096 values -> (1, 64, 64) plane
        self.embedding = nn.Sequential(
            nn.Embedding(classes, 64),
            nn.Linear(64, 64 * 64),
            nn.Unflatten(1, (1, 64, 64)),
        )

        # 4 input channels = 3 image channels + 1 label plane.
        channel_plan = [(4, 64), (64, 128), (128, 256), (256, 512)]
        feature_blocks = [self.conv_block(c_in, c_out) for c_in, c_out in channel_plan]

        self.classifier = nn.Sequential(
            *feature_blocks,
            nn.Conv2d(512, 1, (5, 5), 2, 1),  # 4x4 feature map -> 1x1 score
            nn.Sigmoid(),
        )

    def conv_block(self, in_channels, out_channels):
        """Return Conv(5x5, stride 2, pad 2) -> BatchNorm -> LeakyReLU(0.2).

        Each block halves the spatial size of its input.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=(5, 5), stride=2, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x, label):
        """Score images `x` (3 channels, 64x64) conditioned on integer `label`.

        Returns a (batch, 1, 1, 1) tensor of probabilities in [0, 1].
        """
        label_plane = self.embedding(label)
        stacked = torch.cat([x, label_plane], dim=1)
        return self.classifier(stacked)

In [None]:
# Generator
class Generator(nn.Module):
    """Conditional generator producing 3x64x64 images in [-1, 1].

    The noise vector is projected to a (64, 8, 8) feature map. The class label
    is embedded as an extra (1, 8, 8) plane. The stacked 65-channel map is
    upsampled 8 -> 16 -> 32 -> 64 and projected to 3 channels with Tanh.

    Args:
        classes: number of distinct labels the embedding accepts.
        noise_dim: length of the noise vector. Defaults to 100, the size the
            projection layer was originally hard-coded for.
    """

    def __init__(self, classes, noise_dim=100):
        super(Generator, self).__init__()
        self.classes = classes
        self.noise_dim = noise_dim

        # label -> 64-d vector -> (1, 8, 8) plane
        self.embedding = nn.Sequential(
            nn.Embedding(classes, 64),
            nn.Unflatten(1, (1, 8, 8))
        )

        # noise -> 4096 values -> (64, 8, 8) feature map
        self.latent_vector = nn.Sequential(
            nn.Linear(noise_dim, 64 * 8 * 8),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (64, 8, 8))
        )

        # 65 input channels = 64 noise features + 1 label plane
        upsample_1 = self.upsample_block(65, 256, 1)
        upsample_2 = self.upsample_block(256, 128, 1)
        upsample_3 = self.upsample_block(128, 64, 1)

        self.conv_model = nn.Sequential(
            upsample_1,
            upsample_2,
            upsample_3,
            nn.Conv2d(64, 3, (1, 1), 1, 0),
            nn.Tanh()
        )

    def upsample_block(self, in_channels, out_channels, padding):
        """Return ConvTranspose(4x4, stride 2) -> BatchNorm -> ReLU.

        With padding=1 the block exactly doubles the spatial size.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, (4, 4), 2, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, label):
        """Generate images from noise `x` conditioned on integer `label`.

        Args:
            x: noise of shape (batch, noise_dim) or (batch, noise_dim, 1, 1).
            label: (batch,) integer class indices in [0, classes).

        Returns:
            (batch, 3, 64, 64) tensor in [-1, 1].
        """
        # The notebook samples noise as (batch, noise_dim, 1, 1). nn.Linear acts
        # on the LAST dimension, which would be 1 there and fail with a shape
        # mismatch, so collapse everything after the batch dimension first.
        x = torch.flatten(x, start_dim=1)
        latent_vector = self.latent_vector(x)
        label_embedding = self.embedding(label)
        comb_latent_vector = torch.cat((latent_vector, label_embedding), dim=1)
        return self.conv_model(comb_latent_vector)

In [None]:
# custom weights initialization
# Reference (PyTorch Tutorials)
def weights_init(m):
    """Re-initialise a single module in place; intended for ``net.apply(weights_init)``.

    Dispatch is on the class *name* (first match wins, in this order):
      * name contains "Conv"      -> weight ~ N(0, 0.02)
      * name contains "BatchNorm" -> weight ~ N(1, 0.02), bias = 0
      * name contains "Linear"    -> weight ~ N(0, 0.02)
    Everything else (e.g. Embedding, activations, containers) is left untouched.
    Conv/Linear biases are not modified.
    """
    layer_kind = type(m).__name__

    if "Conv" in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if "BatchNorm" in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
        return

    if "Linear" in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)


In [None]:
dataset = datasets.ImageFolder(
    root=data_dir,
    transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps [0, 1] pixels to [-1, 1], the range of the
        # generator's Tanh output.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

# The dataset's integer targets are fed to nn.Embedding(classes, ...) in both
# models; fail fast here instead of with an index error mid-training.
assert len(dataset.classes) <= classes, (
    f"dataset has {len(dataset.classes)} classes but the models were built for {classes}"
)

# Fully qualified rather than `data.DataLoader`: the training loop's batch
# variable used to rebind the name `data`, which broke re-running this cell.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

In [None]:
# Construct both nets, then move both to the device, then apply the custom
# (random) initialisation -- kept in this order so seeded runs draw random
# numbers in the same sequence.
disc_net = Discriminator(classes)
gen_net = Generator(classes)
for net in (disc_net, gen_net):
    net.to(device)
for net in (disc_net, gen_net):
    net.apply(weights_init)
gen_net

In [None]:
# Binary cross-entropy on the discriminator's sigmoid probabilities.
criterion = nn.BCELoss()

# Both networks use Adam with the same learning rate and momentum terms.
adam_betas = (beta1, 0.999)
disc_optimizer = optim.Adam(disc_net.parameters(), lr=lr, betas=adam_betas)
gen_optimizer = optim.Adam(gen_net.parameters(), lr=lr, betas=adam_betas)

In [None]:
# Training Loop

# Lists to keep track of progress
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
for epoch in range(num_epochs):
    # Unpack each (images, labels) batch directly. Binding it to a variable
    # named `data` would shadow the `torch.utils.data` alias imported at the
    # top and break any later `data.DataLoader` call.
    for i, (real_images, real_labels) in enumerate(dataloader):

        # (1) Update the discriminator
        #   1.a: real images with their true class labels      -> target 1
        #   1.b: generated images with sampled class labels    -> target 0
        disc_net.zero_grad()

        real_images = real_images.to(device)
        real_labels = real_labels.to(device)  # class labels condition the discriminator
        num_images = real_images.size(0)
        label = torch.full((num_images,), real_label, dtype=torch.float, device=device)

        output = disc_net(real_images, real_labels).view(-1)
        disc_err_real = criterion(output, label)
        disc_err_real.backward()

        # Conditional noise: random latent vectors paired with random class labels
        noise = torch.randn(num_images, noise_dim, 1, 1, device=device)
        noise_labels = torch.randint(0, classes, (num_images, ), device=device)
        fake = gen_net(noise, noise_labels)
        label.fill_(fake_label)

        # detach(): the discriminator update must not backpropagate into gen_net
        output = disc_net(fake.detach(), noise_labels).view(-1)
        disc_err_fake = criterion(output, label)
        disc_err_fake.backward()

        # Gradients of both backward() calls have accumulated; disc_err is only
        # kept for logging.
        disc_err = disc_err_real + disc_err_fake
        disc_optimizer.step()

        # (2) Update the generator: score the same fakes against "real" targets.
        # This backward pass also writes gradients into disc_net's parameters;
        # they are cleared by disc_net.zero_grad() at the next iteration.
        gen_net.zero_grad()
        label.fill_(real_label)
        output = disc_net(fake, noise_labels).view(-1)
        gen_err = criterion(output, label)
        gen_err.backward()
        gen_optimizer.step()

        # Training Update
        if i % 50 == 0:
            print(
                f"[{epoch}/{num_epochs}][{i}/{len(dataloader)}]\tLoss_D: {disc_err.item()}\tLoss_G: {gen_err.item()}"
            )

        # Tracking loss
        G_losses.append(gen_err.item())
        D_losses.append(disc_err.item())

        # Snapshot the generator on the fixed noise/labels every 500 iterations
        # and on the very last batch of training.
        # NOTE(review): gen_net is never switched to eval() here, so its
        # BatchNorm layers run in training mode for this snapshot.
        if (iters % 500 == 0) or (
            (epoch == num_epochs - 1) and (i == len(dataloader) - 1)
        ):
            with torch.no_grad():
                snapshot = gen_net(fixed_noise, fixed_labels).detach().cpu()
            progress.append(torch_utils.make_grid(snapshot, padding=2, normalize=True))

        iters += 1



In [None]:
# Save generator (the whole nn.Module object, not only its state_dict).
# NOTE(review): loading a full-module pickle presumably needs the Generator
# class importable, and newer torch.load defaults (weights_only=True) may
# reject it -- consider saving gen_net.state_dict() if that matters downstream.
torch.save(obj=gen_net, f=model_save_path)

In [None]:
# Plot Training Graph: per-iteration generator and discriminator losses.
fig1, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.plot(G_losses, label="G")
loss_ax.plot(D_losses, label="D")
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
fig1.savefig(training_plot_save_path)
plt.show()

In [None]:
# Progress Animation: one frame per generator snapshot stored in `progress`.
frame_interval_ms = 1000  # on-screen delay between frames

fig2 = plt.figure(figsize=(8, 8))
plt.axis("off")
# make_grid produces (C, H, W) tensors; imshow expects (H, W, C).
ims = [[plt.imshow(grid.permute(1, 2, 0).numpy(), animated=True)] for grid in progress]
anim = animation.ArtistAnimation(fig2, ims, interval=frame_interval_ms, repeat_delay=1000, blit=True)
# Derive the video frame rate from the on-screen interval so the saved file
# plays at the same pace (1 frame per second here). At a fixed 60 fps each
# snapshot would be shown for only ~17 ms.
writervideo = animation.FFMpegWriter(fps=max(1, round(1000 / frame_interval_ms)))
anim.save(animation_save_path, writer=writervideo)
plt.close(fig2)