# ViTGAN pytorch implementation

This notebook is a PyTorch implementation of [ViTGAN: Training GANs with Vision Transformers](https://arxiv.org/pdf/2107.04589v1.pdf).

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kJJw6BYW01HgooCZ2zUDt54e1mXqITXH?usp=sharing)

The model consists of a Vision Transformer Generator and a Vision Transformer Discriminator.

It is adversarially trained to map latent vectors to images, which closely resemble the images from a given dataset. In this implementation, the dataset used is CIFAR-10.

The Generator takes a latent vector $z$ as input; through a mapping network and self-modulated LayerNorm it conditions a Vision Transformer Encoder. The encoder's output embedding for each image patch is fed to a SIREN network, together with a Fourier embedding ($E_{fou}$), to produce the pixels of that patch.

![ViTGAN Generator architecture](https://drive.google.com/uc?export=view&id=1XaCVOLq8Bvg-I3qM-bugNZcjIW5L7XTO)

This implementation separates the Generator into a Vision Transformer network and a SIREN network for debugging purposes.

1.   [x] Use vectorized L2 distance in attention for **Discriminator**
2.   [x] Overlapping Image Patches
3.   [x] DiffAugment
4.   [x] Self-modulated LayerNorm (SLN)
5.   [x] Implicit Neural Representation for Patch Generation
6.   [x] ExponentialMovingAverage (EMA)
7.   [x] Balanced Consistency Regularization (bCR)
8.   [x] Improved Spectral Normalization

In [None]:
# Install the third-party packages imported below (IPython shell escapes; run in Jupyter/Colab).
! pip install einops
! pip install git+https://github.com/fadel/pytorch_ema
! pip install stylegan2-pytorch
! pip install tensorboard
! pip install wandb

In [None]:
# Authenticate the wandb CLI so the later wandb.init() call can log this run.
! wandb login

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter
import os

import numpy as np
import matplotlib.pyplot as plt

import time

import torchvision
import torchvision.transforms as transforms
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.utils import make_grid

from torch_ema import ExponentialMovingAverage

from stylegan2_pytorch import stylegan2_pytorch

from torch.utils.tensorboard import SummaryWriter
import wandb

Hyperparameters

In [None]:
# ---- Model architecture ----
image_size = 32
style_mlp_layers = 8  # Mapping-network depth (z -> w); matches GeneratorViT's `style_mlp_layers`
patch_size = 4
latent_dim = 512 # Size of z
hidden_size = 384  # Transformer token width
depth = 4  # Number of transformer encoder blocks
num_heads = 4

dropout_p = 0.
bias = True
weight_modulation = True  # SIREN layers: ModulatedLinear if True, ResLinear otherwise (see SineLayer)
demodulation = False  # Whether ModulatedLinear demodulates its modulated weights
siren_hidden_layers = 1

combine_patch_embeddings = False # Generate an image from a single SIREN, instead of patch-by-patch
combine_patch_embeddings_size = hidden_size * 4

sln_paremeter_size = hidden_size # either hidden_size or 1 (width of the SLN gamma/beta outputs)

# ---- Training ----
batch_size = 50
device = "cuda"
out_features = 3 # The number of color channels

generator_type = "vitgan" # "vitgan", "cnn"
discriminator_type = "vitgan" # "vitgan", "cnn", "stylegan2"

lr = 7e-4 # Generator learning rate
lr_dis = 7e-4 # Discriminator learning rate
beta = (0., 0.99) # Adam optimizer betas for both the generator and the discriminator
batch_size_history_discriminator = False # Whether to use a loss that tracks one sample from each of the last batch_size batches
epochs = 400 # Number of epochs
# Loss weights: balanced consistency regularization (bCR) on real/fake images,
# plus extra discriminator / diversity terms (presumably disabled at 0.0).
lambda_bCR_real = 10
lambda_bCR_fake = 10
lambda_lossD_noise = 0.0
lambda_lossD_history = 0.0
lambda_diversity_penalty = 0.0

# TensorBoard log directory whose name encodes the main hyperparameters, so runs
# with different settings get distinct directories.
# NOTE(review): batch_size and lambda_diversity_penalty are not part of the name.
experiment_folder_name = f'lr-{lr}_\
lr_dis-{lr_dis}_\
bias-{bias}_\
demod-{demodulation}_\
sir_n_layer-{siren_hidden_layers}_\
w_mod-{weight_modulation}_\
patch_s-{patch_size}_\
st_mlp_l-{style_mlp_layers}_\
hid_size-{hidden_size}_\
comb_patch_emb-{combine_patch_embeddings}_\
sln_par_s-{sln_paremeter_size}_\
dis_type-{discriminator_type}_\
gen_type-{generator_type}_\
n_head-{num_heads}_\
depth-{depth}_\
drop_p-{dropout_p}_\
l_bCR_r-{lambda_bCR_real}_\
l_bCR_f-{lambda_bCR_fake}_\
l_D_noise-{lambda_lossD_noise}_\
l_D_his-{lambda_lossD_history}\
'
writer = SummaryWriter(log_dir=experiment_folder_name)

# Start a W&B run and mirror every hyperparameter into its config so runs can be
# compared in the dashboard.
wandb.init(project='ViTGAN-pytorch')
config = wandb.config
# Architecture
config.image_size = image_size
config.bias = bias
config.demodulation = demodulation
config.siren_hidden_layers = siren_hidden_layers
config.weight_modulation = weight_modulation
config.style_mlp_layers = style_mlp_layers
config.patch_size = patch_size
config.latent_dim = latent_dim
config.hidden_size = hidden_size
config.depth = depth
config.num_heads = num_heads

config.dropout_p = dropout_p

config.combine_patch_embeddings = combine_patch_embeddings
config.combine_patch_embeddings_size = combine_patch_embeddings_size

config.sln_paremeter_size = sln_paremeter_size

# Training setup
config.batch_size = batch_size
config.device = device
config.out_features = out_features

config.generator_type = generator_type
config.discriminator_type = discriminator_type

# Optimization and loss weights (Adam betas are logged as two scalars)
config.lr = lr
config.lr_dis = lr_dis
config.beta1 = beta[0]
config.beta2 = beta[1]
config.batch_size_history_discriminator = batch_size_history_discriminator
config.epochs = epochs
config.lambda_bCR_real = lambda_bCR_real
config.lambda_bCR_fake = lambda_bCR_fake
config.lambda_lossD_noise = lambda_lossD_noise
config.lambda_lossD_history = lambda_lossD_history
config.lambda_diversity_penalty = lambda_diversity_penalty

In [None]:
# Either every patch embedding is rendered by the SIREN on its own
# (patch-by-patch), or one combined embedding renders the whole image.
out_patch_size, combined_embedding_size = (
    (image_size, combine_patch_embeddings_size)
    if combine_patch_embeddings
    else (patch_size, hidden_size)
)

# SIREN input width: the embedding size chosen above.
siren_in_features = combined_embedding_size



DiffAugment, adapted from the official reference implementation: https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment-stylegan2-pytorch/DiffAugment_pytorch.py

In [None]:
def DiffAugment(x, policy='', channels_first=True):
    """Apply a comma-separated DiffAugment ``policy`` to a batch of images.

    Each policy name selects a list of augmentation functions from
    ``AUGMENT_FNS``; they are applied in order. With ``channels_first`` the
    batch is NCHW; otherwise it is NHWC and is permuted to NCHW for the
    augmentations and back afterwards. An empty policy returns ``x`` itself.
    """
    if not policy:
        return x
    if not channels_first:
        x = x.permute(0, 3, 1, 2)
    for name in policy.split(','):
        for augment in AUGMENT_FNS[name]:
            x = augment(x)
    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()


def rand_brightness(x):
    """Add one uniform offset in [-0.5, 0.5) per sample to every pixel of an NCHW batch."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset


def rand_saturation(x):
    """Scale each pixel's deviation from its channel mean by one factor in [0, 2) per sample."""
    mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - mean) * factor + mean


def rand_contrast(x):
    """Scale each sample's deviation from its global (C, H, W) mean by a factor in [0.5, 1.5)."""
    mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - mean) * factor + mean


def rand_translation(x, ratio=0.1):
    # Shift each image of an NCHW batch by a random integer offset of up to
    # round(ratio * H) rows and round(ratio * W) columns, independently per
    # sample; pixels shifted in from outside the image are zero.
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # Per-sample offsets in [-shift, shift], shaped (N, 1, 1) to broadcast over the grid.
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    # (N, H, W) index grids. NOTE(review): torch.meshgrid without `indexing=`
    # uses 'ij' ordering and warns on newer PyTorch versions.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    # +1 compensates for the 1-pixel zero border added below. Clamping to the
    # padded range makes out-of-image reads land on that zero border.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    # Gather in NHWC so advanced indexing by (batch, row, col) picks whole
    # channel vectors, then return to NCHW.
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous()
    return x


def rand_cutout(x, ratio=0.3):
    # Zero out one rectangle of round(ratio * H) x round(ratio * W) pixels per
    # sample of an NCHW batch, centred at a random position; rectangles that
    # reach past the border are cropped by the clamping below.
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # Random rectangle centres, shaped (N, 1, 1) to broadcast over the grid.
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    # (N, cut_h, cut_w) index grids covering one rectangle per sample.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    # Move the rectangle so it is centred on the offset; clamp keeps indices inside the image.
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    # The same spatial mask is applied to every channel.
    x = x * mask.unsqueeze(1)
    return x


# Policy name -> augmentation functions that DiffAugment applies in order.
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}

In [None]:
# CIFAR-10 training set as float tensors in [0, 1] (ToTensor).
# NOTE: Normalize with mean 0 and std 1 computes (x - 0) / 1, so it leaves the
# values unchanged.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0., 0., 0.), (1., 1., 1.))
    ])

# Downloads the dataset into ./data on first use.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

Visualize the effect of DiffAugment on a training image

In [None]:
# Augment one training batch with all three policies and display its first image.
img = next(iter(trainloader))[0]
img = DiffAugment(img, policy='color,translation,cutout', channels_first=True)
img = img.permute(0,2,3,1)[0]  # NCHW -> NHWC, then take the first sample (HWC for imshow)
# Rescale to [0, 1] for display; the color augmentations can push values out of range.
img -= img.min()
img /= img.max()
plt.imshow(img)

# Improved Spectral Normalization (ISN)

$$
\bar{W}_{ISN}(W):=\sigma(W_{init})\cdot W/\sigma(W)
$$

Reference code: https://github.com/koshian2/SNGAN

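On every forward pass the weight is divided by its current spectral norm $\sigma(W)$ and rescaled by its spectral norm at initialization $\sigma(W_{init})$. The layer therefore keeps its initial scale instead of being forced to unit spectral norm.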

In [None]:
def l2normalize(v, eps=1e-4):
    """Return ``v`` divided by its L2 norm; ``eps`` guards against division by zero."""
    return v / (v.norm() + eps)


class spectral_norm(nn.Module):
    """Wrap ``module`` with Improved Spectral Normalization (ISN).

    Before every forward call the wrapped parameter ``name`` is replaced by
    ``sigma(W_init) * W / sigma(W)``. Here ``sigma`` is the largest singular
    value of the weight flattened to ``(out, -1)``, estimated by power
    iteration, and ``sigma(W_init)`` is the estimate taken on the first call.

    The raw weight is stored as ``module.<name>_bar``. The power-iteration
    vectors are stored as ``module.<name>_u`` (size ``out``) and
    ``module.<name>_v`` (size ``in``). Both are updated in place on every call,
    so the estimate keeps converging across training steps.

    Args:
        module: Module owning the parameter to normalize.
        name: Name of that parameter.
        power_iterations: Power-iteration steps per forward call.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()
        # sigma(W_init) as a Python float, captured on the first forward call.
        self.w_init_sigma = None
        self.w_initalized = False

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.shape[0]
        _w = w.view(height, -1)
        # The power iteration is a non-differentiable estimate (as in
        # torch.nn.utils.spectral_norm). Its result is written back so u/v
        # persist between calls instead of restarting from the initial vectors.
        with torch.no_grad():
            for _ in range(self.power_iterations):
                v.copy_(l2normalize(torch.mv(_w.t(), u)))
                u.copy_(l2normalize(torch.mv(_w, v)))
        # Clone so autograd saves snapshots rather than the buffers that the
        # next call overwrites in place (e.g. several forwards before backward).
        u_snap = u.clone()
        v_snap = v.clone()
        sigma = torch.dot(u_snap, torch.mv(_w, v_snap))

        if not self.w_initalized:
            self.w_init_sigma = sigma.item()
            self.w_initalized = True
        setattr(self.module, self.name, self.w_init_sigma * w / sigma)

    def _made_params(self):
        try:
            getattr(self.module, self.name + "_u")
            getattr(self.module, self.name + "_v")
            getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        # v lives in the input space of the flattened (height, width) weight.
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Replace the original parameter: the normalized weight is recomputed
        # and set as a plain attribute on every forward call.
        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)

Vision Transformer reference code: \[[Blog Post](https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632)\]

Normal Attention Mechanism

$$
Attention_h(X) = softmax \bigg ( \frac{QK^T}{\sqrt{d_h}} \bigg ) V
$$

Lipschitz Attention Mechanism

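$$
Attention_h(X) = softmax \bigg ( -\frac{d(Q,K)}{\sqrt{d_h}} \bigg ) V
$$

where $d(Q,K)$ is the matrix of pairwise L2 distances between queries and keys. The minus sign gives nearer keys larger attention weights.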

https://arxiv.org/pdf/2006.04710.pdf

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over ``(batch, tokens, emb_size)`` inputs.

    Args:
        emb_size: Token width; split evenly into ``num_heads`` heads of size
            ``d_h = emb_size // num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        discriminator: If True, the attention logits are the negative L2
            distances between queries and keys (Lipschitz attention, Kim et
            al., arXiv:2006.04710), and both linear layers are wrapped in
            ``spectral_norm``. Otherwise the logits are dot products.

    Raises:
        ValueError: if ``emb_size`` is not divisible by ``num_heads``.
    """
    def __init__(self, emb_size=384, num_heads=4, dropout=0, discriminator=False, **kwargs):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})")
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.discriminator = discriminator
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        if self.discriminator:
            self.qkv = spectral_norm(self.qkv)
            self.projection = spectral_norm(self.projection)

    def forward(self, x, mask=None):
        """Attend over the token axis of ``x``.

        ``mask`` is assumed boolean: True where attention is allowed, and
        broadcastable to ``(batch, heads, query, key)``.
        """
        batch, tokens, _ = x.shape
        head_dim = self.emb_size // self.num_heads
        # Same layout as einops "b n (h d qkv) -> qkv b h n d".
        qkv = self.qkv(x).reshape(batch, tokens, self.num_heads, head_dim, 3)
        queries, keys, values = qkv.permute(4, 0, 2, 1, 3)
        if self.discriminator:
            # Negative L2 distance: nearer keys get larger logits.
            energy = -torch.cdist(queries.contiguous(), keys.contiguous(), p=2)
        else:
            energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, heads, query_len, key_len

        # Scale the logits by 1/sqrt(d_h) *before* the softmax.
        energy = energy / head_dim ** 0.5

        if mask is not None:
            # masked_fill is out-of-place, so the result must be kept.
            energy = energy.masked_fill(~mask, torch.finfo(energy.dtype).min)

        att = F.softmax(energy, dim=-1)
        att = self.att_drop(att)
        # Weighted sum of the values over the key axis.
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        # "b h n d -> b n (h d)"
        out = out.transpose(1, 2).reshape(batch, tokens, self.emb_size)
        return self.projection(out)

# Generator

In [None]:
class FullyConnectedLayer(nn.Module):
    """Linear layer with equalized learning rate and an optional activation.

    The weight is stored as ``randn / lr_multiplier`` and scaled at run time by
    ``lr_multiplier / sqrt(in_features)``. The bias is scaled by
    ``lr_multiplier``. Works on inputs of any rank whose last dimension is
    ``in_features``.

    Raises:
        ValueError: if ``activation`` is not one of 'linear', 'lrelu', 'relu', 'gelu'.
    """
    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        bias            = True,     # Apply additive bias before the activation function?
        activation      = 'linear', # Activation function: 'linear', 'relu', 'lrelu', 'gelu'.
        lr_multiplier   = 1,        # Learning rate multiplier.
        bias_init       = 0,        # Initial value for the additive bias.
        **kwargs
    ):
        super().__init__()
        # Either the string 'linear' or an activation module.
        self.activation = activation
        if activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2)
        elif activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'gelu':
            self.activation = nn.GELU()
        elif activation != 'linear':
            # Fail at construction instead of "str is not callable" in forward().
            raise ValueError(f"Unsupported activation: {activation!r}")
        self.weight = nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        w = self.weight.to(x.dtype) * self.weight_gain
        b = self.bias
        if b is not None:
            b = b.to(x.dtype)
            if self.bias_gain != 1:
                b = b * self.bias_gain

        is_linear = isinstance(self.activation, str)  # only 'linear' survives __init__ as a string
        if is_linear and b is not None and x.dim() == 2:
            # Fused bias + matmul; torch.addmm only accepts 2-D matrices.
            return torch.addmm(b.unsqueeze(0), x, w.t())

        x = x.matmul(w.t())
        if b is not None:
            x = x + b
        if not is_linear:
            x = self.activation(x)
        return x

In [None]:
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Rescale ``x`` so its mean square along ``dim`` is (approximately) one."""
    mean_square = torch.mean(x * x, dim=dim, keepdim=True)
    return x * torch.rsqrt(mean_square + eps)

In [None]:
class MappingNetwork(nn.Module):
    """StyleGAN2-style mapping network from a latent ``z`` (and optional label ``c``) to ``w``.

    ``z`` is normalized to unit second moment and optionally concatenated with a
    normalized label embedding. The result passes through ``num_layers``
    FullyConnectedLayers. When ``w_avg_beta`` is set, a running average of ``w``
    is kept in the ``w_avg`` buffer during training; truncation reads it.

    Returns ``(batch, w_dim)``, or ``(batch, num_ws, w_dim)`` when ``num_ws`` is set.
    """
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws          = None,     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track.
        **kwargs
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        # forward() updates w_avg in training mode and reads it for truncation
        # whenever w_avg_beta is set, independent of num_ws, so register it
        # under that same condition.
        if w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                assert z.shape[1] == self.z_dim
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                assert c.shape[1] == self.c_dim
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W.
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation: interpolate towards the running average w_avg.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x

In [None]:
class FeedForwardBlock(nn.Sequential):
    """Transformer MLP: ``emb_size -> expansion * emb_size`` (GELU) -> dropout -> ``emb_size``.

    Args:
        emb_size: Input and output feature width.
        expansion: Hidden-layer width multiplier.
        drop_p: Dropout probability between the two layers.
        bias: Whether both linear layers use a bias (default False).
    """
    def __init__(self, emb_size, expansion=4, drop_p=0., bias=False):
        super().__init__(
            # The input width is emb_size and the hidden width is expansion * emb_size.
            FullyConnectedLayer(emb_size, expansion * emb_size, activation='gelu', bias=bias),
            nn.Dropout(drop_p),
            FullyConnectedLayer(expansion * emb_size, emb_size, bias=bias),
        )

Self-Modulated LayerNorm
$$
SLN(h_{\ell},w)=\gamma_{\ell}(w)\odot\frac{h_{\ell}-\mu}{\sigma}+\beta_{\ell}(w)
$$

where $\gamma_{\ell}, \beta_{\ell}\in \mathbb{R}^D$ or $\gamma_{\ell}, \beta_{\ell}\in \mathbb{R}^1$

In [None]:
class SLN(nn.Module):
    """Self-modulated LayerNorm: ``gamma(w) * LayerNorm(hidden) + beta(w)``.

    ``gamma`` and ``beta`` are bias-free FullyConnectedLayers that map ``w``
    (width ``input_size``) to ``parameter_size`` outputs. ``parameter_size``
    must be ``input_size`` (per-feature affine) or ``1`` (a single scale/shift
    broadcast over all features).

    Args:
        input_size: Feature width of both ``hidden`` and ``w``.
        parameter_size: Width of gamma/beta; None means ``input_size``.
    """
    def __init__(self, input_size, parameter_size=None, **kwargs):
        super().__init__()
        if parameter_size is None:
            parameter_size = input_size
        assert input_size == parameter_size or parameter_size == 1
        self.input_size = input_size
        self.parameter_size = parameter_size
        self.ln = nn.LayerNorm(input_size)
        self.gamma = FullyConnectedLayer(input_size, parameter_size, bias=False)
        self.beta = FullyConnectedLayer(input_size, parameter_size, bias=False)

    def forward(self, hidden, w):
        """Normalize ``hidden`` and apply the affine transform predicted from ``w``.

        Presumably ``hidden`` is (batch, tokens, input_size) and ``w`` is
        (batch, input_size); gamma/beta get a token axis via ``unsqueeze(1)``.
        """
        # Both inputs are input_size wide (LayerNorm(input_size) and the
        # gamma/beta in_features). parameter_size only sets the output width.
        assert hidden.size(-1) == self.input_size and w.size(-1) == self.input_size
        gamma = self.gamma(w).unsqueeze(1)
        beta = self.beta(w).unsqueeze(1)
        return gamma * self.ln(hidden) + beta

In [None]:
class GeneratorTransformerEncoderBlock(nn.Module):
    """Pre-norm transformer block whose normalization is modulated by ``w`` (SLN).

        hidden = hidden + Dropout(MSA(SLN(hidden, w)))
        hidden = hidden + Dropout(FeedForward(SLN(hidden, w)))

    Extra ``kwargs`` (e.g. ``num_heads``) are forwarded to MultiHeadAttention.

    NOTE(review): one SLN module (and one dropout) is shared by both
    sub-layers. Confirm whether separately parameterized norms were intended.
    """
    def __init__(self,
                 hidden_size=384,
                 sln_paremeter_size=384,
                 drop_p=0.,
                 forward_expansion=4,
                 forward_drop_p=0.,
                 **kwargs):
        super().__init__()
        self.sln = SLN(hidden_size, parameter_size=sln_paremeter_size)
        self.msa = MultiHeadAttention(hidden_size, **kwargs)
        self.dropout = nn.Dropout(drop_p)
        self.feed_forward = FeedForwardBlock(hidden_size, expansion=forward_expansion, drop_p=forward_drop_p)

    def forward(self, hidden, w):
        # Attention sub-layer with residual connection.
        res = hidden
        hidden = self.sln(hidden, w)
        hidden = self.msa(hidden)
        hidden = self.dropout(hidden)
        hidden = hidden + res

        # Feed-forward sub-layer with residual connection; its output must be kept.
        res = hidden
        hidden = self.sln(hidden, w)
        hidden = self.feed_forward(hidden)
        hidden = self.dropout(hidden)
        hidden = hidden + res
        return hidden

In [None]:
class GeneratorTransformerEncoder(nn.Module):
    """Stack of ``depth`` GeneratorTransformerEncoderBlocks, each conditioned on ``w``.

    ``kwargs`` are passed unchanged to every block's constructor.
    """
    def __init__(self, depth=4, **kwargs):
        super().__init__()
        self.depth = depth
        blocks = [GeneratorTransformerEncoderBlock(**kwargs) for _ in range(depth)]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden, w):
        for block in self.blocks:
            hidden = block(hidden, w)
        return hidden

# SIREN

Code for SIREN is taken from [SIREN reference colab notebook](https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb)

$$
w^{'}_{ijk}=s_i\cdot w_{ijk}
$$

$$
w^{''}_{ijk}=\frac{w^{'}_{ijk}}{\sqrt{\sum_{i,k}{w^{'}_{ijk}}^2+\epsilon}}
$$

In [None]:
class ModulatedLinear(nn.Module):
    """Per-sample, style-modulated linear layer (StyleGAN2-style weight modulation).

    For each sample ``b`` the shared weight ``W`` (out, in) is scaled by
    ``1/sqrt(in)`` and column-wise by the style ``s_b`` (in,). When
    ``demodulation`` is set, each output row is rescaled to unit L2 norm
    (eps 1e-8). The resulting weight is applied independently to every point
    of ``input[b]``.

    Args:
        in_channels: Input feature width.
        out_channels: Output feature width.
        style_size: Style width. If it differs from ``in_channels``, a
            FullyConnectedLayer maps the style to ``in_channels``.
        bias: Accepted for interface compatibility; no bias is applied.
        demodulation: Whether to demodulate the modulated weight.

    Shapes: ``input`` is ``(batch, points, in_channels)`` and ``style`` is
    ``(batch, style_size)``; the result is ``(batch, points, out_channels)``.
    """
    def __init__(self, in_channels, out_channels, style_size, bias=False, demodulation=True, **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.style_size = style_size
        self.scale = 1 / np.sqrt(in_channels)
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, 1)
        )
        self.modulation = None
        if self.style_size != self.in_channels:
            self.modulation = FullyConnectedLayer(style_size, in_channels, bias=False)
        self.demodulation = demodulation

    def forward(self, input, style):
        batch_size = input.shape[0]

        if self.style_size != self.in_channels:
            style = self.modulation(style)
        style = style.view(batch_size, 1, self.in_channels, 1)
        weight = self.scale * self.weight * style  # (batch, out, in, 1)

        if self.demodulation:
            demod = torch.rsqrt(weight.pow(2).sum([2]) + 1e-8)
            weight = weight * demod.view(batch_size, self.out_channels, 1, 1)

        # out[b, n, o] = sum_i input[b, n, i] * weight[b, o, i]; the point and
        # channel axes of `input` stay separate.
        return torch.einsum('bni,boi->bno', input, weight[..., 0])

In [None]:
class ResLinear(nn.Module):
    """Linear map conditioned by adding a projected style to its input.

    ``style`` is projected to ``in_channels``, given a point axis via
    ``unsqueeze(1)`` (so presumably it is ``(batch, style_size)``), added to
    ``input`` and passed through a bias-free FullyConnectedLayer. ``bias`` is
    accepted for interface compatibility and not used.
    """
    def __init__(self, in_channels, out_channels, style_size, bias=False, **kwargs):
        super().__init__()
        self.linear = FullyConnectedLayer(in_channels, out_channels, bias=False)
        self.style = FullyConnectedLayer(style_size, in_channels, bias=False)
        self.in_channels, self.out_channels, self.style_size = in_channels, out_channels, style_size

    def forward(self, input, style):
        shift = self.style(style).unsqueeze(1)
        return self.linear(input + shift)

In [None]:
class ConLinear(nn.Module):
    """``nn.Linear`` whose weight is re-initialized uniformly in ±sqrt(k / ch_in).

    ``k`` is 9 for the first layer (``is_first``) and 3 otherwise. The bias
    keeps the default ``nn.Linear`` initialization.
    """
    def __init__(self, ch_in, ch_out, is_first=False, bias=True, **kwargs):
        super().__init__()
        self.conv = nn.Linear(ch_in, ch_out, bias=bias)
        numerator = 9 if is_first else 3
        limit = np.sqrt(numerator / ch_in)
        nn.init.uniform_(self.conv.weight, -limit, limit)

    def forward(self, x):
        return self.conv(x)

class SinActivation(nn.Module):
    """Element-wise ``torch.sin``."""
    def forward(self, x):
        return torch.sin(x)

class LFF(nn.Module):
    """Learned Fourier features: ``sin(ConLinear(x))``, mapping 2 input features to ``hidden_size``."""
    def __init__(self, hidden_size, **kwargs):
        super().__init__()
        self.ffm = ConLinear(2, hidden_size, is_first=True)
        self.activation = SinActivation()

    def forward(self, x):
        return self.activation(self.ffm(x))

In [None]:
class SineLayer(nn.Module):
    """SIREN layer ``sin(omega_0 * L(input, style))`` with a style-conditioned linear map ``L``.

    ``L`` is a ModulatedLinear when ``weight_modulation`` is True and a
    ResLinear otherwise. Its weight uses the SIREN initialization:
    U(-1/in, 1/in) for the first layer and
    U(-sqrt(6/in)/omega_0, sqrt(6/in)/omega_0) for later layers.
    """
    def __init__(self, in_features, out_features, style_size, bias=False,
                 is_first=False, omega_0=30, weight_modulation=True, **kwargs):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.weight_modulation = weight_modulation
        layer_cls = ModulatedLinear if weight_modulation else ResLinear
        self.linear = layer_cls(in_features, out_features, style_size=style_size, bias=bias, **kwargs)
        self.init_weights()

    def init_weights(self):
        # The raw weight lives directly on ModulatedLinear, or on ResLinear's inner layer.
        target = self.linear.weight if self.weight_modulation else self.linear.linear.weight
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            target.uniform_(-bound, bound)

    def forward(self, input, style):
        return torch.sin(self.omega_0 * self.linear(input, style))
    
class Siren(nn.Module):
    """Style-conditioned SIREN; every layer is called as ``layer(output, style)``.

    Layers:
    - a first SineLayer ``in_features -> hidden_size``;
    - ``hidden_layers`` SineLayers ``hidden_size -> hidden_size``;
    - an output layer to ``out_features``. With ``outermost_linear`` it is a
      ModulatedLinear/ResLinear with SIREN-style uniform init; otherwise it is
      another SineLayer.

    ``bias`` is passed only to the linear output layer; extra ``kwargs`` are
    forwarded to every layer.
    """
    def __init__(self, in_features, hidden_size, hidden_layers, out_features, style_size, outermost_linear=False, 
                 first_omega_0=30, hidden_omega_0=30., weight_modulation=True, bias=False, **kwargs):
        super().__init__()

        self.net = []
        self.net.append(SineLayer(in_features, hidden_size, style_size,
                                  is_first=True, omega_0=first_omega_0,
                                  weight_modulation=weight_modulation, **kwargs))

        for _ in range(hidden_layers):
            self.net.append(SineLayer(hidden_size, hidden_size, style_size,
                                      is_first=False, omega_0=hidden_omega_0,
                                      weight_modulation=weight_modulation, **kwargs))

        if outermost_linear:
            if weight_modulation:
                final_linear = ModulatedLinear(hidden_size, out_features,
                                               style_size=style_size, bias=bias, **kwargs)
            else:
                final_linear = ResLinear(hidden_size, out_features, style_size=style_size, bias=bias, **kwargs)

            bound = np.sqrt(6 / hidden_size) / hidden_omega_0
            with torch.no_grad():
                if weight_modulation:
                    final_linear.weight.uniform_(-bound, bound)
                else:
                    final_linear.linear.weight.uniform_(-bound, bound)

            self.net.append(final_linear)
        else:
            # style_size is a required positional argument of SineLayer.
            self.net.append(SineLayer(hidden_size, out_features, style_size,
                                      is_first=False, omega_0=hidden_omega_0,
                                      weight_modulation=weight_modulation, **kwargs))

        self.net = nn.Sequential(*self.net)

    def forward(self, coords, style):
        # NOTE(review): clone().detach() cuts `coords` from any upstream graph.
        # If coords are produced by a learnable module (e.g. LFF), its
        # parameters get no gradient through this path -- confirm with the caller.
        coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
        output = coords
        for layer in self.net:
            output = layer(output, style)
        return output

$$
Fou(\mathbf{v})= \left[ \cos(2 \pi \mathbf B \mathbf{v}), \sin(2 \pi \mathbf B \mathbf{v}) \right]^\mathrm{T}
$$

In [None]:
class GeneratorViT(nn.Module):
    """Vision-Transformer generator with a SIREN patch decoder (ViTGAN).

    ``forward(z)``:
      1. ``MappingNetwork`` maps the latent ``z`` to a style vector ``w``.
      2. A transformer encoder conditioned on ``w`` processes ``num_patches`` tokens
         initialised from ``sin(pos_emb)``; the result is normalised by SLN.
      3. Each resulting embedding (one per patch, or one per image when
         ``combine_patch_embeddings`` is set) modulates a SIREN that is evaluated on a
         Fourier embedding of an ``output_patch_size x output_patch_size`` pixel grid.
      4. The SIREN outputs are reshaped into ``[-1, image_size**2, out_features]``.

    Args (selected):
        patch_size: token patch size; the encoder sees ``(image_size // patch_size)**2`` tokens.
        image_size: side length of the generated square image.
        out_features: channels produced per pixel.
        output_patch_size: side length of the pixel grid rendered by the SIREN per
            embedding. When omitted, the notebook-level ``out_patch_size`` setting is used.
        **kwargs: forwarded to ``GeneratorTransformerEncoder`` and ``Siren``.
    """
    def __init__(self,
                style_mlp_layers=8,
                patch_size=4,
                latent_dim=32,
                hidden_size=384,
                sln_paremeter_size=1,
                image_size=32,
                depth=4,
                combine_patch_embeddings=False,
                combined_embedding_size=1024,
                forward_drop_p=0.,
                bias=False,
                out_features=3,
                weight_modulation=True,
                siren_hidden_layers=1,
                output_patch_size=None,
                **kwargs):
        super().__init__()
        self.hidden_size = hidden_size
        self.out_features = out_features

        self.mlp = MappingNetwork(z_dim=latent_dim, c_dim=0, w_dim=hidden_size, num_layers=style_mlp_layers, w_avg_beta=None)

        num_patches = int(image_size//patch_size)**2
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.image_size = image_size
        self.combine_patch_embeddings = combine_patch_embeddings
        self.combined_embedding_size = combined_embedding_size
        # The notebook-level `out_patch_size` global is only looked up when the
        # caller does not pass `output_patch_size` explicitly.
        self.out_patch_size = out_patch_size if output_patch_size is None else output_patch_size

        self.pos_emb = nn.Parameter(torch.randn(num_patches, hidden_size))
        self.transformer_encoder = GeneratorTransformerEncoder(depth,
                                                               hidden_size=hidden_size,
                                                               sln_paremeter_size=sln_paremeter_size,
                                                               drop_p=forward_drop_p,
                                                               forward_drop_p=forward_drop_p,
                                                               **kwargs)
        self.sln = SLN(hidden_size, parameter_size=sln_paremeter_size)
        if combine_patch_embeddings:
            self.to_single_emb = nn.Sequential(
                FullyConnectedLayer(num_patches*hidden_size, combined_embedding_size, bias=bias, activation='gelu'),
                nn.Dropout(forward_drop_p),
            )

        self.lff = LFF(self.hidden_size)

        self.siren_in_features = combined_embedding_size if combine_patch_embeddings else self.hidden_size
        self.siren = Siren(in_features=self.siren_in_features, out_features=out_features,
                           style_size=self.siren_in_features, hidden_size=self.hidden_size, bias=bias,
                           hidden_layers=siren_hidden_layers, outermost_linear=True, weight_modulation=weight_modulation, **kwargs)

        # Number of SIREN-rendered patches along each image side.
        self.num_patches_x = int(image_size//self.out_patch_size)

    def fourier_input_mapping(self, x):
        """Apply the Fourier feature mapping (``self.lff``) to coordinates ``x``."""
        return self.lff(x)

    def fourier_pos_embedding(self):
        """Return the Fourier embedding of one output patch's pixel grid.

        Coordinates are an evenly spaced grid over [-1, 1]^2 with
        ``out_patch_size`` points per side. The result has shape
        ``[out_patch_size**2, hidden_size]`` and lives on the same device as the
        module's parameters.
        """
        ps = self.out_patch_size
        coords = np.linspace(-1, 1, ps, endpoint=True)
        pos = np.stack(np.meshgrid(coords, coords), -1)
        # Follow the module's own device rather than a notebook-level global.
        pos = torch.tensor(pos, dtype=torch.float, device=self.pos_emb.device)
        return self.fourier_input_mapping(pos).reshape([ps**2, self.hidden_size])

    def mix_hidden_and_pos(self, hidden):
        """Return the pixel-grid embedding repeated once per row of ``hidden``.

        Only ``hidden.shape[0]`` is used; the embedding values themselves are
        combined with ``hidden`` inside the SIREN (as its style input).
        Output shape: ``[hidden.shape[0], out_patch_size**2, hidden_size]``.
        """
        pos = self.fourier_pos_embedding()
        return repeat(pos, 'p h -> n p h', n=hidden.shape[0])

    def forward(self, z):
        """Generate images from latents ``z``; returns ``[-1, image_size**2, out_features]``."""
        batch = z.shape[0]
        w = self.mlp(z)
        pos = repeat(torch.sin(self.pos_emb), 'n e -> b n e', b=batch)
        hidden = self.transformer_encoder(pos, w)

        if self.combine_patch_embeddings:
            # Output [batch_size, combined_embedding_size]
            hidden = self.sln(hidden, w).view((batch, -1))
            hidden = self.to_single_emb(hidden)
        else:
            # Output [batch_size*num_patches, hidden_size]
            hidden = self.sln(hidden, w).view((-1, self.hidden_size))

        pos = self.mix_hidden_and_pos(hidden)
        result = self.siren(pos, hidden)

        ps = self.out_patch_size
        n = self.num_patches_x
        result = result.view([-1, n, n, ps, ps, self.out_features])
        # Swap dims 2 and 3 so the patch and in-patch axes interleave per image row
        # before the pixels are flattened.
        result = result.permute([0, 1, 3, 2, 4, 5])
        return result.reshape([-1, self.image_size**2, self.out_features])


# Smoke test: build a ViT generator from the notebook hyperparameters, push one
# random latent batch through it, print the output shape and discard everything.
Generator = GeneratorViT(
    patch_size=patch_size,
    image_size=image_size,
    style_mlp_layers=style_mlp_layers,
    latent_dim=latent_dim,
    hidden_size=hidden_size,
    combine_patch_embeddings=combine_patch_embeddings,
    combined_embedding_size=combined_embedding_size,
    sln_paremeter_size=sln_paremeter_size,
    num_heads=num_heads,
    depth=depth,
    forward_drop_p=dropout_p,
    bias=bias,
    weight_modulation=weight_modulation,
    siren_hidden_layers=siren_hidden_layers,
    demodulation=demodulation,
).to(device)
sample_shape = Generator(torch.randn([batch_size, latent_dim]).to(device)).shape
print(sample_shape)
del Generator, sample_shape

# CNN Generator

In [None]:
class CNNGenerator(nn.Module):
    """DCGAN-style convolutional generator used as a baseline.

    Maps latents of shape ``[batch, z_dim]`` to images of shape
    ``[batch, 3, 32, 32]`` with values in (-1, 1) (Tanh output).

    Args:
        z_dim: latent size. Defaults to the notebook-level ``latent_dim``.
        base_channels: channel width ``ngf``. The first feature map has ``2*ngf``
            channels. Defaults to the notebook-level ``hidden_size``.
    """
    def __init__(self, z_dim=None, base_channels=None):
        super().__init__()
        # The notebook globals are only consulted when no explicit value is given;
        # the chosen sizes are stored so forward() no longer depends on globals.
        self.z_dim = latent_dim if z_dim is None else z_dim
        self.base_channels = hidden_size if base_channels is None else base_channels
        ngf = self.base_channels

        self.w = nn.Linear(self.z_dim, ngf * 2 * 4 * 4, bias=False)
        self.main = nn.Sequential(
            # input: (ngf*2) x 4 x 4, produced by reshaping the linear projection of z
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # (ngf*2) x 4 x 4 -> ngf x 8 x 8
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # ngf x 8 x 8 -> (ngf/2) x 16 x 16
            nn.ConvTranspose2d(ngf, ngf // 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf // 2),
            nn.ReLU(True),
            # (ngf/2) x 16 x 16 -> (ngf/4) x 32 x 32
            nn.ConvTranspose2d(ngf // 2, ngf // 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf // 4),
            nn.ReLU(True),
            # (ngf/4) x 32 x 32 -> 3 x 32 x 32 (kernel 3, stride 1, padding 1 keeps the size)
            nn.ConvTranspose2d(ngf // 4, 3, 3, 1, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Generate images ``[batch, 3, 32, 32]`` from latents ``[batch, z_dim]``."""
        input = self.w(input).view((-1, self.base_channels * 2, 4, 4))
        return self.main(input)

# Discriminator

In [None]:
class PatchEmbedding(nn.Module):
    """Turn an image batch into a token sequence for the ViT discriminator.

    A spectrally-normalised ``Conv2d`` (kernel ``patch_size``, stride
    ``stride_size``) extracts patch embeddings, which may overlap when the stride
    is smaller than the kernel. They are flattened to ``[batch, patches, emb_size]``.
    A learnable class token is prepended and ``sin(positions)`` is added.

    ``batch_size`` is stored but not used by ``forward``.
    """
    def __init__(self, in_channels=3, patch_size=4, stride_size=4, emb_size=384, image_size=32, batch_size=64):
        super().__init__()
        self.patch_size = patch_size
        # Conv instead of a linear layer on unfolded patches (cheaper).
        patch_conv = spectral_norm(
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=stride_size)
        ).to(device)
        self.projection = nn.Sequential(
            patch_conv,
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # Conv output side length without padding: (image - kernel) // stride + 1.
        patches_per_side = (image_size - patch_size + stride_size) // stride_size
        # One extra position for the class token.
        self.positions = nn.Parameter(torch.randn(patches_per_side ** 2 + 1, emb_size))
        self.batch_size = batch_size

    def forward(self, x):
        """Return ``[batch, 1 + patches, emb_size]`` tokens for images ``x`` of rank 4."""
        batch, _, _, _ = x.shape
        tokens = self.projection(x)
        cls = repeat(self.cls_token, '() n e -> b n e', b=batch)
        tokens = torch.cat([cls, tokens], dim=1)
        tokens += torch.sin(self.positions)
        return tokens

In [None]:
class ResidualAdd(nn.Module):
    """Residual wrapper: ``forward(x, **kwargs)`` returns ``fn(x, **kwargs) + x``.

    The sum is computed out of place. The previous in-place ``x += res``
    modified the caller's input tensor whenever ``fn`` returned its argument
    unchanged (e.g. ``nn.Identity``). It also raises an autograd error when that
    tensor is a leaf that requires grad.
    """
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

In [None]:
class DiscriminatorTransformerEncoderBlock(nn.Sequential):
    """Pre-norm transformer block used by the ViT discriminator.

    Two residual branches, applied in order:
      1. LayerNorm -> MultiHeadAttention(emb_size, **kwargs) -> Dropout(drop_p)
      2. LayerNorm -> spectrally-normalised MLP
         (emb_size -> forward_expansion*emb_size -> GELU -> Dropout -> emb_size)
         -> Dropout(drop_p)
    """
    def __init__(self,
                 emb_size=384,
                 drop_p=0.,
                 forward_expansion=4,
                 forward_drop_p=0.,
                 **kwargs):
        expanded = forward_expansion * emb_size

        attention = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, **kwargs),
            nn.Dropout(drop_p),
        )
        attention_branch = ResidualAdd(attention)

        # Submodules are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        mlp_norm = nn.LayerNorm(emb_size)
        mlp = nn.Sequential(
            spectral_norm(nn.Linear(emb_size, expanded)),
            nn.GELU(),
            nn.Dropout(forward_drop_p),
            spectral_norm(nn.Linear(expanded, emb_size)),
        )
        mlp_branch = ResidualAdd(nn.Sequential(mlp_norm, mlp, nn.Dropout(drop_p)))

        super().__init__(attention_branch, mlp_branch)

In [None]:
class DiscriminatorTransformerEncoder(nn.Sequential):
    """A stack of ``depth`` DiscriminatorTransformerEncoderBlock modules applied in sequence.

    All keyword arguments are passed unchanged to every block.
    """
    def __init__(self, depth=4, **kwargs):
        blocks = []
        for _ in range(depth):
            blocks.append(DiscriminatorTransformerEncoderBlock(**kwargs))
        super().__init__(*blocks)

class ClassificationHead(nn.Sequential):
    """Discriminator head: LayerNorm followed by a spectrally-normalised MLP on the class token.

    ``forward`` takes a token sequence (indexing ``x[:, 0, :]`` requires a 3-D
    input, presumably ``[batch, tokens, emb_size]``). It returns raw logits of
    shape ``[batch, n_classes]``.

    The final Linear layer is deliberately NOT followed by an activation, because
    its output is used as logits (the training loop feeds it to
    ``BCEWithLogitsLoss``). A trailing GELU would clamp every logit to roughly
    [-0.17, inf). The discriminator could then never push sigmoid(logit) much
    below ~0.46 for fake images. GELU has no parameters, so dropping it leaves the
    state-dict keys unchanged.
    """
    def __init__(self, emb_size=384, class_size_1=4098, class_size_2=1024, class_size_3=512, n_classes=10):
        # NOTE(review): the 4098 default may have been meant as 4096; it is kept so
        # existing checkpoints keep matching.
        super().__init__(
            nn.LayerNorm(emb_size),
            spectral_norm(nn.Linear(emb_size, class_size_1)),
            nn.GELU(),
            spectral_norm(nn.Linear(class_size_1, class_size_2)),
            nn.GELU(),
            spectral_norm(nn.Linear(class_size_2, class_size_3)),
            nn.GELU(),
            spectral_norm(nn.Linear(class_size_3, n_classes)),
        )

    def forward(self, x):
        # Take only the cls token outputs
        x = x[:, 0, :]
        return super().forward(x)

In [None]:
class ViT(nn.Sequential):
    """Vision Transformer discriminator: PatchEmbedding -> encoder blocks -> ClassificationHead.

    Args:
        in_channels, patch_size, stride_size, emb_size, image_size: forwarded to PatchEmbedding.
        depth: number of DiscriminatorTransformerEncoderBlock layers.
        n_classes: number of output logits.
        diffaugment: DiffAugment policy applied in ``forward`` when ``do_augment`` is True.
        **kwargs: forwarded to every encoder block.
    """
    def __init__(self,
                in_channels=3,
                patch_size=4,
                stride_size=4,
                emb_size=384,
                image_size=32,
                depth=4,
                n_classes=1,
                diffaugment='color,translation,cutout',
                **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, stride_size, emb_size, image_size),
            DiscriminatorTransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes=n_classes)
        )
        # Assigned after nn.Module.__init__: PyTorch requires the parent __init__
        # call to happen before attributes are set on the module.
        self.diffaugment = diffaugment

    def forward(self, img, do_augment=True):
        """Return logits for ``img``, applying DiffAugment first when ``do_augment`` is True."""
        if do_augment:
            img = DiffAugment(img, policy=self.diffaugment)
        return super().forward(img)

# CNN Discriminator

In [None]:
class CNN(nn.Sequential):
    """Baseline convolutional discriminator with optional DiffAugment.

    Three stages of (two 3x3 convs + ReLU, then 2x2 max-pool), followed by an MLP
    that outputs one logit per image (shape ``[batch, 1]``). The
    ``Linear(256*4*4, ...)`` layer implies 32x32 inputs (three 2x poolings: 32 -> 4).

    Args:
        diffaugment: DiffAugment policy applied in ``forward`` when ``do_augment`` is True.
        **kwargs: accepted and ignored.
    """
    def __init__(self,
                diffaugment='color,translation,cutout',
                **kwargs):
        super().__init__(
            nn.Conv2d(3,32,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(32,64,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(128,256,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Flatten(),
            nn.Linear(256*4*4,1024),
            nn.ReLU(),
            nn.Linear(1024,512),
            nn.ReLU(),
            nn.Linear(512,1)
        )
        # Assigned after nn.Module.__init__: PyTorch requires the parent __init__
        # call to happen before attributes are set on the module.
        self.diffaugment = diffaugment

    def forward(self, img, do_augment=True):
        """Return logits for ``img``, applying DiffAugment first when ``do_augment`` is True."""
        if do_augment:
            img = DiffAugment(img, policy=self.diffaugment)
        return super().forward(img)

# StyleGAN2 Discriminator

In [None]:
class StyleGanDiscriminator(stylegan2_pytorch.Discriminator):
    """``stylegan2_pytorch.Discriminator`` with optional DiffAugment on its input.

    ``forward`` returns only the first element of the parent's two-element output;
    the second element is discarded.

    Args:
        diffaugment: DiffAugment policy applied in ``forward`` when ``do_augment`` is True.
        **kwargs: forwarded to ``stylegan2_pytorch.Discriminator``.
    """
    def __init__(self,
                diffaugment='color,translation,cutout',
                **kwargs):
        super().__init__(**kwargs)
        # Assigned after the parent __init__: PyTorch requires nn.Module.__init__
        # to run before attributes are set on the module.
        self.diffaugment = diffaugment

    def forward(self, img, do_augment=True):
        if do_augment:
            img = DiffAugment(img, policy=self.diffaugment)
        out, _ = super().forward(img)
        return out

# Diversity Loss

In [None]:
def diversity_loss(images, num_images_to_calculate_on=10, scale_factor=5):
    """Penalty that grows as generated images become more similar to each other.

    The first ``num_images_to_calculate_on`` images (capped at the batch size) are
    flattened. For every unordered pair (a, b), the cosine similarity cos(a, b) is
    computed. The loss is the mean over pairs of ``sigmoid(scale_factor * cos(a, b))``.

    Args:
        images: tensor whose first dimension indexes images (e.g. ``[batch, C, H, W]``).
        num_images_to_calculate_on: how many leading images to compare.
        scale_factor: multiplier on the cosine similarity before the sigmoid.

    Returns:
        A 0-dim tensor on the same device and dtype as ``images``. It is zero when
        fewer than two images are available (there are no pairs).
    """
    n = min(num_images_to_calculate_on, images.shape[0])
    if n < 2:
        return images.new_zeros(())
    flat = images[:n].reshape(n, -1)
    unit = flat / flat.norm(dim=1, keepdim=True)
    cosine = unit @ unit.t()
    # Strictly-upper-triangular entries = each unordered pair exactly once.
    rows, cols = torch.triu_indices(n, n, offset=1, device=images.device)
    return torch.sigmoid(scale_factor * cosine[rows, cols]).mean()

# Normal-distribution weight initialization

In [None]:
def init_normal(m):
    """Redraw a plain ``nn.Linear`` module's ``weight`` from N(0, 1) in place.

    Intended for ``Module.apply``, which calls it once per submodule. Only modules
    whose exact type is ``nn.Linear`` (subclasses excluded) are considered. They
    are only touched when ``weight`` is present in the instance ``__dict__``.

    NOTE(review): nn.Module keeps registered Parameters in ``_parameters``, not in
    ``__dict__``. For an ordinary nn.Linear this guard is therefore presumably
    False and nothing is changed. It would only fire when ``weight`` is stored as a
    plain instance attribute. Confirm whether the initialisation is meant to run.
    """
    if type(m) is not nn.Linear:
        return
    if 'weight' not in vars(m):
        return
    m.weight.data.normal_(0.0, 1)

# Experiments

In [None]:
# ---- Generator ----------------------------------------------------------------
# Build the selected generator, apply init_normal to every submodule, resume from a
# pickled checkpoint when one exists, and register the model with wandb.
if generator_type == "vitgan":
    Generator = GeneratorViT(
        patch_size=patch_size,
        image_size=image_size,
        style_mlp_layers=style_mlp_layers,
        latent_dim=latent_dim,
        hidden_size=hidden_size,
        combine_patch_embeddings=combine_patch_embeddings,
        combined_embedding_size=combined_embedding_size,
        sln_paremeter_size=sln_paremeter_size,
        num_heads=num_heads,
        depth=depth,
        forward_drop_p=dropout_p,
        bias=bias,
        weight_modulation=weight_modulation,
        siren_hidden_layers=siren_hidden_layers,
        demodulation=demodulation,
    ).to(device)
    # Module.apply visits every submodule recursively.
    Generator.apply(init_normal)

    num_patches_x = int(image_size // out_patch_size)

    generator_ckpt = f'{experiment_folder_name}/weights/Generator.pth'
    if os.path.exists(generator_ckpt):
        Generator = torch.load(generator_ckpt)

    wandb.watch(Generator)

elif generator_type == "cnn":
    cnn_generator = CNNGenerator().to(device)
    cnn_generator.apply(init_normal)

    cnn_generator_ckpt = f'{experiment_folder_name}/weights/cnn_generator.pth'
    if os.path.exists(cnn_generator_ckpt):
        cnn_generator = torch.load(cnn_generator_ckpt)

    wandb.watch(cnn_generator)

# ---- Discriminator --------------------------------------------------------------
# All three discriminator variants resume from the same checkpoint file name.
discriminator_ckpt = f'{experiment_folder_name}/weights/discriminator.pth'

if discriminator_type == "vitgan":
    Discriminator = ViT(
        discriminator=True,
        patch_size=patch_size * 2,
        stride_size=patch_size,
        n_classes=1,
        num_heads=num_heads,
        depth=depth,
        forward_drop_p=dropout_p,
    ).to(device)
    Discriminator.apply(init_normal)

    if os.path.exists(discriminator_ckpt):
        Discriminator = torch.load(discriminator_ckpt)

    wandb.watch(Discriminator)

elif discriminator_type == "cnn":
    cnn_discriminator = CNN().to(device)
    cnn_discriminator.apply(init_normal)

    if os.path.exists(discriminator_ckpt):
        cnn_discriminator = torch.load(discriminator_ckpt)

    wandb.watch(cnn_discriminator)

elif discriminator_type == "stylegan2":
    stylegan2_discriminator = StyleGanDiscriminator(image_size=32).to(device)
    # init_normal is not applied to the StyleGAN2 discriminator.

    if os.path.exists(discriminator_ckpt):
        stylegan2_discriminator = torch.load(discriminator_ckpt)

    wandb.watch(stylegan2_discriminator)

# Testing the generator

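Sanity check: train the generator to map a fixed batch of latent vectors to a fixed batch of real images with a per-pixel MSE loss. The loop runs only when `total_steps` is greater than 0.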

In [None]:
# Number of gradient steps for the fixed-batch sanity check below (0 skips it).
# Each step fits one fixed latent batch to one fixed batch of real images.
total_steps = 0
steps_til_summary = 50  # print/plot every N steps

# Materialise the parameters once: `Module.parameters()` returns a one-shot
# generator. Passing it straight to Adam would exhaust it, leaving the EMA below
# silently tracking no parameters.
if generator_type == "vitgan":
    params = list(Generator.parameters())
else:
    params = list(cnn_generator.parameters())
optim = torch.optim.Adam(lr=lr, params=params)
ema = ExponentialMovingAverage(params, decay=0.995)

# One fixed batch of real images, flattened to [batch, image_size**2, out_features].
ground_truth, _ = next(iter(trainloader))
ground_truth = ground_truth.permute(0, 2, 3, 1).view((-1, image_size**2, out_features))
ground_truth = ground_truth.to(device)

# Fixed latent batch paired with the images above.
z = torch.randn([batch_size, latent_dim]).to(device)

In [None]:
# Fit the generator to the fixed batch with a per-pixel MSE loss, plotting
# generated (top row) vs. target (bottom row) images every `steps_til_summary` steps.
for step in range(total_steps):
    if generator_type == "vitgan":
        model_output = Generator(z)
    elif generator_type == "cnn":
        model_output = cnn_generator(z)
        model_output = model_output.permute([0, 2, 3, 1]).view([-1, image_size**2, out_features])
    loss = ((model_output - ground_truth)**2).mean()

    if not step % steps_til_summary:
        print("Step %d, Total loss %0.6f" % (step, loss))

        fig, axes = plt.subplots(2,8, figsize=(24,6))
        for i in range(8):
            j = np.random.randint(0, batch_size-1)
            for row, source in enumerate((model_output, ground_truth)):
                img = source[j].detach().cpu().reshape(image_size, image_size, 3).numpy()
                # Out-of-place min-max normalisation. On CPU, `.numpy()` shares memory
                # with the tensor, so in-place `-=`/`/=` would corrupt model_output.
                # The max() guard avoids 0/0 for a constant image.
                lo, hi = img.min(), img.max()
                axes[row, i].imshow((img - lo) / max(hi - lo, 1e-8))

        plt.show()

    optim.zero_grad()
    loss.backward()
    optim.step()
    ema.update()

# Training

In [None]:
os.makedirs(f"{experiment_folder_name}/weights", exist_ok = True)
os.makedirs(f"{experiment_folder_name}/samples", exist_ok = True)

# Loss function (takes raw logits)
criterion = nn.BCEWithLogitsLoss()

if discriminator_type == "cnn": discriminator = cnn_discriminator
elif discriminator_type == "stylegan2": discriminator = stylegan2_discriminator
elif discriminator_type == "vitgan": discriminator = Discriminator

# Materialise the generator parameters once: `Module.parameters()` returns a
# one-shot generator. Passing it straight to Adam would exhaust it, so the EMA
# constructed below would silently track no parameters at all.
if generator_type == "cnn":
    params = list(cnn_generator.parameters())
else:
    params = list(Generator.parameters())
optim_g = torch.optim.Adam(lr=lr, params=params, betas=beta)
optim_d = torch.optim.Adam(lr=lr_dis, params=discriminator.parameters(), betas=beta)
ema = ExponentialMovingAverage(params, decay=0.995)

# Fixed latent batch (16 samples) used for the per-epoch sample grid.
fixed_noise = torch.FloatTensor(np.random.normal(0, 1, (16, latent_dim))).to(device)

# Buffer of past generated images; the training loop overwrites one slot per step.
discriminator_f_img = torch.zeros([batch_size, 3, image_size, image_size]).to(device)

trainset_len = len(trainloader.dataset)
# Adversarial training. Each batch runs:
#   1. one discriminator update: real/fake BCE, the optional fake-history and
#      noise terms, and the bCR consistency terms;
#   2. one generator update: BCE against "real" labels plus the diversity penalty.
# Once per epoch, a fixed-noise sample grid is logged/saved and both models are
# pickled to `{experiment_folder_name}/weights`.
step = 0
for epoch in range(epochs):
    for batch_id, batch in enumerate(trainloader):
        step += 1

        # ----- Train discriminator -----

        # Real images. Every logit is flattened to [batch] so the bCR mse_loss
        # compares element-wise. Comparing [batch] with the raw [batch, 1] output of
        # the ViT/CNN discriminators would broadcast to a [batch, batch] matrix.
        r_img = batch[0].to(device)
        r_logit = discriminator(r_img).flatten()
        r_label = torch.ones(r_logit.shape[0]).to(device)

        lossD_real = criterion(r_logit, r_label)

        lossD_bCR_real = F.mse_loss(r_logit, discriminator(r_img, do_augment=False).flatten())

        # Fake images from a fresh latent batch.
        latent_vector = torch.FloatTensor(np.random.normal(0, 1, (batch_size, latent_dim))).to(device)

        if generator_type == "vitgan":
            f_img = Generator(latent_vector)
            f_img = f_img.reshape([-1, image_size, image_size, out_features])
            f_img = f_img.permute(0, 3, 1, 2)
        else:
            model_output = cnn_generator(latent_vector)
            f_img = model_output

        assert f_img.size(0) == batch_size, f_img.shape
        assert f_img.size(1) == out_features, f_img.shape
        assert f_img.size(2) == image_size, f_img.shape
        assert f_img.size(3) == image_size, f_img.shape

        # Every discriminator term uses detached fakes, so lossD.backward() never
        # runs through (or accumulates gradients into) the generator.
        f_img_d = f_img.detach()

        f_label = torch.zeros(batch_size).to(device)
        # Save a single generated image to the discriminator's history batch
        if batch_size_history_discriminator:
            discriminator_f_img[step % batch_size] = f_img_d[0]
            f_logit_history = discriminator(discriminator_f_img).flatten()
            lossD_fake_history = criterion(f_logit_history, f_label)
        else:
            lossD_fake_history = 0
        # Train the discriminator on the images generated only from this batch
        f_logit = discriminator(f_img_d).flatten()
        lossD_fake = criterion(f_logit, f_label)

        lossD_bCR_fake = F.mse_loss(f_logit, discriminator(f_img_d, do_augment=False).flatten())

        # Uniform noise in [-1, 1) as an additional "fake" input.
        f_noise_input = torch.FloatTensor(np.random.rand(*f_img.shape)*2 - 1).to(device)
        f_noise_logit = discriminator(f_noise_input).flatten()
        lossD_noise = criterion(f_noise_logit, f_label)

        lossD = (lossD_real * 0.5
                 + lossD_fake * 0.5
                 + lossD_fake_history * lambda_lossD_history
                 + lossD_noise * lambda_lossD_noise
                 + lossD_bCR_real * lambda_bCR_real
                 + lossD_bCR_fake * lambda_bCR_fake)

        optim_d.zero_grad()
        lossD.backward()
        optim_d.step()

        # ----- Train generator -----

        if generator_type == "vitgan":
            f_img = Generator(latent_vector)
            f_img = f_img.reshape([-1, image_size, image_size, out_features])
            f_img = f_img.permute(0, 3, 1, 2)
        else:
            model_output = cnn_generator(latent_vector)
            f_img = model_output

        assert f_img.size(0) == batch_size
        assert f_img.size(1) == out_features
        assert f_img.size(2) == image_size
        assert f_img.size(3) == image_size

        f_logit = discriminator(f_img).flatten()
        r_label = torch.ones(batch_size).to(device)
        lossG_main = criterion(f_logit, r_label)

        lossG_diversity = diversity_loss(f_img) * lambda_diversity_penalty
        lossG = lossG_main + lossG_diversity

        optim_g.zero_grad()
        lossG.backward()
        optim_g.step()
        ema.update()

        writer.add_scalar("Loss/Generator", lossG_main, step)
        writer.add_scalar("Loss/Gen(diversity)", lossG_diversity, step)
        writer.add_scalar("Loss/Dis(real)", lossD_real, step)
        writer.add_scalar("Loss/Dis(fake)", lossD_fake, step)
        writer.add_scalar("Loss/Dis(fake_history)", lossD_fake_history, step)
        writer.add_scalar("Loss/Dis(noise)", lossD_noise, step)
        writer.add_scalar("Loss/Dis(bCR_fake)", lossD_bCR_fake * lambda_bCR_fake, step)
        writer.add_scalar("Loss/Dis(bCR_real)", lossD_bCR_real * lambda_bCR_real, step)
        writer.flush()

        wandb.log({
            'Generator': lossG_main,
            'Gen(diversity)': lossG_diversity,
            'Dis(real)': lossD_real,
            'Dis(fake)': lossD_fake,
            'Dis(fake_history)': lossD_fake_history,
            'Dis(noise)': lossD_noise,
            'Dis(bCR_fake)': lossD_bCR_fake * lambda_bCR_fake,
            'Dis(bCR_real)': lossD_bCR_real * lambda_bCR_real
        })

        if batch_id%20 == 0:
            print(f'epoch {epoch}/{epochs}; batch {batch_id}/{int(trainset_len/batch_size)}')
            print(f'Generator: {"{:.3f}".format(float(lossG_main))}, '+\
                  f'Gen(diversity): {"{:.3f}".format(float(lossG_diversity))}, '+\
                  f'Dis(real): {"{:.3f}".format(float(lossD_real))}, '+\
                  f'Dis(fake): {"{:.3f}".format(float(lossD_fake))}, '+\
                  f'Dis(fake_history): {"{:.3f}".format(float(lossD_fake_history))}, '+\
                  f'Dis(noise) {"{:.3f}".format(float(lossD_noise))}, '+\
                  f'Dis(bCR_fake): {"{:.3f}".format(float(lossD_bCR_fake * lambda_bCR_fake))}, '+\
                  f'Dis(bCR_real): {"{:.3f}".format(float(lossD_bCR_real * lambda_bCR_real))}')

            # Plot 8 randomly selected samples
            fig, axes = plt.subplots(1,8, figsize=(24,3))
            output = f_img.detach().permute(0, 2, 3, 1).cpu()
            for i in range(8):
                j = np.random.randint(0, batch_size-1)
                img = output[j].reshape(image_size, image_size, 3).numpy()
                # Out-of-place min-max normalisation. On CPU, `.numpy()` shares memory
                # with the tensor, so in-place ops would mutate it. The max() guard
                # avoids 0/0 for a constant image.
                lo, hi = img.min(), img.max()
                axes[i].imshow((img - lo) / max(hi - lo, 1e-8))
            plt.show()

    # Per-epoch sample grid: both generator types are sampled in eval mode and
    # without building an autograd graph. Previously the CNN generator ran in
    # train mode, so BatchNorm updated its running stats on the fixed noise.
    if generator_type == "vitgan":
        Generator.eval()
    else:
        cnn_generator.eval()
    with torch.no_grad():
        if generator_type == "vitgan":
            vis = Generator(fixed_noise)
            vis = vis.reshape([-1, image_size, image_size, out_features])
            vis = vis.permute(0, 3, 1, 2)
        else:
            vis = cnn_generator(fixed_noise)

    assert vis.shape[0] == fixed_noise.shape[0], f'vis.shape[0] is {vis.shape[0]}, but should be {fixed_noise.shape[0]}'
    assert vis.shape[1] == out_features, f'vis.shape[1] is {vis.shape[1]}, but should be {out_features}'
    assert vis.shape[2] == image_size, f'vis.shape[2] is {vis.shape[2]}, but should be {image_size}'
    assert vis.shape[3] == image_size, f'vis.shape[3] is {vis.shape[3]}, but should be {image_size}'

    # Keep the result (the previous `vis.detach().cpu()` discarded it).
    vis = vis.detach().cpu()
    vis = make_grid(vis, nrow = 4, padding = 5, normalize = True)
    writer.add_image(f'Generated/epoch_{epoch}', vis)
    wandb.log({'examples': wandb.Image(vis)})

    vis = T.ToPILImage()(vis)
    vis.save(f'{experiment_folder_name}/samples/vis{epoch}.jpg')
    if generator_type == "vitgan":
        Generator.train()
    else:
        cnn_generator.train()
    print(f"Save sample to {experiment_folder_name}/samples/vis{epoch}.jpg")

    # Save the checkpoints.
    if generator_type == "vitgan":
        torch.save(Generator, f'{experiment_folder_name}/weights/Generator.pth')
    elif generator_type == "cnn":
        torch.save(cnn_generator, f'{experiment_folder_name}/weights/cnn_generator.pth')
    torch.save(discriminator, f'{experiment_folder_name}/weights/discriminator.pth')
    print("Save model state.")

writer.close()