<a href="https://colab.research.google.com/github/zeligism/ConGAN/blob/main/GAN_%2B_SimSiam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Mount drive

In [13]:
# Make Google Drive available to the notebook under /content/drive (Colab only).
from google.colab import drive

drive.mount(mountpoint="/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Header

## Imports

In [14]:
import os
import glob
import random
import datetime
import yaml
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
import torchvision
import torchvision.utils as vutils
import torchvision.transforms as transforms
import torch.utils.tensorboard as tensorboard

from PIL import Image, ImageDraw
from math import log2
from pprint import pformat
from collections import defaultdict

## Utility Functions

### Report Utils

In [15]:
def plot_lines(losses_dict, filename=None, title=""):
    """
    Plot one curve per entry of `losses_dict` on a shared set of axes
    (e.g. the discriminator and generator losses over training).

    Args:
        losses_dict: Mapping from legend label to a sequence of values, one per step.
        filename: Where to save the figure. If None, the plot is shown but not saved.
        title: Title of the figure.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.set_title(title)

    for curve_label, values in losses_dict.items():
        ax.plot(values, label=curve_label)

    ax.set_xlabel("t")
    ax.legend()

    # Save before showing: the figure must still be alive for savefig.
    if filename is not None:
        fig.savefig(filename)

    plt.show()
    plt.close()


def create_progress_animation(frames, filename):
    """
    Save an animation showing the generator's progress, e.g. on a fixed latent vector.

    Args:
        frames: Iterable of channel-first (C, H, W) image tensors, one per frame.
        filename: Output path passed to `ArtistAnimation.save`.
    """
    fig = plt.figure(figsize=(8, 8))
    plt.axis("off")

    # Each animation frame is a list holding a single image artist (HWC layout for imshow).
    artists = []
    for frame in frames:
        artists.append([plt.imshow(frame.permute(1, 2, 0), animated=True)])

    animation.ArtistAnimation(fig, artists, blit=True).save(filename)

    plt.close()


def generate_grid(generator, latent):
    """
    Check the generator's output on latent vectors and return it as an image grid.

    Args:
        generator: The generator.
        latent: Latent vectors from which the images will be generated.

    Returns:
        A grid (from `torchvision.utils.make_grid`) of the images generated by
        `generator` from `latent`, rescaled from [-1, 1] to [0, 1].
    """
    with torch.no_grad():
        fake = generator(latent)

    # `value_range` is the current name of make_grid's old `range` keyword,
    # which newer torchvision releases no longer accept.
    return vutils.make_grid(fake.cpu(), padding=2, normalize=True, value_range=(-1, 1))


def generate_G_grid(generator, before):
    """
    Generate a grid of pairs of images, where each pair shows a before-after
    transition when applying `generator` on `before`.

    Args:
        generator: Module mapping a batch of images to a batch of images of the same shape.
        before: Image batch (N, C, H, W), or a single image (C, H, W).

    Returns:
        An image grid with 8 images (4 before/after pairs) per row, rescaled from
        [-1, 1] to [0, 1].
    """
    # Promote a single image to a batch of one. `unsqueeze` is not in-place,
    # so its result must be reassigned.
    if before.dim() == 3:
        before = before.unsqueeze(0)

    with torch.no_grad():
        after = generator(before)
        # Interleave along the batch: before_0, after_0, before_1, after_1, ...
        # Built on the inputs' own device/dtype instead of a CPU float buffer.
        row = torch.stack([before, after], dim=1).flatten(0, 1)

    return vutils.make_grid(row.cpu(), nrow=8, padding=2, normalize=True, value_range=(-1, 1))


def generate_makeup_grid(applier_ref, remover, before, after_ref):
    """
    Generate a grid, 8 images per row, as follows:
      Image #1: real photo of a face WITHOUT makeup (call it face #1).
      Image #2: real (makeup reference) photo of a face WITH makeup (call it face #2).
      Image #3: fake photo of face #1 WITH makeup style from face #2 (applied).
      Image #4: fake photo of face #2 WITHOUT makeup (removed).
      Image #5: Repeat the same pattern from Image #1...

    In case only 4 images are needed per row, change `nrow` below to 4.

    Args:
        applier_ref: Called as `applier_ref(before, after_ref)`; returns `before`
            with the makeup style of `after_ref` applied.
        remover: Called as `remover(after_ref)`; returns `after_ref` without makeup.
        before: Batch (N, C, H, W) or single image (C, H, W) of faces without makeup.
        after_ref: Batch or single image of makeup reference faces, matching `before`.

    Returns:
        The image grid, rescaled from [-1, 1] to [0, 1].
    """
    # Promote single images to batches of one (unsqueeze returns a new tensor).
    if before.dim() == 3:
        before = before.unsqueeze(0)
    if after_ref.dim() == 3:
        after_ref = after_ref.unsqueeze(0)

    with torch.no_grad():
        fake_after = applier_ref(before, after_ref)
        fake_before_ref = remover(after_ref)
        # Interleave per sample: [before, reference, applied, removed], next sample, ...
        row = torch.stack([before, after_ref, fake_after, fake_before_ref], dim=1).flatten(0, 1)

    return vutils.make_grid(row.cpu(), nrow=8, padding=2, normalize=True, value_range=(-1, 1))


# Classes

## PyTorch Modules

### DCGAN-style

In [16]:
class ConvBlock(nn.Module):
    """
    Conv2d (no bias) -> optional BatchNorm2d -> activation.

    With the default kernel 4, stride 2 and padding 1 the spatial size is halved,
    e.g. [in_channels, 64, 64] -> [out_channels, 32, 32].

    Args:
        in_channels / out_channels: Channel counts of the convolution.
        kernel_size, stride, padding: Forwarded to nn.Conv2d.
        use_batchnorm: Append a BatchNorm2d after the convolution.
        use_spectralnorm: Wrap the convolution with spectral normalization.
        activation: Activation class, instantiated with no arguments.
            None means LeakyReLU(0.2, inplace=True).
    """
    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1,
                 use_batchnorm=True, use_spectralnorm=False, activation=None):
        super().__init__()

        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                         stride=stride, padding=padding, bias=False)
        if use_spectralnorm:
            conv = nn.utils.spectral_norm(conv)
        self.conv = conv

        self.batchnorm = nn.BatchNorm2d(out_channels) if use_batchnorm else None

        if activation is None:
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.activation = activation()

    def forward(self, x):
        out = self.conv(x)
        if self.batchnorm is not None:
            out = self.batchnorm(out)
        return self.activation(out)


class ConvTBlock(nn.Module):
    """
    ConvTranspose2d (no bias) -> optional BatchNorm2d -> activation.

    With the default kernel 4, stride 2 and padding 1 the spatial size doubles,
    e.g. [in_channels, 32, 32] -> [out_channels, 64, 64].

    Args:
        in_channels / out_channels: Channel counts of the transposed convolution.
        kernel_size, stride, padding: Forwarded to nn.ConvTranspose2d.
        use_batchnorm: Append a BatchNorm2d after the transposed convolution.
        use_spectralnorm: Wrap the transposed convolution with spectral normalization.
        activation: Activation class, instantiated with no arguments.
            None means ReLU(inplace=True).
    """
    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1,
                 use_batchnorm=True, use_spectralnorm=False, activation=None):
        super().__init__()

        upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                                    stride=stride, padding=padding, bias=False)
        if use_spectralnorm:
            upconv = nn.utils.spectral_norm(upconv)
        self.convT = upconv

        self.batchnorm = nn.BatchNorm2d(out_channels) if use_batchnorm else None

        if activation is None:
            self.activation = nn.ReLU(inplace=True)
        else:
            self.activation = activation()

    def forward(self, x):
        out = self.convT(x)
        if self.batchnorm is not None:
            out = self.batchnorm(out)
        return self.activation(out)


class DCGAN_Discriminator(nn.Module):
    """
    DCGAN-style convolutional discriminator / encoder.

    Each `D_block` halves the spatial size. The number of blocks is chosen so the final
    feature map is at most 5x5 (3, 4 or 5 whenever image_size >= 6). A head layer then
    maps that feature map to `num_latents` outputs per sample.

    Args:
        num_latents: Number of outputs per sample.
        num_features: Channels of the first block. Block i has
            min(max_features, round(num_features * feature_multiplier**i)) channels.
        image_channels: Channels of the input images.
        image_size: Side length of the (square) input images.
        feature_multiplier: Channel growth factor between consecutive blocks.
        max_features: Upper bound on the channels of any block.
        gan_type: "gan" / "gan-gp" add a sigmoid on the output. "gan-gp" / "wgan-gp"
            (gradient penalty) disable batch norm in the blocks.
        fully_convolutional: If True, the head is a conv whose kernel covers the whole
            final feature map. Otherwise it is flatten + linear.
        activation: Activation class forwarded to each block (None = block default).
        use_batchnorm: Use batch norm in the blocks (forced off with gradient penalty).
        use_spectralnorm: Apply spectral norm in the blocks.
        D_block: Block class, called as D_block(in_ch, out_ch, activation=...,
            use_batchnorm=..., use_spectralnorm=...).

    forward(x) returns (N, num_latents), squeezed to (N,) when num_latents == 1.
    """
    def __init__(self,
                 num_latents=16,
                 num_features=64,
                 image_channels=3,
                 image_size=64,
                 feature_multiplier=2,
                 max_features=512,
                 gan_type="gan",
                 fully_convolutional=True,
                 activation=None,
                 use_batchnorm=True,
                 use_spectralnorm=False,
                 D_block=ConvBlock):
        super().__init__()

        penalized = gan_type in ("gan-gp", "wgan-gp")
        block_kwargs = dict(
            activation=activation,
            use_batchnorm=use_batchnorm and not penalized,
            use_spectralnorm=use_spectralnorm,
        )

        # Halve the spatial size until it is <= 5; each halving is one block.
        final_size = image_size
        depth = 0
        while final_size > 5:
            final_size //= 2
            depth += 1

        widths = [min(max_features, round(num_features * feature_multiplier**k))
                  for k in range(depth)]

        self.input_layer = D_block(image_channels, widths[0], **block_kwargs)

        self.main_layers = nn.Sequential(*[
            D_block(c_in, c_out, **block_kwargs)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ])

        # Head: final_size x final_size feature map -> num_latents values.
        if fully_convolutional:
            head = [nn.Conv2d(widths[-1], num_latents, final_size, bias=False),
                    nn.Flatten()]
        else:
            head = [nn.Flatten(),
                    nn.Linear(widths[-1] * final_size**2, num_latents, bias=False)]
        self.output_layer = nn.Sequential(*head)

        # Probabilities are only needed for the BCE-based ("gan") losses.
        self.output_activation = nn.Sigmoid() if gan_type in ("gan", "gan-gp") else None

    def forward(self, x):
        h = self.output_layer(self.main_layers(self.input_layer(x)))
        if self.output_activation is not None:
            h = self.output_activation(h)
        # Flatten everything but the batch dim; drop the feature dim if it is 1.
        return h.view(h.size(0), -1).squeeze(1)


class DCGAN_Generator(nn.Module):
    """
    DCGAN-style transposed-convolution generator.

    Computes the same halving chain of sizes as DCGAN_Discriminator
    (image_size -> ... -> s with s <= 5) and walks it backwards. It starts from an
    s x s map built from the latent, and each layer doubles the size. A layer uses a
    5x5 kernel when its target size is odd and 4x4 otherwise, so the output is exactly
    image_size x image_size. The output goes through tanh.

    Args:
        num_latents: Size of each latent vector.
        num_features: Channels of the last hidden block. Hidden widths are the
            discriminator's widths, reversed.
        image_channels: Channels of the generated images.
        image_size: Side length of the generated (square) images.
        feature_multiplier: Channel growth factor between consecutive blocks.
        max_features: Upper bound on the channels of any block.
        gan_type: Unused here; accepted for signature parity with the discriminator.
        fully_convolutional: If True, the latent is lifted by a transposed conv.
            Otherwise a linear layer + reshape is used.
        activation, use_batchnorm, use_spectralnorm: Forwarded to every `G_block`.
        G_block: Block class, called as G_block(in_ch, out_ch, kernel_size=..., ...).
    """
    def __init__(self,
                 num_latents=100,
                 num_features=64,
                 image_channels=3,
                 image_size=64,
                 feature_multiplier=2,
                 max_features=512,
                 gan_type="gan",
                 fully_convolutional=True,
                 activation=None,
                 use_batchnorm=True,
                 use_spectralnorm=False,
                 G_block=ConvTBlock):
        super().__init__()

        block_kwargs = dict(activation=activation,
                            use_batchnorm=use_batchnorm,
                            use_spectralnorm=use_spectralnorm)

        # Halving chain of spatial sizes, e.g. 64 -> 32 -> 16 -> 8 -> 4.
        sizes = [image_size]
        while sizes[-1] > 5:
            sizes.append(sizes[-1] // 2)
        start_size = sizes[-1]
        depth = len(sizes) - 1

        widths = [min(max_features, round(num_features * feature_multiplier**k))
                  for k in range(depth)]

        # The generator goes small -> large, so walk both lists from the end.
        sizes.reverse()
        widths.reverse()

        if fully_convolutional:
            # Latent is a (num_latents, 1, 1) map; a start_size kernel makes it start_size^2.
            self.input_layer = G_block(num_latents, widths[0], kernel_size=start_size,
                                       stride=1, padding=0, **block_kwargs)
        else:
            self.input_layer = nn.Sequential(
                nn.Flatten(),
                nn.Linear(num_latents, widths[0] * sizes[0]**2, bias=False),
                View(widths[0], sizes[0], sizes[0])
            )

        # Stride-2 transposed conv (padding 1) maps n -> 2n - 4 + k: k=4 gives 2n, k=5 gives 2n+1.
        self.main_layers = nn.Sequential(*[
            G_block(c_in, c_out, kernel_size=4 + target % 2, **block_kwargs)
            for c_in, c_out, target in zip(widths[:-1], widths[1:], sizes[1:])
        ])

        self.output_layer = nn.ConvTranspose2d(widths[-1], image_channels,
                                               kernel_size=4 + image_size % 2,
                                               stride=2, padding=1, bias=False)
        self.output_activation = nn.Tanh()

    def forward(self, x):
        # Treat each latent as a (C, 1, 1) feature map, with C inferred.
        out = x.view(x.size(0), -1, 1, 1)
        for stage in (self.input_layer, self.main_layers,
                      self.output_layer, self.output_activation):
            out = stage(out)
        return out


class View(nn.Module):
    """
    Reshape module (wraps `Tensor.view`) for use inside nn.Sequential.

    Args:
        *shape: Target shape.
        including_batch: If False (default), the batch size x.size(0) is kept and
            `shape` gives the remaining dims. If True, `shape` is the full target shape.
    """
    def __init__(self, *shape, including_batch=False):
        super().__init__()
        self.shape = shape
        self.including_batch = including_batch

    def forward(self, x):
        target = self.shape if self.including_batch else (x.size(0), *self.shape)
        return x.view(*target)


### Residual

In [None]:
class ChannelNoise(nn.Module):
    """
    Channel noise injection module.

    On every forward pass, one single-channel Gaussian noise map per sample is drawn
    and scaled by `std`. It is then multiplied by a learnable per-channel scale (which
    broadcasts it to all channels) and added to the input.

    Args:
        num_channels: Number of channels of the input feature map.
        std: Standard deviation of the injected noise before the learned scale.
    """

    def __init__(self, num_channels, std=0.02):
        super().__init__()
        self.std = std
        # Learnable per-channel gain, broadcast over batch, H and W.
        self.scale = nn.Parameter(torch.ones(1, num_channels, 1, 1))

    def forward(self, x):
        noise_size = (x.size(0), 1, *x.shape[2:])  # single channel
        # Sample directly on x's device instead of allocating on the CPU and copying
        # over on every call. `.to(x)` then only matches the dtype.
        noise = self.std * torch.randn(noise_size, device=x.device).to(x)

        return x + self.scale * noise


class ResidualBlock(nn.Module):
    """
    Basic residual block with noise injection.

    main path: conv3x3 (dilation d1) -> BN -> ReLU -> ChannelNoise -> conv3x3 (dilation d2) -> BN
    output:    ReLU(main(x) + skip(x))

    Args:
        in_channels: Input channels.
        out_channels: Output channels of both convolutions.
        dilation: Pair (d1, d2) of dilations for the two convolutions. Padding equals
            the dilation, so H and W are preserved.
        downsample: Optional module applied to the input on the skip path. Needed
            whenever the main path changes the shape (e.g. in_channels != out_channels),
            since the two paths are added elementwise.
        dropout_p: Accepted for signature parity with ResidualBottleneck; unused here.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 dilation=(1,1),
                 downsample=None,
                 dropout_p=0.0):
        super().__init__()

        self.dilation = dilation
        self.downsample = downsample

        d_first, d_second = dilation[0], dilation[1]
        layers = [
            nn.Conv2d(in_channels, out_channels, 3,
                      padding=d_first, dilation=d_first, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            ChannelNoise(out_channels),
            nn.Conv2d(out_channels, out_channels, 3,
                      padding=d_second, dilation=d_second, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        if self.downsample is None:
            skip = x
        else:
            skip = self.downsample(x)

        return F.relu(self.main(x) + skip)


class ResidualBottleneck(nn.Module):
    """
    ResNet-style bottleneck block.

    main path: conv1x1 -> BN -> ReLU -> Dropout -> conv3x3 (dilated) -> BN -> ReLU -> Dropout
               -> conv1x1 (expands to out_channels * 4) -> BN
    output:    ReLU(main(x) + skip(x))

    Args:
        in_channels: Input channels.
        out_channels: Bottleneck width. The block outputs out_channels * 4 channels.
        downsample: Optional module applied to the input on the skip path. Needed
            whenever in_channels != out_channels * 4, since the paths are added.
        dilation: An int, or a pair whose second entry is used by the 3x3
            convolution (padding equals the dilation, so H and W are preserved).
            An int d is treated as (d, d).
        dropout_p: Dropout probability after each hidden ReLU.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 downsample=None,
                 dilation=1,
                 dropout_p=0.0):
        super().__init__()

        # The body indexes `dilation[1]`, so expand a plain int (including the
        # default of 1, which previously raised TypeError) into a pair.
        if isinstance(dilation, int):
            dilation = (dilation, dilation)

        self.downsample = downsample
        self.dilation = dilation

        self.main = nn.Sequential(

            ### Conv 1x1 ###
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(p=dropout_p),

            ### Conv 3x3 ###
            nn.Conv2d(out_channels, out_channels, 3,
                      padding=dilation[1], dilation=dilation[1], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(p=dropout_p),

            ### Conv 1x1 ###
            nn.Conv2d(out_channels, out_channels * 4, 1, bias=False),
            nn.BatchNorm2d(out_channels * 4),
        )

    def forward(self, x):

        residual = x if self.downsample is None else self.downsample(x)

        return F.relu(self.main(x) + residual)



### MaskGAN

In [None]:
class MaskGAN(nn.Module):
    """
    Container pairing a DCGAN-style discriminator `D` with a MaskGenerator `G`.

    Args:
        num_features: Base width shared by D and G.
        max_features, image_channels, image_size, feature_multiplier, gan_type:
            Forwarded to DCGAN_Discriminator only. MaskGenerator always works
            on 3-channel images.
        with_reference: Forwarded to MaskGenerator (reference-conditioned masks).
    """
    def __init__(self,
                 num_features=64,
                 max_features=512,
                 image_channels=3,
                 image_size=64,
                 feature_multiplier=2,
                 gan_type="gan",
                 with_reference=False):
        super().__init__()

        # D is constructed before G.
        self.D = DCGAN_Discriminator(num_features=num_features,
                                     max_features=max_features,
                                     image_channels=image_channels,
                                     image_size=image_size,
                                     feature_multiplier=feature_multiplier,
                                     gan_type=gan_type)
        self.G = MaskGenerator(num_features=num_features,
                               with_reference=with_reference)


class MaskGenerator(nn.Module):
    """
    A neural network that generates an additive mask to apply to a source image.

    Features are extracted from the source image. If `with_reference` is set, features
    are also extracted from a reference image and concatenated channel-wise. The
    features pass through dilated residual blocks and are mapped to a 3-channel tanh
    mask. The output is clamp(source + mask, -1, 1).

    Args:
        num_features: Channels produced by each feature extractor.
        with_reference: If True, forward() requires a `reference` image and the mask
            network sees 2 * num_features channels.
    """
    def __init__(self, num_features=64, with_reference=False):
        super().__init__()

        # Per-extractor width (stored before any doubling below).
        self.num_features = num_features
        self.with_reference = with_reference

        def make_features_extractor(num_features):
            # 3-channel image -> num_features maps; 7x7 conv with padding 3 keeps H and W.
            return nn.Sequential(
                nn.Conv2d(3, num_features, 7, padding=3, bias=False),
                nn.ReLU(),
            )

        # Extract features from source
        self.source_features_extractor = make_features_extractor(self.num_features)

        if self.with_reference:
            # Extract features from reference
            self.reference_features_extractor = make_features_extractor(self.num_features)
            # Source and reference features are concatenated, doubling the width.
            num_features *= 2

        self.mask_generator = nn.Sequential(
            ResidualBlock(num_features, num_features),
            ResidualBlock(num_features, num_features, dilation=(2,2)),
            ResidualBlock(num_features, num_features, dilation=(4,4)),
            ResidualBlock(num_features, num_features, dilation=(8,8)),
            nn.Conv2d(num_features, num_features, 3, padding=2, dilation=2, bias=False),
            nn.ReLU(),
            nn.Conv2d(num_features, 3, 3, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, source, reference=None):
        """
        Args:
            source: Source images with 3 channels.
            reference: Reference images with 3 channels. Required if and only if the
                module was built with `with_reference=True`.

        Returns:
            clamp(source + mask, -1, 1).

        Raises:
            ValueError: If `reference` is given without `with_reference`, or missing with it.
        """
        # Explicit checks instead of `assert` (which is stripped under -O). A missing
        # reference would otherwise surface as an obscure conv2d error on None.
        if self.with_reference and reference is None:
            raise ValueError("MaskGenerator(with_reference=True) requires a reference image")
        if not self.with_reference and reference is not None:
            raise ValueError("MaskGenerator(with_reference=False) does not accept a reference image")

        features = self.source_features_extractor(source)

        if self.with_reference:
            reference_features = self.reference_features_extractor(reference)
            features = torch.cat([features, reference_features], dim=1)

        mask = self.mask_generator(features)

        # source + mask can leave [-1, 1]; clamp back (this zeroes gradients of saturated pixels).
        return (source + mask).clamp(-1, 1)



### ConsistentGAN

In [None]:
class ConsistentGAN(nn.Module):
    def __init__(self,
                 base_encoder,
                 repr_dim,
                 pred_dim,
                 latent_dim,
                 *args,
                 image_size=64,
                 **kwargs):
        """
        s ~ S is the representation/encoding space, e.g. s = Enc(x).
        z ~ Z is the latent/seed space for G, e.g. x_fake ~ G(z).

        Case A:
            - Sample x_real ~ X
            - s <- Encoder(x_real)
            - Sample z ~ Z (coupled with s somehow? e.g. z = f(s) + noise)
            - x_fake <- G(z)
            - Test D on x_real, x_fake, i.e. test whether x is in X
            - Calculate GAN loss
            - Calculate SimSiam loss with Predictor(s)

        Case B:
            - Sample x_real ~ X
            - s_real <- Encoder(x_real)
            - x_recon <- Decoder(s_real)
            - Sample z ~ Z (e.g. N(0,1) or U(-1,1))
            - s_fake <- G(z)
            - Test D on s_real, s_fake, i.e. test whether Dec(s) is in X (how?)
            - Calculate GAN loss
            - Calculate Contrastive/SimSiam loss with Predictor(s)
            - Calculate reconstruction loss, e.g. || Enc(x_recon) - Enc(x_real) ||

        We choose Case B for now.

        Adversarial learning:
            Case A:
                Train D and G so that we can discriminate x based on its representation.
                Train G so that the representation has enough info for reconstructing x.
                Ideally, G would generate different views of x given similar s's.
            Case B:
                Train D and G so that we can discriminate representations.
                Train G so that it produces representations as real (close to Enc(x)) as possible.
                This assumes that Enc(x) is the real representation.
                (No transformation of random noise is guaranteed to be real enough, but
                we should try to mitigate this problem nonetheless. For example, we can at least
                make training stable enough by pre-training Enc/Dec.)
                Ideally, G would produce accurate representations that can be decoded later.

        Contrastive learning:
            If x1 and x2 are views of the same x (e.g. x1, x2 = rand_aug(x), rand_aug(x)),
            then Encoder(x1) should be similar to Encoder(x2), so we do this (SimSiam algorithm):
            min -0.5*{ sim(Predictor(s1), s2.detach()) + sim(Predictor(s2), s1.detach()) }
            where sim = CosineSimilarity(dim=1), i.e. the similarity is maximized.

        Reconstruction learning:
            We can add a (deterministic) Decoder that learns to decode s to x.
            This does not take into account the invariance of the representation to augmentations.
            It simply learns the inverse of Encoder. G will learn to produce representations
            that are as real as possible, and the Decoder will learn to decode them into their
            corresponding x's such that Enc(Dec(s)) = s (do we stop grad at s for this loss?)

        Architecture:
            We assume x comes from an image dataset, e.g. CIFAR10.
            The Encoder-Decoder pair can have a DCGAN-style architecture.
            The generator can also have an arch similar to the decoder.
            Other networks can be simple FCNs.

        Args:
            base_encoder: Encoder constructor, called as
                `base_encoder(num_classes=repr_dim, zero_init_residual=True)`. The result
                must expose its final linear layer as `.fc` (e.g. a torchvision ResNet).
            repr_dim: Dimension of the representation s (encoder/projector output).
            pred_dim: Hidden dimension of the SimSiam predictor.
            latent_dim: Dimension of the latent z fed to G.
            image_size: Image side length, stored for the (not yet enabled) decoder.
            *args, **kwargs: Accepted for call compatibility; currently ignored.
        """
        super().__init__()
        self.image_size = image_size
        self.repr_dim = repr_dim
        self.latent_dim = latent_dim
        self.pred_dim = pred_dim

        ### Adapted from the SimSiam repo >>>>>>>>
        # create the encoder
        # num_classes is the output fc dimension, zero-initialize last BNs
        self.encoder = base_encoder(num_classes=repr_dim, zero_init_residual=True)

        # build a 3-layer projector
        prev_dim = self.encoder.fc.weight.shape[1]
        self.encoder.fc = nn.Sequential(nn.Linear(prev_dim, prev_dim, bias=False),
                                        nn.BatchNorm1d(prev_dim),
                                        nn.ReLU(inplace=True), # first layer
                                        nn.Linear(prev_dim, prev_dim, bias=False),
                                        nn.BatchNorm1d(prev_dim),
                                        nn.ReLU(inplace=True), # second layer
                                        self.encoder.fc,
                                        nn.BatchNorm1d(repr_dim, affine=False)) # output layer
        # hack: don't train the original fc's bias, as it is followed by BN
        # (index 6 is the encoder's original fc). Guarded for encoders whose fc has no bias.
        if self.encoder.fc[6].bias is not None:
            self.encoder.fc[6].bias.requires_grad = False

        # build a 2-layer predictor
        self.predictor = nn.Sequential(nn.Linear(repr_dim, pred_dim, bias=False),
                                       nn.BatchNorm1d(pred_dim),
                                       nn.ReLU(inplace=True), # hidden layer
                                       nn.Linear(pred_dim, repr_dim)) # output layer
        ### <<<<<<<<

        # Make D's architecture kinda similar to predictor @TODO
        # @XXX: Shouldn't use batchnorm with grad_penalty
        D_hidden_dim = repr_dim // 2
        self.D = nn.Sequential(nn.Linear(repr_dim, D_hidden_dim, bias=False),
                               nn.BatchNorm1d(D_hidden_dim),
                               nn.LeakyReLU(0.2, inplace=True),
                               nn.Linear(D_hidden_dim, D_hidden_dim, bias=False),
                               nn.BatchNorm1d(D_hidden_dim),
                               nn.LeakyReLU(0.2, inplace=True),
                               nn.Linear(D_hidden_dim, D_hidden_dim, bias=False),
                               nn.BatchNorm1d(D_hidden_dim),
                               nn.LeakyReLU(0.2, inplace=True),
                               nn.Linear(D_hidden_dim, 1))

        # Same for generator (latent -> representations)
        self.G = nn.Sequential(nn.Linear(latent_dim, repr_dim, bias=False),
                               nn.BatchNorm1d(repr_dim),
                               nn.ReLU(inplace=True), # hidden layer
                               nn.Linear(repr_dim, repr_dim, bias=False),
                               nn.BatchNorm1d(repr_dim),
                               nn.ReLU(inplace=True), # hidden layer
                               nn.Linear(repr_dim, repr_dim, bias=False),
                               nn.BatchNorm1d(repr_dim),
                               nn.ReLU(inplace=True), # hidden layer
                               nn.Linear(repr_dim, repr_dim))

        # Decoder should be a ConvT net. We'll use DCGAN's G for now
        #self.decoder = DCGAN_Generator(num_latents=repr_dim, image_size=image_size) #@XXX

    def sample_latent(self, batch_size, device=None):
        """
        Sample a batch of standard-normal latents of shape (batch_size, latent_dim).

        Args:
            batch_size: Number of latents.
            device: Device to sample on (default: torch's default device, as before).
        """
        return torch.randn(batch_size, self.latent_dim, device=device)

    def forward(self, x1, x2):
        """
        Input:
            x1: first views of images
            x2: second views of images
        Output:
            p1, p2, z1, z2: predictors and targets of the network
            See Sec. 3 of https://arxiv.org/abs/2011.10566 for detailed notations
        """

        # compute features for one view
        z1 = self.encoder(x1) # NxC
        z2 = self.encoder(x2) # NxC

        p1 = self.predictor(z1) # NxC
        p2 = self.predictor(z2) # NxC

        # Targets are detached (stop-gradient), as in SimSiam.
        return p1, p2, z1.detach(), z2.detach()



## Trainer

#### Init Utils

In [None]:
def create_weights_init(conv_std=0.01, batchnorm_std=0.01, scheme="kaiming"):
    """
    Return a weights initialization function for a net, to be used as
    `net.apply(create_weights_init())`, for example.

    Schemes:
        "kaiming" (default):
            - nn.Conv2d weights ~ N(0, conv_std)
            - nn.ConvTranspose2d weights: Kaiming-normal with ReLU gain
            - affine nn.BatchNorm2d: weight = 1, bias = 0 (`batchnorm_std` is not used)
        "normal" (DCGAN-style):
            - conv / transposed-conv weights (1d, 2d, 3d) ~ N(0, conv_std)
            - affine batch-norm layers (1d, 2d, 3d): weight ~ N(1, batchnorm_std), bias = 0
    Convolution biases are left untouched by both schemes.

    Args:
        conv_std: Standard deviation of the normal init of conv weights.
        batchnorm_std: Standard deviation of batch-norm weights ("normal" scheme only).
        scheme: "kaiming" or "normal".

    Raises:
        ValueError: If `scheme` is not recognized.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d,
                  nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    def weights_init_normal(module):
        # isinstance rather than a class-name substring match, so wrapper modules
        # such as ConvBlock (which own no `.weight`) are skipped instead of crashing.
        if isinstance(module, conv_types):
            nn.init.normal_(module.weight, 0.0, conv_std)
        elif isinstance(module, batchnorm_types) and module.affine:
            nn.init.normal_(module.weight, 1.0, batchnorm_std)
            nn.init.zeros_(module.bias)

    def weights_init_kaiming(module):
        if isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight, 0.0, conv_std)
        elif isinstance(module, nn.ConvTranspose2d):
            nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        elif isinstance(module, nn.BatchNorm2d) and module.affine:
            # Non-affine batch norms have no weight/bias to initialize.
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)

    initializers = {"kaiming": weights_init_kaiming, "normal": weights_init_normal}
    if scheme not in initializers:
        raise ValueError(f"Weights init scheme '{scheme}' not recognized.")

    return initializers[scheme]


def init_optim(params, optim_choice="sgd", lr=1e-4, momentum=0.0, betas=(0.9, 0.999)):
    """
    Initializes the optimizer.

    Args:
        params: Parameters (or parameter groups) the optimizer will optimize.
        optim_choice: One of "adam", "adamw", "rmsprop", "sgd".
        lr: Learning rate.
        momentum: Momentum factor, used by "sgd" and "rmsprop".
        betas: Adam coefficients, used by "adam" and "adamw".

    Returns:
        The optimizer (torch.optim.Optimizer).

    Raises:
        ValueError: If `optim_choice` is not recognized.
    """
    if optim_choice == "adam":
        return torch.optim.Adam(params, lr=lr, betas=betas)
    if optim_choice == "adamw":
        return torch.optim.AdamW(params, lr=lr, betas=betas)
    if optim_choice == "rmsprop":
        # RMSprop's own default momentum is 0, matching this function's default,
        # so forwarding `momentum` only matters when the caller sets it.
        return torch.optim.RMSprop(params, lr=lr, momentum=momentum)
    if optim_choice == "sgd":
        return torch.optim.SGD(params, lr=lr, momentum=momentum)

    raise ValueError(f"Optimizer '{optim_choice}' not recognized.")


#### GAN Utils

In [None]:
def get_D_loss(gan_type="gan"):
    """
    Return the discriminator loss function for `gan_type`.

    "gan" / "gan-gp" -> D_loss_GAN, "wgan" / "wgan-gp" -> D_loss_WGAN.

    Raises:
        ValueError: For any other `gan_type`.
    """
    if gan_type in ("gan", "gan-gp"):
        return D_loss_GAN
    if gan_type in ("wgan", "wgan-gp"):
        return D_loss_WGAN
    raise ValueError(f"gan_type {gan_type} not supported")


def get_G_loss(gan_type="gan"):
    """
    Return the generator loss function for `gan_type`.

    "gan" / "gan-gp" -> G_loss_GAN, "wgan" / "wgan-gp" -> G_loss_WGAN.

    Raises:
        ValueError: For any other `gan_type`.
    """
    if gan_type in ("gan", "gan-gp"):
        return G_loss_GAN
    if gan_type in ("wgan", "wgan-gp"):
        return G_loss_WGAN
    raise ValueError(f"gan_type {gan_type} not supported")


def D_loss_GAN(D_on_real, D_on_fake, label_smoothing=True):
    """
    Binary cross-entropy discriminator loss.

    Args:
        D_on_real: D's probabilities on real samples (values in [0, 1], e.g. after a sigmoid).
        D_on_fake: D's probabilities on fake samples.
        label_smoothing: If True, real targets are noisy, drawn uniformly from [0.7, 1.0].
            Otherwise a fixed one-sided smoothed target of 0.9 is used. Fake targets are 0.

    Returns:
        Scalar BCE(D(x), real_target) + BCE(D(x_g), 0), i.e. -log(D(x)) - log(1 - D(x_g));
        minimizing it maximizes log(D(x)) + log(1 - D(x_g)).
    """
    if label_smoothing:
        # Noisy real labels, kept <= 1. A BCE target above 1 puts a negative weight on
        # log(1 - D(x)), which makes the loss unbounded below and rewards a saturated D.
        real_label = 0.7 + 0.3 * torch.rand_like(D_on_real)
    else:
        # One-sided label smoothing.
        real_label = torch.ones_like(D_on_real) - 0.1
    fake_label = torch.zeros_like(D_on_fake)

    D_loss_on_real = F.binary_cross_entropy(D_on_real, real_label)
    D_loss_on_fake = F.binary_cross_entropy(D_on_fake, fake_label)

    D_loss = D_loss_on_real + D_loss_on_fake

    return D_loss.mean()


def D_loss_WGAN(D_on_real, D_on_fake, grad_penalty=0.0):
    """
    WGAN critic loss, to be minimized.

    Minimizing it maximizes D(x) - D(x_g) - grad_penalty. With a gradient penalty this is
    typically const * (|| grad of D(x_i) wrt x_i || - 1)^2, where
    x_i <- eps * x + (1 - eps) * x_g and eps ~ rand(0, 1).

    Args:
        D_on_real: Critic scores on real samples.
        D_on_fake: Critic scores on fake samples.
        grad_penalty: Scalar or per-sample penalty term (0 disables it).

    Returns:
        mean(D(x_g) - D(x) + grad_penalty).
    """
    critic_gap = D_on_fake - D_on_real
    return (critic_gap + grad_penalty).mean()


def G_loss_GAN(D_on_fake):
    """
    Non-saturating generator loss: mean of -log(D(G(z))).

    This is BCE against an all-ones ("real") target. It is used instead of minimizing
    log(1 - D(G(z))) for stability only.

    Args:
        D_on_fake: D's probabilities on generated samples, in [0, 1].
    """
    real_target = torch.ones_like(D_on_fake)
    return F.binary_cross_entropy(D_on_fake, real_target).mean()


def G_loss_WGAN(D_on_fake):
    """WGAN generator loss: minimize -D(G(z)), averaged over the batch."""
    return torch.neg(D_on_fake).mean()

# NOTE(review): the two helpers below (get_D_grad_norm, get_grad_penalty) are disabled.
# They sit inside a bare string literal, so they are never defined at runtime.
# `random_interpolate` and `simple_gradient_penalty` further down appear to provide
# the replacement; confirm nothing else needs these before deleting this block.
"""
def get_D_grad_norm(discriminator, real, fake):

    batch_size = real.size()[0]
    device = real.device

    # Calculate gradient penalty
    eps = torch.rand([batch_size, 1, 1, 1], device=device)
    interpolated = eps * real + (1 - eps) * fake
    interpolated.requires_grad_()
    D_on_inter = discriminator(interpolated)

    # Calculate gradient of D(x_i) wrt x_i for each batch
    D_grad = torch.autograd.grad(D_on_inter, interpolated,
                                 torch.ones_like(D_on_inter), create_graph=True)

    # D_grad will be a 1-tuple, as in: (grad,)
    D_grad_norm = D_grad[0].view([batch_size, -1]).norm(dim=1)

    return D_grad_norm


def get_grad_penalty(grad_norm, gp_coeff=10.):
    # D's gradient penalty is `gp_coeff * (|| grad of D(x_i) wrt x_i || - 1)^2`
    grad_penalty = (grad_norm - 1).pow(2) * gp_coeff

    return grad_penalty
"""

def random_interpolate(real, fake):
    """
    Per-sample random convex combination of `real` and `fake`.

    One eps ~ U[0, 1) is drawn per sample (dim 0) and broadcast over all remaining
    dims. Inputs of any rank >= 1 therefore work: e.g. (N, C, H, W) images or
    (N, D) representation vectors. Previously eps was always (N, 1, 1, 1), which
    silently broadcast 2-D inputs to a wrong (N, 1, N, D) shape.

    Args:
        real: Batch of real samples.
        fake: Batch of fake samples, same shape as `real`.

    Returns:
        eps * real + (1 - eps) * fake, with the same shape as `real`.
    """
    # (N, 1, ..., 1) with one trailing singleton per non-batch dim of `real`.
    eps_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=real.device).to(real)
    return eps * real + (1 - eps) * fake

def simple_gradient_penalty(D, x, center=0.):
    """
    Gradient-norm penalty: mean over the batch of (|| grad of D(x) wrt x || - center)^2.

    center=0 gives an R1-style penalty. center=1 evaluated at interpolates gives the
    WGAN-GP penalty (up to its coefficient).

    Note: `x.requires_grad_()` is called on the caller's tensor in place.

    Args:
        D: Discriminator/critic module.
        x: Input batch (batch along dim 0).
        center: Target gradient norm.

    Returns:
        Scalar tensor, built with create_graph=True so it can be backpropagated.
    """
    x.requires_grad_()
    D_on_x = D(x)
    (D_grad,) = torch.autograd.grad(D_on_x, x, torch.ones_like(D_on_x), create_graph=True)
    # reshape rather than view: the gradient may be non-contiguous
    # (e.g. channels_last inputs), and view would raise there.
    D_grad_norm = D_grad.reshape(x.size(0), -1).norm(dim=1)
    return (D_grad_norm - center).pow(2).mean()


### Base Trainer

In [None]:
class BaseTrainer:
    """The base trainer class.

    Runs a generic iteration loop. Subclasses implement `train_step` (and may
    extend `pre_train_step`, `post_train_step`, `stop`) and pull batches via
    `sample_dataset()`. Training ends when the internal dataset sampler is
    exhausted after `num_epochs` epochs, which raises StopIteration.
    """

    def __init__(self, model, dataset,
        name="trainer",
        results_dir="results/",
        load_model_path=None,
        num_gpu=1,
        num_workers=0,
        batch_size=4,
        report_interval=10,
        save_interval=100000,
        use_tensorboard=False,  # XXX: not implemented yet
        description="no description given",
        **kwargs):
        """
        Initializes BaseTrainer.

        Args:
            model: The model or net.
            dataset: The dataset on which the model will be training.
            name: Name of this trainer.
            results_dir: Directory in which results will be saved for each run.
            load_model_path: Path to the model that will be loaded, if any.
            num_gpu: Number of GPUs to use for training.
            num_workers: Number of workers sampling from the dataset.
            batch_size: Size of the batch. Must be > num_gpu.
            report_interval: Report stats every `report_interval` iters.
            save_interval: Save model every `save_interval` iters.
            use_tensorboard: Stored only. `run` always opens a SummaryWriter.
            description: Description of the experiment the trainer is running.
            kwargs: Extra keyword arguments are accepted and ignored.
        """

        self.model = model
        self.dataset = dataset

        self.name = name
        self.results_dir = results_dir
        self.load_model_path = load_model_path

        self.num_gpu = num_gpu
        self.num_workers = num_workers
        self.batch_size = batch_size

        self.report_interval = report_interval
        self.save_interval = save_interval
        self.description = description
        self.save_results = False

        self.start_time = datetime.datetime.now()
        self.stop_time = datetime.datetime.now()
        self.iters = 1  # current iteration (i.e. # of batches processed so far)
        self.batch = 1  # current batch
        self.epoch = 1  # current epoch
        # Batches per epoch. The DataLoader keeps the last partial batch
        # (drop_last=False), so this is ceil(len(dataset) / batch_size).
        # `1 + len // bs` over-counted by one whenever len was a multiple of bs,
        # so `finished_epoch` in `post_train_step` never triggered.
        self.num_batches = (len(self.dataset) + self.batch_size - 1) // self.batch_size
        self.num_epochs = 0  # number of epochs to run

        self._dataset_sampler = iter(())  # generates samples from the dataset
        self._data = defaultdict(list)  # contains data of experiment

        self.writer = None
        self.use_tensorboard = use_tensorboard

        # Load model if necessary
        if load_model_path is not None:
            self.load_model(load_model_path)

        # Initialize device
        using_cuda = torch.cuda.is_available() and self.num_gpu > 0
        self.device = torch.device("cuda:0" if using_cuda else "cpu")

        # Move model to device and parallelize model if possible
        self.model = self.model.to(self.device)
        if self.device.type == "cuda" and self.num_gpu > 1:
            # Single-process multi-GPU. `torch.nn.DistributedDataParallel` does
            # not exist (DDP lives in torch.nn.parallel and needs an initialized
            # process group), so the previous call always raised here.
            # NOTE(review): the wrapper exposes the wrapped model's submodules
            # through `.module` only.
            device_ids = list(range(min(self.num_gpu, torch.cuda.device_count())))
            self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)


    def load_model(self, model_path):
        """
        Loads weights from `model_path` into `self.model` if the file exists.

        Tensors are mapped to CPU so GPU-saved checkpoints also load on
        CPU-only machines. `__init__` moves the model to its device afterwards.
        """
        if not os.path.isfile(model_path):
            print(f"Couldn't load model: file '{model_path}' does not exist")
            print("Training model from scratch.")
        else:
            print("Loading model...")
            self.model.load_state_dict(torch.load(model_path, map_location="cpu"))


    def save_model(self, model_path):
        """Saves `self.model.state_dict()` to `model_path`."""
        print("Saving model...")
        torch.save(self.model.state_dict(), model_path)


    def time_since_start(self):
        """Returns the seconds elapsed since the current run started."""
        elapsed_time = datetime.datetime.now() - self.start_time
        return elapsed_time.total_seconds()


    def run(self, num_epochs, save_results=False):
        """
        Runs the trainer. Trainer will train the model and then save it.
        Note that running trainer more than once will accumulate the results.

        Args:
            num_epochs: Number of epochs to run.
            save_results: A flag indicating whether we should save the results this run.
        """
        self.start_time = datetime.datetime.now()
        # Epochs are counted cumulatively across runs.
        self.num_epochs = num_epochs + self.epoch - 1
        self.save_results = save_results

        # Create experiment directory (and results_dir, even if nested)
        experiment_name = self.get_experiment_name()
        experiment_dir = os.path.join(self.results_dir, experiment_name)
        if self.save_results:
            os.makedirs(experiment_dir, exist_ok=True)

        with tensorboard.SummaryWriter(f"runs/{experiment_name}") as self.writer:
            # Try training the model, then stop the training when an exception is thrown
            try:
                self.train()
            finally:
                self.stop_time = datetime.datetime.now()
                self.stop()


    def train(self):
        """
        Trains the model until the dataset sampler is exhausted, i.e. for
        `self.num_epochs` epochs.
        """

        # Train until dataset sampler is exhausted (i.e. until it throws StopIteration)
        self.init_dataset_sampler()

        try:
            print(f"Starting training {self.name}...")
            while True:
                # One training step/iteration
                self.pre_train_step()
                self.train_step()
                self.post_train_step()
                self.iters += 1

        except StopIteration:
            print("Finished training.")


    def init_dataset_sampler(self):
        """
        Initializes the sampler (or iterator) of the dataset with a shuffling
        DataLoader configuration.
        """
        loader_config = {
            "batch_size": self.batch_size,
            "shuffle": True,
            "num_workers": self.num_workers,
        }
        self._dataset_sampler = iter(self.sample_loader(loader_config))


    def sample_loader(self, loader_config):
        """
        A generator that yields batches from the dataset, from the current
        epoch up to and including `self.num_epochs`.

        `self.epoch` and `self.batch` track the position. After the last
        epoch, `self.epoch` is `num_epochs + 1`. If no epoch is left to run,
        `self.epoch` is left unchanged. Previously it was incremented anyway.

        Args:
            loader_config: Configuration for pytorch's data loader.
        """
        while self.epoch <= self.num_epochs:
            data_loader = torch.utils.data.DataLoader(self.dataset, **loader_config)
            for self.batch, sample in enumerate(data_loader, 1):
                yield sample
            self.epoch += 1


    def sample_dataset(self):
        """
        Samples the dataset. To be called by the client.

        Returns:
            A sample from the dataset.

        Raises:
            StopIteration: When all epochs have been consumed.
        """
        return next(self._dataset_sampler)


    def pre_train_step(self):
        """
        The training preparation, or what happens before each training step.
        """
        pass


    def train_step(self):
        """
        Makes one training step.
        """
        pass


    def post_train_step(self):
        """
        The training checkpoint, or what happens after each training step.
        """
        should_report_stats = self.iters % self.report_interval == 0
        should_save_progress = self.iters % self.save_interval == 0
        finished_epoch = self.batch == self.num_batches

        # Report training stats
        if should_report_stats or finished_epoch:
            self.report_stats()

        if self.save_results and should_save_progress:
            model_path = os.path.join(self.results_dir,
                                      self.get_experiment_name(),
                                      f"model@{self.iters}.pt")
            self.save_model(model_path)


    def stop(self):
        """
        Stops the trainer, or what happens when the trainer stops.
        Note: This will run even on keyboard interrupts.
        """

        # plot losses, if any
        plot_lines(self.get_data_containing("loss"), title="Losses")


    def get_experiment_name(self, delimiter=", "):
        """
        Gets the name of the current experiment: start timestamp plus key params.

        Args:
            delimiter: The delimiter between experiment's parameters.
        """
        info = {
            "name": self.name,
            "batch_size": self.batch_size,
        }

        timestamp = self.start_time.strftime("%y%m%d-%H%M%S")
        experiment = delimiter.join(f"{k}={v}" for k,v in info.items())

        return "[{}] {}".format(timestamp, experiment)


    def report_stats(self, precision=3):
        """
        Default training stats report.
        Prints the current (last) value of each data list recorded.
        """

        # Progress of training
        progress = f"[{self.epoch}/{self.num_epochs}][{self.batch}/{self.num_batches}]  "

        # Show the stat of an item
        item_stat = lambda item: f"{item[0]} = {item[1][-1]:.{precision}f}"
        # Join the stats, skipping any label that has no values yet
        stats = ",  ".join(map(item_stat, (it for it in self._data.items() if it[1])))

        report = progress + stats

        print(report)


    def get_current_value(self, label):
        """
        Get the current value of the quantity given by `label`.

        Uses `.get` so that querying an unknown label does not insert an empty
        list into the defaultdict. An empty list would later break
        `report_stats`.

        Args:
            label: Name/label of the data/quantity.

        Returns:
            The current value of the quantity given by `label`, or None.
        """
        values = self._data.get(label)
        return values[-1] if values else None


    def get_data_containing(self, phrase):
        """
        Get the data lists that contain `phrase` in their names/labels.

        Args:
            phrase: A phrase to find in the label of the data, such as "loss".

        Returns:
            A dict containing the data lists that contain `phrase` in their labels.
        """
        return {k: v for k, v in self._data.items() if phrase in k}


    def add_data(self, **kwargs):
        """
        Appends each keyword value to the data list named by its keyword.

        Args:
            kwargs: Dict of values to be added to data lists corresponding to their labels.
        """
        for key, value in kwargs.items():
            self._data[key].append(value)


    def __repr__(self):

        self_dict = dict({k:v for k,v in self.__dict__.items() if k[0] != "_"})
        pretty_dict = pformat(self_dict)
        
        return self.__class__.__name__ + "(**" + pretty_dict + ")"


