### Code Authors
Özgür Aslan 2236958 aslan.ozgur@metu.edu.tr  
Burak Bolat 2237097 burak.bolat@metu.edu.tr

### Paper Authors

Mahmoud Afifi, Marcus A. Brubaker, Michael S. Brown

### Paper Information
The paper we selected to implement is [HistoGAN: Controlling Colors of GAN-Generated and Real Images via Color Histograms](https://arxiv.org/abs/2011.11731).
The main idea of the paper is to use the color histograms of target images to control the colors of the generated image without changing its high-level features (gender, glasses, beard, hair style, objects in the background, ...).  
To accomplish this, the authors modify the StyleGAN2 architecture:
- In the last two style blocks, instead of an affine transformation of the w vector, they use the color histogram projected by a neural network.
- In contrast to the mixing regularization of StyleGAN2, they use a histogram-based loss.
- The histogram loss uses two different target images to compute color histograms and interpolates these histograms to obtain a new one. The interpolated histogram is fed to the generator, which generates an image whose colors are controlled by the interpolated histogram. This way, the authors try to prevent the generator from overfitting to the color histograms of the training dataset.
- Due to hardware limitations, the network generates 256x256 images instead of 1024x1024 images.
- Also due to hardware limitations, they use a batch size of 2 with gradient accumulation.
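
The notebook below imports `random_interpolate_hists` from `utils`, whose code is not shown here. The following NumPy sketch is a hypothetical illustration of the interpolation step described above, not the authors' implementation. It assumes each histogram in a batch is blended with a randomly chosen partner from the same batch using a weight t ~ U(0, 1); the function name `random_interpolate_batch` is our own.

```python
import numpy as np

def random_interpolate_batch(hists, rng=None):
    """Blend each histogram in a batch with a random partner from the same batch.

    hists: array of shape (B, 3, bins, bins), each sample normalized to sum to 1.
    Returns the blended histograms (same shape) and the weights t of shape (B,).
    """
    rng = np.random.default_rng() if rng is None else rng
    b = hists.shape[0]
    partners = hists[rng.permutation(b)]              # second target histogram per sample
    t = rng.uniform(0.0, 1.0, size=b)                 # interpolation weight per sample
    t_b = t.reshape((b,) + (1,) * (hists.ndim - 1))   # broadcast over histogram dims
    blended = t_b * hists + (1.0 - t_b) * partners
    return blended, t

rng = np.random.default_rng(0)
h = rng.random((4, 3, 64, 64))
h /= h.sum(axis=(1, 2, 3), keepdims=True)             # normalize each sample
mixed, t = random_interpolate_batch(h, rng)
print(mixed.shape, mixed.sum(axis=(1, 2, 3)))         # each blended histogram still sums to 1
```

Because the blend is a convex combination, a normalized pair stays normalized. A random permutation can pair a sample with itself; a real implementation may want to avoid that.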

![arch](materials/arch.png)

### Experimental Goals
This study implements only the HistoGAN part of the paper, not ReHistoGAN.
Our experimental goal is to train HistoGAN to generate anime faces with controlled histograms and to compute the FID scores reported in the paper.

![exp1](materials/expg1.jpeg)
![exp2](materials/expg2.jpeg)

#### Histogram Computation
Computing the histogram is critical, since it directly determines the style input of the last two blocks. The authors use a logarithmic chrominance (log-chroma) space, which normalizes each color channel with respect to the other two channels in logarithmic space. This space has two axes, u and v: for the red channel, u is the log ratio of red to green and v is the log ratio of red to blue. The same holds for the green and blue channels.  

After projecting the image from RGB to the RGB-uv space, the histogram is computed there, which is computationally efficient and more stable. The authors use 64 bins per axis, giving a 64x64 histogram over (u, v) for each channel. With three channels (red, green and blue), the full histogram is 3x64x64. Each pixel's contribution is weighted by its intensity, i.e. a pixel with high RGB values has a larger effect on the histogram. The last difference from the histograms of previous works is the kernel used for binning: the authors do not use hard bin assignment. Instead, each normalized pixel is distributed over the bins with a soft kernel. For example, after normalizing the red channel with respect to green and blue we obtain some (u, v) values; instead of adding 1 (chosen for simplicity, recall the intensity weighting) to the single bin H(u, v), the pixel adds to the neighboring bins of (u, v) according to an inverse-quadratic kernel of their distance.
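
A minimal NumPy sketch of the soft RGB-uv histogram described above. The bin range [-3, 3], the kernel width `tau = 0.02`, the log offset `eps` and normalizing the whole 3x64x64 histogram to sum to one are our assumptions for the sketch; the paper and our training code may use different values. The inverse-quadratic kernel is applied separately along u and v and the two factors are multiplied.

```python
import numpy as np

def rgb_uv_hist(img, bins=64, boundary=(-3.0, 3.0), tau=0.02, eps=1e-6):
    """Soft RGB-uv histogram of an HxWx3 float image in [0, 1].

    Returns an array of shape (3, bins, bins), normalized to sum to 1.
    """
    pixels = img.reshape(-1, 3).astype(np.float64)
    intensity = np.sqrt((pixels ** 2).sum(axis=1))    # per-pixel weight sqrt(R^2+G^2+B^2)
    log_px = np.log(pixels + eps)
    centers = np.linspace(boundary[0], boundary[1], bins)
    hist = np.zeros((3, bins, bins))
    for c in range(3):
        o1, o2 = [k for k in range(3) if k != c]      # the two other channels
        u = log_px[:, c] - log_px[:, o1]              # log(I_c / I_o1)
        v = log_px[:, c] - log_px[:, o2]              # log(I_c / I_o2)
        # inverse-quadratic kernel between every pixel and every bin center
        ku = 1.0 / (1.0 + ((u[:, None] - centers[None, :]) / tau) ** 2)
        kv = 1.0 / (1.0 + ((v[:, None] - centers[None, :]) / tau) ** 2)
        hist[c] = (intensity[:, None] * ku).T @ kv    # (bins, bins) soft 2D histogram
    return hist / max(hist.sum(), 1e-12)

gray = np.full((16, 16, 3), 0.5)                      # R = G = B, so u = v = 0
h = rgb_uv_hist(gray)
print(h.shape)                                        # (3, 64, 64)
print(np.unravel_index(h[0].argmax(), h[0].shape))   # peak at the bins next to u = v = 0
```

For a gray image every pixel falls at u = v = 0, so the mass concentrates around the central bins; a reddish image shifts the red channel's peak towards positive u and v.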

The histogram feature is computed similarly to the style vector (w): it is passed through the same neural network architecture, with different parameters. More precisely, the histogram passes through an 8-layer MLP that outputs a latent histogram vector of size 512.  
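
To illustrate the shapes involved, here is a NumPy sketch of an 8-layer MLP that maps a flattened 3x64x64 histogram to a 512-dimensional histogram feature. The hidden width of 512, the leaky ReLU slope of 0.2 and the Kaiming-style initialization are assumptions made for the sketch, not details taken from the paper; `init_mlp` and `project_hist` are hypothetical names.

```python
import numpy as np

def init_mlp(in_dim, out_dim=512, hidden=512, num_layers=8, rng=None):
    """Kaiming-normal weights and zero biases for a num_layers-deep MLP."""
    rng = np.random.default_rng(0) if rng is None else rng
    dims = [in_dim] + [hidden] * (num_layers - 1) + [out_dim]
    return [(rng.normal(0.0, np.sqrt(2.0 / d_in), size=(d_in, d_out)).astype(np.float32),
             np.zeros(d_out, dtype=np.float32))
            for d_in, d_out in zip(dims[:-1], dims[1:])]

def project_hist(hist, layers, slope=0.2):
    """Map histograms of shape (B, 3, bins, bins) to features of shape (B, out_dim)."""
    x = hist.reshape(hist.shape[0], -1).astype(np.float32)  # flatten to (B, 3*bins*bins)
    for i, (w, b) in enumerate(layers):
        x = x @ w + b
        if i < len(layers) - 1:                 # no activation after the last layer
            x = np.where(x > 0, x, slope * x)   # leaky ReLU
    return x

layers = init_mlp(3 * 64 * 64)
feat = project_hist(np.random.default_rng(1).random((2, 3, 64, 64)), layers)
print(feat.shape)  # (2, 512)
```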

Below are some histograms we computed. The images were collected by crawling the internet.  

![Sample image 1](materials/gresized.png) ![RGB-uv histogram of sample image 1](materials/ghist1.png)  
![Sample image 2](materials/rresized.png) ![RGB-uv histogram of sample image 2](materials/rhist1.png)

#### Loss for Training with Histogram
Since the paper conditions generation on a target histogram, the generated image should have a histogram close to the target. Therefore, the Hellinger distance between the histograms of the generated and target images is computed as a closeness measure and minimized. The losses are shown below.

Difference between histograms  
![l1](materials/hloss.png)

Total loss for generator  
![lt](materials/total_loss.png)
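
A NumPy sketch of the histogram loss, written in the standard Hellinger form H(H_g, H_t) = (1/sqrt(2)) ||sqrt(H_t) - sqrt(H_g)||_2; we assume this matches the loss in the figure above. The weight `alpha` that combines it with the adversarial loss is a hypothetical parameter used only for illustration, and no particular value is implied.

```python
import numpy as np

def hellinger_distance(h_gen, h_target):
    """Per-sample Hellinger distance between batches of normalized histograms."""
    b = h_gen.shape[0]
    diff = np.sqrt(h_target.reshape(b, -1)) - np.sqrt(h_gen.reshape(b, -1))
    return np.linalg.norm(diff, axis=1) / np.sqrt(2.0)

def generator_total_loss(adv_loss, h_gen, h_target, alpha):
    """Adversarial loss plus the batch-averaged histogram loss (alpha is hypothetical)."""
    return adv_loss + alpha * hellinger_distance(h_gen, h_target).mean()

a = np.array([[1.0, 0.0]])
b = np.array([[0.0, 1.0]])
print(hellinger_distance(a, a), hellinger_distance(a, b))  # 0 for identical, 1 for disjoint histograms
```

The distance is bounded in [0, 1] for normalized histograms, which keeps the histogram term on a predictable scale regardless of the number of bins.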

### Discriminator

The discriminator consists of residual blocks. There are log_2(N) - 1 such blocks, where N is the image resolution (256 here), so the discriminator has 7 blocks. The first block takes the 3-channel image as input and outputs m feature channels; each subsequent block doubles the number of channels of the previous one. After the residual blocks, a fully connected layer outputs a scalar.
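
A small helper illustrating the channel schedule described above. It assumes m = 16 (the `network_capacity` value in our config) and that the first block maps 3 channels to m; the name `disc_channel_schedule` is hypothetical.

```python
import math

def disc_channel_schedule(resolution=256, m=16):
    """(in_channels, out_channels) of each residual block for a given image resolution."""
    num_blocks = int(math.log2(resolution)) - 1
    channels = [3] + [m * 2 ** i for i in range(num_blocks)]
    return list(zip(channels[:-1], channels[1:]))

for block in disc_channel_schedule(256):
    print(block)
# prints (3, 16) through (512, 1024), one block per line
```

For 64x64 images, as in the config used in the notebook below, the same rule gives 5 blocks.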

![res](materials/residual.png)


### Important Note on Dataset
The Anime Face Dataset is hosted on Kaggle and therefore requires a Kaggle account. With an account, it can be downloaded from:
https://www.kaggle.com/datasets/splcher/animefacedataset

### Faced Challenges

#### Architecture Challenges

As stated above, HistoGAN is built on StyleGAN2. StyleGAN2 scales the weights of the convolution filters directly (modulation and demodulation), whereas the first version performs a nearly equivalent operation (AdaIN) on the convolved feature maps instead of on the filters. The original StyleGAN2 is implemented in TensorFlow, which allows multiplying the weight variables directly. However, PyTorch does not allow in-place operations on the parameters of built-in modules such as torch.nn.Conv2d. Therefore, we implemented our own conv2d: the model parameters are PyTorch Variables and the convolution is computed with PyTorch's fold and unfold operations. This way, we can apply the convolution after scaling the weights of the convolution filters.
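
Our PyTorch version based on `unfold`/`fold` is not shown here. As an illustration only, the following NumPy sketch shows the same idea: modulate and demodulate the filter weights per sample, then compute the convolution as a matrix product over unfolded patches, with a hand-written `im2col` playing the role of `unfold`. Stride 1, odd kernel sizes and zero "same" padding are assumptions of the sketch.

```python
import numpy as np

def im2col(x, k):
    """Unfold a (C, H, W) array into (C*k*k, H*W) patches with zero 'same' padding (odd k)."""
    c, h, w = x.shape
    p = k // 2
    xp = np.pad(x, ((0, 0), (p, p), (p, p)))
    cols = np.empty((c, k, k, h, w), dtype=x.dtype)
    for i in range(k):
        for j in range(k):
            cols[:, i, j] = xp[:, i:i + h, j:j + w]
    return cols.reshape(c * k * k, h * w)

def modulated_conv2d(x, weight, style, demodulate=True, eps=1e-8):
    """StyleGAN2-style modulated convolution computed as a matrix product over patches.

    x: (B, C_in, H, W), weight: (C_out, C_in, k, k), style: (B, C_in).
    """
    b, _, h, w = x.shape
    c_out, _, k, _ = weight.shape
    out = np.empty((b, c_out, h, w), dtype=np.result_type(x, weight))
    for n in range(b):
        # modulation: scale the input-channel slices of the filters by this sample's style
        wn = weight * style[n][None, :, None, None]
        if demodulate:
            # demodulation: rescale each output filter to unit L2 norm
            wn = wn / np.sqrt((wn ** 2).sum(axis=(1, 2, 3), keepdims=True) + eps)
        # convolution = (C_out, C_in*k*k) filters times (C_in*k*k, H*W) patches
        out[n] = (wn.reshape(c_out, -1) @ im2col(x[n], k)).reshape(c_out, h, w)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 5, 5))
identity = np.zeros((3, 3, 3, 3))
for ch in range(3):
    identity[ch, ch, 1, 1] = 1.0                  # 3x3 filter that copies channel ch
style = np.full((2, 3), 2.0)
y_mod = modulated_conv2d(x, identity, style, demodulate=False)  # equals 2 * x
y_demod = modulated_conv2d(x, identity, style)                  # demodulation undoes the 2x
print(np.allclose(y_mod, 2 * x), np.allclose(y_demod, x))
```

Another route used by many PyTorch StyleGAN2 ports is to pass the per-sample modulated weights to `torch.nn.functional.conv2d`, which takes the weight tensor as an argument, by folding the batch into the channel dimension and setting `groups` to the batch size.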


#### Training Challenges

During training, we found that the paper does not mention the output activation of the generator, i.e. how the pixel values are produced. We tried different assumptions, such as using sigmoid or tanh to constrain the pixels to a range. Another option, which we saw in other generator implementations, is using ReLU or leaky ReLU.  
The StyleGAN2 authors state that they used the non-saturating loss for some datasets and the WGAN-GP loss for others, but the HistoGAN paper does not clearly say which one it uses. Consequently, we implemented both. The non-saturating loss led to numerical issues such as NaNs or infs, while WGAN-GP produced high loss values and saturated rapidly (see the figures above). These issues may result from our hand-implemented convolution operations.

#### Ambiguous Loss

The paper does not clearly state which loss functions are used for the generator and the discriminator. It explains how they are combined with the histogram loss, but the exact adversarial losses are not given. As a result, we considered three discriminator losses: Wasserstein, hinge and non-saturating (softplus). We focused mostly on the non-saturating loss, since it is the one mentioned in the paper.
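
For reference, these are the three adversarial losses as written in `train_discriminator` and `train_generator` in the notebook below, ignoring the division by `acc_gradient_iter` used for gradient accumulation. Here $D$ is the discriminator score, $G$ the generator, $x$ a real image, $z$ the noise and $\operatorname{softplus}(t) = \log(1 + e^{t})$.

$$
\begin{aligned}
\text{Wasserstein:}\quad \mathcal{L}_D &= \mathbb{E}_{z}[D(G(z))] - \mathbb{E}_{x}[D(x)], & \mathcal{L}_G &= -\mathbb{E}_{z}[D(G(z))] \\
\text{Hinge:}\quad \mathcal{L}_D &= \mathbb{E}_{x}[\max(0,\, 1 - D(x))] + \mathbb{E}_{z}[\max(0,\, 1 + D(G(z)))], & \mathcal{L}_G &= -\mathbb{E}_{z}[D(G(z))] \\
\text{Non-saturating:}\quad \mathcal{L}_D &= \mathbb{E}_{x}[\operatorname{softplus}(-D(x))] + \mathbb{E}_{z}[\operatorname{softplus}(D(G(z)))], & \mathcal{L}_G &= \mathbb{E}_{z}[\operatorname{softplus}(-D(G(z)))]
\end{aligned}
$$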

We implemented both the gradient penalty and the R1 penalty for the discriminator. The R1 penalty is a gradient penalty computed only on the scores of real images. Since neither penalty led to good performance, we also trained the discriminator with spectral normalization.
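
Spectral normalization divides each weight tensor by its largest singular value, estimated with power iteration. A minimal NumPy sketch, assuming the weight is reshaped to (out_channels, -1) as is common; PyTorch ships this as `torch.nn.utils.spectral_norm`, which keeps a persistent vector `u` and runs one power iteration per forward pass by default (`n_power_iterations=1`), whereas this sketch restarts from a random vector and iterates several times.

```python
import numpy as np

def spectral_normalize(weight, num_iters=20, rng=None):
    """Divide a weight tensor by a power-iteration estimate of its largest singular value."""
    rng = np.random.default_rng(0) if rng is None else rng
    w = weight.reshape(weight.shape[0], -1)   # (out_channels, everything else)
    u = rng.standard_normal(w.shape[0])
    for _ in range(num_iters):
        v = w.T @ u
        v /= np.linalg.norm(v) + 1e-12
        u = w @ v
        u /= np.linalg.norm(u) + 1e-12
    sigma = u @ w @ v                         # estimate of the top singular value
    return weight / sigma, sigma

# a 3x4 matrix with known singular values 5, 2 and 1
q1, _ = np.linalg.qr(np.random.default_rng(1).standard_normal((3, 3)))
q2, _ = np.linalg.qr(np.random.default_rng(2).standard_normal((4, 4)))
w = q1 @ np.diag([5.0, 2.0, 1.0]) @ q2[:3]
w_sn, sigma = spectral_normalize(w)
print(round(sigma, 6))                                  # approximately 5.0
print(np.linalg.svd(w_sn, compute_uv=False).max())      # approximately 1.0
```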

#### Model Initialization  

The model initialization was unclear. StyleGAN2 initializes weights from a standard normal distribution and uses the equalized learning rate method; however, Kaiming initialization could also be used. The authors adopt the StyleGAN2 architecture, yet they do not fully copy its training details; in particular, they do not mention using an equalized learning rate. We implemented both Kaiming normal initialization and the equalized learning rate, but neither gave a noticeable improvement. As a result, we stuck with Kaiming normal.
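
The two options differ only in where the He scale sqrt(2 / fan_in) lives. A NumPy sketch for a linear layer (the class name `EqualizedLinear` is hypothetical): Kaiming normal bakes the scale into the stored weights once, while the equalized learning rate stores N(0, 1) weights and multiplies them by the scale in every forward pass.

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in, fan_out = 512, 512
he_std = np.sqrt(2.0 / fan_in)           # Kaiming/He standard deviation

# Kaiming normal: the scale is baked into the stored weights at initialization
w_kaiming = rng.normal(0.0, he_std, size=(fan_in, fan_out))

class EqualizedLinear:
    """Stores N(0, 1) weights and multiplies them by the He scale at every forward pass."""
    def __init__(self, fan_in, fan_out, rng):
        self.weight = rng.standard_normal((fan_in, fan_out))
        self.scale = np.sqrt(2.0 / fan_in)
    def __call__(self, x):
        return x @ (self.weight * self.scale)

layer = EqualizedLinear(fan_in, fan_out, rng)
x = rng.standard_normal((256, fan_in))
# both layers start with the same effective weight scale and output scale
print((layer.weight * layer.scale).std() / he_std, w_kaiming.std() / he_std)  # both close to 1
print((x @ w_kaiming).std(), layer(x).std())                                   # both close to sqrt(2)
```

At initialization the two are statistically identical. The difference appears during training: with adaptive optimizers such as Adam (or DiffGrad), the update size is roughly independent of a weight's scale, so keeping the stored weights at unit variance gives every layer the same effective learning rate, which is the motivation for equalized learning rate in ProGAN and StyleGAN.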

#### Training with Different Architectures

Since we could not obtain the desired generated images, we also implemented StyleGANv1 + HistoGAN block and a slightly changed version of StyleGANv1 that is described in the StyleGAN2 paper (see the revised StyleGAN architecture in the StyleGAN2 paper). We implemented all the versions shown in the figure below.

![archs](materials/archs.png)

#### Implementation Info
- Python 3.7.13
- PyTorch 1.11.0 with CUDA 10.2

We used conda environments for a clean library setup and included an environment.yml file.

In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torch_optimizer import DiffGrad
from data import AnimeFacesDataset
from model import Discriminator, HistoGAN
from loss import compute_gradient_penalty, pl_reg, gp_only_real, wgan_gp_disc_loss, wgan_gp_gen_loss
from utils import random_interpolate_hists
import os


In [None]:
config = dict(
    num_epochs = 10, # number of epochs for training
    batch_size = 16, # batch size
    acc_gradient_total = 16, # total number of samples seen by the networks in 1 iteration
    r1_factor = 10, # coefficient of the r1 regularization term
    r1_update_iter = 4, # R1 regularization is applied every r1_update_iter iterations
    decay_coeff = 0.99, # EMA decay coefficient for updating the path length target variable
    plr_update_iter = 32, # path length regularization is applied every plr_update_iter iterations
    save_iter = 400, # model checkpoints are saved every save_iter iterations
    image_res = 64, # the resolution of the images
    network_capacity = 16, # capacity of the network, used for the channels of the constant input in the generator
    latent_dim = 64, # dimensionality of the noise vectors
    bin_size = 64, # bin size of the histograms
    learning_rate = 0.0002, # learning rate
    mapping_layer_num = 8, # number of Linear layers in the mapping part of the generator (z -> w)
    mixing_prob = 0.9, # probability of using two distinct noises for generation
    use_plr = True, # whether to use path length regularization in training
    use_r1r = True, # whether to use R1 regularization in training
    kaiming_init = True, # initialize networks with the Kaiming initialization method by He et al.
    use_eqlr = False, # use equalized learning rate coefficients for weights (similar to Kaiming, but applied in every forward pass)
    use_spec_norm = False, # use spectral normalization of the discriminator weights (for stabilization)
    disc_arch = "ResBlock", # architecture of the discriminator (used for bookkeeping)
    gen_arch = "InputModDemod", # architecture of the generator (used for bookkeeping)
    optim = "DiffGrad",  # optimizer used ("Adam" or "DiffGrad")
    loss_type = "wasser" # loss type: "wasser" (Wasserstein), "hinge" or "softplus" (non-saturating)
    )

In [None]:
# set global variables from config
device = "cuda" if torch.cuda.is_available() else "cpu"
real_image_dir = "images/anime_face"
image_res = config["image_res"]
transform = transforms.Compose(
        [transforms.Resize((image_res,image_res))])
batch_size = config["batch_size"]
num_epochs = config["num_epochs"]
acc_gradient_total = config["acc_gradient_total"]
acc_gradient_iter = acc_gradient_total //batch_size
r1_factor = config["r1_factor"]
r1_update_iter = config["r1_update_iter"]
decay_coeff = config["decay_coeff"]
target_scale = torch.tensor([0.0], device=device) # running (EMA) target of the path length, no gradient
plr_factor = np.log(2)/(256**2*(np.log(256)-np.log(2)))
plr_update_iter = config["plr_update_iter"]
save_iter = config["save_iter"]
network_capacity = config["network_capacity"] 
latent_dim = config["latent_dim"]
bin_size = config["bin_size"]
learning_rate = config["learning_rate"]
mapping_layer_num = config["mapping_layer_num"]
mixing_prob = config["mixing_prob"]
num_gen_layers = int(np.log2(image_res)-1)
use_plr = config["use_plr"]
use_r1r = config["use_r1r"]
kaiming_init = config["kaiming_init"]
use_eqlr = config["use_eqlr"]
use_spec_norm = config["use_spec_norm"]
loss_type = config["loss_type"]
optim = config["optim"]
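pre_gen_name = config.get("pre_gen_name") # optional pretrained generator checkpoint (None if not set)
pre_disc_name = config.get("pre_disc_name") # optional pretrained discriminator checkpoint (None if not set)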
log_interval = 200

In [None]:
# directory to save generated images
fake_image_dir = "generated_images"
os.makedirs(fake_image_dir, exist_ok=True)
# number of residual blocks in the discriminator: log2(image_res) - 1
num_res_blocks = int(np.log2(image_res) - 1)
# network_capacity (intermediate channel sizes of the discriminator and learnable constant
# channel size of the generator), bin_size and learning_rate are taken from config above
# channel sizes of the log2(256) - 1 = 7 generator blocks used for 256x256 images
generator_channel_sizes = [1024, 512, 512, 512, 256, 128, 64]
# coefficient of the gradient penalty
coeff_penalty = 10 # same as the StyleGAN2 paper

In [None]:
def truncation_trick(generator, latent_size, batch_size): # for saving images from the mean of the mapped w vectors
    with torch.no_grad():
        z = torch.randn((1000, latent_size)).to(device)
        w = generator.get_w_from_z(z)
        w_mean = torch.mean(w, dim=0, keepdim=True)
        fake_imgs = generator.gen_image_from_w(w_mean.repeat(batch_size,1), None) # target_hist.size(0), target_hist
    return fake_imgs

def mixing_noise(): # mixing noises for better generalization
    if torch.rand((1,)) < mixing_prob:
        ri = torch.randint(1, num_gen_layers, (1,)).item()
        z = torch.cat([torch.randn((batch_size, 1, latent_dim)).expand(-1,ri,-1), torch.randn((batch_size, 1, latent_dim)).expand(-1,num_gen_layers-ri,-1)], dim=1)
    else:
        z = torch.randn((batch_size, num_gen_layers, latent_dim))
    return z

In [None]:
def train_discriminator(generator, discriminator, disc_optim, chunk_data, batch_size, iter): # training loop of the discriminator
    hist_list = []
    disc_optim.zero_grad()
    total_disc_loss = 0
    total_real_loss = 0
    total_fake_loss = 0
    total_r1_loss = 0
    for index in range(chunk_data.size(0)//batch_size):
        batch_data = chunk_data[index*batch_size:(index+1)*batch_size]
        # move to the device first so the tensor that is scored tracks gradients (needed for R1)
        batch_data = batch_data.to(device).requires_grad_()
        target_hist = random_interpolate_hists(batch_data)
        # the histogram is only an input for the generator step, so it does not need a graph
        hist_list.append(target_hist.detach().clone())
        z = mixing_noise().to(device) # torch.randn(batch_size, latent_dim)
        fake_data, _ = generator(z, None) #target_hist
        fake_data = fake_data.detach()
        fake_scores = discriminator(fake_data)
        real_scores = discriminator(batch_data)
        # every loss is divided by acc_gradient_iter because gradients are accumulated
        if loss_type == "hinge":
            real_loss = torch.mean(torch.nn.functional.relu(1 - real_scores)) / acc_gradient_iter
            fake_loss = torch.mean(torch.nn.functional.relu(1 + fake_scores)) / acc_gradient_iter
        elif loss_type == "softplus":
            real_loss = torch.mean(torch.nn.functional.softplus(-real_scores)) / acc_gradient_iter
            fake_loss = torch.mean(torch.nn.functional.softplus(fake_scores)) / acc_gradient_iter
        elif loss_type == "wasser":
            real_loss = -torch.mean(real_scores) / acc_gradient_iter
            fake_loss = torch.mean(fake_scores) / acc_gradient_iter
        else:
            raise ValueError(f"Unknown loss_type: {loss_type}")

        disc_loss = real_loss + fake_loss
        total_disc_loss += disc_loss.item()
        total_fake_loss += fake_loss.item()
        total_real_loss += real_loss.item()
        if use_r1r and iter % r1_update_iter == 0:
            # lazy R1 regularization: gradient penalty on real images only
            r1_loss = gp_only_real(batch_data, real_scores, r1_factor) / acc_gradient_iter
            total_r1_loss += r1_loss.item()
            disc_loss = disc_loss + r1_loss
        # gradients accumulate over the mini-batches until the optimizer step below
        disc_loss.backward()

    disc_optim.step()
    disc_optim.zero_grad()
    return hist_list

In [None]:
def train_generator(generator, discriminator, gene_optim, batch_size, iter, hist_list): # Training loop for generator
    global target_scale
    total_gene_loss = 0
    total_plr_loss = 0
    gene_optim.zero_grad()

    for target_hist in hist_list:
        z = mixing_noise().to(device) 
        fake_data, w = generator(z, target_hist)  
        disc_score = discriminator(fake_data)
        if loss_type in ["wasser", "hinge"]:
            g_loss = -torch.mean(disc_score) / acc_gradient_iter
        elif loss_type == "softplus":
            g_loss = torch.mean(torch.nn.functional.softplus(-disc_score)) / acc_gradient_iter      
        total_gene_loss += g_loss.item()
        pl_loss = 0 
        if use_plr and (iter+1) % plr_update_iter == 0:
            std = 0.1 / (w.std(dim=0, keepdim=True) + 1e-8)
            w_changed = w + torch.randn_like(w, device=device) / (std + 1e-8)
            changed_data = generator.gen_image_from_w(w_changed, target_hist)
            pl_lengths = ((changed_data - fake_data) ** 2).mean(dim=(1, 2, 3))
            avg_pl_length = torch.mean(pl_lengths).item()
            pl_loss = torch.mean(torch.square(pl_lengths - target_scale)) 
            total_plr_loss += pl_loss.item()

        g_loss += pl_loss
        g_loss.backward()

    if use_plr and (iter+1) % plr_update_iter == 0:
        # EMA update: keep decay_coeff of the old target and blend in the new average path length
        target_scale = decay_coeff * target_scale + (1 - decay_coeff) * avg_pl_length
   

    gene_optim.step()
    gene_optim.zero_grad()
    del g_loss, total_gene_loss, total_plr_loss

In [None]:
# Taken from https://pytorch.org/vision/stable/auto_examples/plot_visualization_utils.html

import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
from torchvision.utils import make_grid


plt.rcParams["savefig.bbox"] = 'tight'


def show(imgs):
    # accept a single CxHxW tensor (e.g. the output of make_grid) or a list of them
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        # make_grid(normalize=True) yields floats in [0, 1]; to_pil_image expects a CxHxW uint8 tensor
        img = img.detach().mul(255).add_(0.5).clamp_(0, 255).to("cpu", torch.uint8)
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()

In [None]:
# Training loop. Each chunk from the dataloader is split into acc_gradient_total // batch_size
# mini-batches whose gradients are accumulated (with the default config this is a single batch).
# Assumes `dataset`, `dataloader`, `generator`, `discriminator`, `gene_optim` and `disc_optim`
# have been created in an earlier cell (not shown here).
os.makedirs("models", exist_ok=True)
dataset_size = len(dataset)
total_iter = 0
for epoch in range(num_epochs):
    for iter, chunk_data in enumerate(dataloader):
        training_percent = 100 * iter * chunk_data.size(0) / dataset_size
        total_iter += 1
        hist_list = train_discriminator(generator, discriminator, disc_optim, chunk_data, batch_size, total_iter)
        train_generator(generator, discriminator, gene_optim, batch_size, total_iter, hist_list)
        if iter % log_interval == 0:
            # preview a grid of generated images; no gradients are needed here
            with torch.no_grad():
                z = mixing_noise().to(device)
                fake_data, _ = generator(z, hist_list[0])
            grid = make_grid(fake_data, nrow=3, normalize=True)
            show(grid)
            print(f"Epoch {epoch}: {training_percent:.1f}% of the dataset processed")
        if (iter + 1) % save_iter == 0:
            torch.save(generator.state_dict(), "models/generator_{}.pt".format(epoch))
            torch.save(discriminator.state_dict(), "models/discriminator_{}.pt".format(epoch))
        