In [None]:
import os
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import datetime
from copy import deepcopy
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchsummary import summary
# Seed torch's global RNG so the latent vectors drawn later with torch.randn
# (fixed_test_vector) are the same on every run.
torch.manual_seed(69)
# Load the TensorBoard notebook extension (provides the %tensorboard magic).
%load_ext tensorboard

In [None]:
# Hyperparameters
# Run on the GPU whenever one is visible to torch.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Data / evaluation settings
BATCH_SIZE = 128
TEST_SUBSET_SIZE = 5000
NUM_WORKERS = 2
IMAGE_SIZE = 128
CHANNELS_IMG = 3

# Model / optimisation settings
Z_DIM = 256
LEARNING_RATE = 1e-4

# Log directory, created up front (no error if it already exists).
LOG_FOLDER = "logs/"
os.makedirs(LOG_FOLDER, exist_ok=True)

In [None]:
# Fixed latent vectors for the generators, grouped per evaluation batch:
# shape (NUM_TEST_BATCHES, BATCH_SIZE, Z_DIM, 1, 1).
# Ceil division: splitting TEST_SUBSET_SIZE samples into BATCH_SIZE batches gives
# ceil(5000 / 128) = 40 batches, the last holding only 8 samples (unless the loader
# drops it). Floor division would leave that final batch without a latent group.
NUM_TEST_BATCHES = (TEST_SUBSET_SIZE + BATCH_SIZE - 1) // BATCH_SIZE
fixed_test_vector = torch.randn(NUM_TEST_BATCHES, BATCH_SIZE, Z_DIM, 1, 1).to(DEVICE)

In [None]:
def load_checkpoint(checkpoint_file, model, device=None, *, optimizer=None, lr=None):
    """
    Restore model weights (and optionally optimizer state) from a checkpoint file.

    Args:
        checkpoint_file: Path to a file written with ``torch.save``: either a dict
            with a ``"state_dict"`` entry (optionally also ``"optimizer"``) or a
            bare model state dict.
        model: Module whose weights are overwritten in place.
        device: ``map_location`` passed to ``torch.load``. ``None`` (the default)
            uses ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise, so
            GPU-saved checkpoints still load on CPU-only machines.
        optimizer: Optional optimizer, restored from the checkpoint's
            ``"optimizer"`` entry when the checkpoint has one.
        lr: Optional learning rate written into every param group of
            ``optimizer`` after loading; requires ``optimizer``.

    Raises:
        ValueError: If ``lr`` is given without ``optimizer``.
    """
    if lr is not None and optimizer is None:
        raise ValueError("lr was given but there is no optimizer to apply it to")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # weights_only=True: only tensors / primitive containers are unpickled.
    checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=True)
    # Accept both {"state_dict": ...} wrappers and bare state dicts.
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    model.load_state_dict(state_dict)

    if optimizer is not None:
        if "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
        else:
            print("=> Checkpoint has no optimizer state; optimizer left as is")
        if lr is not None:
            for group in optimizer.param_groups:
                group["lr"] = lr

    print("=> Loaded checkpoint")

In [None]:
def _binary_classification_metrics(targets, scores, threshold=0.5):
    """Accuracy/precision/recall/F1 at ``threshold`` plus ROC AUC for 0/1 targets and probability scores."""
    preds = (scores >= threshold).astype(int)  # hard 0/1 predictions from probabilities
    # roc_auc_score raises when only one class is present; report NaN instead.
    roc_auc = roc_auc_score(targets, scores) if np.unique(targets).size > 1 else float('nan')
    return {
        'accuracy': accuracy_score(targets, preds),
        'precision': precision_score(targets, preds, zero_division=1),
        'recall': recall_score(targets, preds, zero_division=1),
        'f1': f1_score(targets, preds, zero_division=1),
        'roc_auc': roc_auc,
    }


