# Import Libraries

In [33]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import InterpolationMode
from torchvision.utils import make_grid, save_image

from tqdm import tqdm

import os

import wandb

# Make Required Directories

In [12]:
# Create every working directory the notebook writes to; existing ones are left untouched.
for required_directory in ("CIFAR10", "images", "models", "wandb"):
    os.makedirs(required_directory, exist_ok=True)

# Set WANDB Environment

In [13]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# SECURITY: never hardcode the W&B API key in the notebook. Provide it from
# outside (export WANDB_API_KEY=... before starting the kernel, or run
# `wandb login` once). Any key that was previously committed here must be
# treated as leaked and revoked.
if not os.environ.get("WANDB_API_KEY"):
    print("WANDB_API_KEY is not set; wandb.login() will use stored credentials or prompt for a key.")

# Network Architecture

## Generator Architecture

In [14]:
class Generator(nn.Module):
    """Class-conditional generator producing 3x32x32 images.

    The class label is embedded into a vector of size ``latent_dimension`` and
    multiplied element-wise with the noise vector. The product goes through a
    linear layer to 384 features, is reshaped to ``(N, 384, 1, 1)`` and then
    upsampled by four transposed convolutions (1 -> 4 -> 8 -> 16 -> 32 pixels).
    The final ``Tanh`` keeps the output in ``[-1, 1]``.

    Args:
        n_classes: number of distinct labels accepted by the embedding.
        latent_dimension: size of the noise vector (and of the label embedding).
    """

    def __init__(self, n_classes, latent_dimension):
        super(Generator, self).__init__()

        self.n_classes = n_classes
        self.latent_dimension = latent_dimension

        self.label_embedding = nn.Embedding(n_classes, latent_dimension)

        # The input width must follow latent_dimension (it was hardcoded to 110,
        # which made every other latent size fail in forward()).
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=latent_dimension, out_features=384, ),
            nn.ReLU(inplace=True),
        )

        # 1x1 -> 4x4
        self.tconv2 = nn.Sequential(
            nn.ConvTranspose2d(384, 192, 4, 1, 0, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )

        # 4x4 -> 8x8
        self.tconv3 = nn.Sequential(
            nn.ConvTranspose2d(192, 96, 4, 2, 1, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(True),
        )

        # 8x8 -> 16x16
        self.tconv4 = nn.Sequential(
            nn.ConvTranspose2d(96, 48, 4, 2, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(True),
        )

        # 16x16 -> 32x32, three output channels squashed to [-1, 1]
        self.tconv5 = nn.Sequential(
            nn.ConvTranspose2d(48, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input_noise: torch.Tensor, input_labels: torch.Tensor):
        """Generate images for the given noise vectors and integer labels.

        Args:
            input_noise: float32 tensor whose last dimension is ``latent_dimension``.
            input_labels: integer label tensor; its embedding must be
                broadcastable against ``input_noise``.

        Returns:
            Tensor of shape ``(N, 3, 32, 32)`` where ``N`` is the number of
            noise vectors (all leading dimensions are flattened together).
        """
        assert input_noise.dtype == torch.float32

        # Condition the noise on the label by element-wise multiplication.
        generator_input = torch.mul(self.label_embedding(input_labels), input_noise)

        x = self.fc1(generator_input)

        # Treat the 384 features as channels of a 1x1 feature map.
        x = x.view((-1, 384, 1, 1,))

        x = self.tconv2(x)
        x = self.tconv3(x)
        x = self.tconv4(x)
        x = self.tconv5(x)

        return x


## Discriminator Architecture

In [15]:
class Discriminator(nn.Module):
    """AC-GAN discriminator for 3x32x32 images with two output heads.

    Six convolutional blocks reduce a 32x32 input to a ``512 x 4 x 4`` feature
    map (the stride-2 blocks halve the resolution: 32 -> 16 -> 8 -> 4). The
    flattened features feed:

    * ``fc7_discriminator``: sigmoid probability that the image is real;
    * ``fc7_auxiliary``: raw class scores (logits) over ``n_classes``.

    The auxiliary head deliberately has no ``Softmax``: this notebook trains it
    with ``torch.nn.CrossEntropyLoss``, which applies log-softmax internally.
    Applying softmax twice squashes the logits and weakens the gradients.
    Apply ``softmax(dim=1)`` to the returned scores if probabilities are needed.

    Args:
        n_classes: number of classes predicted by the auxiliary head.
        latent_dimension: stored for symmetry with the generator; not used by
            any layer.
    """

    def __init__(self, n_classes, latent_dimension):
        super(Discriminator, self).__init__()

        self.n_classes = n_classes
        self.latent_dimension = latent_dimension

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout(p=0.5, inplace=False),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )

        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )

        self.conv5 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )

        self.conv6 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )

        # Convenience container running the six blocks in order.
        self.conv_blocks = nn.Sequential(
            self.conv1,
            self.conv2,
            self.conv3,
            self.conv4,
            self.conv5,
            self.conv6,
        )

        self.fc7_discriminator = nn.Sequential(
            nn.Linear(4 * 4 * 512, 1),
            nn.Sigmoid()
        )

        # Kept as a Sequential so the parameter names (fc7_auxiliary.0.*) stay
        # compatible with previously saved checkpoints.
        self.fc7_auxiliary = nn.Sequential(
            nn.Linear(4 * 4 * 512, n_classes),
        )

    def forward(self, batch_image):
        """Score a batch of images.

        Args:
            batch_image: tensor of shape ``(N, 3, 32, 32)``.

        Returns:
            Tuple ``(validity, class_scores)``: ``validity`` has shape
            ``(N, 1)`` with values in ``[0, 1]``; ``class_scores`` has shape
            ``(N, n_classes)`` and holds unnormalized logits.
        """
        x = batch_image
        x = self.conv_blocks(x)
        # Flatten the 512 x 4 x 4 feature map for the linear heads.
        x = x.view(x.shape[0], 4 * 4 * 512)
        predicted_validity_probability = self.fc7_discriminator(x)
        predicted_auxiliary_logits = self.fc7_auxiliary(x)

        return predicted_validity_probability, predicted_auxiliary_logits


