Implement a decoder network that takes generated images as input and outputs the random latent vector that was fed to the GAN generator. Add a norm-based reconstruction loss between the generator's input and the decoder's output, and train the decoder simultaneously with the regular GAN losses.

In [None]:
import torch
import os
from torch import nn
import torchvision
from torchvision import transforms
import torchvision.models as models
from torchvision.transforms import ToTensor, Compose, Normalize, Resize
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import random_split
from PIL import Image
from torch import nn
import pathlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import trange, tqdm
from torchinfo import summary
import math
from torch.utils.tensorboard import SummaryWriter

In [None]:

# Root folder passed to torchvision's ImageFolder below.
dataroot = "data/Animals_data/animals/animals"
# When True the GAN + decoder are trained and saved; otherwise saved weights are loaded.
TRAIN_DCGAN = True
# Number of passes over the dataloader in train_dcgan; also sets the StepLR period (N_EPOCHS // 4).
N_EPOCHS = 250
# Batch size of the GAN dataloader and default batch size of sample_noise.
BATCH_SIZE = 64
# Discriminator updates performed per generator/decoder update.
N_critic = 1
# Number of channels of the latent vector fed to the generator.
z_dim = 100
# Channel count of the images (default for Generator output / Discriminator input).
Img_channels = 3
# (channels, height, width) of the images after the Resize/CenterCrop transform.
Input_Shape = (3, 128, 128)
# Base channel width; layer widths are multiples of this value.
Hidden_dims = 64
# Base learning rate.
lr = 1e-4
# Adam (beta1, beta2) shared by all three optimizers.
betas = (0.5, 0.999)
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
NUM_WORKERS = 1  # Number of dataloader workers

In [None]:
# TensorBoard writer shared by the logging helpers below; event files are
# written under DC_GAN/tensorboard.
# (Plain string literal: the previous f-string had no placeholders.)
dcgan_writer = SummaryWriter(log_dir='DC_GAN/tensorboard')


def log_losses_to_tensorboard(epoch, g_loss, d_loss, dec_loss):
    """Write the generator, discriminator and decoder losses for `epoch`.

    Each value is logged as a scalar on the module-level `dcgan_writer`,
    using `epoch` as the global step.
    """
    tagged_losses = (
        ('Loss/Generator', g_loss),
        ('Loss/Discriminator', d_loss),
        ('Loss/Decoder', dec_loss),
    )
    for tag, value in tagged_losses:
        dcgan_writer.add_scalar(tag, value, epoch)


def log_gradients_to_tensorboard(writer, model, epoch, model_name):
    """Log per-parameter and total L2 gradient norms of `model`.

    Parameters whose `.grad` is None are skipped. Each remaining parameter is
    logged under `Gradients/<model_name>/<param name>`, followed by the
    overall norm (square root of the summed squared norms) under
    `Gradients/<model_name>/total_norm`.
    """
    squared_sum = 0
    for param_name, tensor in model.named_parameters():
        grad = tensor.grad
        if grad is None:
            continue
        l2 = grad.norm(2).item()
        squared_sum += l2 ** 2
        writer.add_scalar(f'Gradients/{model_name}/{param_name}', l2, epoch)
    overall = squared_sum ** 0.5
    writer.add_scalar(f'Gradients/{model_name}/total_norm', overall, epoch)

