# Import Libraries

In [None]:
import os
import math
import itertools
import random
import glob
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid, save_image

import shutil
import imageio
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from PIL import Image

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

# Load Data

In [None]:
class CelebADataset(Dataset):
    """Images from a folder paired with a 0/1 vector of selected attributes.

    The sorted ``*.jpg`` files under ``root`` are split by position: the last
    2000 files are the non-training split, everything before them is the
    training split.

    Args:
        root: Directory containing the ``.jpg`` images.
        csv_file: Attribute CSV. The first column holds the image file name,
            every following column holds one attribute.
        transforms_: List of torchvision transforms applied to each image.
            ``None`` applies no transform (the PIL image is returned as is).
        mode: ``"train"`` selects the training split; any other value selects
            the last 2000 files.
        attributes: Names of the attribute columns used as labels, in order.
            ``None`` selects every attribute column of the CSV.

    Raises:
        KeyError: If a requested attribute is not an attribute column of the
            CSV (raised while building the annotations).
    """

    def __init__(self, root, csv_file, transforms_=None, mode="train", attributes=None):
        # Compose over an empty list is the identity, so the default
        # ``transforms_=None`` is usable instead of failing on the first item.
        self.transform = transforms.Compose(transforms_ if transforms_ is not None else [])
        self.selected_attrs = attributes
        self.files = sorted(glob.glob("%s/*.jpg" % root))
        self.files = self.files[:-2000] if mode == "train" else self.files[-2000:]
        self.annotations = self.get_annotations(csv_file)

    def get_annotations(self, csv_file):
        """Return a dict mapping image file name to its list of 0/1 labels.

        A label is 1 when the attribute value in the CSV equals 1 and 0
        otherwise. As a side effect ``self.label_names`` is set to the
        attribute column names (all columns except the first).
        """
        df = pd.read_csv(csv_file)
        self.label_names = df.columns[1:]

        # Fall back to every attribute column when no selection was given.
        selected = list(self.label_names if self.selected_attrs is None else self.selected_attrs)

        # Restrict the lookup to the attribute columns so the file-name column
        # can never be picked as a label; an unknown name raises KeyError.
        attr_frame = df.iloc[:, 1:][selected]

        # Vectorised replacement for the row-by-row walk; ``tolist`` yields
        # plain Python ints.
        labels = attr_frame.eq(1).astype(int).to_numpy().tolist()
        return dict(zip(df.iloc[:, 0], labels))

    def __getitem__(self, index):
        """Return ``(transformed image, float32 label tensor)`` for ``index``."""
        filepath = self.files[index % len(self.files)]
        filename = os.path.basename(filepath)
        img = self.transform(Image.open(filepath))
        label = torch.tensor(self.annotations[filename], dtype=torch.float32)

        return img, label

    def __len__(self):
        """Number of image files in the selected split."""
        return len(self.files)

In [None]:
# Training-time pipeline: upscale the short side to int(1.12 * 128) pixels,
# take a random 128x128 crop, randomly mirror, then map pixels to [-1, 1].
train_transforms = [
    transforms.Resize(size=int(1.12 * 128), interpolation=Image.BICUBIC),
    transforms.RandomCrop(size=128),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
]

