# Import Libraries

In [None]:
import torch
import torch.nn as nn
import torchvision
import os
import PIL
import pdb
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from tqdm.auto import tqdm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import wandb #pip install wandb

# Visualization Function

In [None]:
# wandb = Weights & Biases (experiment tracking)
def show(tensor, num=25, wandbactive=0, name=''):
    """Plot the first `num` images of a batch as a grid with 5 images per row.

    Args:
        tensor: batch of images; it is detached and copied to the CPU first.
        num: maximum number of images taken from the front of the batch.
        wandbactive: when equal to 1, the grid is additionally logged to wandb.
        name: key under which the grid image is logged to wandb.
    """
    images = tensor.detach().cpu()[:num]
    # Reorder axes with permute(1, 2, 0) so the grid can be handed to imshow
    grid = make_grid(images, nrow=5).permute(1, 2, 0)

    # optional: mirror the grid to wandb, values clipped to [0, 1]
    if wandbactive == 1:
        clipped = grid.numpy().clip(0, 1)
        wandb.log({name: wandb.Image(clipped)})

    plt.imshow(grid.clip(0, 1))
    plt.show()

# Parameters & Hyper-Parameters

In [None]:
# --- Paths ---
data_path = './dataset/img_align_celeba/'  # directory that holds the training images
checkpt_path ='./Checkpoints/'  # prefix (directory) used when writing/reading checkpoint files
# Terminology: "critic" is the WGAN name for the discriminator
# --- Training hyper-parameters ---
epochs = 10000
batch_size = 128
lr = 1e-4  # learning rate used for both optimizers
z_dim=200  # length of the latent noise vector fed to the generator
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")  # first GPU when available, else CPU
crit_cycles=5 # critic updates per generator update; author's rationale: the critic is the weaker network and needs more training
# --- Bookkeeping ---
gen_losses=[]  # generator loss, one entry per training step
crit_losses=[]  # critic loss averaged over crit_cycles, one entry per training step
wandbact = 1 # 1 - also send logs/images to the wandb site (optional)
last_epoch = 0 # epoch index training starts from; replaced by the return value of load_chckpt when resuming

# Login & config wandb

In [None]:
# Place the API key generated on the wandb site in a file 'wandb_API.txt'
# in the working directory (keeps the secret out of the notebook itself).
# `with` closes the file even if reading fails; strip() removes the trailing
# newline/whitespace most editors append, which would otherwise become part of the key.
with open('wandb_API.txt') as f:
    key = f.read().strip()
wandb.login(key=key)

In [None]:
%%capture
# Random id used as the wandb group name for this experiment
exp_name = wandb.util.generate_id()
# The config is run metadata shown on the wandb dashboard; its values are taken
# from the notebook's own hyper-parameters so the dashboard cannot drift from
# what is actually trained (both optimizers in this notebook are Adam).
myrun = wandb.init(
        project='Face_GAN',
        group=exp_name,
        config={
            'optimizer':'adam',
            'model':'wgan gp',
            'epoch':epochs,
            'batch_size':batch_size
        }
)
config = wandb.config

In [None]:
# Print the generated id that was passed to wandb.init as the run's group name
print(exp_name)

# Dataset
Source: [CelebA dataset on Kaggle](https://www.kaggle.com/jessicali9530/celeba-dataset)

In [None]:
class Dataset(Dataset):
    """Image-folder dataset yielding `(image_tensor, file_name)` pairs.

    Every image is opened as RGB, resized to `size` x `size`, converted to a
    float32 tensor in channels-first layout and divided by 255.

    NOTE: the class keeps the name `Dataset` (shadowing the imported torch base
    class) because other cells instantiate it under that name.

    Args:
        path: directory that contains the image files.
        size: edge length in pixels of the square output images.
        limit: maximum number of files to use (`None` uses every entry).
    """

    def __init__(self, path, size=128, limit=10000):
        self.sizes = [size, size]  # square target size handed to Resize

        # os.listdir returns entries in arbitrary order; sorting first makes the
        # subset selected by `limit` identical on every run and every machine.
        names = sorted(os.listdir(path))[:limit]

        self.items = [os.path.join(path, name) for name in names]
        self.labels = names  # file names; labels are not needed for training

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):  # PIL -> np -> tensor
        # The context manager releases the file handle as soon as the pixels
        # have been copied into the converted image.
        with PIL.Image.open(self.items[idx]) as img:
            data = img.convert('RGB')
        data = np.asarray(torchvision.transforms.Resize(self.sizes)(data))  # size x size x 3
        data = np.transpose(data, (2, 0, 1)).astype(np.float32, copy=False)  # 3 x size x size
        data = torch.from_numpy(data).div(255)  # scale 0..255 down to 0..1
        return data, self.labels[idx]

In [None]:
# Build the dataset from up to 10000 files in data_path, resized to 128x128
dataset = Dataset(data_path, size=128, limit=10000)
# Shuffled mini-batches of `batch_size` samples
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Sanity check: pull one batch (x = image tensors, y = file names) and display it
x, y = next(iter(dataloader))
show(x)

# Generator

Output size (width or height) produced by each layer type:

Conv2d : `(n + 2*pad - ks)//stride + 1`<br>
ConvTranspose2d : `(n - 1)*stride - 2*pad + ks`

- n : input width or height
- ks : kernel size
- pad : padding
- ConvTranspose2d arguments : in_channels, out_channels, kernel_size, stride=1, padding=0

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator: latent vector -> 3-channel 128x128 image.

    Args:
        z_dim: length of the latent noise vector.
        d_dim: base channel multiplier; hidden widths are d_dim * (32, 16, 8, 4, 2).
    """

    def __init__(self, z_dim=64, d_dim=16):
        super().__init__()
        self.z_dim = z_dim

        # Hidden channel widths, widest first (512, 256, 128, 64, 32 for d_dim=16)
        widths = [d_dim * mult for mult in (32, 16, 8, 4, 2)]

        # Stem: the noise is treated as a 1x1 image with z_dim channels -> 4x4
        layers = [
            nn.ConvTranspose2d(z_dim, widths[0], kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(inplace=True),
        ]

        # Four upsampling stages: channels shrink, resolution doubles (4x4 -> 64x64)
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Head: 64x64 -> 128x128 with 3 channels; Tanh bounds the output to [-1, 1]
        layers += [
            nn.ConvTranspose2d(widths[-1], 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]

        # A single flat Sequential keeps the layer indices (and state_dict keys) stable
        self.gen = nn.Sequential(*layers)

    def forward(self, noise):
        # (batch, z_dim) -> (batch, z_dim, 1, 1) so the noise can enter the conv stack
        batch = noise.size(0)
        latent = noise.view(batch, self.z_dim, 1, 1)
        return self.gen(latent)

In [None]:
def gen_noise(num, z_dim, device=device):
    """Draw `num` latent vectors of length `z_dim` from a standard normal distribution.

    The `device` default is bound to the notebook-level `device` when this
    function is defined.
    """
    shape = (num, z_dim)  # e.g. 128 x 200
    return torch.randn(*shape, device=device)

# Critic

In [None]:
class Critic(nn.Module):
    """Convolutional critic: 3-channel 128x128 image -> one unbounded score per image.

    Args:
        d_dim: base channel multiplier; hidden widths are d_dim * (1, 2, 4, 8, 16).
    """

    def __init__(self, d_dim=16):
        super().__init__()

        # Channel progression including the RGB input (3, 16, 32, 64, 128, 256 for d_dim=16)
        channels = [3] + [d_dim * mult for mult in (1, 2, 4, 8, 16)]

        # Five downsampling stages, each halves the resolution: 128x128 -> 4x4
        layers = []
        for c_in, c_out in zip(channels, channels[1:]):
            layers.extend((
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ))

        # Final 4x4 convolution collapses the 4x4 map into a single 1x1 score
        layers.append(nn.Conv2d(channels[-1], 1, 4, 1, 0))

        # A single flat Sequential keeps the layer indices (and state_dict keys) stable
        self.crit = nn.Sequential(*layers)

    def forward(self, image):
        # image: batch x 3 x 128 x 128 -> scores: batch x 1 x 1 x 1
        scores = self.crit(image)
        # Flatten to batch x 1
        return scores.view(scores.size(0), -1)

# Initialize Weights (Optional)

In [None]:
## optional, init your weights in different ways
def init_weights(m):
    """DCGAN-style weight initialisation, intended for `module.apply(init_weights)`.

    - Conv2d / ConvTranspose2d: weight ~ N(0, 0.02), bias = 0.
    - BatchNorm2d: scale (weight) ~ N(1, 0.02), shift (bias) = 0.

    The BatchNorm scale is centred on 1, not 0: a scale near zero multiplies
    every normalised activation by ~0 and suppresses the signal through the layer.
    Layers created without a bias (`bias=False`) or without affine parameters
    (`affine=False`) are handled without raising.

    Args:
        m: the module visited by `apply`; other module types are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # weight/bias are None when the layer was built with affine=False
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

##gen=gen.apply(init_weights)
##crit=crit.apply(init_weights)

# Model

In [None]:
# Models
# Instantiate both networks and move them to the selected device.
# The generator receives the notebook's z_dim; the critic keeps its default d_dim.
gen = Generator(z_dim).to(device)
critic = Critic().to(device)

In [None]:
# Optimizers
# One Adam optimizer per network, both with the same learning rate and betas=(0.5, 0.9)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.9))
critic_opt = torch.optim.Adam(critic.parameters(), lr=lr, betas=(0.5, 0.9))

In [None]:
#wandb
# Only when wandb logging is switched on: register both networks with wandb.watch (log_freq=100)
if(wandbact==1):
    wandb.watch(gen, log_freq=100)
    wandb.watch(critic, log_freq=100)

# Gradient Penalty Calculation

In [None]:
def get_grad_penalty(real, fake, critic, alpha, gamma=10):
    """Compute the WGAN-GP gradient penalty on images interpolated between real and fake.

    Args:
        real: batch of real images.
        fake: batch of generated images, same shape as `real`.
        critic: callable mapping an image batch to one score per image.
        alpha: per-sample mixing weights, broadcastable against the images;
            the interpolated images must require grad (the training loop
            achieves this by creating `alpha` with requires_grad=True).
        gamma: weight of the penalty term.

    Returns:
        Scalar tensor: gamma * mean((||grad||_2 - 1)^2), differentiable
        with respect to the critic parameters.
    """
    # Per-sample blend of real and fake images
    real_part = real * alpha
    fake_part = fake * (1 - alpha)
    blended = real_part + fake_part
    scores = critic(blended)

    # Gradient of the scores w.r.t. the blended images; create_graph=True keeps
    # the result differentiable so the penalty can be back-propagated.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )

    # One flat gradient vector per sample, then its L2 norm
    flat = grads.view(grads.size(0), -1)
    norms = flat.norm(2, dim=1)

    # Penalise deviation of each norm from 1
    return gamma * (norms - 1).pow(2).mean()

# Save Checkpoints

In [None]:
def save_chckpt(name):
    """Write generator and critic checkpoints for the current epoch.

    Reads the notebook-level `epoch`, `gen`, `gen_opt`, `critic`, `critic_opt`
    and `checkpt_path`. Two files are written: `G-{name}.pkl` and
    `Critic-{name}.pkl`, each holding the epoch index plus the model and
    optimizer state dicts.

    Args:
        name: suffix used in both checkpoint file names.
    """
    gen_file = f"{checkpt_path}G-{name}.pkl"
    critic_file = f"{checkpt_path}Critic-{name}.pkl"

    # torch.save does not create missing directories, so make sure the target
    # folder exists before writing (first run on a fresh checkout).
    target_dir = os.path.dirname(gen_file)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    torch.save({
        'epoch':epoch,
        'model_state_dict':gen.state_dict(),
        'optimizer_state_dict':gen_opt.state_dict()
    }, gen_file)

    torch.save({
        'epoch':epoch,
        'model_state_dict':critic.state_dict(),
        'optimizer_state_dict':critic_opt.state_dict()
    }, critic_file)

    print(f"Saved Checkpoint:\n\t Epoch : {epoch}")

# Load Checkpoint

In [None]:
def load_chckpt(name):
    """Restore generator and critic (weights and optimizer state) from checkpoint files.

    Reads `G-{name}.pkl` and `Critic-{name}.pkl` under the notebook-level
    `checkpt_path` and loads them into the notebook-level `gen`, `gen_opt`,
    `critic` and `critic_opt`.

    Args:
        name: suffix used in both checkpoint file names.

    Returns:
        The epoch index stored in the critic checkpoint.
    """
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine can also be loaded on a CPU-only one.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(f"{checkpt_path}G-{name}.pkl", map_location=device)
    gen.load_state_dict(checkpoint['model_state_dict'])
    gen_opt.load_state_dict(checkpoint['optimizer_state_dict'])

    checkpoint = torch.load(f"{checkpt_path}Critic-{name}.pkl", map_location=device)
    critic.load_state_dict(checkpoint['model_state_dict'])
    critic_opt.load_state_dict(checkpoint['optimizer_state_dict'])

    # Local name only; the caller assigns the return value to the global last_epoch
    last_epoch = checkpoint['epoch']

    print(f"Checkpoint Loaded:\n\t Epoch : {last_epoch}")
    return last_epoch

# Load From Previous Checkpoint

In [None]:
#last_epoch = load_chckpt("Checkpoint")

# Training Loop

In [None]:
# Outer progress bar counts epochs; its counter is moved to last_epoch so a
# resumed run shows the correct position.
pbar1 = tqdm(range(epochs))
pbar1.n = last_epoch
pbar1.refresh()
# NOTE(review): save_chckpt stores the index of the epoch that just finished,
# so resuming from a loaded checkpoint runs that epoch once more - confirm this is intended.
for epoch in range(last_epoch,epochs):
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real) # batch_size, smaller for a final partial batch
        real = real.to(device)

        # ---- Critic: crit_cycles updates for every generator update ----
        mean_critic_loss = 0
        for i in range(crit_cycles):
            critic_opt.zero_grad()

            noise = gen_noise(cur_batch_size, z_dim, device)
            # The generator is not updated in this phase, so its forward pass
            # does not need an autograd graph (saves memory and time).
            with torch.no_grad():
                fake = gen(noise)
            critic_fake_pred = critic(fake)
            critic_real_pred = critic(real)

            # requires_grad=True makes the interpolated images differentiable inputs
            alpha = torch.rand(len(real), 1,1,1, device=device, requires_grad=True) # batch x 1 x 1 x 1
            grad_penalty = get_grad_penalty(real, fake, critic, alpha)

            # Wasserstein critic loss plus gradient penalty
            crit_loss = critic_fake_pred.mean() - critic_real_pred.mean() + grad_penalty

            mean_critic_loss += crit_loss.item() / crit_cycles

            # A fresh graph is built in every cycle, so nothing has to be retained
            crit_loss.backward()
            critic_opt.step()
        crit_losses += [mean_critic_loss]

        # ---- Generator: one update ----
        gen_opt.zero_grad()
        noise = gen_noise(cur_batch_size, z_dim, device)
        fake = gen(noise)
        critic_fake_pred = critic(fake)

        # The generator tries to maximise the critic score of its fakes
        gen_loss = -critic_fake_pred.mean()
        gen_loss.backward()
        gen_opt.step()

        gen_losses += [gen_loss.item()]

    # ---- End of epoch: logging, previews, loss curves, checkpoint ----
    if(wandbact==1):
        # Losses of the last step of the epoch; .item() logs a plain float
        # instead of a tensor that is still attached to the autograd graph.
        wandb.log({'Epoch':epoch, 'Critic Loss':mean_critic_loss, 'Generator Loss':gen_loss.item()})

    # Honour the wandbact switch instead of always uploading the preview grids
    show(fake, wandbactive=wandbact, name='fake')
    show(real, wandbactive=wandbact, name='real')

    # Mean losses over the steps of this epoch
    gen_mean = sum(gen_losses[-len(dataloader):]) / len(dataloader)
    critic_mean = sum(crit_losses[-len(dataloader):]) / len(dataloader)

    # Loss curves over all steps so far
    plt.plot(range(len(gen_losses)), torch.Tensor(gen_losses), label='Generator Loss')
    plt.plot(range(len(gen_losses)), torch.Tensor(crit_losses), label='Critic Loss')
    plt.ylim(-150, 150)
    plt.legend()
    plt.show()
    print(f"Epoch : {epoch}; Generator Loss : {gen_mean}; Critic Loss : {critic_mean}\n")

    save_chckpt("Checkpoint")

    pbar1.update()

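Steps (batches) per epoch at batch size 128; the last, partial batch also counts as a step:

`10000 / 128 = 78.125 steps` → 79 steps per epoch<br>
`50000 / 128 = 390.625 steps` → 391 steps per epoch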

# Generate New Faces

In [None]:
# Sample a fresh batch of latent vectors and decode them into faces.
noise = gen_noise(batch_size, z_dim)
# Inference only: no_grad avoids building an autograd graph for the whole batch.
# The generator is deliberately left in the mode it was in (no eval() call),
# so the outputs are computed exactly as before.
with torch.no_grad():
    fake = gen(noise)
show(fake)

In [None]:
# Display one generated sample (index 4): detach, move to the CPU, reorder the
# axes with permute(1,2,0) for imshow, and clip the values to [0, 1]
plt.imshow(fake[4].detach().cpu().permute(1,2,0).squeeze().clip(0,1))
plt.show()