### ConGAN Trainer

In [None]:
class ConGANTrainer(BaseTrainer):
    """The trainer for a (W)GAN model exposing `G`, `D` and `gan_type`."""

    def __init__(self, model, dataset,
                 D_optim_config=None,
                 G_optim_config=None,
                 D_iters=5,
                 clamp=(-0.01, 0.01),
                 grad_penalty=10.,
                 noise_std=0.01,
                 generate_grid_interval=200,
                 constants=None,
                 **kwargs):
        """
        Initializes ConGANTrainer.

        Note:
            Optimizer's configurations/parameters must be passable to the
            optimizer (in torch.optim). It should also include a parameter
            `optim_choice` for the choice of the optimizer (e.g. "sgd" or "adam").

        Args:
            model: The model. Must expose `G`, `D` and `gan_type`.
            dataset: The dataset. The "before" entry of each batch is used as real data.
            D_optim_config: Configurations for the discriminator's optimizer (default: none).
            G_optim_config: Configurations for the generator's optimizer (default: none).
            D_iters: Number of discriminator updates per generator update;
                each consumes one batch.
            clamp: Weight-clipping range for the discriminator.
                NOTE: stored but not currently applied anywhere in this class.
            grad_penalty: Coefficient of the gradient penalty on D (0 disables it).
            noise_std: Std of Gaussian noise added to D's real and fake inputs.
            generate_grid_interval: Record a grid of samples every
                `generate_grid_interval` iterations.
            constants: Unused.
            kwargs: Forwarded to BaseTrainer.
        """
        super().__init__(model, dataset, **kwargs)

        # None defaults avoid sharing one mutable dict across all instances.
        if D_optim_config is None:
            D_optim_config = {}
        if G_optim_config is None:
            G_optim_config = {}

        self.D_iters = D_iters
        self.clamp = clamp
        self.noise_std = noise_std
        self.generate_grid_interval = generate_grid_interval

        # Initialize optimizers for generator and discriminator
        self.D_optim = init_optim(self.model.D.parameters(), **D_optim_config)
        self.G_optim = init_optim(self.model.G.parameters(), **G_optim_config)

        # TODO: Specify loss type instead (minimax or wasserstein)
        self.D_loss_fn = get_D_loss(self.model.gan_type)
        self.G_loss_fn = get_G_loss(self.model.gan_type)

        # Gradient-penalty coefficient. The old code assigned the misspelled
        # `grad_prenalty`, which raised NameError on construction.
        self.grad_penalty = grad_penalty

        # Initialize list of image grids generated from a fixed latent variable
        grid_size = 8 * 8
        # NOTE(review): this reads `model.num_latents` while `sample_latent`
        # reads `model.latent_dim`. Confirm which attribute the model defines.
        self._fixed_latent = torch.randn([grid_size, self.model.num_latents], device=self.device)
        self._generated_grids = []


    def train_step(self):
        """
        Makes one training step.

        GAN:
        Throughout this doc, we will denote a sample from the real data
        distribution, fake data distribution, and latent variables respectively
        as follows:
            x ~ real,    x_g ~ fake,    z ~ latent

        Now recall that in order to train a GAN, we try to find a solution to
        a min-max game of the form `min_G max_D V(G,D)`, where G is the generator,
        D is the discriminator, and V(G,D) is the score function.
        For a regular GAN, V(G,D) = log(D(x)) + log(1 - D(x_g)),
        which is the Jensen-Shannon (JS) divergence between the probability
        distributions P(x) and P(x_g), where P(x_g) is parameterized by G.

        When it comes to Wasserstein GAN (WGAN), the objective is to minimize
        the Wasserstein (or Earth-Mover) distance instead of the JS-divergence.
        See Theorem 3 and Algorithm 1 in the original paper for more details.
        We can achieve that (thanks to the Kantorovich-Rubinstein duality)
        by first maximizing  `D(x) - D(x_g)` in the space of 1-Lipschitz
        discriminators D, where x ~ data and x_g ~ fake.
        Then, we have the gradient wrt G of the Wasserstein distance equal
        to the gradient of -D(G(z)).
        Since we assumed that D should be 1-Lipschitz, we can enforce
        k-Lipschitzness by clamping the weights of D to be in some fixed box,
        which would be approximate up to a scaling factor.

        Enforcing Lipschitzness is done more elegantly in WGAN-GP,
        which is just WGAN with gradient penalty (GP). The gradient penalty
        is used because of the statement that a differentiable function is
        1-Lipschitz iff it has gradient norm equal to 1 almost everywhere
        under P(x) and P(x_g). Hence, the objective will be similar to WGAN,
        which is `min_G max_D of D(x) - D(x_g)`, but now we add the gradient
        penalty in the D_step such that it will be minimized.

        Only the results of the last D update are recorded.

        Links to the papers:
        GAN:     https://arxiv.org/pdf/1406.2661.pdf
        WGAN:    https://arxiv.org/pdf/1701.07875.pdf
        WGAN-GP: https://arxiv.org/pdf/1704.00028.pdf
        """

        for _ in range(self.D_iters):
            # Sample real data from the dataset
            sample = self.sample_dataset()
            real = sample["before"].to(self.device)

            # Sample latent and train discriminator
            latent = self.sample_latent()
            D_results = self.D_step(real, latent)

        # Sample latent and train generator
        latent = self.sample_latent()
        G_results = self.G_step(latent)

        # Record data
        self.add_data(**D_results, **G_results)
        losses = {"D_loss": D_results["D_loss"],
                  "G_loss": G_results["G_loss"]}
        D_evals = {"D_on_real": D_results["D_on_real"],
                   "D_on_fake": D_results["D_on_fake2"]} if "D_on_fake2" in D_results else \
                  {"D_on_real": D_results["D_on_real"],
                   "D_on_fake": G_results["D_on_fake2"]}
        self.writer.add_scalars("Loss", losses, self.iters)
        self.writer.add_scalars("D_evals", D_evals, self.iters)


    def D_step(self, real, latent):
        """One discriminator update on `real` vs. G(latent), with instance noise."""

        # Add noise to real (out of place, so the caller's tensor is untouched)
        real = real + torch.randn_like(real) * self.noise_std

        # Sample from generator without tracking G's graph, then add noise
        with torch.no_grad():
            fake = self.model.G(latent)
            fake = fake + torch.randn_like(fake) * self.noise_std

        # Classify real and fake images
        D_on_real = self.model.D(real)
        D_on_fake = self.model.D(fake)

        # Adversarial loss
        adv_loss = self.D_loss_fn(D_on_real, D_on_fake)

        # Gradient penalty (previously referenced an undefined `grad_penalty`)
        D_grad_penalty = torch.zeros((), device=adv_loss.device)
        if self.grad_penalty != 0:
            D_grad_penalty = simple_gradient_penalty(
                self.model.D, random_interpolate(real.detach(), fake), center=1.0)

        # Calculate gradients and minimize loss
        self.D_optim.zero_grad()
        D_loss = adv_loss + self.grad_penalty * D_grad_penalty
        D_loss.backward()
        self.D_optim.step()

        return {
            "D_on_real": D_on_real.mean().item(),
            "D_on_fake": D_on_fake.mean().item(),
            "D_grad_penalty": D_grad_penalty.item(),
            "D_loss": D_loss.item(),
        }


    def G_step(self, latent):
        """One generator update on noisy samples G(latent)."""

        # Sample from generator. Noise is added out of place: an in-place `+=`
        # on G's output breaks backward when the last op's gradient needs its
        # own output (e.g. tanh or sigmoid).
        fake = self.model.G(latent)
        fake = fake + torch.randn_like(fake) * self.noise_std

        # Classify fake images
        D_on_fake = self.model.D(fake)

        # Adversarial loss
        adv_loss = self.G_loss_fn(D_on_fake)

        # Calculate gradients and minimize loss
        self.G_optim.zero_grad()
        G_loss = adv_loss
        G_loss.backward()
        self.G_optim.step()

        return {
            "D_on_fake2": D_on_fake.mean().item(),
            "G_loss": G_loss.item(),
        }

    def sample_latent(self):
        """
        Samples from the latent space (i.e. input space of the generator).

        Returns:
            A standard-normal tensor of shape [batch_size, model.latent_dim].
        """

        # Calculate latent size and sample from normal distribution
        latent_size = [self.batch_size, self.model.latent_dim]
        latent = torch.randn(latent_size, device=self.device)

        return latent


    #################### Reporting and Tracking Methods ####################

    def stop(self):
        """
        Stops the trainer and report the result of the experiment.
        """

        losses = {**self.get_data_containing("D_loss"), **self.get_data_containing("G_loss")}
        lines_to_plot = {"Discriminator Evaluations": "D_on",
                         "Gradient Penalty": "grad_penalty",}

        if not self.save_results:
            plot_lines(losses, title="Losses")
            for title, keyword in lines_to_plot.items():
                plot_lines(self.get_data_containing(keyword), title=title)
            return

        # Create experiment directory in the model's directory
        experiment_dir = os.path.join(self.results_dir, self.get_experiment_name())

        # Save model
        model_path = os.path.join(experiment_dir, "model.pt")
        self.save_model(model_path)

        # Plot losses of D and G
        losses_file = os.path.join(experiment_dir, "losses.png")
        plot_lines(losses, filename=losses_file, title="Losses of D and G")

        # Plot evals of D on real and fake data (`evals` was previously undefined)
        evals = self.get_data_containing("D_on")
        evals_file = os.path.join(experiment_dir, "evals.png")
        plot_lines(evals, filename=evals_file, title="Evaluations of D on real and fake data")

        # Create an animation of the generator's progress
        animation_file = os.path.join(experiment_dir, "progress.mp4")
        create_progress_animation(self._generated_grids, animation_file)

        # Write details of experiment
        details_txt = os.path.join(experiment_dir, "repr.txt")
        with open(details_txt, "w") as f:
            f.write(self.__repr__())


    def post_train_step(self):
        """
        The post-training step.
        """
        super().post_train_step()

        should_generate_grid = self.iters % self.generate_grid_interval == 0

        # Check generator's progress by recording its output on a fixed input
        if should_generate_grid:
            grid = generate_grid(self.model.G, self._fixed_latent)
            self._generated_grids.append(grid)
            self.writer.add_image("grid", grid, self.iters)