In [None]:
def weights_init(m):
    """DCGAN-style initialisation for a single module.

    Modules whose class name contains 'Conv' get weights drawn from
    N(0, 0.02); modules whose class name contains 'BatchNorm' get weights
    drawn from N(1, 0.02) and a zero bias. Any other module is left untouched.

    Only the module passed in is inspected (its children are not visited), so
    use `model.apply(weights_init)` to initialise every layer of a network.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
class Discriminator(nn.Module):
    """Convolutional real/fake classifier.

    Five stride-2 4x4 convolutions (each followed by LeakyReLU and BatchNorm)
    halve the spatial size, so a 128x128 image is reduced to 4x4; a final
    4x4 convolution plus Sigmoid then yields one probability per image with
    shape (batch, 1, 1, 1).
    """

    def __init__(self, Input_channels=Img_channels):
        super().__init__()
        # Channel widths seen by the five down-sampling stages.
        widths = [Input_channels, Hidden_dims, Hidden_dims,
                  2 * Hidden_dims, 4 * Hidden_dims, 8 * Hidden_dims]

        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Conv2d(in_channels=c_in, out_channels=c_out,
                          kernel_size=4, stride=2, padding=1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                nn.BatchNorm2d(c_out),
            ]
        # Collapse the remaining 4x4 map into a single score.
        layers += [
            nn.Conv2d(in_channels=8 * Hidden_dims, out_channels=1,
                      kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-image probability that `x` is real."""
        return self.main(x)

In [None]:
# generator architecture
class Generator(nn.Module):
    """Map a latent vector of shape (batch, z, 1, 1) to a 128x128 image.

    A transposed convolution expands the 1x1 latent to 5x5; five
    nearest-neighbour upsampling + convolution stages then grow it to
    128x128. The output passes through Tanh, so pixel values lie in [-1, 1].

    The LayerNorm layers hard-code the spatial size of every stage, so the
    latent input must be exactly 1x1 spatially.

    Args:
        z: number of latent channels (now actually used for the first layer;
            previously the global `z_dim` was read regardless of this value).
        Output_channels: number of image channels produced.
    """

    def __init__(self, z=z_dim, Output_channels=Img_channels):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # 1x1 -> 5x5: (1 - 1) * 1 - 2 * 1 + 7 = 5
            nn.ConvTranspose2d(in_channels=z, out_channels=4*Hidden_dims,
                               kernel_size=7, stride=1, padding=1, bias=False),
            nn.LayerNorm([4*Hidden_dims, 5, 5]),
            nn.LeakyReLU(0.2, True),

            # 5x5 -> 10x10 by upsampling, then the 7x7 conv (padding 2)
            # trims the map to 8x8.
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channels=4*Hidden_dims, out_channels=2*Hidden_dims,
                      kernel_size=7, stride=1, padding=2),
            nn.LayerNorm([2*Hidden_dims, 8, 8]),
            nn.LeakyReLU(0.2, True),

            # 8x8 -> 16x16 (size-preserving 5x5 conv)
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channels=2*Hidden_dims, out_channels=Hidden_dims,
                      kernel_size=5, stride=1, padding=2),
            nn.LayerNorm([Hidden_dims, 16, 16]),
            nn.LeakyReLU(0.2, True),

            # 16x16 -> 32x32
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channels=Hidden_dims, out_channels=Hidden_dims,
                      kernel_size=5, stride=1, padding=2),
            nn.LayerNorm([Hidden_dims, 32, 32]),
            nn.LeakyReLU(0.2, True),

            # 32x32 -> 64x64
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channels=Hidden_dims, out_channels=Hidden_dims,
                      kernel_size=3, stride=1, padding=1),
            nn.LayerNorm([Hidden_dims, 64, 64]),
            # NOTE(review): slope 0.5 here differs from the 0.2 used in every
            # other stage; kept as-is -- confirm it is intentional.
            nn.LeakyReLU(0.5, True),

            # 64x64 -> 128x128, projected to image channels
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channels=Hidden_dims, out_channels=Output_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.Tanh())

    def forward(self, x):
        """Generate images from latent batch `x` of shape (batch, z, 1, 1)."""
        return self.main(x)

In [None]:
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# weights_init only looks at the class name of the module it receives, so
# calling it directly on the top-level network matches neither 'Conv' nor
# 'BatchNorm' and initialises nothing. Module.apply visits every sub-module.
generator.apply(weights_init)
discriminator.apply(weights_init)

In [None]:
# Layer-by-layer summary for a single image of the configured input shape.
summary(discriminator, input_size=(1, *Input_Shape))

In [None]:
# Layer-by-layer summary for a single latent vector of z_dim channels.
summary(generator, input_size=(1, z_dim, 1, 1))

