In [1]:
import os
import torch
import shutil
import numpy as np
import torchvision
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [19]:
# Load celebA dataset
def get_dataloader(batch_size, shuffle, n_workers, image_size):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_size),
        torchvision.transforms.CenterCrop(image_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = torchvision.datasets.CelebA(root='./data', transform=transform, download=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_workers)
    return dataloader

In [20]:
class SelfAttentionConv(nn.Module):
    def __init__(self, in_d, downscale_factor=8):
        super().__init__()
        """
        Downscale factor is suggested in the paper to reduce memory consumption
        they use 8, but you can use 4 or 2 if you have enough memory
        """
        self.downscale_factor = downscale_factor
        self.k_conv = nn.Conv2d(in_d, in_d // self.downscale_factor, 1, 1, 0)
        self.q_conv = nn.Conv2d(in_d, in_d // self.downscale_factor, 1, 1, 0)
        self.v_conv = nn.Conv2d(in_d, in_d, 1, 1, 0)
        # gamma is a learnable parameter (used in original paper)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.out_conv = nn.Conv2d(in_d, in_d, 1, 1, 0)

    def forward(self, x):
        """
        x: (batch_size, in_d, h, w)
        """
        batch_size, in_d, h, w = x.shape
        # Embed input and reshape for matrix multiplication
        k = self.k_conv(x).view(batch_size, -1, h * w)
        q = self.q_conv(x).view(batch_size, -1, h * w)
        v = self.v_conv(x).view(batch_size, -1, h * w)
        # (batch_size, h * w, h * w)
        attn = torch.bmm(k.transpose(-2, -1), q)
        attn = F.softmax(attn, dim=-1)
        # (batch_size, in_d, h * w)
        out = torch.bmm(v, attn.transpose(-2, -1))
        out = out.view(batch_size, in_d, h, w)
        out = self.out_conv(out)
        # Add residual connection and scale by learnable parameter gamma
        out = self.gamma * out + x
        return out, attn


# Generator, DC-GAN with self attention
class Generator(nn.Module):
    def __init__(self, noise_dim=64, target_output_size=256, emb_dim=1024):
        super().__init__()
        self.in_layer = nn.Sequential(
            nn.Conv2d(noise_dim, emb_dim, 3, 1, 1),
            nn.BatchNorm2d(emb_dim),
            nn.ReLU(),
        )

        self.layers = nn.ModuleList()
        n_layers = int(np.log2(target_output_size))
        for _ in range(n_layers-1):
            self.layers.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        emb_dim,
                        emb_dim // 2 if emb_dim // 2 >= 64 else emb_dim,
                        kernel_size=4,
                        stride=2,
                        padding=1,
                    ),
                    nn.BatchNorm2d(emb_dim // 2 if emb_dim // 2 >= 64 else emb_dim),
                    nn.ReLU(),
                )
            )
            # Don't shrink emb_dim below too low
            if emb_dim // 2 >= 64:
                emb_dim = emb_dim // 2
        self.layers = nn.Sequential(*self.layers)
        # Self attention sandwhiches the penultimate layer
        self.self_attn_1 = SelfAttentionConv(emb_dim)
        self.self_attn_2 = SelfAttentionConv(emb_dim)
        self.penultimate_layer = nn.Sequential(
            nn.ConvTranspose2d(emb_dim, emb_dim, 3, 1, 1),
            nn.BatchNorm2d(emb_dim),
            nn.ReLU(),
        )
        self.out_layer = nn.Sequential(nn.Conv2d(emb_dim, 3, 3, 1, 1), nn.Tanh())

    def forward(self, noise):
        out = self.in_layer(noise)
        # Main layers
        out = self.layers(out)
        # Attention layers
        out, _ = self.self_attn_1(out)
        out = self.penultimate_layer(out)
        out, _ = self.self_attn_2(out)
        # Out layers
        out = self.out_layer(out)
        return out



# Discriminator, adopts the PatchGAN architecture ie: output is a receptive field of n x n
class Discriminator(nn.Module):
    def __init__(self, in_d=3, emb_dim=256):
        super().__init__()
        self.in_layer = nn.Sequential(
            nn.Conv2d(in_d, emb_dim, 3, 1, 1),
            nn.BatchNorm2d(emb_dim),
            nn.ReLU(),
        )

        self.layers = nn.ModuleList()
        for _ in range(4):
            self.layers.append(
                nn.Sequential(
                    nn.Conv2d(
                        emb_dim,
                        emb_dim // 2 if emb_dim // 2 >= 64 else emb_dim,
                        kernel_size=4,
                        stride=2,
                        padding=1,
                    ),
                    nn.BatchNorm2d(emb_dim // 2 if emb_dim // 2 >= 64 else emb_dim),
                    nn.ReLU(),
                )
            )
            # Don't shrink emb_dim below too low
            if emb_dim // 2 >= 64:
                emb_dim = emb_dim // 2
        self.layers = nn.Sequential(*self.layers)
        # Self attention sandwhiches the penultimate layer
        self.self_attn_1 = SelfAttentionConv(emb_dim)
        self.self_attn_2 = SelfAttentionConv(emb_dim)
        self.penultimate_layer = nn.Sequential(
            nn.Conv2d(emb_dim, emb_dim, 3, 1, 1),
            nn.BatchNorm2d(emb_dim),
            nn.ReLU(),
        )
        self.out_layer = nn.Sequential(nn.Conv2d(emb_dim, 1, 3, 1, 1), nn.Sigmoid())

    def forward(self, noise):
        out = self.in_layer(noise)
        # Main layers
        out = self.layers(out)
        # Attention layers
        out, _ = self.self_attn_1(out)
        out = self.penultimate_layer(out)
        out, _ = self.self_attn_2(out)
        # Out layers
        out = self.out_layer(out)
        return out

In [27]:
class Trainer:
    def __init__(
        self,
        n_ckpt_steps: int,
        n_log_steps: int,
        epochs: int,
        # Data parameters
        batch_size: int,
        # Optimiser parameters
        lr: float,
        beta_1: float,
        beta_2: float,
        # Model parameters
        noise_d: int,
        emb_d: int,
        output_size: int,
        flush_prev_logs: bool = True,
    ):
        """
        n_ckpt_steps: Saves a checkpoint every n_ckpt_steps
        epochs: Number of epochs to train for
        data_dir: Directory containing the data, must contain trainA and trainB folders
        batch_size: Batch size
        lr: Learning rate
        beta_1: Beta 1 for Adam optimiser
        beta_2: Beta 2 for Adam optimiser
        """
        if flush_prev_logs:
            shutil.rmtree("results/logs", ignore_errors=True)
        os.makedirs("results/logs", exist_ok=True)
        torch.cuda.empty_cache()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.generator = Generator(
            noise_dim=noise_d, emb_dim=emb_d, target_output_size=output_size
        ).to(self.device)
        self.discriminator = Discriminator(in_d=3, emb_dim=emb_d).to(self.device)
        self.optim_G = torch.optim.Adam(
            self.generator.parameters(),
            lr=lr,
            betas=(beta_1, beta_2),
        )
        self.optim_D = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=lr,
            betas=(beta_1, beta_2),
        )
        self.train_dataloader = get_dataloader(batch_size, True, 4, output_size)
        # Hyperparameters
        self.n_log_steps = n_log_steps
        self.n_ckpt_steps = n_ckpt_steps
        self.epochs = epochs
        self.batch_size = batch_size
        self.noise_d = noise_d
        # Init variables
        self.global_step = 0

    def to_device(self, *args):
        return [arg.to(self.device) for arg in args]

    def run(self):
        for e in range(self.epochs):
            self.train_iter(e)

    def train_iter(self, epoch: int):
        self.generator.train()
        self.discriminator.train()
        # Train on a batch of images
        for i, real_img in enumerate(
            tqdm(self.train_dataloader, desc=f"Epoch: {epoch}", leave=False)
        ):
            # Train the generators
            self.optim_G.zero_grad()
            loss_G = self.compute_generator_loss()
            loss_G.backward()
            self.optim_G.step()
            # Train the discriminators
            self.optim_D.zero_grad()
            noise = torch.randn((self.batch_size, 64, 1, 1)).to(self.device)
            fake_image = self.generator(noise)
            loss_D = self.compute_discriminator_loss(real_img, fake_image)
            loss_D.backward()
            self.optim_D.step()
            # Log the losses
            if self.global_step % self.n_log_steps == 0:
                self.log_image(real_img, f"{self.global_step}_real_image.png")
                self.log_image(
                    self.generator(noise), f"{self.global_step}_fake_image.png"
                )

            # Increment the global step
            self.global_step += 1
            if self.global_step % self.n_ckpt_steps == 0:
                self.save_checkpoint()

    def compute_generator_loss(self) -> torch.Tensor:
        noise_x = torch.randn(self.batch_size, self.noise_d, 1, 1).to(self.device)
        fake_image = self.generator(noise_x)
        fake_pred_D = self.discriminator(fake_image)
        loss_G = F.mse_loss(fake_pred_D, torch.ones_like(fake_pred_D))
        return loss_G

    def compute_discriminator_loss(
        self,
        real_img: torch.Tensor,
        generated_img: torch.Tensor,
    ) -> torch.Tensor:
        print(real_img.shape, generated_img.shape)
        real_pred_D = self.discriminator(real_img)
        real_loss_D = F.mse_loss(
            real_pred_D, torch.ones_like(real_pred_D)
        )
        fake_pred_D = self.discriminator(generated_img)
        fake_loss_D = F.mse_loss(
            fake_pred_D, torch.zeros_like(fake_pred_D)
        )
        loss_D = (real_loss_D + fake_loss_D) / 2
        return loss_D

    def save_checkpoint(self):
        torch.save(
            {
                "generator": self.generator.state_dict(),
                "discriminator": self.discriminator.state_dict(),
                "optim_G": self.optim_G.state_dict(),
                "optim_D": self.optim_D.state_dict(),
            },
            f"results/checkpoints/{self.global_step}.pt",
        )

    def load_checkpoint(self, path: str):
        state_dict = torch.load(path)
        self.generator.load_state_dict(state_dict["generator"])
        self.discriminator.load_state_dict(state_dict["discriminator"])
        self.optim_G.load_state_dict(state_dict["optim_G"])
        self.optim_D.load_state_dict(state_dict["optim_D"])

    def log_image(self, img: torch.Tensor, name: str):
        torchvision.utils.save_image(
            img,
            "results/logs" + name,
            normalize=True,
            range=(-1, 1),
        )


In [28]:
Trainer(
    n_ckpt_steps=1000,
    n_log_steps=100,
    epochs=100,
    batch_size=32,
    lr=0.0002,
    beta_1=0.5,
    beta_2=0.999,
    noise_d=64,
    emb_d=64,
    output_size=64,
).run()

Files already downloaded and verified


                                                  

AttributeError: 'list' object has no attribute 'shape'

In [None]:
# TODO
# - Add spectral norm