# Training

## Utils

In [None]:
def get_D_loss(gan_type="gan"):
    if gan_type in ("gan", "gan-gp"):
        return D_loss_GAN
    elif gan_type in ("wgan", "wgan-gp"):
        return D_loss_WGAN
    else:
        raise ValueError(f"gan_type {gan_type} not supported")


def get_G_loss(gan_type="gan"):
    if gan_type in ("gan", "gan-gp"):
        return G_loss_GAN
    elif gan_type in ("wgan", "wgan-gp"):
        return G_loss_WGAN
    else:
        raise ValueError(f"gan_type {gan_type} not supported")


def D_loss_GAN(D_real, D_fake, label_smoothing=True):
    """Binary cross-entropy discriminator loss with noisy real labels.

    Args:
        D_real: Discriminator probabilities on real samples (in [0, 1]).
        D_fake: Discriminator probabilities on fake samples (in [0, 1]).
        label_smoothing: If True, real targets are drawn uniformly from
            [0.7, 1.2). If False, a constant 0.9 target is used, so targets
            are smoothed either way.

    Returns:
        Scalar tensor BCE(D_real, real_targets) + BCE(D_fake, 0), i.e.
        -log(D(x)) - log(1 - D(x_g)).
    """
    # NOTE(review): smoothed targets can exceed 1 (up to 1.2), which BCE does
    # not treat as a probability. Confirm this is intended (cf. "GAN hacks").
    real_targets = (0.7 + 0.5 * torch.rand_like(D_real)) if label_smoothing \
        else (torch.ones_like(D_real) - 0.1)
    fake_targets = torch.zeros_like(D_fake)

    real_term = F.binary_cross_entropy(D_real, real_targets)
    fake_term = F.binary_cross_entropy(D_fake, fake_targets)
    return (real_term + fake_term).mean()