def evaluate_model(discriminator, generator, test_loader, fixed_test_vector, device='cuda'):
    """
    Evaluate a discriminator on real test samples and on generated samples.

    Each batch from ``test_loader`` is paired with the same number of fake
    images produced by ``generator`` from ``fixed_test_vector``. Real samples
    are scored against the labels the loader yields, fake samples against 0,
    and probabilities are thresholded at 0.5 for the hard-label metrics.

    Args:
        discriminator: Model returning one probability (sigmoid output) per
            sample. It is switched to eval mode and left there.
        generator: Model mapping a batch of latent vectors to images. It is
            switched to eval mode and left there.
        test_loader: Iterable of ``(images, labels)`` batches of real samples.
            NOTE(review): real images are scored against these labels as-is;
            this assumes the loader labels real images as 1 — confirm.
        fixed_test_vector: Fixed latents grouped as
            ``(num_groups, group_size, *latent_shape)``. They are consumed in
            order, one per real sample, wrapping around to the start if the
            loader yields more samples than there are latents. When every
            earlier batch has ``group_size`` samples, batch ``i`` uses exactly
            ``fixed_test_vector[i]``.
        device: Device the real images and latents are moved to.

    Returns:
        dict: ``loss`` (mean over batches of real BCE + fake BCE),
        ``accuracy``, ``precision``, ``recall``, ``f1`` and ``roc_auc``
        (NaN when the targets contain a single class).

    Raises:
        ValueError: If ``fixed_test_vector`` holds no latents or
            ``test_loader`` yields no batches.
    """
    discriminator.eval()
    generator.eval()
    criterion = torch.nn.BCELoss()

    # One flat pool of latents, so a short final batch gets exactly as many fakes
    # as reals (per-group indexing would give it a full group, or run past the end).
    latent_pool = fixed_test_vector.flatten(0, 1)
    pool_size = latent_pool.shape[0]
    if pool_size == 0:
        raise ValueError("fixed_test_vector contains no latent vectors")

    running_loss = 0.0
    num_batches = 0
    latent_offset = 0
    correct, seen = 0, 0
    score_chunks, target_chunks = [], []

    with torch.no_grad():
        loop = tqdm(test_loader, leave=True, desc="Testing")
        for real, real_labels in loop:
            real = real.to(device)
            # reshape(-1) rather than squeeze(): squeeze() turns a batch of one into
            # a 0-d tensor that torch.cat below cannot concatenate.
            real_labels = real_labels.to(device).float().reshape(-1)
            real_outputs = discriminator(real).reshape(-1)
            real_loss = criterion(real_outputs, real_labels)

            n_real = real.shape[0]
            latent_idx = (latent_offset + torch.arange(n_real, device=latent_pool.device)) % pool_size
            latent_offset += n_real
            fake = generator(latent_pool[latent_idx].to(device))
            fake_labels = torch.zeros_like(real_labels)
            fake_outputs = discriminator(fake).reshape(-1)
            fake_loss = criterion(fake_outputs, fake_labels)

            running_loss += (real_loss + fake_loss).item()
            num_batches += 1

            batch_scores = torch.cat((real_outputs, fake_outputs), dim=0).cpu().numpy()
            batch_targets = torch.cat((real_labels, fake_labels), dim=0).cpu().numpy()
            score_chunks.append(batch_scores)
            target_chunks.append(batch_targets)

            # Cheap running figures for the progress bar; the full metric set is
            # computed once after the loop rather than re-scoring everything each batch.
            correct += int(((batch_scores >= 0.5) == batch_targets).sum())
            seen += batch_targets.size
            loop.set_postfix(loss=running_loss / num_batches, accuracy=correct / seen)

    if num_batches == 0:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    # running_loss is a sum of per-batch losses: divide by the batch count exactly once.
    final_metrics = {
        'loss': running_loss / num_batches,
        **_binary_classification_metrics(
            np.concatenate(target_chunks), np.concatenate(score_chunks)
        ),
    }

    print(f"Test: Loss: {final_metrics['loss']:.4f} | "
          f"Accuracy: {final_metrics['accuracy']:.4f} | "
          f"Precision: {final_metrics['precision']:.4f} | "
          f"Recall: {final_metrics['recall']:.4f} | "
          f"F1: {final_metrics['f1']:.4f} | "
          f"ROC AUC: {final_metrics['roc_auc']:.4f}")

    return final_metrics

## ProGAN

In [None]:
# ProGAN checkpoint paths
CHECKPOINT_PRO_CRITIC = "models/pro_critic_128_2.pth"
CHECKPOINT_PRO_GEN = "models/pro_generator_128_2.pth"
CHECKPOINT_PRO_DISC = "models/pro_disc_128_3.pth"

PRO_LR = 5.5e-4

# Multiplier of IN_CHANNELS per resolution: index 0 is the 4x4 base and each
# further index doubles the side length (index 5 -> 128x128).
FACTORS = [1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]
IN_CHANNELS = 512
CHANNELS_IMG = 3  # same value as in the hyperparameter cell

