<h1><b>BicycleGAN</b><i> (Implementation in PyTorch)</i></h1>

> <h2>Multimodal Image-to-Image Translation</h2>


<br>

![BicycleGAN](img/bicyclegan.png)

<h2>Introduction</h2>

<br>

Deep learning techniques have made rapid progress in conditional image generation. However, most techniques in this space have focused on generating a single result. Our aim is to generate a distribution of output images given an input image.<br><br>
Mapping from a high-dimensional input to a high-dimensional output distribution is challenging. A common approach to representing multimodality is learning a low-dimensional latent code, which should represent aspects of the possible outputs not contained in the input image.  At inference time,
a deterministic generator uses the input image, along with stochastically sampled latent codes, to produce randomly sampled outputs.


<hr/>

<br>
<h2>Why BicycleGAN?</h2>

<br>

A common problem in existing methods is mode collapse, where only a small number of the modes present in the real data get represented in the output.<br>

<h3>Mode Collapse</h3>

<br>

Real-life data distributions are multimodal. For example, the <b>MNIST</b> dataset has 10 major modes, the digits 0 to 9. Under mode collapse, only a few of these modes are generated.
In short, mode collapse is a lack of variety. Complete collapse is rare, whereas partial collapse is common. The figure below illustrates it.
<br>

![Mode Collapse](img/mode_collapse.png)

The top row covers all 10 modes of MNIST, whereas the bottom row produces only a single mode (the digit '6').<br><br>
BicycleGAN encourages a bijection between the output and latent spaces.<br>
Besides the direct task of mapping the latent code (along with the input) to the output, it jointly learns an encoder from the output back to the latent space. This discourages two different latent codes from generating the same output (a non-injective mapping), which helps prevent <b>mode collapse</b>.

<hr/>

<br>
<h2>What does BicycleGAN do?</h2>

<br>

The goal is to learn a multimodal mapping between two image domains, for example edges and photographs, or night and day images.
Consider an input domain $\mathcal{A} \subset \mathbb{R}^{H \times W \times 3}$ that is to be mapped to an output domain $\mathcal{B} \subset \mathbb{R}^{H \times W \times 3}$ (for example, $\mathcal{A}$ is edge maps and $\mathcal{B}$ is photographs matching those edges).<br><br>
We are given a dataset of paired instances $(\mathrm{A} \in \mathcal{A}, \mathrm{B} \in \mathcal{B})$, representative of a joint distribution $p(\mathrm{A}, \mathrm{B})$. There could be multiple plausible outputs $\mathrm{B}$ for a single input $\mathrm{A}$, but the training dataset usually contains only one such pair. Nevertheless, given a new input $\mathrm{A}$ at test time, the model should generate a diverse set of outputs $\hat{\mathrm{B}}$ corresponding to different modes of the distribution $p(\mathrm{B}|\mathrm{A})$.<br><br>
We would like to learn a mapping that samples the output $\hat{\mathrm{B}}$ 
from the true conditional distribution given $\mathrm{A}$ and produces results that are both diverse and realistic.<br><br>
To do so, we learn a low-dimensional latent space $\mathrm{z} \in \mathbb{R}^{Z}$ (with $Z$ = `latent_dim` = 8 in this notebook), which encapsulates the ambiguous aspects of the
output mode that are not present in the input image. For example, a sketch of a shoe could map to a variety of colors and textures, which could be compressed into this latent code. We then learn a deterministic mapping $G : (\mathrm{A}, \mathrm{z}) \rightarrow \mathrm{B}$ to the output. To enable stochastic sampling, we want the latent code $\mathrm{z}$ to be drawn from a prior distribution $p(\mathrm{z})$; we use a standard Gaussian $\mathcal{N}(0, I)$ in this work.
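Test-time sampling can be sketched as follows: one input A is repeated, each copy is paired with its own code z drawn from N(0, I), and the deterministic generator turns every pair into a different output. This is a minimal sketch under stated assumptions; `ToyGenerator` is a hypothetical stand-in, not the U-Net generator defined later in this notebook.

```python
import torch
import torch.nn as nn

class ToyGenerator(nn.Module):
    """Hypothetical stand-in for G: (A, z) -> B. Not the U-Net used in this notebook."""
    def __init__(self, latent_dim, channels=3):
        super().__init__()
        # Map z to a per-channel offset so that different codes change the output
        self.fc = nn.Linear(latent_dim, channels)

    def forward(self, a, z):
        offset = self.fc(z).view(z.size(0), -1, 1, 1)  # (N, C, 1, 1), broadcast over H x W
        return torch.tanh(a + offset)

latent_dim, n_samples = 8, 4
G = ToyGenerator(latent_dim)
a = torch.rand(1, 3, 16, 16)                 # a single input image A
a_rep = a.repeat(n_samples, 1, 1, 1)         # the same A for every sample
z = torch.randn(n_samples, latent_dim)       # z ~ N(0, I), one code per sample
with torch.no_grad():
    b_hat = G(a_rep, z)                      # n_samples different outputs B-hat
print(b_hat.shape)  # torch.Size([4, 3, 16, 16])
```

The generator itself is deterministic; all diversity comes from the sampled codes.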
<br>

<hr/>

<br>
<h2>Our model consists of 2 parts</h2>

<br>

>  <h2>Conditional Variational Autoencoder GAN: cVAE-GAN (1st part of the model)
 $(\mathrm{B} \rightarrow \mathrm{z} \rightarrow \hat{\mathrm{B}})$ </h2>

<br>

![cVAE-GAN](img/cvae.png)
<br>


*  The ground-truth image B is mapped to a latent code z by an encoder E.
*  The generator G then uses both the latent code and the input image A to synthesize the desired output $\hat{\mathrm{B}}$.
*  The overall model can be viewed as a reconstruction of B, with the latent encoding z concatenated with the paired input A in the middle, much like an autoencoder.
* The distribution $Q(\mathrm{z}|\mathrm{B})$ of the latent code z (the output of the encoder E) is assumed to be Gaussian, $Q(\mathrm{z}|\mathrm{B})=E(\mathrm{B})$.<br><br>


<hr/>


> <h3><b>cVAE-GAN objective, a conditional version of the VAE-GAN</b></h3>

<br>

$$G^{*},\ E^{*}=\arg\min_{G,E}\max_{D}\ \mathcal{L}_{\mathrm{G}\mathrm{A}\mathrm{N}}^{\mathrm{V}\mathrm{A}\mathrm{E}}(G,\ D,\ E)+\lambda \mathcal{L}_{1}^{\mathrm{V}\mathrm{A}\mathrm{E}}(G,\ E)+\lambda_{\mathrm{K}\mathrm{L}}\mathcal{L}_{\mathrm{K}\mathrm{L}}(E)$$

<h3>where</h3> $$\mathcal{L}_{\mathrm{G}\mathrm{A}\mathrm{N}}^{\mathrm{V}\mathrm{A}\mathrm{E}}=\mathrm{E}_{\mathrm{A},\mathrm{B}\sim p(\mathrm{A},\mathrm{B})}[\log(D(\mathrm{A},\ \mathrm{B}))]+\mathrm{E}_{\mathrm{A},\mathrm{B}\sim p(\mathrm{A},\mathrm{B}),\mathrm{z}\sim E(\mathrm{B})}[\log(1-D(\mathrm{A},\ G(\mathrm{A},\ \mathrm{z})))]$$

<br>
This is the standard GAN loss, in which the generator and the discriminator play a min-max game: the generator tries to fool the discriminator, while the discriminator tries to distinguish the images produced by the generator from real ones. Here the latent code fed to the generator comes from the encoder, $\mathrm{z}\sim E(\mathrm{B})$.
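The min-max objective can be written with binary cross-entropy on discriminator logits. A minimal sketch, assuming D outputs raw logits (note that the notebook's `MultiDiscriminator` later uses a least-squares loss instead). `g_loss` uses the common non-saturating variant, $-\log D(\mathrm{A}, G(\mathrm{A},\mathrm{z}))$, rather than minimizing $\log(1 - D)$ literally.

```python
import torch
import torch.nn.functional as F

def d_loss(real_logits, fake_logits):
    """D maximizes log D(A,B) + log(1 - D(A,G(A,z))); we minimize the negative."""
    real = F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
    fake = F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
    return real + fake

def g_loss(fake_logits):
    """Non-saturating generator loss: minimize -log D(A, G(A, z))."""
    return F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))

real_logits = torch.tensor([2.0, 1.5])
fake_logits = torch.tensor([-1.0, 0.5])
# The same value computed directly from the formula, with D = sigmoid(logits)
manual = -(torch.log(torch.sigmoid(real_logits)) + torch.log(1 - torch.sigmoid(fake_logits))).mean()
print(torch.allclose(d_loss(real_logits, fake_logits), manual))  # True
```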
<br><br>

$$\mathcal{L}_{1}^{\mathrm{V}\mathrm{A}\mathrm{E}}(G,\ E)= \mathrm{E}_{\mathrm{A},\mathrm{B}\sim p(\mathrm{A},\mathrm{B}),\mathrm{z}\sim E(\mathrm{B})}||\mathrm{B}-G(\mathrm{A},\ \mathrm{z})||_{1}$$

<br>
To encourage the generator output to match the ground truth and to stabilize training, we use an $\ell_{1}$ loss between the output and the ground-truth image.
<br><br><br>

$$\mathcal{L}_{\mathrm{K}\mathrm{L}}(E)=\mathrm{E}_{\mathrm{B}\sim p(\mathrm{B})}[\mathcal{D}_{\mathrm{K}\mathrm{L}}(E(\mathrm{B})||\mathcal{N}(0,\ I))]$$

<br>
The latent distribution encoded by $E(\mathrm{B})$ is encouraged to be close to a standard Gaussian, which enables sampling at inference time, when $\mathrm{B}$ is not known.<br> <b>Here</b> $$\displaystyle \mathcal{D}_{\mathrm{K}\mathrm{L}}(p||q)=\int p(z)\log\frac{p(z)}{q(z)}dz$$
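For a diagonal Gaussian $E(\mathrm{B}) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ the KL term has a closed form, $\mathcal{D}_{\mathrm{K}\mathrm{L}} = \frac{1}{2}\sum_j\left(\sigma_j^2 + \mu_j^2 - 1 - \log\sigma_j^2\right)$, which is the expression used for `loss_kl` in the training loop. The sketch below checks it against `torch.distributions.kl_divergence`; the helper name `kl_to_standard_normal` is ours.

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over all elements."""
    return 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - logvar - 1)

mu = torch.tensor([[0.5, -1.0, 0.0]])
logvar = torch.tensor([[0.0, -0.5, 0.3]])

closed_form = kl_to_standard_normal(mu, logvar)
reference = kl_divergence(
    Normal(mu, torch.exp(0.5 * logvar)),                       # std = exp(logvar / 2)
    Normal(torch.zeros_like(mu), torch.ones_like(mu)),         # N(0, I)
).sum()
print(torch.allclose(closed_form, reference))  # True
```

When $\mu = 0$ and $\log\sigma^2 = 0$, the divergence is exactly zero.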


<hr/>


<br><br>
Consider the deterministic version of this approach, i.e., dropping the KL-divergence term and encoding $\mathrm{z} = E(\mathrm{B})$. This variant is called cAE-GAN.
cAE-GAN gives no guarantee on the distribution of the latent space z, which makes test-time
sampling of z difficult.

<br><br>

<hr/>

<h2>Conditional Latent Regressor GAN: cLR-GAN  (2nd part of the model)</h2> 

$(\mathrm{z} \rightarrow \hat{\mathrm{B}} \rightarrow \hat{\mathrm{z}})$ 

<br>

![cLR-GAN](img/clr.jpg)

<br>

A randomly drawn latent code z is recovered with $\hat{\mathrm{z}}=E(G(\mathrm{A},\ \mathrm{z}))$
<br><br>
Here the encoder E produces a point estimate $\hat{\mathrm{z}}$, whereas the encoder in the previous section predicts a Gaussian distribution (in the training code, the predicted mean `mu` serves as this point estimate).
<br><br>
<h3><b>cLR-GAN objective function</b></h3>
<br>
$$G^{*},\ E^{*}=\arg\min_{G,E}\max_{D}\ \mathcal{L}_{\mathrm{G}\mathrm{A}\mathrm{N}}(G,\ D)+\lambda_{\mathrm{l}\mathrm{a}\mathrm{t}\mathrm{e}\mathrm{n}\mathrm{t}}\mathcal{L}_{1}^{\mathrm{l}\mathrm{a}\mathrm{t}\mathrm{e}\mathrm{n}\mathrm{t}}(G,\ E)$$

<br>
<h3>where</h3>
$$\mathcal{L}_{1}^{\mathrm{l}\mathrm{a}\mathrm{t}\mathrm{e}\mathrm{n}\mathrm{t}}(G,\ E)=\mathrm{E}_{\mathrm{A}\sim p(\mathrm{A}),\mathrm{z}\sim p(\mathrm{z})}||\mathrm{z}-E(G(\mathrm{A},\ \mathrm{z}))||_{1}$$
<br>
$\hat{\mathrm{z}}=E(G(\mathrm{A},\ \mathrm{z}))$ is encouraged to be close to the randomly drawn $\mathrm{z}$ to enable bijective mapping.
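The latent regression loss can be sketched with toy networks. This is a minimal sketch, assuming hypothetical linear stand-ins for G and E that operate on flattened images; the L1 distance is averaged over elements, as `nn.L1Loss` does in the notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

latent_dim, n, img_dim = 8, 4, 32

# Hypothetical toy networks on flattened images (stand-ins for G and E)
G = nn.Linear(img_dim + latent_dim, img_dim)     # G(A, z) on the concatenated inputs
E_mu = nn.Linear(img_dim, latent_dim)            # point estimate z_hat = E(B_hat)

a = torch.randn(n, img_dim)
z = torch.randn(n, latent_dim)                   # z ~ p(z) = N(0, I)
b_hat = G(torch.cat([a, z], dim=1))
z_hat = E_mu(b_hat)

# L1 latent regression loss: mean |z - E(G(A, z))|
loss_latent = F.l1_loss(z_hat, z)
loss_latent.backward()                           # gradients reach both G and E
print(loss_latent.item() >= 0, G.weight.grad is not None)  # True True
```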
<br><br>

The discriminator loss $\mathcal{L}_{\mathrm{G}\mathrm{A}\mathrm{N}}(G,\ D)$ on $\hat{\mathrm{B}}$ is used to encourage the network to generate realistic results.

<hr/>

<br><br>
<h2>Hybrid Model: BicycleGAN</h2>
<br><br>
The hybrid model combines the cVAE-GAN and cLR-GAN objectives.<br><br>
Training is done in both directions, aiming to take advantage of both cycles<br> 
($\mathrm{B}\rightarrow \mathrm{z}\rightarrow\hat{\mathrm{B}}$ and $\mathrm{z}\rightarrow\hat{\mathrm{B}}\rightarrow\hat{\mathrm{z}}$), hence the name BicycleGAN.<br><br>

<h3>Combined Objective</h3>

<br>

$$
G^{*},\ E^{*}=\arg\min_{G,E}\max_{D}\ \mathcal{L}_{\mathrm{G}\mathrm{A}\mathrm{N}}^{\mathrm{V}\mathrm{A}\mathrm{E}}(G,\ D,\ E)+\lambda \mathcal{L}_{1}^{\mathrm{V}\mathrm{A}\mathrm{E}}(G,\ E)
+\mathcal{L}_{\mathrm{G}\mathrm{A}\mathrm{N}}(G,\ D)+\lambda_{\mathrm{l}\mathrm{a}\mathrm{t}\mathrm{e}\mathrm{n}\mathrm{t}}\mathcal{L}_{1}^{\mathrm{l}\mathrm{a}\mathrm{t}\mathrm{e}\mathrm{n}\mathrm{t}}(G,\ E)+\lambda_{\mathrm{K}\mathrm{L}}\mathcal{L}_{\mathrm{K}\mathrm{L}}(E)\ ,
$$

<br>
where the hyper-parameters $\lambda, \lambda_{\mathrm{l}\mathrm{a}\mathrm{t}\mathrm{e}\mathrm{n}\mathrm{t}}$, and $\lambda_{\mathrm{K}\mathrm{L}}$ control the relative importance of each term.
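As a plain sketch, the combined objective for G and E is a weighted sum of five scalar terms. The function name is hypothetical, and the default weights simply mirror `lambda_pixel`, `lambda_latent` and `lambda_kl` from this notebook's configuration.

```python
def bicyclegan_objective(l_gan_vae, l1_vae, l_gan, l1_latent, l_kl,
                         lam=10.0, lam_latent=0.5, lam_kl=0.01):
    """Weighted sum of the five BicycleGAN terms optimized by G and E (hypothetical helper).

    l_gan_vae: adversarial loss on G(A, E(B))      l1_vae: ||B - G(A, z)||_1
    l_gan:     adversarial loss on G(A, z), z~N(0,I) l1_latent: ||z - E(G(A, z))||_1
    l_kl:      KL(E(B) || N(0, I))
    """
    return (l_gan_vae + lam * l1_vae
            + l_gan + lam_latent * l1_latent
            + lam_kl * l_kl)

total = bicyclegan_objective(1.0, 1.0, 1.0, 1.0, 1.0)
print(total)
```

In the training loop, these terms are split across two backward passes: `loss_GE` holds everything except the latent term, which is backpropagated separately as `loss_latent`.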

<hr/>


<br><br>

<h2>Implementation details</h2> 

<br>

<h3>Network architecture</h3>
<br>
For the generator G, a U-Net is used: an encoder-decoder
architecture with symmetric skip connections. This architecture has been shown to produce strong results in the unimodal image prediction setting when there is a spatial correspondence between input and output pairs.<br><br>

<img src="img/unet.png" alt="UNET">

<br>

For the discriminator D, the paper uses two PatchGAN discriminators at different
scales, which predict real vs. fake for overlapping image patches (the implementation below uses three scales).
<br><br><br>
For the encoder E, two networks are considered: <br>
(1) $E_{CNN}$: a CNN with a few convolutional and downsampling layers.<br>
(2) $E_{ResNet}$: a classifier with several residual blocks (the implementation below builds on ResNet-18).
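A possible $E_{CNN}$ can be sketched as a few stride-2 convolutions, global average pooling, and two linear heads for $\mu$ and $\log\sigma^2$. This is a minimal sketch: `ConvEncoder` is a hypothetical name, and its layer widths and kernel sizes are illustrative assumptions, not the paper's exact configuration.

```python
import torch
import torch.nn as nn

class ConvEncoder(nn.Module):
    """Hypothetical E_CNN: stride-2 convs, global pooling, then mu / logvar heads."""
    def __init__(self, latent_dim=8, in_channels=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.2),  # H/2
            nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.LeakyReLU(0.2),          # H/4
            nn.Conv2d(128, 256, 4, stride=2, padding=1), nn.LeakyReLU(0.2),         # H/8
            nn.AdaptiveAvgPool2d(1),  # collapse the spatial dimensions to 1 x 1
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

    def forward(self, img):
        h = self.features(img).flatten(1)  # (N, 256)
        return self.fc_mu(h), self.fc_logvar(h)

mu, logvar = ConvEncoder()(torch.randn(2, 3, 128, 128))
print(mu.shape, logvar.shape)  # torch.Size([2, 8]) torch.Size([2, 8])
```

Adaptive pooling keeps this sketch independent of the input resolution, unlike a fixed-size pooling kernel.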

<hr/>

<br><br>

<h2>Implementation in PyTorch</h2>

<br>

<h3>Dataset : Edges2Shoes</h3>

<br>

![Edges2Shoes](img/edgesToShoes.png)


<br><br>
An example of training data:
<br>

![training-data](img/train.jpg)

<br><br>
An example of validation data:
<br>

![validation-data](img/val.jpg)

<br>




<h4>Let us first import the required libraries and modules</h4>

```python
import os
import sys
import glob
import math
import time
import random
import itertools
import datetime
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.utils import save_image
import torchvision.transforms as transforms
```

<h4>Custom Dataloader</h4>

```python
class ImageDataset(Dataset):
    """Loads side-by-side image pairs: the left half is domain A, the right half is domain B."""

    def __init__(self, root, input_shape, mode="train"):
        self.transform = transforms.Compose(
            [
                transforms.Resize(input_shape[-2:], transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                # Map pixel values from [0, 1] to [-1, 1] to match the generator's Tanh output
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

        self.files = sorted(glob.glob(os.path.join(root, mode) + "/*.*"))

    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)]).convert("RGB")
        w, h = img.size
        # Split the concatenated pair into its two halves (integer box coordinates)
        img_A = img.crop((0, 0, w // 2, h))
        img_B = img.crop((w // 2, 0, w, h))

        # Random horizontal flip, applied identically to both halves
        if np.random.random() < 0.5:
            img_A = img_A.transpose(Image.FLIP_LEFT_RIGHT)
            img_B = img_B.transpose(Image.FLIP_LEFT_RIGHT)

        img_A = self.transform(img_A)
        img_B = self.transform(img_B)

        return {"A": img_A, "B": img_B}

    def __len__(self):
        return len(self.files)
```
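A small, self-contained check of the split performed in `__getitem__`, assuming (as the crop code does) that each file stores A on the left and B on the right. The synthetic red/blue image is purely illustrative and stands in for an Edges2Shoes file.

```python
from PIL import Image

# Synthetic side-by-side pair: left half red ("A"), right half blue ("B")
w, h = 256, 128
pair = Image.new("RGB", (w, h))
pair.paste((255, 0, 0), (0, 0, w // 2, h))
pair.paste((0, 0, 255), (w // 2, 0, w, h))

img_A = pair.crop((0, 0, w // 2, h))   # left half -> domain A (edges)
img_B = pair.crop((w // 2, 0, w, h))   # right half -> domain B (photo)

print(img_A.size, img_A.getpixel((0, 0)), img_B.getpixel((0, 0)))
# (128, 128) (255, 0, 0) (0, 0, 255)
```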

<h4>Function to initialize weights for convolution and batchnorm layers</h4>

```python
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        # Convolution weights ~ N(0, 0.02)
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        # BatchNorm scale ~ N(1, 0.02), shift = 0
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
```
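A short usage sketch: `nn.Module.apply` calls the function on every submodule recursively, so one call initializes a whole network. The function is repeated here so the snippet runs on its own; the small network is hypothetical.

```python
import torch
import torch.nn as nn

def weights_init_normal(m):
    # Same rule as above: N(0, 0.02) for conv weights; N(1, 0.02) scale and zero shift for BatchNorm2d
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

torch.manual_seed(0)
net = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64), nn.LeakyReLU(0.2))
net.apply(weights_init_normal)  # visits Sequential, Conv2d, BatchNorm2d, LeakyReLU

print(net[0].weight.std().item())      # close to 0.02
print(net[1].bias.abs().max().item())  # 0.0
```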

<h3>UNetDown Block</h3>
<br>(Block for the downsampling layers of the U-Net)<br>
<br>A small block consisting of (convolution layer - normalization layer - nonlinearity layer)<br>


    Parameters

    1. in_size : Input dimension (number of channels)
    2. out_size : Output dimension (number of channels)
    3. normalize : If true, add a Batch Normalization layer; otherwise skip it
    4. dropout : Probability of dropping a unit
    

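```python
class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        # Stride-2 convolution halves the spatial resolution
        layers = [nn.Conv2d(in_size, out_size, 3, stride=2, padding=1, bias=False)]
        if normalize:
            # Note: the second positional argument of BatchNorm2d is eps (here 0.8)
            layers.append(nn.BatchNorm2d(out_size, 0.8))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            # The dropout argument was previously accepted but never used
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
```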

<h3>UNetUp Block</h3>
<br>(Block for the upsampling layers of the U-Net)<br>
<br>A small block consisting of (upsampling layer - convolution layer - normalization layer - nonlinearity layer), whose output is concatenated with the skip connection<br>
    
    Parameters

    1. in_size : Input dimension (number of channels)
    2. out_size : Output dimension (number of channels) of the convolution, before the skip connection is concatenated
    3. skip_input (in forward method) : skip connection from the corresponding downsampling layer

```python
class UNetUp(nn.Module):
    def __init__(self, in_size, out_size):
        super(UNetUp, self).__init__()
        self.model = nn.Sequential(
            nn.Upsample(scale_factor=2),  # double the spatial resolution
            nn.Conv2d(in_size, out_size, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_size, 0.8),  # second positional argument is eps
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip_input):
        x = self.model(x)
        # Concatenate along channels: output has out_size + skip channels
        x = torch.cat((x, skip_input), 1)
        return x
```

<h2>Generator</h2>

> U-Net Generator 

<br>
Each upsampling output is paired with the downsampling activation that has the same width and height, and the two are concatenated along the channel dimension.<br>

    Pairs : (up_1, down_6)
            (up_2, down_5)  
            (up_3, down_4) 
            (up_4, down_3) 
            (up_5, down_2) 
            (up_6, down_1)
            down_7 (the bottleneck) has no partner.
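The channel counts these pairs produce can be checked with a little arithmetic. This pure-Python sketch mirrors the layer widths used in the `Generator` below: each `UNetUp` outputs its convolution width plus the channels of its skip partner, and that sum becomes the next block's `in_size`.

```python
# Channel bookkeeping for the U-Net pairs above (mirrors the Generator's layer widths)
down_out = [64, 128, 256, 512, 512, 512, 512]     # d1 .. d7
up_conv_out = [512, 512, 512, 256, 128, 64]       # conv width inside up1 .. up6
skips = [down_out[5 - i] for i in range(6)]       # partners: d6, d5, d4, d3, d2, d1

# torch.cat((x, skip_input), 1) adds the skip's channels to the conv output
up_out = [conv + skip for conv, skip in zip(up_conv_out, skips)]
# up1 reads d7; every later block reads the previous block's concatenated output
in_sizes = [down_out[-1]] + up_out[:-1]

print(up_out)    # [1024, 1024, 1024, 512, 256, 128]
print(in_sizes)  # [512, 1024, 1024, 1024, 512, 256]
```

The last entry of `up_out` (128) is why the final layer is `nn.Conv2d(128, channels, ...)`.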




```python
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        channels, self.h, self.w = img_shape

        # Maps z to one extra H x W input channel
        self.fc = nn.Linear(latent_dim, self.h * self.w)

        self.down1 = UNetDown(channels + 1, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512)
        self.down5 = UNetDown(512, 512)
        self.down6 = UNetDown(512, 512)
        self.down7 = UNetDown(512, 512, normalize=False)
        self.up1 = UNetUp(512, 512)
        self.up2 = UNetUp(1024, 512)
        self.up3 = UNetUp(1024, 512)
        self.up4 = UNetUp(1024, 256)
        self.up5 = UNetUp(512, 128)
        self.up6 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2), nn.Conv2d(128, channels, 3, stride=1, padding=1), nn.Tanh()
        )

    def forward(self, x, z):
        # Propagate noise through the fc layer and reshape it to the image shape
        # x: (N, 3, 128, 128)  z: (N, 8)
        z = self.fc(z).view(z.size(0), 1, self.h, self.w)  # z: (N, 1, 128, 128)

        # Concatenate x and z along channels: (N, 4, 128, 128)
        d1 = self.down1(torch.cat((x, z), 1))  # d1: (N, 64, 64, 64)
        d2 = self.down2(d1)         # d2: (N, 128, 32, 32)
        d3 = self.down3(d2)         # d3: (N, 256, 16, 16)
        d4 = self.down4(d3)         # d4: (N, 512, 8, 8)
        d5 = self.down5(d4)         # d5: (N, 512, 4, 4)
        d6 = self.down6(d5)         # d6: (N, 512, 2, 2)
        d7 = self.down7(d6)         # d7: (N, 512, 1, 1)
        u1 = self.up1(d7, d6)       # u1: (N, 1024, 2, 2)
        u2 = self.up2(u1, d5)       # u2: (N, 1024, 4, 4)
        u3 = self.up3(u2, d4)       # u3: (N, 1024, 8, 8)
        u4 = self.up4(u3, d3)       # u4: (N, 512, 16, 16)
        u5 = self.up5(u4, d2)       # u5: (N, 256, 32, 32)
        u6 = self.up6(u5, d1)       # u6: (N, 128, 64, 64)

        return self.final(u6)       # final: (N, 3, 128, 128)
```

<h2>MultiDiscriminator </h2>
<br>
  It uses three PatchGAN discriminators that see the input at progressively halved resolutions, so they return outputs of different sizes (i.e., real/fake scores for patches of different sizes).<br>


    disc_1 : (N, channels, 128, 128) -> (N, 1, 8, 8)
    disc_2 : (N, channels, 64, 64) -> (N, 1, 4, 4)
    disc_3 : (N, channels, 32, 32) -> (N, 1, 2, 2)


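During training, the generator has to fool all the discriminators, which is intended to make it more robust. Note that `compute_loss` uses a least-squares (MSE) objective against the scalar targets 1 (real) and 0 (fake), rather than the log-likelihood form of the GAN loss given earlier.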
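The output sizes listed above follow from the standard convolution/pooling size formula $\lfloor (n + 2p - k)/s \rfloor + 1$. This pure-Python sketch reproduces them for the layer settings used in `MultiDiscriminator`; the helper names are ours.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of Conv2d / AvgPool2d along one spatial dimension (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1

def patch_output_size(size):
    # Four 4x4 stride-2 blocks, then a final 3x3 stride-1 conv with padding 1
    for _ in range(4):
        size = conv_out(size, kernel=4, stride=2, padding=1)
    return conv_out(size, kernel=3, stride=1, padding=1)

size, sizes = 128, []
for _ in range(3):                                        # three discriminators
    sizes.append((size, patch_output_size(size)))
    size = conv_out(size, kernel=3, stride=2, padding=1)  # AvgPool2d(3, stride=2, padding=1)

print(sizes)  # [(128, 8), (64, 4), (32, 2)]
```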

```python
class MultiDiscriminator(nn.Module):
    def __init__(self, input_shape):
        super(MultiDiscriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_filters, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        channels, _, _ = input_shape
        # Three discriminators with identical architecture, one per scale
        self.models = nn.ModuleList()
        for i in range(3):
            self.models.add_module(
                "disc_%d" % i,
                nn.Sequential(
                    *discriminator_block(channels, 64, normalize=False),
                    *discriminator_block(64, 128),
                    *discriminator_block(128, 256),
                    *discriminator_block(256, 512),
                    nn.Conv2d(512, 1, 3, padding=1)
                ),
            )

        # Halves the resolution between consecutive discriminators
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def compute_loss(self, x, gt):
        """Computes the MSE between model output and scalar gt"""
        loss = sum([torch.mean((out - gt) ** 2) for out in self.forward(x)])
        return loss

    def forward(self, x):
        outputs = []
        for m in self.models:
            outputs.append(m(x))
            x = self.downsample(x)
        return outputs
```


<h2>Encoder</h2>
<br>
The outputs are mu and log(var) for the reparameterization trick used in Variational Autoencoders.<br>Encoding is done in this order:


    1. Run the encoder to get mu and log_var
    2. std = exp(log_var / 2)
    3. eps ~ N(0, I)
    4. encoded_z = eps * std + mu (reparameterization trick)


```python
class Encoder(nn.Module):
    def __init__(self, latent_dim, input_shape):
        super(Encoder, self).__init__()
        resnet18_model = resnet18(pretrained=False)
        # Keep conv1 ... layer3 of ResNet-18; drop layer4, avgpool and fc
        self.feature_extractor = nn.Sequential(*list(resnet18_model.children())[:-3])
        # Kernel size 8 assumes 128 x 128 inputs (feature map is 8 x 8)
        self.pooling = nn.AvgPool2d(kernel_size=8, stride=8, padding=0)
        # Output is mu and log(var) for reparameterization trick used in VAEs
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

    def forward(self, img):
        # img : (N, 3, 128, 128)
        out = self.feature_extractor(img)  # out : (N, 256, 8, 8)
        out = self.pooling(out)            # out : (N, 256, 1, 1)
        out = out.view(out.size(0), -1)    # out : (N, 256)
        mu = self.fc_mu(out)               # mu : (N, latent_dim)
        logvar = self.fc_logvar(out)       # logvar : (N, latent_dim)
        return mu, logvar
```


<h2>Reparameterization trick</h2>
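<br>If we sampled z directly from $\mathcal{N}(\mu,\,\sigma^{2})$, we could not backpropagate the error through the sampling layer, because drawing a random sample is a stochastic operation with no gradient with respect to $\mu$ and $\sigma$. The reparameterization trick rewrites the sample as $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$: the randomness moves into $\epsilon$, and $z$ becomes a differentiable function of $\mu$ and $\sigma$.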
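The claim that gradients now flow through $\mu$ and $\log\sigma^2$ can be checked directly. In this minimal sketch, $\partial z/\partial \mu = 1$ and $\partial z/\partial \log\sigma^2 = \tfrac{1}{2}\,\epsilon\,\sigma$, which autograd reproduces.

```python
import torch

torch.manual_seed(0)
mu = torch.zeros(4, 8, requires_grad=True)
logvar = torch.zeros(4, 8, requires_grad=True)

eps = torch.randn_like(mu)          # all randomness lives in eps ~ N(0, I)
std = torch.exp(0.5 * logvar)
z = mu + eps * std                  # differentiable in mu and logvar

z.sum().backward()
print(torch.allclose(mu.grad, torch.ones_like(mu)))            # True: dz/dmu = 1
print(torch.allclose(logvar.grad, 0.5 * eps * std.detach()))   # True: dz/dlogvar = eps * std / 2
```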

```python
def reparameterization(mu, logvar):
    # z = mu + eps * std with eps ~ N(0, I); gradients flow through mu and logvar
    std = torch.exp(logvar / 2)
    eps = torch.randn_like(std)  # same shape, dtype and device as std
    z = eps * std + mu
    return z
    # z : (N, latent_dim)
```

<h3>Assigning default values to arguments and initializing the models</h3>

```python
epoch = 0                      # epoch to start training from
n_epochs = 10                  # number of epochs of training
dataset_name = "edges2shoes"   # name of the dataset
batch_size = 8                 # size of the batches
lr = 0.0002                    # adam: learning rate
b1 = 0.5                       # adam: decay of first order momentum of gradient
b2 = 0.999                     # adam: decay of second order momentum of gradient
n_cpu = 8                      # number of cpu threads to use during batch generation
img_height = 128               # size of image height
img_width = 128                # size of image width
channels = 3                   # number of image channels
latent_dim = 8                 # dimension of the latent code z
sample_interval = 400          # interval between saving generator samples
checkpoint_interval = -1       # interval between model checkpoints
lambda_pixel = 10              # pixelwise (L1) loss weight, lambda
lambda_latent = 0.5            # latent loss weight, lambda_latent
lambda_kl = 0.01               # Kullback-Leibler loss weight, lambda_KL
mae_loss = torch.nn.L1Loss()   # mean absolute error loss

input_shape = (channels, img_height, img_width)       # shape of input image (tuple)

cuda = torch.cuda.is_available()                      # availability of GPU
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

generator = Generator(latent_dim, input_shape)    # initialize generator
encoder = Encoder(latent_dim, input_shape)        # initialize encoder
D_VAE = MultiDiscriminator(input_shape)           # discriminator for the cVAE-GAN branch
D_LR = MultiDiscriminator(input_shape)            # discriminator for the cLR-GAN branch

if cuda:
    generator = generator.cuda()
    encoder = encoder.cuda()
    D_VAE = D_VAE.cuda()
    D_LR = D_LR.cuda()
    mae_loss = mae_loss.cuda()

# Initialize weights on CPU and GPU alike (previously this ran only when CUDA was available);
# the ResNet-based encoder keeps its default initialization
generator.apply(weights_init_normal)
D_VAE.apply(weights_init_normal)
D_LR.apply(weights_init_normal)
```

<h3>Making the directory where output images will be saved</h3>

```python
os.makedirs("images/%s" % dataset_name, exist_ok=True)
```

<h3>Optimizers : </h3>

```python
optimizer_E = torch.optim.Adam(encoder.parameters(), lr=lr, betas=(b1, b2))
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_VAE = torch.optim.Adam(D_VAE.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_LR = torch.optim.Adam(D_LR.parameters(), lr=lr, betas=(b1, b2))
```

```python
# Training pairs from <root>/train
dataloader = DataLoader(
    ImageDataset("../../data/%s" % dataset_name, input_shape),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
)
# Validation pairs from <root>/val, used only for saving sample grids
val_dataloader = DataLoader(
    ImageDataset("../../data/%s" % dataset_name, input_shape, mode="val"),
    batch_size=8,
    shuffle=True,
    num_workers=1,
)
```

<h3>Saves a generated sample from the validation set</h3>

```python
def sample_images(batches_done, n_cols=8):
    """Saves a grid: each row is one validation input A followed by n_cols samples G(A, z)."""
    generator.eval()
    imgs = next(iter(val_dataloader))
    img_samples = None
    with torch.no_grad():
        for img_A, img_B in zip(imgs["A"], imgs["B"]):

            # Repeat the input image once per output column
            real_A = img_A.view(1, *img_A.shape).repeat(n_cols, 1, 1, 1)
            real_A = Variable(real_A.type(Tensor))

            # Sample one latent code per column (previously latent_dim doubled as the column count)
            sampled_z = Variable(Tensor(np.random.normal(0, 1, (n_cols, latent_dim))))
            # Generate samples
            fake_B = generator(real_A, sampled_z)
            # Concatenate samples horizontally
            fake_B = torch.cat([x for x in fake_B.cpu()], -1)
            img_sample = torch.cat((img_A, fake_B), -1)
            img_sample = img_sample.view(1, *img_sample.shape)
            # Concatenate with previous samples vertically
            img_samples = img_sample if img_samples is None else torch.cat((img_samples, img_sample), -2)
    save_image(img_samples, "images/%s/%s.png" % (dataset_name, batches_done), nrow=8, normalize=True)
    generator.train()
```

<h3>TRAINING</h3>

```python
# Targets for the least-squares adversarial loss
valid = 1
fake = 0

prev_time = time.time()
for epoch in range(epoch, n_epochs):
    for i, batch in enumerate(dataloader):

        # Set model input
        real_A = Variable(batch["A"].type(Tensor))
        real_B = Variable(batch["B"].type(Tensor))

        # -------------------------------
        #  Train Generator and Encoder
        # -------------------------------

        optimizer_E.zero_grad()
        optimizer_G.zero_grad()

        # ----------
        # cVAE-GAN
        # ----------

        # Produce output using encoding of B (cVAE-GAN)
        mu, logvar = encoder(real_B)
        encoded_z = reparameterization(mu, logvar)
        fake_B = generator(real_A, encoded_z)

        # Pixelwise loss of translated image by VAE
        loss_pixel = mae_loss(fake_B, real_B)
        # Kullback-Leibler divergence of encoded B
        loss_kl = 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - logvar - 1)
        # Adversarial loss
        loss_VAE_GAN = D_VAE.compute_loss(fake_B, valid)

        # ---------
        # cLR-GAN
        # ---------

        # Produce output using sampled z (cLR-GAN)
        sampled_z = Variable(Tensor(np.random.normal(0, 1, (real_A.size(0), latent_dim))))
        _fake_B = generator(real_A, sampled_z)
        # cLR Loss: Adversarial loss
        loss_LR_GAN = D_LR.compute_loss(_fake_B, valid)

        # ----------------------------------
        # Total Loss (Generator + Encoder)
        # ----------------------------------

        loss_GE = loss_VAE_GAN + loss_LR_GAN + lambda_pixel * loss_pixel + lambda_kl * loss_kl

        # retain_graph=True: the graph through _fake_B is reused by the latent loss below
        loss_GE.backward(retain_graph=True)
        optimizer_E.step()

        # ---------------------
        # Generator Only Loss
        # ---------------------

        # Latent L1 loss: recover sampled_z from _fake_B, using mu as the point estimate.
        # This also writes gradients into E, but E has already stepped; those gradients are
        # cleared by optimizer_E.zero_grad() on the next iteration, so only G is updated here.
        _mu, _ = encoder(_fake_B)
        loss_latent = lambda_latent * mae_loss(_mu, sampled_z)

        loss_latent.backward()
        optimizer_G.step()

        # ----------------------------------
        #  Train Discriminator (cVAE-GAN)
        # ----------------------------------

        # zero_grad() also clears gradients that the generator losses left in D_VAE
        optimizer_D_VAE.zero_grad()

        loss_D_VAE = D_VAE.compute_loss(real_B, valid) + D_VAE.compute_loss(fake_B.detach(), fake)

        loss_D_VAE.backward()
        optimizer_D_VAE.step()

        # ---------------------------------
        #  Train Discriminator (cLR-GAN)
        # ---------------------------------

        optimizer_D_LR.zero_grad()

        loss_D_LR = D_LR.compute_loss(real_B, valid) + D_LR.compute_loss(_fake_B.detach(), fake)

        loss_D_LR.backward()
        optimizer_D_LR.step()

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(dataloader) + i
        batches_left = n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D VAE_loss: %f, LR_loss: %f] [G loss: %f, pixel: %f, kl: %f, latent: %f] ETA: %s"
            % (
                epoch,
                n_epochs,
                i,
                len(dataloader),
                loss_D_VAE.item(),
                loss_D_LR.item(),
                loss_GE.item(),
                loss_pixel.item(),
                loss_kl.item(),
                loss_latent.item(),
                time_left,
            )
        )

        if batches_done % sample_interval == 0:
            sample_images(batches_done)
```

[Epoch 4/10] [Batch 2860/6229] [D VAE_loss: 0.944683, LR_loss: 1.485869] [G loss: 4.076747, pixel: 0.139999, kl: 10.436765, latent: 0.338610] ETA: 11:49:17.409138

<h2>Results (Output Images):</h2>
<br><br>
<h4>After 0 batches </h4>
<br>

![Result_0](img/0.png)

<br>
<h4>After 1200 batches </h4>
<br>

![Result_1200](img/1200.png)

<br>
<h4>After 3200 batches </h4>
<br>

![Result_3200](img/3200.png)

<br>
<h4>After 8800 batches </h4>
<br>

![Result_8800](img/8800.png)

<br>
<h4>After 11600 batches </h4>
<br>

![Result_11600](img/11600.png)

<br>
<h4>After 18800 batches </h4>
<br>

![Result_18800](img/18800.png)

<br>
<h4>After 20400 batches </h4>
<br>

![Result_20400](img/20400.png)

<h3>References</h3>
<br><br>

@misc { zhu2018multimodal ,<br>
&emsp;&emsp;&emsp;&emsp;title = { Toward Multimodal Image-to-Image Translation }, <br>
&emsp;&emsp;&emsp;&emsp;author = { Jun-Yan Zhu and Richard Zhang and Deepak Pathak and Trevor Darrell and Alexei A. Efros and Oliver Wang and Eli Shechtman },<br>
&emsp;&emsp;&emsp;&emsp;year = { 2018 },<br>
&emsp;&emsp;&emsp;&emsp;eprint = { 1711.11586 },<br>
&emsp;&emsp;&emsp;&emsp;archivePrefix = { arXiv },<br>
&emsp;&emsp;&emsp;&emsp;primaryClass = { cs.CV }<br>
}
<br><br>
* https://github.com/eriklindernoren/PyTorch-GAN