def D_loss_WGAN(D_real, D_fake):
    """Wasserstein critic loss (without gradient penalty).

    The critic maximizes D(x) - D(x_g), so this returns the negation,
    mean(D(x_g) - D(x)), for minimization. Any gradient penalty must be
    added by the caller.

    Args:
        D_real: Critic scores on real samples.
        D_fake: Critic scores on fake samples.

    Returns:
        Scalar tensor.
    """
    gap = D_fake - D_real
    return gap.mean()


def G_loss_GAN(D_fake):
    """Non-saturating generator loss: mean(-log(D(G(z)))).

    Using "real" targets for the fakes, instead of minimizing
    log(1 - D(G(z))), is done for stability.

    Args:
        D_fake: Discriminator probabilities on generated samples (in [0, 1]).

    Returns:
        Scalar tensor.
    """
    return F.binary_cross_entropy(D_fake, torch.ones_like(D_fake)).mean()


def G_loss_WGAN(D_fake):
    """Wasserstein generator loss: mean(-D(G(z))).

    Args:
        D_fake: Critic scores on generated samples.

    Returns:
        Scalar tensor.
    """
    flipped = -D_fake
    return flipped.mean()


def interpolate(real, fake):
    """Return eps * real + (1 - eps) * fake with one eps ~ U[0, 1) per sample.

    The eps tensor has shape (N, 1, ..., 1), matching the rank of `real`, so
    it broadcasts correctly for both image batches (N, C, H, W) and flat
    representations (N, D). The old fixed (N, 1, 1, 1) shape turned an
    (N, D) input into an (N, 1, N, D) result.

    Args:
        real: Batch of real samples.
        fake: Batch of fake samples, same shape as `real`.
    """
    eps_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=real.device, dtype=real.dtype)
    return eps * real + (1 - eps) * fake