# Weights Initialization

In [16]:
def weights_init_normal(module, weights_init_type="xavier"):
    """Initialize one layer in place; intended for ``nn.Module.apply``.

    The layer kind is detected from the class name:

    * names containing ``"Conv"``: weight is re-initialized, bias is left as is;
    * names containing ``"Linear"``: weight is re-initialized, bias is zeroed;
    * names containing ``"BatchNorm2d"``: weight ~ N(1, 0.02), bias zeroed;
    * anything else is left untouched.

    Layers without a weight or bias (for example ``bias=False`` or
    ``affine=False``) are handled by skipping the missing parameter.

    Args:
        module: the layer to initialize.
        weights_init_type: ``"xavier"`` (Xavier normal), ``"he"`` (Kaiming
            normal); any other value selects N(0, 0.02).
    """
    classname = module.__class__.__name__
    is_conv = classname.find("Conv") != -1
    is_linear = not is_conv and classname.find("Linear") != -1

    weight = getattr(module, "weight", None)
    bias = getattr(module, "bias", None)

    if is_conv or is_linear:
        if weight is not None:
            if weights_init_type == "xavier":
                torch.nn.init.xavier_normal_(weight.data)
            elif weights_init_type == "he":
                torch.nn.init.kaiming_normal_(weight.data)
            else:
                torch.nn.init.normal_(weight.data, 0.0, 0.02)

        # Only linear layers get their bias reset; convolution biases are kept.
        if is_linear and bias is not None:
            torch.nn.init.constant_(bias.data, 0.0)

    elif classname.find("BatchNorm2d") != -1:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias.data, 0.0)


# Generate Images