In [None]:
class WSConv2d(nn.Module):
    """
    Conv2d with an equalized learning rate (weight scaling).

    Rather than rescaling the stored weights, the input is multiplied by
    ``sqrt(gain / fan_in)`` before the convolution, which yields the same
    result. The bias is kept out of the scaled convolution and added afterwards.
    ``kernel_size`` must be an int (it is squared to get the fan-in).

    Adapted from:
    https://github.com/nvnbny/progressive_growing_of_gans/blob/master/modelUtils.py
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2
    ):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        fan_in = in_channels * (kernel_size ** 2)
        self.scale = (gain / fan_in) ** 0.5

        # Re-home the conv's bias on this module so it is not scaled, and remove it from conv.
        self.bias = self.conv.bias
        self.conv.bias = None

        # Unit-normal weights; the runtime input scale supplies the per-layer normalization.
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        scaled_input = x * self.scale
        return self.conv(scaled_input) + self.bias.view(1, -1, 1, 1)


class PixelNorm(nn.Module):
    """Scale each spatial position's feature vector (dim 1) to approximately unit RMS."""

    def __init__(self):
        super().__init__()
        self.epsilon = 1e-8  # keeps the denominator away from zero

    def forward(self, x):
        mean_square = torch.mean(x ** 2, dim=1, keepdim=True)
        return x / torch.sqrt(mean_square + self.epsilon)