def simple_gradient_penalty(D, x, center=0.):
    """Penalty mean((||grad_x D(x)_i|| - center)^2) over the batch.

    `x` gets `requires_grad_()` applied in place. The graph is kept
    (`create_graph=True`) so the penalty itself can be differentiated.

    Args:
        D: Callable scoring a batch.
        x: Input batch; dim 0 indexes samples.
        center: Target per-sample gradient norm.

    Returns:
        Scalar tensor.
    """
    x.requires_grad_()
    D_x = D(x)
    grads = torch.autograd.grad(
        outputs=D_x, inputs=x, grad_outputs=torch.ones_like(D_x), create_graph=True)[0]
    deviation = grads.view(x.size(0), -1).norm(dim=1) - center
    return deviation.pow(2).mean()


## Prep

In [None]:
class Args:
    """Hyper-parameters for the SimSiam + GAN run (stand-in for argparse).

    Attribute names mirror the flags of the SimSiam reference script. The
    constructor used to be spelled `__init`, so it never ran and instances
    had no attributes at all.
    """

    def __init__(self):
        self.data = "/content/drive/My Drive/gansiam/dataset/"
        self.arch = "resnet50"
        self.workers = 32
        self.epochs = 100
        self.start_epoch = 0
        self.batch_size = 512
        self.learning_rate = 0.05
        # `main_worker` reads `args.lr`; keep it as an alias of learning_rate.
        self.lr = self.learning_rate
        self.momentum = 0.9
        self.weight_decay = 1e-4
        self.print_freq = 10
        self.resume = ""
        self.world_size = -1
        self.rank = -1
        self.seed = None
        self.gpu = None
        self.dist_url = None
        self.dist_backend = None
        self.multiprocessing_distributed = False

        # SimSiam
        self.dim = 2048
        self.repr_dim = self.dim
        self.pred_dim = 512
        self.fix_pred_lr = False

        # GAN (previously plain locals, so they were discarded)
        self.D_iters = 3
        self.grad_penalty = 10.
        # NOTE(review): `main_worker` reads `args.latent_dim` and `train` reads
        # `args.siam_coeff`; neither is set here. TODO: define them.