In [None]:
# Evaluation-time pipeline: deterministic resize to 128x128 (no crop, no
# flip), then the same [-1, 1] normalisation used for training.
valid_transforms = [
    transforms.Resize(size=(128, 128), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
]

In [None]:
# Both splits read from the same image folder; CelebADataset separates them
# by file position via its ``mode`` argument.
train_path = valid_path = "/mnt/d/Datasets/celeba/img_align_celeba/img_align_celeba/"

In [None]:
# Attribute CSV read by CelebADataset.get_annotations, which treats the first
# column as the image file name and the remaining columns as attributes.
attributes_file = "/mnt/d/Datasets/celeba/list_attr_celeba.csv"

In [None]:
# Attribute columns that make up the label vector, in this order. The
# positions matter: label_changes addresses these attributes by index.
selected_attributes = ["Black_Hair", "Blond_Hair", "Brown_Hair", "Male", "Young"]

In [None]:
# Length of the label vector fed to the generator and predicted by the
# discriminator's classification head.
c_dim = len(selected_attributes)
# (channels, height, width) of the images after the transforms above.
img_shape = (3, 128, 128)

In [None]:
# Shuffled batches of 16 from the training split, loaded by 4 worker processes.
train_dataloader = DataLoader(
    dataset=CelebADataset(
        root=train_path,
        csv_file=attributes_file,
        transforms_=train_transforms,
        mode="train",
        attributes=selected_attributes,
    ),
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Shuffled batches of 10 from the held-out split (the last 2000 files); the
# training loop draws one batch from it each time it renders a sample grid.
valid_dataloader = DataLoader(
    dataset=CelebADataset(
        root=valid_path,
        csv_file=attributes_file,
        transforms_=valid_transforms,
        mode="val",
        attributes=selected_attributes,
    ),
    batch_size=10,
    shuffle=True,
    num_workers=1,
)

# Visualization

In [None]:
# Pull one batch from the training loader to preview the augmented inputs.
data_iter = iter(train_dataloader)
images, _ = next(data_iter)


def imshow(img):
    """Show a channel-first image tensor with matplotlib.

    Args:
        img: Tensor of shape (C, H, W) whose values were normalised with
            mean 0.5 and std 0.5 per channel (as the transforms in this
            notebook do), i.e. roughly in [-1, 1].
    """
    # Undo the (x - 0.5) / 0.5 normalisation so values are back in [0, 1];
    # detach/cpu makes this work for CUDA tensors and tensors with grad.
    img = (img.detach().cpu() * 0.5 + 0.5).clamp(0, 1)
    # matplotlib wants channels last. Indexing [0] (as before) showed only
    # the first colour channel rendered through a colormap.
    npimg = img.permute(1, 2, 0).numpy()
    if npimg.shape[2] == 1:
        # Single-channel images must be passed as a 2-D array.
        npimg = npimg[:, :, 0]
    plt.imshow(npimg)
    plt.axis('off')
    plt.show()


imshow(make_grid(images, nrow=8))

# Network Initialization

In [None]:
def weights_init_normal(m):
    """Initialise a module in place based on its class name.

    Intended for use with ``Module.apply``:

    * class name contains ``"Conv"``: weights drawn from N(0, 0.02);
    * class name contains ``"BatchNorm2d"``: weights drawn from N(1, 0.02)
      and biases set to 0;
    * anything else is left untouched.
    """
    layer_type = type(m).__name__

    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if "BatchNorm2d" in layer_type:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, val=0.0)

# Generator

In [None]:
class Generator(nn.Module):
    """Image-to-image generator conditioned on a label vector.

    The label vector is broadcast over the spatial dimensions and
    concatenated to the image as extra input channels. The network is a
    single ``nn.Sequential``: a 7x7 stem, two stride-2 downsampling
    convolutions, three pairs of 3x3 convolutions at 256 channels, two
    stride-2 transposed convolutions and a 7x7 output convolution with Tanh.

    Note: the 256-channel stage is a plain stack of convolutions; there are
    no skip connections around the pairs.

    Args:
        img_shape: ``(channels, height, width)`` of the input images.
        c_dim: Length of the label vector.
    """

    def __init__(self, img_shape, c_dim):
        super().__init__()
        channels, _, _ = img_shape

        def norm(width):
            # Every normalisation layer in the network uses the same settings.
            return nn.InstanceNorm2d(width, affine=True, track_running_stats=True)

        # Stem: image channels plus one channel per label entry.
        layers = [
            nn.Conv2d(channels + c_dim, 64, 7, stride=1, padding=3, bias=False),
            norm(64),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: 64 -> 128 -> 256 channels, halving H and W each time.
        for width_in, width_out in ((64, 128), (128, 256)):
            layers += [
                nn.Conv2d(width_in, width_out, 4, stride=2, padding=1, bias=False),
                norm(width_out),
                nn.ReLU(inplace=True),
            ]

        # Bottleneck: three conv-norm-relu-conv-norm groups at 256 channels.
        for _ in range(3):
            layers += [
                nn.Conv2d(256, 256, 3, stride=1, padding=1, bias=False),
                norm(256),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 256, 3, stride=1, padding=1, bias=False),
                norm(256),
            ]

        # Upsampling: 256 -> 128 -> 64 channels, doubling H and W each time.
        for width_in, width_out in ((256, 128), (128, 64)):
            layers += [
                nn.ConvTranspose2d(width_in, width_out, 4, stride=2, padding=1, bias=False),
                norm(width_out),
                nn.ReLU(inplace=True),
            ]

        # Output: back to image channels, squashed to [-1, 1].
        layers += [
            nn.Conv2d(64, channels, 7, stride=1, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x, c):
        """Translate images ``x`` (B, C, H, W) towards labels ``c`` (B, c_dim)."""
        batch, n_labels = c.size(0), c.size(1)
        height, width = x.size(2), x.size(3)
        # Tile every label entry into a constant H x W plane.
        label_planes = c.view(batch, n_labels, 1, 1).repeat(1, 1, height, width)
        return self.model(torch.cat((x, label_planes), 1))

In [None]:
# Build the generator on the target device, then re-initialise its weights.
generator = Generator(img_shape=img_shape, c_dim=c_dim).to(device)
generator.apply(weights_init_normal)

# Discriminator

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic with a validity head and an attribute head.

    A stack of ``n_strided`` stride-2 convolutions (64 channels, doubling at
    every layer) produces a feature map that feeds two heads:

    * ``adv_out``: 3x3 convolution giving a patch map of validity scores;
    * ``cls_out``: convolution whose kernel spans the whole feature map,
      giving one logit per attribute.

    Args:
        img_shape: ``(channels, height, width)`` of the input images. The
            classifier kernel is derived from the height only.
        c_dim: Number of attribute logits.
        n_strided: Number of stride-2 convolution layers. Previously this
            only affected the classifier kernel size while the stack was
            fixed at six layers; the stack now follows it.

    Raises:
        ValueError: If ``n_strided`` is smaller than 1 or the image is too
            small to be halved ``n_strided`` times.
    """

    def __init__(self, img_shape=(3, 128, 128), c_dim=5, n_strided=6):
        super(Discriminator, self).__init__()
        channels, img_size, _ = img_shape

        # Side length of the feature map after n_strided halvings; the
        # classifier kernel covers it entirely.
        kernel_size = img_size // (2 ** n_strided) if n_strided >= 1 else 0
        if kernel_size < 1:
            raise ValueError(
                f"img_size={img_size} cannot be downsampled by n_strided={n_strided} stride-2 layers"
            )

        layers = []
        in_dim, out_dim = channels, 64
        for _ in range(n_strided):
            layers.append(nn.Conv2d(in_dim, out_dim, 4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            in_dim, out_dim = out_dim, out_dim * 2
        self.model = nn.Sequential(*layers)

        # in_dim is now the channel count of the last conv layer
        # (2048 for the default n_strided=6).
        self.adv_out = nn.Conv2d(in_dim, 1, 3, padding=1, bias=False)
        self.cls_out = nn.Conv2d(in_dim, c_dim, kernel_size, bias=False)

    def forward(self, img):
        """Return ``(validity patch map, attribute logits)`` for ``img``."""
        feature_repr = self.model(img)
        out_adv = self.adv_out(feature_repr)
        # Flatten the classifier output to (batch, c_dim).
        out_cls = self.cls_out(feature_repr).view(feature_repr.size(0), -1)
        return out_adv, out_cls

In [None]:
# Build the discriminator on the target device, then re-initialise its weights.
discriminator = Discriminator(img_shape=img_shape, c_dim=c_dim).to(device)
discriminator.apply(weights_init_normal)

# Train

In [None]:
# L1 distance used by the training loop as the reconstruction loss between
# the recovered images and the original inputs.
criterion_cycle = nn.L1Loss()

In [None]:
# One Adam optimiser per network, both with the same hyper-parameters.
generator_optimizer = optim.Adam(
    params=generator.parameters(),
    lr=1e-4,
    betas=(0.5, 0.995),
)
discriminator_optimizer = optim.Adam(
    params=discriminator.parameters(),
    lr=1e-4,
    betas=(0.5, 0.995),
)

In [None]:
# Loss weights: attribute classification, reconstruction, gradient penalty.
lambda_cls, lambda_rec, lambda_gp = 1, 10, 10
# The generator is updated once every n_critic batches.
n_critic = 5

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples, lambda_gp=10.0):
    """Return the weighted gradient penalty of ``D`` on random interpolates.

    For every sample a point is drawn uniformly on the segment between the
    real and the fake image, ``D`` is evaluated there, and the squared
    deviation of the per-sample gradient L2 norm from 1 is averaged over the
    batch and multiplied by ``lambda_gp``.

    Args:
        D: Callable returning a pair whose first element is the validity
            output; only that element is used.
        real_samples: Real images, shape (B, C, H, W).
        fake_samples: Generated images with the same shape.
        lambda_gp: Weight applied to the penalty. The returned value is
            already scaled, so callers must not scale it again.

    Returns:
        Scalar tensor, differentiable w.r.t. the parameters of ``D``.
    """
    # The penalty trains the critic only, so cut any graph that leads back
    # into whatever produced the samples.
    real_samples = real_samples.detach()
    fake_samples = fake_samples.detach()
    batch_size = real_samples.size(0)

    # One mixing coefficient per sample, created directly on the samples'
    # device instead of going through NumPy and a global ``device``.
    alpha = torch.rand(batch_size, 1, 1, 1, device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    d_interpolates, _ = D(interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,  # the penalty itself is back-propagated later
        retain_graph=True,
    )[0]

    # Flatten to (B, C*H*W) so the norm is taken over each whole sample;
    # norm(dim=1) on the 4-D tensor only reduced over the channel axis.
    gradients = gradients.reshape(batch_size, -1)
    gradients_norm = gradients.norm(2, dim=1)
    gradient_penalty = lambda_gp * ((gradients_norm - 1) ** 2).mean()
    return gradient_penalty

In [None]:
# Number of full passes over train_dataloader performed by the training loop.
n_epochs = 50

In [None]:
# Output folder for the sample grids written by the training loop;
# exist_ok=True makes re-running this cell harmless.
os.makedirs("./stargan", exist_ok=True)

In [None]:
# Loss histories filled by the training loop at each logging step.
# Discriminator: adversarial term, classification term, total.
d_loss_adv, d_loss_cls, d_loss_s = [], [], []
# Generator: adversarial term, classification term, total.
g_loss_adv, g_loss_cls, g_loss_s = [], [], []

In [None]:
# Label edits used when rendering sample grids: one entry per generated
# image. Each entry is a tuple of (column, value) pairs applied to a copy of
# the source label vector, where `column` is a position in
# selected_attributes. A value of -1 means "flip the current 0/1 value";
# any other value is written as is.
label_changes = [
    ((0, 1), (1, 0), (2, 0)),  # Set to black hair
    ((0, 0), (1, 1), (2, 0)),  # Set to blonde hair
    ((0, 0), (1, 0), (2, 1)),  # Set to brown hair
    ((3, -1),),  # Flip gender
    ((4, -1),),  # Age flip
]

In [None]:
# Training loop: the discriminator is updated on every batch and the
# generator once every n_critic batches. At each logging step the current
# losses are recorded and a grid of translated validation images is saved.
for epoch in range(n_epochs):
    for batch_idx, (images, labels) in enumerate(train_dataloader):
        images = images.to(device)
        labels = labels.to(device)

        # Random 0/1 target labels for this batch.
        sampled_c = torch.randint(0, 2, (images.size(0), c_dim), dtype=torch.float32, device=device)

        # ------------------------- discriminator step -------------------------
        # The fakes are only inputs for the critic here, so no generator graph
        # is needed (it used to be built and then thrown away).
        with torch.no_grad():
            fake_images = generator(images, sampled_c)

        discriminator_optimizer.zero_grad()

        real_validity, pred_cls = discriminator(images)
        fake_validity, _ = discriminator(fake_images)

        # compute_gradient_penalty already multiplies by lambda_gp; scaling
        # the result a second time gave an effective weight of lambda_gp ** 2.
        gradient_penalty = compute_gradient_penalty(discriminator, images, fake_images, lambda_gp)

        loss_D_adv = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty

        # reduction="sum" replaces the deprecated size_average=False; dividing
        # by the batch size gives the per-sample summed BCE.
        loss_D_cls = F.binary_cross_entropy_with_logits(pred_cls, labels, reduction="sum") / pred_cls.size(0)

        loss_D = loss_D_adv + lambda_cls * loss_D_cls

        loss_D.backward()
        discriminator_optimizer.step()

        # --------------------------- generator step ---------------------------
        if batch_idx % n_critic == 0:
            generator_optimizer.zero_grad()

            # Translate to the sampled labels, then back to the original ones.
            generated_images = generator(images, sampled_c)
            recov_imgs = generator(generated_images, labels)

            fake_validity, pred_cls = discriminator(generated_images)

            loss_G_adv = -torch.mean(fake_validity)

            loss_G_cls = F.binary_cross_entropy_with_logits(pred_cls, sampled_c, reduction="sum") / pred_cls.size(0)

            loss_G_rec = criterion_cycle(recov_imgs, images)

            loss_G = loss_G_adv + lambda_cls * loss_G_cls + lambda_rec * loss_G_rec

            loss_G.backward()
            generator_optimizer.step()

        # ------------------------- logging and sampling -------------------------
        # NOTE(review): 12538 presumably matches the number of batches per
        # epoch so this fires once per epoch at batch 0 -- confirm. Batch 0 is
        # also a generator step, so the loss_G* values below are fresh.
        if batch_idx % 12538 == 0:
            d_loss_adv.append(loss_D_adv.item())
            d_loss_cls.append(loss_D_cls.item())
            d_loss_s.append(loss_D.item())

            g_loss_adv.append(loss_G_adv.item())
            g_loss_cls.append(loss_G_cls.item())
            g_loss_s.append(loss_G.item())

            print(f"[Epoch {epoch+1}/{n_epochs}] [Batch {batch_idx}/{len(train_dataloader)}] [D loss: {loss_D.item():.6f}] [G loss: {loss_G.item():.6f}]")

            valid_images, valid_labels = next(iter(valid_dataloader))

            valid_images = valid_images.to(device)
            valid_labels = valid_labels.to(device)

            # One grid row per validation image: the source image followed by
            # one translation per entry of label_changes. Separate names are
            # used so the training batch variables are not overwritten, and
            # autograd is disabled because nothing here is back-propagated.
            sample_rows = []
            with torch.no_grad():
                for src_img, src_label in zip(valid_images, valid_labels):
                    tiled_imgs = src_img.repeat(c_dim, 1, 1, 1)
                    target_labels = src_label.repeat(c_dim, 1)

                    for sample_i, changes in enumerate(label_changes):
                        for col, val in changes:
                            # -1 flips the current 0/1 value; anything else is set directly.
                            target_labels[sample_i, col] = 1 - target_labels[sample_i, col] if val == -1 else val

                    translated = generator(tiled_imgs, target_labels)
                    # Lay the translations side by side along the width axis.
                    translated = torch.cat(list(translated), -1)
                    sample_rows.append(torch.cat((src_img, translated), -1))

            # Stack the rows along the height axis into a single image.
            img_samples = torch.cat(sample_rows, -2)

            save_image(img_samples.view(1, *img_samples.shape), f"./stargan/epoch_{epoch}.png", normalize=True)