In [17]:
def log_generated_images_in_grids(table, epoch, generator, n_row=2):
    """Append one row of per-class sample grids to a W&B table.

    For every class, ``n_row ** 2`` images are generated and tiled into an
    ``n_row`` x ``n_row`` grid. The row added to ``table`` is the epoch
    followed by one ``wandb.Image`` per class, in label order.

    Reads the module-level ``hps`` (``n_classes``, ``latent_dimension``,
    ``image_size``) and ``device``. The generator's train/eval mode is not
    changed here.

    Args:
        table: object exposing ``add_data`` (a ``wandb.Table`` in this notebook).
        epoch: value stored in the first column of the new row.
        generator: model called as ``generator(noises, labels)``.
        n_row: side length of each class grid.
    """
    n_samples = n_row ** 2
    n_classes = hps["n_classes"]
    image_size = hps["image_size"]

    # Preview generation needs no gradients; skipping autograd saves memory.
    with torch.no_grad():
        noises = torch.normal(0, 1, (n_samples, n_classes, hps["latent_dimension"])).to(device)
        # Every sample row holds the labels 0 .. n_classes - 1.
        labels = torch.arange(0, n_classes, 1).repeat(n_samples, 1).to(device)

        generated_images = generator(noises, labels).reshape(n_samples, n_classes, 3, image_size, image_size)

    # One grid per class: all samples that share label n.
    grids = [make_grid(generated_images[:, n, :, :, :], padding=1, nrow=n_row, normalize=True) for n in range(n_classes)]

    table.add_data(
        epoch,
        *[wandb.Image(grid) for grid in grids]
    )

## Check GPU Availability

Check whether CUDA is available and, based on that, build the `device` object used for PyTorch tensor computation.

In [18]:
# Use the first GPU when CUDA is usable, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if cuda else torch.device("cpu")

## Hyper-Parameter Setting
In this section, the hyper-parameters used for training the AC-GAN are defined. Hyper-parameter optimization (HPO) will be done in later parts.

In [19]:
# Hyper-parameters of the run; this mapping is also passed to wandb.init as the run config.
hps = dict(
    batch_size=500,
    epochs=250,
    image_size=32,
    noise_standard_deviation=0.05,
    workers=4,
    n_classes=10,
    latent_dimension=110,
    weights_init_type="normal",
    learning_rate=0.0001,
    beta1=0.5,
    beta2=0.999,
)

# Experiment Track Initialization

In [20]:
# Authenticate with Weights & Biases before starting the run.
wandb.login()

EXPERIMENT_NUM = 1

wandb_run_settings = dict(
    project="Auxiliary-Classifier-GAN-PyTorch",
    name=f"experiment {EXPERIMENT_NUM}",
    config=hps,
)
run = wandb.init(**wandb_run_settings)

# From here on `hps` refers to the run's config object instead of the plain dict.
hps = wandb.config

# Data

## Dataset Initialization

In [21]:
# CIFAR-10 training split, resized to the configured image size and scaled
# from [0, 1] to [-1, 1] per channel (mean 0.5, std 0.5).
dataset = datasets.CIFAR10(
    "CIFAR10",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.Resize(hps["image_size"], interpolation=InterpolationMode.BICUBIC, ),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
)

# The worker count comes from the logged hyper-parameters, so the value
# recorded in the W&B config is the one actually used.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=hps["batch_size"],
    shuffle=True,
    num_workers=hps["workers"]
)


Files already downloaded and verified


# Model

## Model Initialization

In [22]:
# Both networks are built from the same class count and latent size.
model_settings = dict(n_classes=hps["n_classes"], latent_dimension=hps["latent_dimension"])

generator = Generator(**model_settings)
discriminator = Discriminator(**model_settings)

## Weight Setting

In [23]:
def _initialize_layer(module):
    """Initialize one layer with the scheme selected in the hyper-parameters."""
    weights_init_normal(module, weights_init_type=hps["weights_init_type"])


generator.apply(_initialize_layer)
discriminator.apply(_initialize_layer)

Discriminator(
  (conv1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout(p=0.5, inplace=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Dropout(p=0.5, inplace=False)
  )
  (conv3): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Dropout(p=0.5, inplace=False)
  )
  (conv4): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_s

## Optimizer Definition

In [24]:
# Generator and discriminator use identical Adam settings.
adam_settings = dict(lr=hps["learning_rate"], betas=(hps["beta1"], hps["beta2"]))

generator_optimizer = torch.optim.Adam(generator.parameters(), **adam_settings)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), **adam_settings)

## Loss Criteria Definition

In [25]:
# Real/fake head: binary cross-entropy on probabilities.
adversarial_loss = nn.BCELoss()
# Class head: categorical cross-entropy.
auxiliary_loss = nn.CrossEntropyLoss()

## CPU to GPU

In [26]:
# Move the networks and loss modules to the GPU when one is available.
if cuda:
    for gpu_module in (generator, discriminator, adversarial_loss, auxiliary_loss):
        gpu_module.cuda()