# Copied from SimSiam repo with some adjustments >>>>>>>>
import argparse
import builtins
import math
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

def simsiam_main(args):
    """Entry point adapted from the SimSiam reference script.

    Optionally seeds `random` and torch (and forces deterministic cuDNN),
    resolves the distributed settings on `args`, then either spawns one
    `main_worker` process per visible GPU (`args.multiprocessing_distributed`)
    or calls `main_worker` inline.

    Args:
        args: Namespace-like config (e.g. `Args`). `world_size` and
            `distributed` are updated in place.
    """
    seed = args.seed
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    # With env:// initialization, an unset world size comes from the environment.
    using_env_init = args.dist_url == "env://"
    if using_env_init and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if not args.multiprocessing_distributed:
        # Single process: run the worker inline.
        main_worker(args.gpu, ngpus_per_node, args)
        return

    # One process per GPU, so the global world size scales by the GPU count.
    args.world_size = ngpus_per_node * args.world_size
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))


def main_worker(gpu, ngpus_per_node, args):
    """Per-process training entry point (adapted from the SimSiam reference).

    Builds the ConsistentGAN model, wraps it in DistributedDataParallel (the
    only supported mode), creates the SimSiam, D and G optimizers, optionally
    resumes from `args.resume`, and trains for `args.epochs` epochs. Each
    epoch writes a checkpoint from the master process.

    Args:
        gpu: GPU index for this process, or None.
        ngpus_per_node: Number of GPUs on this node.
        args: Config namespace (see `Args`); mutated in place (gpu, rank,
            batch_size, workers, start_epoch).

    Raises:
        NotImplementedError: If not running distributed.
    """
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*args):
            pass
        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
        torch.distributed.barrier()
    # create model
    print("=> creating model '{}'".format(args.arch))
    # <<<<<<<<<<
    model = ConsistentGAN(models.__dict__[args.arch],
                          args.dim,
                          args.pred_dim,
                          args.latent_dim)
    # >>>>>>>>>>

    # infer learning rate before changing batch size
    init_lr = args.lr * args.batch_size / 256

    if args.distributed:
        # Apply SyncBN
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    print(model) # print model after SyncBatchNorm

    # define loss function (criterion) and optimizer
    criterion = nn.CosineSimilarity(dim=1).cuda(args.gpu)
    # <<<<<<<<<<
    # Define D and G loss functions. `gan_type` used to be an undefined name
    # here (NameError). It is now read from args, and the shared selectors
    # also map "wgan-gp" to the Wasserstein losses.
    # NOTE(review): defaulting to "gan" when args has no gan_type; confirm.
    gan_type = getattr(args, "gan_type", "gan")
    D_criterion = get_D_loss(gan_type)
    G_criterion = get_G_loss(gan_type)
    # >>>>>>>>>>

    if args.fix_pred_lr:
        optim_params = [{'params': model.module.encoder.parameters(), 'fix_lr': False},
                        {'params': model.module.predictor.parameters(), 'fix_lr': True}]
    else:
        optim_params = model.parameters()

    optimizer = torch.optim.SGD(optim_params, init_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # <<<<<<<<<<
    D_optimizer = torch.optim.SGD(model.module.D.parameters(), init_lr,
                                  momentum=args.momentum,
                                  )#weight_decay=args.weight_decay)
    G_optimizer = torch.optim.SGD(model.module.G.parameters(), init_lr,
                                  momentum=args.momentum,
                                  )#weight_decay=args.weight_decay)
    # >>>>>>>>>>

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # GAN optimizer states; absent in checkpoints written before they were saved.
            if 'D_optimizer' in checkpoint:
                D_optimizer.load_state_dict(checkpoint['D_optimizer'])
            if 'G_optimizer' in checkpoint:
                G_optimizer.load_state_dict(checkpoint['G_optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
    augmentation = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]

    train_dataset = datasets.ImageFolder(
        traindir,
        TwoCropsTransform(transforms.Compose(augmentation)))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, init_lr, epoch, args)

        # train for one epoch
        # <<<<<<<<<<
        train(train_loader, model,
              criterion, D_criterion, G_criterion,
              optimizer, D_optimizer, G_optimizer,
              epoch, args)
        # >>>>>>>>>>

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer' : optimizer.state_dict(),
                'D_optimizer': D_optimizer.state_dict(),
                'G_optimizer': G_optimizer.state_dict(),
            }, is_best=False, filename='checkpoint_{:04d}.pth.tar'.format(epoch))