class ConvBlock(nn.Module):
    """
    Two 3x3 WSConv2d layers, each followed by LeakyReLU(0.2) and optional PixelNorm.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
        use_pixelnorm: Apply PixelNorm after each activation when True.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
            if self.use_pn:
                x = self.pn(x)
        return x

In [None]:
class Generator(nn.Module):
    """
    ProGAN generator: latent (N, z_dim, 1, 1) -> 4x4 base -> one ConvBlock per 2x upsampling.

    Channel widths follow FACTORS: the base uses ``in_channels`` and block ``k``
    maps ``in_channels * FACTORS[k]`` to ``in_channels * FACTORS[k + 1]``
    channels. ``vec_to_rgb_layers[k]`` projects the feature map produced after
    ``k`` blocks to image channels (index 0 is ``initial_vec_to_rgb`` itself).

    Args:
        z_dim: Channels of the latent input.
        in_channels: Channel width of the 4x4 base feature map.
        img_channels: Channels of the generated image.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()

        # Base block: 1x1 latent -> 4x4 feature map with in_channels channels.
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        # 1x1 projection of the base feature map to image channels.
        self.initial_vec_to_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )

        self.prog_blocks = nn.ModuleList([])
        # Entry 0 is the same module object as initial_vec_to_rgb.
        self.vec_to_rgb_layers = nn.ModuleList([self.initial_vec_to_rgb])

        # Pair consecutive FACTORS so block k maps width FACTORS[k] -> FACTORS[k + 1].
        for factor_in, factor_out in zip(FACTORS[:-1], FACTORS[1:]):
            block_in = int(in_channels * factor_in)
            block_out = int(in_channels * factor_out)
            self.prog_blocks.append(ConvBlock(block_in, block_out))
            self.vec_to_rgb_layers.append(
                WSConv2d(block_out, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        # alpha is a scalar in [0, 1]: 1 keeps only `generated`, 0 only `upscaled`.
        # Both tensors must have the same shape.
        return alpha * generated + (1 - alpha) * upscaled

    def forward(self, x, alpha=0.5, steps=5):
        """
        Generate images at resolution 4 * 2**steps.

        For ``steps > 0`` the result is ``tanh`` of a blend of the last block's
        RGB output (weight ``alpha``) and the nearest-upsampled previous feature
        map projected to RGB (weight ``1 - alpha``). ``steps == 0`` returns the
        raw (no tanh) 4x4 RGB projection.
        NOTE(review): the default alpha=0.5 gives an even blend; a fully
        faded-in model is usually sampled with alpha=1 — confirm the intent.
        """
        features = self.initial(x)  # 1x1 -> 4x4
        if steps == 0:
            return self.initial_vec_to_rgb(features)

        for block_idx in range(steps):
            upscaled = F.interpolate(features, scale_factor=2, mode="nearest")
            features = self.prog_blocks[block_idx](upscaled)

        # `upscaled` still has the previous block's channel count, so it uses the
        # previous to-RGB layer; `features` uses the one matching the last block.
        rgb_upscaled = self.vec_to_rgb_layers[steps - 1](upscaled)
        rgb_generated = self.vec_to_rgb_layers[steps](features)
        return torch.tanh(self.fade_in(alpha, rgb_upscaled, rgb_generated))

pro_gen = Generator(Z_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
# Adam optimizer for the generator; it is not stepped anywhere below in this notebook.
opt_pro_gen = optim.Adam(pro_gen.parameters(), lr=1e-3, betas=(0.0, 0.99))

# Load onto DEVICE: load_checkpoint's default map_location is "cuda", which fails
# on CPU-only machines.
load_checkpoint(CHECKPOINT_PRO_GEN, pro_gen, device=DEVICE)

pro_gen.eval()

In [None]:
class Discriminator(nn.Module):
    """
    ProGAN discriminator, mirroring the ProGAN generator's layer layout.

    ``prog_blocks[k]`` and ``rgb_to_vec_layers[k]`` serve the k-th largest
    input resolution (index 0 is the largest: 128x128 with the current 6-entry
    FACTORS). The last ``rgb_to_vec_layers`` entry is ``initial_rgb_to_vec``,
    used for 4x4 inputs that go straight to ``final_block``.

    Args:
        z_dim: Accepted for signature symmetry; not used.
        in_channels: Channel width at 4x4; other resolutions scale it by FACTORS.
        img_channels: Channels of the input images.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()
        self.prog_blocks = nn.ModuleList([])
        self.rgb_to_vec_layers = nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # Walk FACTORS from the narrowest (largest-resolution) end back towards 1,
        # so index 0 of both lists belongs to the largest input size.
        for idx in reversed(range(1, len(FACTORS))):
            width_in = int(in_channels * FACTORS[idx])
            width_out = int(in_channels * FACTORS[idx - 1])
            self.prog_blocks.append(ConvBlock(width_in, width_out, use_pixelnorm=False))
            self.rgb_to_vec_layers.append(
                WSConv2d(img_channels, width_in, kernel_size=1, stride=1, padding=0)
            )

        # From-RGB layer for 4x4 inputs; also stored as the last rgb_to_vec entry
        # (one module object registered under two names).
        self.initial_rgb_to_vec = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_to_vec_layers.append(self.initial_rgb_to_vec)
        # 2x downsampling between resolutions.
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # Head for 4x4 feature maps -> one probability per sample.
        self.final_block = nn.Sequential(
            # +1 input channel: minibatch_std appends one statistics channel.
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),  # 4x4 -> 1x1
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, 1, kernel_size=1, padding=0, stride=1),
            nn.Flatten(),
            nn.Sigmoid(),
        )

    def fade_in(self, alpha, downscaled, out):
        """Blend the first block's output with the downscaled skip path (alpha=1 keeps only ``out``)."""
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        """Append one channel holding the batch-wide mean of per-feature standard deviations."""
        # std across the batch (dim 0) for every (channel, y, x), averaged to one scalar
        # and broadcast as a single extra channel for every sample.
        # NOTE: with a single-sample batch torch.std's unbiased estimate is NaN.
        batch_stat = torch.std(x, dim=0).mean()
        stat_map = batch_stat.repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, stat_map], dim=1)

    def forward(self, x, alpha=0.5, steps=5):
        # Index of the first layer to use; steps=0 means a 4x4 input, which only
        # needs the last from-RGB layer and the final block.
        first = len(self.prog_blocks) - steps
        out = self.leaky(self.rgb_to_vec_layers[first](x))

        if steps != 0:
            # Skip path: downsample the image and use the next (smaller) resolution's
            # from-RGB layer so its channel count matches the block output.
            downscaled = self.leaky(self.rgb_to_vec_layers[first + 1](self.avg_pool(x)))
            out = self.avg_pool(self.prog_blocks[first](out))
            # Unlike the generator, the fade happens right after the first block.
            out = self.fade_in(alpha, downscaled, out)

            for idx in range(first + 1, len(self.prog_blocks)):
                out = self.avg_pool(self.prog_blocks[idx](out))

        return self.final_block(self.minibatch_std(out))

pro_disc = Discriminator(Z_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
# Critic optimizer at PRO_LR. The load_checkpoint call below (path + model only)
# restores model weights only, so the learning rate is set here.
opt_pro_disc = optim.Adam(pro_disc.parameters(), lr=PRO_LR, betas=(0.0, 0.99))

load_checkpoint(CHECKPOINT_PRO_CRITIC, pro_disc, device=DEVICE)

# Make every critic parameter trainable and switch to training mode.
pro_disc.requires_grad_(True)
pro_disc.train()

## WGAN

In [None]:
# WGAN checkpoint paths: critic, generator and a separately saved discriminator.
CHECKPOINT_WGAN_CRITIC = "models/wgan_critic_128_2.pth"
CHECKPOINT_WGAN_GEN = "models/wgan_generator_128_2.pth"
CHECKPOINT_WGAN_DISC = "models/wgan_discriminator_128_2.pth"

# Base channel widths of the WGAN generator and discriminator.
FEATURES_GEN = FEATURES_DISC = 16

In [None]:
class Generator(nn.Module):
    """
    DCGAN-style generator: latent (N, channels_noise, 1, 1) -> image (N, channels_img, 128, 128).

    Args:
        channels_noise: Channels of the latent input.
        channels_img: Channels of the generated image.
        features_g: Base width; the first block has ``features_g * 32`` channels
            and each following block halves it.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        # 1x1 latent -> 4x4 feature map.
        layers = [self._block(channels_noise, features_g * 32, 4, 1, 0)]
        # Five stride-2 blocks: 4 -> 8 -> 16 -> 32 -> 64 -> 128, halving the width each time.
        for width in (32, 16, 8, 4, 2):
            layers.append(self._block(features_g * width, features_g * width // 2, 4, 2, 1))
        layers += [
            nn.Conv2d(features_g, channels_img, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),  # output range [-1, 1]
        ]
        self.net = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU."""
        upconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        return nn.Sequential(upconv, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        return self.net(x)

wgan_gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(DEVICE)
# Load onto DEVICE: load_checkpoint's default map_location is "cuda", which fails
# on CPU-only machines.
load_checkpoint(CHECKPOINT_WGAN_GEN, wgan_gen, device=DEVICE)
wgan_gen.eval()

In [None]:
class Discriminator(nn.Module):
    """
    DCGAN-style discriminator for 128x128 images; outputs one probability per sample, shape (N, 1).

    Args:
        channels_img: Channels of the input images.
        features_d: Base width; the five downsampling blocks double it each time.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        layers = [
            # (N, channels_img, 128, 128) -> (N, features_d, 128, 128)
            nn.Conv2d(channels_img, features_d, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Five stride-2 blocks: 128 -> 64 -> 32 -> 16 -> 8 -> 4, widening to features_d * 32.
        for width in (1, 2, 4, 8, 16):
            layers.append(self._block(features_d * width, features_d * width * 2, 4, 2, 1))
        top_width = features_d * 32
        layers += [
            nn.Conv2d(top_width, top_width, kernel_size=4, stride=1, padding=0, bias=False),  # 4x4 -> 1x1
            nn.LeakyReLU(0.2),
            nn.Conv2d(top_width, 1, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Flatten(),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv2d (no bias) -> affine InstanceNorm2d -> LeakyReLU(0.2)."""
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        return nn.Sequential(
            conv,
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)

wgan_disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(DEVICE)
# Load onto DEVICE: load_checkpoint's default map_location is "cuda", which fails
# on CPU-only machines.
load_checkpoint(CHECKPOINT_WGAN_CRITIC, wgan_disc, device=DEVICE)
# Make every parameter trainable and switch to training mode.
wgan_disc.requires_grad_(True)
wgan_disc.train()

## Testing

In [None]:
# Overwrite the critic weights loaded earlier with the discriminator checkpoint,
# mapped onto DEVICE (the loader's default map_location is "cuda").
load_checkpoint(CHECKPOINT_WGAN_DISC, wgan_disc, device=DEVICE)

In [None]:
# NOTE(review): test_loader is not defined in the visible notebook cells; build a
# DataLoader over the TEST_SUBSET_SIZE-sample test set (batch size BATCH_SIZE) first.
wgan_disc_test_metrics = evaluate_model(
    discriminator=wgan_disc,
    generator=wgan_gen,
    test_loader=test_loader,
    fixed_test_vector=fixed_test_vector,
    device=DEVICE,  # evaluate_model defaults to "cuda", which fails on CPU-only machines
)