# Training

In [27]:
# AC-GAN training loop.
#
# For every batch the generator is updated first (it should make the
# discriminator call its images real AND recognize the class each image was
# conditioned on). The discriminator is then updated on the real batch and on
# the detached fake batch. After each epoch the summed losses and a table of
# sample grids are logged to W&B; both models are checkpointed every 50 epochs.
batch_size = hps["batch_size"]
epochs = hps["epochs"]
batches = len(dataloader)

# log_freq must be at least 1 even when an epoch has fewer than four batches.
wandb.watch([generator, discriminator], log="all", log_freq=max(1, batches // 4))

# Full-size target tensors; they are sliced to the actual batch size below so
# that a smaller final batch does not cause a shape mismatch.
REAL = torch.zeros(batch_size, 1, requires_grad=False).fill_(1.0).to(device, )
FAKE = torch.zeros(batch_size, 1, requires_grad=False).fill_(0.0).to(device, )

columns = ["Epoch", *[f"Label {i}" for i in range(hps["n_classes"])]]

for epoch in tqdm(range(epochs)):

    generator_losses = torch.zeros(batches).to(device=device, dtype=torch.float)
    discriminator_losses = torch.zeros(batches).to(device=device, dtype=torch.float)

    generator.train()
    discriminator.train()

    for batch, (batch_real_images, batch_real_labels) in enumerate(dataloader):
        batch_real_images, batch_real_labels = batch_real_images.to(device, dtype=torch.float), batch_real_labels.to(device, dtype=torch.long)

        current_batch_size = batch_real_images.size(0)
        real_targets = REAL[:current_batch_size]
        fake_targets = FAKE[:current_batch_size]

        # Generator Train
        generator_optimizer.zero_grad()

        # Sample batch_noise and labels as generator input
        # and Generate a batch of images
        batch_noise = torch.normal(0, 1, (current_batch_size, hps["latent_dimension"])).to(device)
        batch_fake_generated_labels = torch.randint(0, hps["n_classes"], (current_batch_size,)).to(device, dtype=torch.long)
        batch_fake_generated_images = generator(batch_noise, batch_fake_generated_labels)

        batch_predicted_validity_for_fake_images, batch_predicted_label_for_fake_images = discriminator(batch_fake_generated_images)

        source_loss = adversarial_loss(batch_predicted_validity_for_fake_images, real_targets)
        # The class target is the label the fake image was generated for,
        # not the label of the unrelated real batch.
        class_loss = auxiliary_loss(batch_predicted_label_for_fake_images, batch_fake_generated_labels)

        generator_loss = source_loss + class_loss
        # Store a detached copy so the bookkeeping tensor does not hold on to
        # the autograd graph of every batch.
        generator_losses[batch] = generator_loss.detach()

        generator_loss.backward()
        generator_optimizer.step()

        # Discriminator Train
        # zero_grad also clears the gradients the generator step left on the
        # discriminator parameters.
        discriminator_optimizer.zero_grad()

        batch_predicted_validity_for_real_images, batch_predicted_label_for_real_images = discriminator(batch_real_images)

        source_loss_real = adversarial_loss(batch_predicted_validity_for_real_images, real_targets)
        class_loss_real = auxiliary_loss(batch_predicted_label_for_real_images, batch_real_labels)
        discriminator_loss_real = source_loss_real + class_loss_real

        # detach(): the discriminator update must not backpropagate into the generator.
        batch_predicted_validity_for_fake_images, batch_predicted_label_for_fake_images = discriminator(batch_fake_generated_images.detach())

        source_loss_fake = adversarial_loss(batch_predicted_validity_for_fake_images, fake_targets)
        class_loss_fake = auxiliary_loss(batch_predicted_label_for_fake_images, batch_fake_generated_labels)
        discriminator_loss_fake = source_loss_fake + class_loss_fake

        discriminator_loss = 0.5 * (discriminator_loss_real + discriminator_loss_fake)
        discriminator_losses[batch] = discriminator_loss.detach()

        discriminator_loss.backward()
        discriminator_optimizer.step()

    # Epoch losses are the sums (not the means) of the per-batch losses.
    epoch_discriminator_loss = discriminator_losses.sum().item()
    epoch_generator_loss = generator_losses.sum().item()

    generator.eval()
    discriminator.eval()

    # Log Losses
    wandb.log({
        f"Loss/Generator": epoch_generator_loss,
        f"Loss/Discriminator": epoch_discriminator_loss,
    }, step=epoch)

    # Generate Images Grid and Log
    generated_images_table = wandb.Table(columns=columns)
    log_generated_images_in_grids(generated_images_table, epoch, generator)
    wandb.log({"Generated Images Per Epoch": generated_images_table, }, step=epoch)

    # Save Model (epochs 0, 50, 100, ...)
    if epoch % 50 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': generator.state_dict(),
            'optimizer_state_dict': generator_optimizer.state_dict(),
            'loss': epoch_generator_loss,
        }, f'models/generator_epoch{epoch}.pth')

        torch.save({
            'epoch': epoch,
            'model_state_dict': discriminator.state_dict(),
            'optimizer_state_dict': discriminator_optimizer.state_dict(),
            'loss': epoch_discriminator_loss,
        }, f'models/discriminator_epoch{epoch}.pth')