## Train

In [None]:
def _train_discriminator(model, real_repr, D_criterion, D_optimizer,
                         batch_size, args):
    """Run ``args.D_iters`` discriminator updates using ``real_repr`` as real data.

    Real representations and generator samples are both perturbed with Gaussian
    noise scaled by ``args.noise``. When ``args.grad_penalty`` is non-zero, a
    gradient penalty (``simple_gradient_penalty`` on ``interpolate(real, fake)``)
    is added to the loss.

    Returns:
        The loss tensor of the last update, or None if ``args.D_iters`` is 0.
    """
    # NOTE(review): assumes `args` defines `grad_penalty`; the bare names
    # `noise`/`grad_penalty` used here previously were never defined in train().
    D_loss = None
    for _ in range(args.D_iters):
        latent = model.sample_latent(batch_size)

        # Add noise to the real sample
        real = real_repr + torch.randn_like(real_repr) * args.noise

        # No autograd graph is built through G for the discriminator step.
        with torch.no_grad():
            fake = model.G(latent)
            fake = fake + torch.randn_like(fake) * args.noise

        # Classify real and fake data
        D_real = model.D(real)
        D_fake = model.D(fake)

        D_loss = D_criterion(D_real, D_fake)
        if args.grad_penalty != 0:
            D_grad_penalty = simple_gradient_penalty(
                model.D, interpolate(real, fake), center=1.0)
            D_loss = D_loss + args.grad_penalty * D_grad_penalty

        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()
    return D_loss


