In [3]:
!pip install tqdm
!pip install numpy
!pip install torch
!pip install matplotlib
!pip install torchvision
!pip install torchmetrics[image]

You should consider upgrading via the '/home/j/Desktop/Programming/DeepLearning/GANs/venv/bin/python3 -m pip install --upgrade pip' command.[0m[33m
You should consider upgrading via the '/home/j/Desktop/Programming/DeepLearning/GANs/venv/bin/python3 -m pip install --upgrade pip' command.[0m[33m
You should consider upgrading via the '/home/j/Desktop/Programming/DeepLearning/GANs/venv/bin/python3 -m pip install --upgrade pip' command.[0m[33m
You should consider upgrading via the '/home/j/Desktop/Programming/DeepLearning/GANs/venv/bin/python3 -m pip install --upgrade pip' command.[0m[33m
You should consider upgrading via the '/home/j/Desktop/Programming/DeepLearning/GANs/venv/bin/python3 -m pip install --upgrade pip' command.[0m[33m
You should consider upgrading via the '/home/j/Desktop/Programming/DeepLearning/GANs/venv/bin/python3 -m pip install --upgrade pip' command.[0m[33m
[0m

In [4]:
import os
import torch
import shutil
import torchvision
from tqdm import tqdm
import torch.nn as nn
from typing import Dict, Any
from natsort import natsorted
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.fid import FrechetInceptionDistance


In [5]:

# Load celebA dataset
def get_dataloader(batch_size, shuffle, n_workers, image_size):
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.CenterCrop(image_size),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    dataset = torchvision.datasets.CelebA(
        root="./data", transform=transform, download=True
    )
    # dataset = torch.utils.data.Subset(dataset, [i for i in range(128)])
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_workers
    )
    return dataloader


In [6]:
class SelfAttentionConv(nn.Module):
    def __init__(self, in_d, downscale_factor=8):
        super().__init__()
        """
        Downscale factor is suggested in the paper to reduce memory consumption
        they use 8, but 4 or 2 can be used with enough gpu memory
        """
        self.downscale_factor = downscale_factor
        self.k_conv = nn.utils.spectral_norm(
            nn.Conv2d(in_d, in_d // self.downscale_factor, 1, 1, 0)
        )
        self.q_conv = nn.utils.spectral_norm(
            nn.Conv2d(in_d, in_d // self.downscale_factor, 1, 1, 0)
        )
        self.v_conv = nn.utils.spectral_norm(nn.Conv2d(in_d, in_d, 1, 1, 0))
        # gamma is a learnable parameter (used in original paper)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.out_conv = nn.utils.spectral_norm(nn.Conv2d(in_d, in_d, 1, 1, 0))

    def forward(self, x):
        """
        x: (batch_size, in_d, h, w)
        """
        batch_size, in_d, h, w = x.shape
        # Embed input and reshape for matrix multiplication
        k = self.k_conv(x).view(batch_size, -1, h * w)
        q = self.q_conv(x).view(batch_size, -1, h * w)
        v = self.v_conv(x).view(batch_size, -1, h * w)
        # (batch_size, h * w, h * w)
        attn = torch.bmm(k.transpose(-2, -1), q)
        attn = F.softmax(attn, dim=-1)
        # (batch_size, in_d, h * w)
        out = torch.bmm(v, attn.transpose(-2, -1))
        out = out.view(batch_size, in_d, h, w)
        out = self.out_conv(out)
        # Add residual connection and scale by learnable parameter gamma
        out = self.gamma * out + x
        return out, attn


# Generator, DC-GAN with self attention
class Generator(nn.Module):
    def __init__(self, noise_dim=100, emb_dim=64):
        super().__init__()
        self.layer_1 = nn.Sequential(
            nn.utils.spectral_norm(nn.ConvTranspose2d(noise_dim, emb_dim * 8, 4)),
            nn.BatchNorm2d(emb_dim * 8),
            nn.ReLU(),
        )
        self.layer_2 = nn.Sequential(
            nn.utils.spectral_norm(
                nn.ConvTranspose2d(emb_dim * 8, emb_dim * 4, 4, 2, 1)
            ),
            nn.BatchNorm2d(emb_dim * 4),
            nn.ReLU(),
        )
        self.layer_3 = nn.Sequential(
            nn.utils.spectral_norm(
                nn.ConvTranspose2d(emb_dim * 4, emb_dim * 2, 4, 2, 1)
            ),
            nn.BatchNorm2d(emb_dim * 2),
            nn.ReLU(),
        )
        self.layer_4 = nn.Sequential(
            nn.utils.spectral_norm(nn.ConvTranspose2d(emb_dim * 2, emb_dim, 4, 2, 1)),
            nn.BatchNorm2d(emb_dim),
            nn.ReLU(),
        )

        self.layer_5 = nn.Sequential(nn.ConvTranspose2d(emb_dim, 3, 4, 2, 1), nn.Tanh())

        self.self_attn_1 = SelfAttentionConv(emb_dim * 2)
        self.self_attn_2 = SelfAttentionConv(emb_dim)

    def forward(self, noise):
        x = self.layer_1(noise)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x, attn1 = self.self_attn_1(x)
        x = self.layer_4(x)
        x, attn2 = self.self_attn_2(x)
        x = self.layer_5(x)
        return x, attn1, attn2


 # Discriminator
class Discriminator(nn.Module):
    def __init__(self, in_d=3, emb_dim=64):
        super().__init__()
        self.layer_1 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_d, emb_dim * 2, 4, 2, 1)),
            nn.BatchNorm2d(emb_dim * 2),
            nn.ReLU(),
        )
        self.layer_2 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(emb_dim * 2, emb_dim * 4, 4, 2, 1)),
            nn.BatchNorm2d(emb_dim * 4),
            nn.ReLU(),
        )
        self.layer_3 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(emb_dim * 4, emb_dim * 4, 4, 2, 1)),
            nn.BatchNorm2d(emb_dim * 4),
            nn.ReLU(),
        )
        self.layer_4 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(emb_dim * 4, emb_dim * 8, 4, 2, 1)),
            nn.BatchNorm2d(emb_dim * 8),
            nn.ReLU(),
        )
        self.layer_5 = nn.Conv2d(emb_dim * 8, 1, 4, 1, 0)

        self.self_attn_1 = SelfAttentionConv(emb_dim * 4)
        self.self_attn_2 = SelfAttentionConv(emb_dim * 8)

    def forward(self, image):
        x = self.layer_1(image)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x, attn1 = self.self_attn_1(x)
        x = self.layer_4(x)
        x, attn2 = self.self_attn_2(x)
        x = self.layer_5(x)
        return x, attn1, attn2

In [7]:
class Trainer:
    def __init__(
        self,
        n_ckpt_steps: int,
        n_log_steps: int,
        epochs: int,
        # Data parameters
        batch_size: int,
        # Optimiser parameters
        g_lr: float,
        d_lr: float,
        # Model parameters
        noise_d: int,
        output_image_dims: int,
        flush_prev_logs: bool = True,
        adv_loss: str = "hinge",
        ckpt_path: str = None,
    ):
        """
        n_ckpt_steps: Saves a checkpoint every n_ckpt_steps
        n_log_steps: Logs every n_log_steps
        epochs: Number of epochs to train for
        batch_size: Batch size
        lr: Learning rate
        noise_d: Noise dimension
        emb_d: Embedding dimension
        output_image_dims: Output size ie: height and width of the image
        flush_prev_logs: Flushes previous logs and checkpoints
        """

        torch.cuda.empty_cache()
        # Initialize models
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.generator = Generator().to(self.device)
        self.discriminator = Discriminator().to(self.device)
        if ckpt_path:
            print(f"Loading checkpoint from {ckpt_path}")
            self.generator.load_state_dict(torch.load(ckpt_path)["generator"])
            self.discriminator.load_state_dict(torch.load(ckpt_path)["discriminator"])

        # Print the parameter count for each model
        print(
            f"Generator has {self.count_parameters(self.generator):,} trainable parameters"
        )
        print(
            f"Discriminator has {self.count_parameters(self.discriminator):,} trainable parameters"
        )
        # Initialize optimisers
        self.optim_G = torch.optim.Adam(
            self.generator.parameters(),
            lr=g_lr,
        )
        self.optim_D = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=d_lr,
        )
        self.train_dataloader = get_dataloader(batch_size, True, 4, output_image_dims)
        # Hyperparameters
        self.n_log_steps = n_log_steps
        self.n_ckpt_steps = n_ckpt_steps
        self.epochs = epochs
        self.batch_size = batch_size
        self.noise_d = noise_d
        self.adv_loss = adv_loss
        # Init globals & Metrics
        self.global_step = 0
        self.inception = InceptionScore(normalize=True)
        self.fid = FrechetInceptionDistance(normalize=True)
        self.g_losses, self.d_losses = [], []
        # Reset logs
        if flush_prev_logs:
            shutil.rmtree("results/logs", ignore_errors=True)
        # Create log directory
        os.makedirs("results/logs", exist_ok=True)
        self.log_dir = f'results/logs/train-run-{len(os.listdir("results/logs")) + 1}'
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(f"{self.log_dir}/images", exist_ok=True)
        os.makedirs(f"{self.log_dir}/metrics", exist_ok=True)
        os.makedirs(f"{self.log_dir}/checkpoints", exist_ok=True)
        print(f"Logging to {self.log_dir}")

    def reset_grads(self):
        self.optim_G.zero_grad()
        self.optim_D.zero_grad()

    def count_parameters(self, model: nn.Module):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def to_device(self, *args):
        return [arg.to(self.device) for arg in args]

    def run(self):
        for e in range(self.epochs):
            self.train_iter(e)
        # Save final model
        torch.save(self.generator.state_dict(), f"{self.log_dir}/generator.pt")
        # Save losses
        torch.save(self.g_losses, f"{self.log_dir}/g_losses.pt")
        torch.save(self.d_losses, f"{self.log_dir}/d_losses.pt")
        self.display_metrics(f"{self.log_dir}/metrics")

    def train_iter(self, epoch: int):
        self.generator.train()
        self.discriminator.train()
        # Train on a batch of images
        for i, (real_img, tags) in enumerate(
            tqdm(self.train_dataloader, desc=f"Epoch: {epoch}", leave=False)
        ):
            real_img, tags = real_img.to(self.device), tags.to(self.device)
            # Train the discriminator
            noise = torch.randn((self.batch_size, self.noise_d, 1, 1)).to(self.device)
            fake_image, _, _ = self.generator(noise)
            loss_D = self.discriminator_step(real_img, fake_image)
            # Train the generators
            loss_G = self.generator_step()

            # Log the losses & Metrics
            if self.global_step % self.n_log_steps == 0:
                self.g_losses.append(loss_G.item())
                self.d_losses.append(loss_D.item())
                self.log_image(real_img, f"{self.global_step}_real_image.png")
                self.log_image(
                    self.generator(noise)[0], f"{self.global_step}_fake_image.png"
                )
                # Compute inception score
                self.inception.update((self.generator(noise)[0].detach().cpu()))
                # Compute FID
                self.fid.update(self.generator(noise)[0].detach().cpu(), False)
                self.fid.update(real_img.cpu(), True)
                # Compute discriminator accuracy
                real_pred_D, _, _ = self.discriminator(real_img)
                fake_pred_D, _, _ = self.discriminator(fake_image)
                real_acc = ((real_pred_D > 0.5).float().mean()).item()
                fake_acc = ((fake_pred_D < 0.5).float().mean()).item()
                # Log the metrics
                self.log_metrics(
                    {
                        "Generator Loss": loss_G.item(),
                        "Discriminator Loss": loss_D.item(),
                        "Inception Score": self.inception.compute(),
                        "FID": self.fid.compute().item(),
                        "Real Accuracy": real_acc,
                        "Fake Accuracy": fake_acc,
                    }
                )
            # Increment the global step
            self.global_step += 1
            # Save a checkpoint
            if self.global_step % self.n_ckpt_steps == 0:
                self.save_checkpoint()

    def discriminator_step(self, real_img, generated_img):
        if self.adv_loss == "wgan":
            # Compute the discriminator loss
            real_pred_D, _, _ = self.discriminator(real_img)
            fake_pred_D, _, _ = self.discriminator(generated_img)
            d_loss = -torch.mean(real_pred_D) + torch.mean(fake_pred_D)
            self.reset_grads()
            d_loss.backward()
            self.optim_D.step()

            # Compute the gradient penalty note that the code for calculating gradient penalty is inspired by this tensorflow repo:
            # https://github.com/taki0112/Self-Attention-GAN-Tensorflow/blob/6c073a4c8bf9898ab3ed9b470451f5630cd05373/SAGAN.py#L6
            alpha = (
                torch.rand(real_img.size(0), 1, 1, 1)
                .to(self.device)
                .expand_as(real_img)
            )
            interpolated = Variable(
                (alpha * real_img) + (1 - alpha) * real_img, requires_grad=True
            )
            out, _, _ = self.discriminator(interpolated)

            grad = torch.autograd.grad(
                outputs=out,
                inputs=interpolated,
                grad_outputs=torch.ones(out.size()).to(self.device),
                retain_graph=True,
                create_graph=True,
                only_inputs=True,
            )[0]

            grad = grad.view(grad.size(0), -1)
            grad_norm = torch.sqrt(torch.sum(grad**2, dim=1))
            d_loss_gp = torch.mean((grad_norm - 1) ** 2)
            # Backward + Optimize
            d_loss = 10 * d_loss_gp
            self.reset_grads()
            d_loss.backward()
            self.optim_D.step()
            return d_loss

        elif self.adv_loss == "hinge":
            real_pred_D, _, _ = self.discriminator(real_img)
            fake_pred_D, _, _ = self.discriminator(generated_img)
            loss_D = F.relu(1 - real_pred_D).mean() + F.relu(1 + fake_pred_D).mean()
            return loss_D
        elif self.adv_loss == "lsgan":
            real_pred_D, _, _ = self.discriminator(real_img)
            fake_pred_D, _, _ = self.discriminator(generated_img)
            real_loss = F.mse_loss(real_pred_D, torch.ones_like(real_pred_D))
            fake_loss = F.mse_loss(fake_pred_D, torch.zeros_like(fake_pred_D))
            return real_loss + fake_loss

    # For generator loss
    def generator_step(self):
        noise_x = torch.randn(self.batch_size, self.noise_d, 1, 1).to(self.device)
        fake_image, _, _ = self.generator(noise_x)
        fake_pred_D, _, _ = self.discriminator(fake_image)
        g_loss = -fake_pred_D.mean()
        self.reset_grads()
        g_loss.backward()
        self.optim_G.step()
        return g_loss

    def save_checkpoint(self):
        torch.save(
            {
                "generator": self.generator.state_dict(),
                "discriminator": self.discriminator.state_dict(),
                "optim_G": self.optim_G.state_dict(),
                "optim_D": self.optim_D.state_dict(),
            },
            f"{self.log_dir}/checkpoints/step-{self.global_step}.pt",
        )

    def load_checkpoint(self, path: str):
        state_dict = torch.load(path)
        self.generator.load_state_dict(state_dict["generator"])
        self.discriminator.load_state_dict(state_dict["discriminator"])
        self.optim_G.load_state_dict(state_dict["optim_G"])
        self.optim_D.load_state_dict(state_dict["optim_D"])

    def log_image(self, img: torch.Tensor, name: str):
        torchvision.utils.save_image(
            img,
            f"{self.log_dir}/images/{name}",
            normalize=True,
            range=(-1, 1),
        )

    def log_metrics(self, metrics: Dict[str, Any]):
        torch.save(
            metrics, f"{self.log_dir}/metrics/step-{self.global_step}-metrics.pt"
        )

    def display_metrics(self, metrics_dir: str):
        g_loss, d_loss = [], []
        real_acc, fake_acc = [], []
        for file in natsorted(os.listdir(metrics_dir)):
            metrics = torch.load(os.path.join(metrics_dir, file))
            g_loss.append(metrics["Generator Loss"])
            d_loss.append(metrics["Discriminator Loss"])
            real_acc.append(metrics["Real Accuracy"])
            fake_acc.append(metrics["Fake Accuracy"])
        fig, ax = plt.subplots(2, 2, figsize=(20, 10))
        ax[0, 0].plot(g_loss)
        ax[0, 0].set_title("Generator Loss")
        ax[0, 1].plot(d_loss)
        ax[0, 1].set_title("Discriminator Loss")
        ax[1, 0].plot(real_acc)
        ax[1, 0].set_title("Real Accuracy")
        ax[1, 1].plot(fake_acc)
        ax[1, 1].set_title("Fake Accuracy")
        plt.show()


In [8]:
Trainer(
    n_ckpt_steps=5000,
    n_log_steps=750,
    epochs=30,
    batch_size=32,
    g_lr=0.0001,
    d_lr=0.0004,
    noise_d=100,
    output_image_dims=64,  # H x W
    flush_prev_logs=False,
    adv_loss="wgan",
    ckpt_path=None,
).run()


Generator has 3,624,181 trainable parameters
Discriminator has 4,426,819 trainable parameters
Files already downloaded and verified




Logging to results/logs/train-run-19


                                                               

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/j/Desktop/Programming/DeepLearning/GANs/venv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
  File "/tmp/ipykernel_55120/3059889046.py", line 1, in <module>
    Trainer(
  File "/tmp/ipykernel_55120/3503854386.py", line 100, in run
    self.train_iter(e)
  File "/tmp/ipykernel_55120/3503854386.py", line 132, in train_iter
    self.inception.update((self.generator(noise)[0].detach().cpu()))
  File "/home/j/Desktop/Programming/DeepLearning/GANs/venv/lib/python3.8/site-packages/torchmetrics/metric.py", line 390, in wrapped_func
    update(*args, **kwargs)
  File "/home/j/Desktop/Programming/DeepLearning/GANs/venv/lib/python3.8/site-packages/torchmetrics/image/inception.py", line 139, in update
    features = self.inception(imgs)
  File "/home/j/Desktop/Programming/DeepLearning/GANs/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return handle
  File "/home/j/D

In [None]:
class InferenceEngine:
    def __init__(self, model_path: str):
        self.generator = Generator()
        self.generator.load_state_dict(torch.load(model_path)["generator"])

    def generate(self, output_path: str = "results/out.png"):
        noise = torch.randn((1, 100, 1, 1))
        image, _, _ = self.generator(noise)
        # Make directory to save to

        torchvision.utils.save_image(image, output_path, normalize=True, range=(-1, 1))
        print(f"Generated image saved at {output_path}")