# Save Final Model
# Each network is written to disk together with its optimizer state and then
# uploaded to the current W&B run as a "model" artifact.
for artifact_name, checkpoint_path, final_model, final_optimizer in (
    ('generator', f'models/generator.pth', generator, generator_optimizer),
    ('discriminator', f'models/discriminator.pth', discriminator, discriminator_optimizer),
):
    final_checkpoint = {
        'model_state_dict': final_model.state_dict(),
        'optimizer_state_dict': final_optimizer.state_dict(),
    }
    torch.save(final_checkpoint, checkpoint_path)

    artifact = wandb.Artifact(artifact_name, type='model')
    artifact.add_file(checkpoint_path)
    logged_artifact = run.log_artifact(artifact)

# Keep the last logged artifact as the cell's displayed value.
logged_artifact

 77%|███████▋  | 386/500 [3:14:26<57:21, 30.19s/it]  wandb: Network error (TransientError), entering retry loop.
 78%|███████▊  | 388/500 [3:15:28<57:20, 30.71s/it]wandb: Network error (TransientError), entering retry loop.
 78%|███████▊  | 390/500 [3:16:25<54:25, 29.69s/it]wandb: Network error (TransientError), entering retry loop.
 80%|████████  | 400/500 [3:22:01<52:47, 31.68s/it]  wandb: Network error (TransientError), entering retry loop.
 80%|████████  | 401/500 [3:22:32<51:34, 31.26s/it]wandb: Network error (TransientError), entering retry loop.
 82%|████████▏ | 408/500 [3:26:04<46:26, 30.29s/it]wandb: Network error (TransientError), entering retry loop.
 83%|████████▎ | 413/500 [3:28:32<42:39, 29.42s/it]wandb: Network error (TransientError), entering retry loop.
 91%|█████████▏| 457/500 [3:50:51<22:04, 30.81s/it]wandb: Network error (TransientError), entering retry loop.
wandb: Network error (TransientError), entering retry loop.
wandb: Network error (TransientError), entering 

<wandb.sdk.wandb_artifacts.Artifact at 0x7fcd7882d720>

In [30]:
# Close the W&B run started by wandb.init above.
wandb.finish()

# Generate Final Result

In [57]:
# Generate n_row samples for every class with the trained generator and save
# them as a single image grid.
n_row = 20

# Inference only: no autograd graph is needed for the final samples.
with torch.no_grad():
    noises = torch.normal(0, 1, (n_row, hps["n_classes"], hps["latent_dimension"])).to(device)
    # Every sample row holds the labels 0 .. n_classes - 1.
    labels = torch.arange(0, hps["n_classes"], 1).repeat(n_row, 1).to(device)

    generated_images = generator(noises, labels).reshape(n_row * hps["n_classes"], 3, hps["image_size"], hps["image_size"])

# nrow is the number of images per grid row.
grids = make_grid(generated_images, padding=2, nrow=n_row, normalize=True)
save_image(grids, "images/final_result.png")

In [56]:
# NOTE(review): USERNAME/PROJECT looks like a placeholder - replace it with the real W&B entity/project path.
%wandb USERNAME/PROJECT