def _train_generator(model, target_repr, G_criterion, G_optimizer,
                     batch_size, args):
    """Run one generator update.

    The loss is the adversarial loss ``G_criterion(D(fake))`` plus
    ``args.recon_coeff`` times the MSE between the noised samples and
    ``target_repr``.

    Returns:
        The generator loss tensor.
    """
    latent = model.sample_latent(batch_size)
    fake = model.G(latent)
    # Out-of-place add: modifying G's output in place can break autograd if
    # G's last operation saved its output for the backward pass.
    fake = fake + torch.randn_like(fake) * args.noise

    D_fake = model.D(fake)
    G_loss = G_criterion(D_fake) + args.recon_coeff * F.mse_loss(fake, target_repr)

    G_optimizer.zero_grad()
    G_loss.backward()
    G_optimizer.step()
    return G_loss


def train(train_loader, model,
          criterion, D_criterion, G_criterion,
          optimizer, D_optimizer, G_optimizer,
          epoch, args):
    """Train SimSiam plus a GAN on its representations for one epoch.

    For every batch of two augmented views ``(x1, x2)``:

    1. Update the model via ``optimizer`` with the negated, ``args.siam_coeff``
       scaled symmetric objective
       ``-0.5 * (criterion(pred1, repr2).mean() + criterion(pred2, repr1).mean())``.
    2. For each view's representation, run ``args.D_iters`` discriminator
       updates with it as real data. Then run one generator update whose
       reconstruction target is the *other* view's representation.

    Args:
        train_loader: Yields ``(images, _)``, where ``images[0]`` and
            ``images[1]`` are the two views.
        model: Must be callable as ``model(x1=..., x2=...)`` returning
            ``(pred1, pred2, repr1, repr2)``, and expose ``G``, ``D`` and
            ``sample_latent``.
        criterion, D_criterion, G_criterion: SimSiam, discriminator and
            generator losses.
        optimizer, D_optimizer, G_optimizer: Matching optimizers.
        epoch: Epoch index (only used for the progress prefix).
        args: Reads ``gpu``, ``siam_coeff``, ``D_iters``, ``noise``,
            ``grad_penalty``, ``recon_coeff`` and ``print_freq``.
    """
    import time  # `time` is not part of the notebook's top import cell

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4f')
    D_losses = AverageMeter('D Loss', ':.4f')
    G_losses = AverageMeter('G Loss', ':.4f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, D_losses, G_losses],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images[0] = images[0].cuda(args.gpu, non_blocking=True)
            images[1] = images[1].cuda(args.gpu, non_blocking=True)

        x1 = images[0]
        x2 = images[1]
        batch_size = x1.size(0)

        # compute output and loss
        # Note: repr are detached
        pred1, pred2, repr1, repr2 = model(x1=x1, x2=x2)
        loss = -0.5 * args.siam_coeff * \
            (criterion(pred1, repr2).mean() + criterion(pred2, repr1).mean())
        losses.update(loss.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # GAN on each view's representation; the generator reconstructs the
        # other view's representation.
        for real_repr, target_repr in ((repr1, repr2), (repr2, repr1)):
            D_loss = _train_discriminator(model, real_repr, D_criterion,
                                          D_optimizer, batch_size, args)
            G_loss = _train_generator(model, target_repr, G_criterion,
                                      G_optimizer, batch_size, args)
            if D_loss is not None:  # None when args.D_iters == 0
                D_losses.update(D_loss.item(), batch_size)
            G_losses.update(G_loss.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar',
                    best_filename='model_best.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: Object to save.
        is_best: If true, also copy the written file to ``best_filename``.
        filename: Destination path of the checkpoint.
        best_filename: Destination of the copy made when ``is_best`` is true.
            Defaults to the previously hard-coded ``'model_best.pth.tar'``.
    """
    import shutil  # `shutil` is not part of the notebook's top import cell

    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)


class AverageMeter(object):
    """Track the latest value of a metric and its sample-weighted running mean.

    ``str(meter)`` renders ``"<name> <val> (<avg>)"`` with both numbers
    formatted by ``fmt`` (a format spec including the leading colon).
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` in the mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

    def __str__(self):
        template = f"{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**vars(self))


class ProgressMeter(object):
    """Print a tab-separated progress line: ``prefix[batch/total]`` then each meter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the counter for ``batch`` followed by ``str()`` of every meter."""
        fields = [self.prefix + self.batch_fmtstr.format(batch)]
        fields.extend(map(str, self.meters))
        print('\t'.join(fields))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the batch index to the width of the total so columns line up.
        width = len(str(num_batches // 1))
        slot = '{:%dd}' % width
        return '[%s/%s]' % (slot, slot.format(num_batches))


def adjust_learning_rate(optimizer, init_lr, epoch, args):
    """Set each param group's learning rate from a half-cosine decay schedule.

    ``lr = init_lr * 0.5 * (1 + cos(pi * epoch / args.epochs))`` equals
    ``init_lr`` at epoch 0 and decays towards 0 as ``epoch`` approaches
    ``args.epochs``. Param groups whose ``'fix_lr'`` entry is truthy keep
    ``init_lr`` instead.

    Returns:
        The scheduled (cosine) learning rate for this epoch.
    """
    import math  # `math` is not part of the notebook's top import cell

    cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr if param_group.get('fix_lr', False) else cur_lr
    return cur_lr