In [None]:
def sample_noise(size=z_dim, batch_size=BATCH_SIZE):
    """Draw standard-normal latent vectors of shape (batch_size, size, 1, 1).

    The tensor is moved to the module-level `device` before being returned.
    """
    shape = (batch_size, size, 1, 1)
    latent = torch.randn(*shape)
    return latent.to(device)

In [None]:
# Adam optimisers for the two adversaries; the discriminator uses a slightly
# larger step size than the generator.
opt_gen = torch.optim.Adam(params=generator.parameters(), lr=1e-4, betas=betas)
opt_disc = torch.optim.Adam(params=discriminator.parameters(), lr=1.2e-4, betas=betas)

# Halve each learning rate every quarter of the training run.
scheduler_gen, scheduler_disc = (
    torch.optim.lr_scheduler.StepLR(optimizer, step_size=N_EPOCHS // 4, gamma=0.5)
    for optimizer in (opt_gen, opt_disc)
)

In [None]:
# Resize the shorter side to 128, crop the central 128x128 patch and scale
# pixel values to [-1, 1] (the range of the generator's Tanh output).
transform = Compose([
    Resize(128),
    transforms.CenterCrop(128),
    ToTensor(),
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

img_dataset = torchvision.datasets.ImageFolder(dataroot, transform=transform)

dataloader = DataLoader(
    dataset=img_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

In [None]:
class Decoder(nn.Module):
    """Map a 128x128 image back to a latent vector of shape (batch, z_dim, 1, 1).

    Five size-preserving conv + LayerNorm + LeakyReLU stages are interleaved
    with 2x average pooling (128 -> 64 -> 32 -> 16 -> 8); a last pooling step
    gives a 4x4 map that a 4x4 convolution collapses to 1x1. The LayerNorm
    shapes hard-code the 128x128 input size.

    NOTE(review): the final Tanh bounds every latent component to [-1, 1],
    whereas `sample_noise` draws unbounded standard-normal values -- confirm
    this restriction is intended for the reconstruction loss.
    """

    def __init__(self, img_channels=3, z_dim=z_dim):
        super().__init__()
        # (in_channels, out_channels, kernel_size, spatial side) per stage
        stages = [
            (img_channels, Hidden_dims, 3, 128),
            (Hidden_dims, Hidden_dims, 3, 64),
            (Hidden_dims, Hidden_dims, 5, 32),
            (Hidden_dims, 2 * Hidden_dims, 5, 16),
            (2 * Hidden_dims, 4 * Hidden_dims, 5, 8),
        ]

        layers = []
        for index, (c_in, c_out, kernel, side) in enumerate(stages):
            if index:
                # Every stage after the first starts by halving the map.
                layers.append(nn.AvgPool2d(kernel_size=2))
            layers.extend([
                # padding = kernel // 2 keeps the spatial size unchanged
                nn.Conv2d(in_channels=c_in, out_channels=c_out,
                          kernel_size=kernel, stride=1, padding=kernel // 2),
                nn.LayerNorm([c_out, side, side]),
                nn.LeakyReLU(0.2, True),
            ])

        layers.extend([
            nn.AvgPool2d(kernel_size=2),  # 8x8 -> 4x4
            nn.Conv2d(in_channels=4 * Hidden_dims, out_channels=z_dim,
                      kernel_size=4, stride=1, padding=0),  # 4x4 -> 1x1
            nn.Tanh(),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the decoded latent for image batch `x`."""
        return self.main(x)

In [None]:
decoder = Decoder(z_dim=z_dim).to(device)
# weights_init inspects only the module it is handed; Module.apply is needed
# so that the convolution layers inside the decoder are actually initialised.
decoder.apply(weights_init)
summary(decoder, input_size=(1, 3, 128, 128))

For our decoder loss we use a combination of:
1. An L2 norm-based reconstruction loss between Z and Decoder(Generator(Z)).
2. An L1 perceptual loss between real images and Generator(Decoder(real images)). We prefer L1 over MSE to reduce sensitivity to outliers when comparing pixel values.

In [None]:
# Loss terms used for the decoder.
criterion_recon = nn.MSELoss()  # latent reconstruction: z vs Decoder(Generator(z))
perceptual_loss = nn.L1Loss()   # pixel-wise L1 between images

# Adam with weight decay for the decoder; the learning rate is halved every
# quarter of the training run.
opt_dec = torch.optim.Adam(
    params=decoder.parameters(), lr=1e-4, betas=betas, weight_decay=1e-4)
scheduler_dec = torch.optim.lr_scheduler.StepLR(
    optimizer=opt_dec, step_size=N_EPOCHS // 4, gamma=0.5)

In [None]:
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738
# Code taken from: https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment_pytorch.py

def DiffAugment(x, policy='color,translation,cutout', channels_first=True):
    """Apply the augmentations named in `policy` to the batch `x`.

    `policy` is a comma-separated list of keys of `AUGMENT_FNS`; the functions
    of each key are applied in order. An empty/None policy returns `x`
    unchanged. When `channels_first` is False the batch is permuted to
    channels-first for the augmentations and permuted back afterwards.
    """
    if not policy:
        return x

    if not channels_first:
        x = x.permute(0, 3, 1, 2)

    for group in policy.split(','):
        for augment in AUGMENT_FNS[group]:
            x = augment(x)

    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()


def rand_brightness(x):
    """Add one random offset in [-0.5, 0.5) to every sample of the batch."""
    batch = x.size(0)
    offset = torch.rand(batch, 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset


def rand_saturation(x):
    """Scale each sample's deviation from its channel mean by a factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean


def rand_contrast(x):
    """Scale each sample's deviation from its overall mean by a factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean


def rand_translation(x, ratio=0.125):
    """Shift every sample by a random integer offset, filling with zeros.

    Each sample of the channels-first batch `x` is translated independently
    along dims 2 and 3 by up to `ratio` of the respective size (rounded).
    Pixels shifted in from outside the image are zero.
    """
    shift_x, shift_y = int(x.size(2) * ratio +
                           0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1,
                                  size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1,
                                  size=[x.size(0), 1, 1], device=x.device)
    # indexing='ij' is the behaviour the original call relied on implicitly;
    # stating it avoids the deprecation warning for the implicit default.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
        indexing='ij',
    )
    # +1 accounts for the one-pixel zero border added below; clamping sends
    # every out-of-range source index onto that border.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = torch.nn.functional.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    # Gather in channels-last layout, then restore channels-first.
    x = x_pad.permute(0, 2, 3, 1).contiguous()[
        grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous()
    return x


def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangle per sample.

    The rectangle spans `ratio` of dims 2 and 3 (rounded) and is centred at a
    random position; parts falling outside the image are clamped onto the
    border, so the visible cut-out can be smaller than the nominal size. The
    same mask is applied to all channels of a sample.
    """
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(
        2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(
        3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    # indexing='ij' is the behaviour the original call relied on implicitly;
    # stating it avoids the deprecation warning for the implicit default.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
        indexing='ij',
    )
    grid_x = torch.clamp(grid_x + offset_x -
                         cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y -
                         cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3),
                      dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    # Broadcast the per-sample mask over the channel dimension.
    x = x * mask.unsqueeze(1)
    return x


# Maps each policy keyword accepted by DiffAugment to the augmentation
# functions applied (in order) for that keyword.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    translation=[rand_translation],
    cutout=[rand_cutout],
)

In [None]:
# training loop

def _save_sample_grid(fake, epoch, step, save_dir="DC_GAN/generated_images"):
    """Save the generated batch `fake` as an image grid under `save_dir`."""
    with torch.no_grad():
        # Generator output is in [-1, 1] (Tanh); rescale to [0, 1] for imshow.
        grid = torchvision.utils.make_grid((fake + 1) / 2.0, normalize=False)
        grid_np = grid.cpu().detach().permute(1, 2, 0).numpy()
    plt.figure(figsize=(8, 8))
    plt.imshow(grid_np)
    plt.axis('off')
    plt.title(f"Epoch {epoch+1}, iter {step}")
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(
        save_dir, f'generated_image_epoch_{epoch+1}_batch_{step+1}.png')
    plt.savefig(save_path)
    plt.close()


def train_dcgan():
    """Jointly train the generator, discriminator and decoder.

    Per batch:
      1. `N_critic` discriminator updates with the log-loss on augmented real
         images and augmented (detached) generated images.
      2. One joint generator/decoder update on the sum of
         - the non-saturating generator loss,
         - the MSE between the sampled noise and Decoder(augmented fake),
         - the L1 between augmented real images and
           Generator(Decoder(augmented real)).

    Progress is printed every 50 steps and an image grid is saved every 200
    steps. After every epoch the schedulers are stepped and the losses and
    gradient norms of the last batch are written to TensorBoard.
    """
    step = 0
    EPS = 1e-9  # keeps log() finite when the discriminator saturates
    for epoch in trange(N_EPOCHS):
        for (X, _) in dataloader:
            X = X.to(device)
            curr_batch_size = X.shape[0]

            # training discriminator
            for _ in range(N_critic):
                opt_disc.zero_grad()
                # real images
                d_real = discriminator(DiffAugment(X))
                loss_d_real = -1*torch.log(d_real+EPS).mean()
                # Each loss owns its graph and is back-propagated exactly
                # once, so retain_graph is not needed.
                loss_d_real.backward()
                # fake images (detached: no gradient reaches the generator)
                noise = sample_noise(batch_size=curr_batch_size)
                fake = generator(noise)
                d_fake = discriminator(DiffAugment(fake).detach())
                loss_d_fake = -1*torch.log(1-d_fake+EPS).mean()
                loss_d_fake.backward()

                opt_disc.step()

            # training generator and decoder
            opt_gen.zero_grad()
            opt_dec.zero_grad()

            g_fake = discriminator(DiffAugment(fake))
            loss_g = (-1*torch.log(g_fake+EPS)).mean()

            reconstructed_noise = decoder(DiffAugment(fake))
            loss_recon = criterion_recon(reconstructed_noise, noise)
            aug_real = DiffAugment(X)
            real_image_noise = decoder(aug_real)
            image_recon_loss = perceptual_loss(
                generator(real_image_noise), aug_real)

            total_loss = loss_g + loss_recon + image_recon_loss
            # NOTE: this backward pass also accumulates into the
            # discriminator's .grad buffers; they are cleared by
            # opt_disc.zero_grad() before the next discriminator update, but
            # the discriminator gradient norms logged at epoch end include
            # this contribution.
            total_loss.backward()

            opt_dec.step()
            opt_gen.step()

            if step % 50 == 0:
                print(f"epoch:{epoch+1} iter{step} disc_loss:{(loss_d_fake+loss_d_real).item()} gen_loss:{loss_g.item()} dec_loss:{loss_recon.item()}")

            if step % 200 == 0:
                _save_sample_grid(fake, epoch, step)
            step += 1

        scheduler_dec.step()
        scheduler_disc.step()
        scheduler_gen.step()

        log_losses_to_tensorboard(epoch, loss_g.item(
        ), loss_d_fake.item()+loss_d_real.item(), loss_recon.item())
        log_gradients_to_tensorboard(
            dcgan_writer, generator, epoch, 'Generator')
        log_gradients_to_tensorboard(
            dcgan_writer, discriminator, epoch, 'Discriminator')
        log_gradients_to_tensorboard(dcgan_writer, decoder, epoch, 'Decoder')


# Checkpoint locations for the three networks.
save_dir = "DC_GAN/models"
os.makedirs(save_dir, exist_ok=True)
save_path_g = os.path.join(save_dir, "generator")
save_path_d = os.path.join(save_dir, "discriminator")
save_path_dec = os.path.join(save_dir, "decoder")


if TRAIN_DCGAN:
    train_dcgan()
    torch.save(generator.state_dict(), save_path_g)
    torch.save(discriminator.state_dict(), save_path_d)
    torch.save(decoder.state_dict(), save_path_dec)
else:
    # map_location lets a checkpoint saved on a GPU be restored on whatever
    # `device` this session uses (e.g. a CPU-only machine).
    # Only the decoder and generator are restored; the discriminator
    # checkpoint is not loaded here.
    decoder.load_state_dict(torch.load(
        save_path_dec, map_location=device, weights_only=True))
    decoder.eval()
    generator.load_state_dict(torch.load(
        save_path_g, map_location=device, weights_only=True))
    generator.eval()

Obtain the decoder output (trained in the previous step) for all the input
images. Train an MLP to solve a classification task by taking these decoded vectors as input. Compute and report the classification accuracy
and the F1 score.

Generate a t-SNE plot of the decoded latents for all real images to check if they are separable by class.

In [None]:
from sklearn.manifold import TSNE

z_vectors = []
class_labels = []

with torch.no_grad():  # No need to compute gradients for this
    for images, labels in dataloader:
        images = images.to(device)
        z = decoder(images).squeeze(-1).squeeze(-1)  # Get the latent vector z
        # Move each batch off the accelerator right away so the whole
        # dataset's latents do not pile up in device memory.
        z_vectors.append(z.cpu())
        class_labels.append(labels)

# Convert lists to tensors and numpy arrays
# Shape (num_samples, latent_dim)
z_vectors = torch.cat(z_vectors).numpy()
class_labels = torch.cat(class_labels).cpu().numpy()  # Shape (num_samples,)

# Apply t-SNE to reduce the latent vectors to 2D
tsne = TSNE(n_components=2, perplexity=30, random_state=42)
z_tsne = tsne.fit_transform(z_vectors)

# Plot the t-SNE results
plt.figure(figsize=(8, 6))
# Points are coloured by class index through the 'tab10' colormap.
# NOTE(review): 'tab10' has only 10 distinct colours, so with more than 10
# classes several classes will share a colour -- confirm the class count.
scatter = plt.scatter(z_tsne[:, 0], z_tsne[:, 1],
                      c=class_labels, cmap='tab10', s=50)
plt.colorbar(scatter)
plt.title("t-SNE of Latent Vectors with Class Labels")
plt.xlabel("t-SNE component 1")
plt.ylabel("t-SNE component 2")
plt.show()

In [None]:
# Configuration for the MLP and ResNet classifiers.
# NOTE: this rebinds BATCH_SIZE; objects created earlier (e.g. `dataloader`)
# keep the batch size they were built with.
BATCH_SIZE = 128
EPOCHS = 200

In [None]:
class MLP(nn.Module):
    """Three-hidden-layer ReLU perceptron producing `output_dim` logits."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return logits for `x`.

        The two trailing dimensions are squeezed first (when they have size
        1), so a decoder output of shape (batch, dim, 1, 1) can be passed in
        directly.
        """
        hidden = x.squeeze(-1).squeeze(-1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.relu(layer(hidden))
        return self.out(hidden)

In [None]:

def prepare_datasets(img_dataset, train_fraction=0.8, seed=42):
    """Split `img_dataset` into reproducible train and test subsets.

    Args:
        img_dataset: any dataset supporting `len()` and indexing.
        train_fraction: share of the samples assigned to the training subset
            (must lie in [0, 1]); the remainder forms the test subset.
        seed: value passed to `torch.manual_seed` before splitting. Note that
            this reseeds torch's global random generator.

    Returns:
        A `(train_dataset, test_dataset)` tuple -- in that order.

    Raises:
        ValueError: if `train_fraction` is outside [0, 1].
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(
            f"train_fraction must be in [0, 1], got {train_fraction}")

    total = len(img_dataset)
    train_size = int(train_fraction * total)
    test_size = total - train_size

    torch.manual_seed(seed)

    train_dataset, test_dataset = random_split(
        img_dataset, [train_size, test_size])

    return train_dataset, test_dataset

# prepare_datasets returns (train, test) in that order; unpacking them the
# other way round trained the classifier on the 20% split and evaluated it
# on the 80% split.
train_dataset, test_dataset = prepare_datasets(img_dataset)
train_loader = DataLoader(train_dataset, num_workers=4, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, num_workers=4, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
def test_classifier(model, num_classes, test_loader, fn=lambda img: img):
    """Evaluate `model` on `test_loader`.

    Args:
        model: classifier returning one score per class; put into eval mode.
        num_classes: number of classes (size of the per-class counters).
        test_loader: iterable of batches whose first two items are the inputs
            and the integer labels.
        fn: transformation applied to the inputs before they reach `model`
            (identity by default).

    Returns:
        `(top1_accuracy, top5_accuracy, f1)`. Accuracies are percentages.
        `f1` is the harmonic mean of the macro-averaged precision and the
        macro-averaged recall (not the mean of per-class F1 scores). When
        the model outputs fewer than 5 classes, "top-5" uses all of them.
        An empty loader yields `(0.0, 0.0, 0.0)`.
    """
    model.eval()
    top1_correct = 0
    top5_correct = 0
    total = 0
    true_positives = torch.zeros(num_classes).to(device)
    false_positives = torch.zeros(num_classes).to(device)
    false_negatives = torch.zeros(num_classes).to(device)

    with torch.inference_mode():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(fn(images))
            # topk raises if k exceeds the number of classes.
            k = min(5, outputs.size(1))
            _, predicted = outputs.topk(k, 1, largest=True, sorted=True)
            predicted = predicted.t()
            top1 = predicted[0]
            hit = top1 == labels
            miss = ~hit
            top1_correct += hit.sum().item()
            top5_correct += (predicted ==
                             labels.unsqueeze(0)).sum().item()
            total += labels.size(0)

            # Per-class counts for the whole batch at once (replaces a
            # per-sample Python loop over device tensors).
            true_positives += torch.bincount(
                labels[hit], minlength=num_classes)
            false_positives += torch.bincount(
                top1[miss], minlength=num_classes)
            false_negatives += torch.bincount(
                labels[miss], minlength=num_classes)

    if total == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        return 0.0, 0.0, 0.0

    # Calculate precision and recall for each class
    precision_per_class = true_positives / \
        (true_positives + false_positives + 1e-10)  # Avoid division by zero
    recall_per_class = true_positives / \
        (true_positives + false_negatives + 1e-10)

    # Average precision and recall over all classes
    precision = precision_per_class.mean().item()
    recall = recall_per_class.mean().item()
    f1 = 2 * (precision * recall) / (precision + recall + 1e-10)

    top1_accuracy = 100 * top1_correct / total
    top5_accuracy = 100 * top5_correct / total
    return top1_accuracy, top5_accuracy, f1

In [None]:
# Train an MLP classifier on the (frozen) decoder's latent codes.
num_classes = len(img_dataset.classes)
mlp = MLP(z_dim, 128, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
mlp_optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)

for epoch in trange(EPOCHS):
    mlp.train()  # test_classifier switches the MLP to eval mode
    for (X, y) in train_loader:
        X = X.to(device)
        y = y.to(device)
        mlp_optimizer.zero_grad()
        # The decoder is only a feature extractor here: skip building its
        # autograd graph instead of building it and detaching afterwards.
        with torch.no_grad():
            noise = decoder(X)
        outputs = mlp(noise)
        loss = criterion(outputs, y)
        loss.backward()
        mlp_optimizer.step()

    if epoch % 10 == 0:
        top1_accuracy, top5_accuracy, f1 = test_classifier(
            model=mlp, num_classes=num_classes, test_loader=test_loader, fn=decoder)
        # `loss` is the loss of the last training batch of this epoch.
        print(f'Epoch [{epoch+1}/{EPOCHS}], Loss: {loss.item():.4f} Top1: {top1_accuracy} Top5: {top5_accuracy}, F1: {f1}')