## 1. IMPORT RELEVANT LIBRARIES

Please have a look at the **requirements.txt** file for the prerequisite libraries and their versions.<br>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
from torchvision import datasets, models, transforms

import numpy as np
import matplotlib.pyplot as plt

import time
import os


# Notebook-wide default device: the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

## 2. CREATE THE ARCHITECTURES FOR MASK GENERATOR AND INPAINTING GENERATOR

* The **Binarization layer** works as follows:<br> In *forward propagation*, rounds each element to the nearest integer.<br>In *backward propagation*, sends back the gradients as is.

* Both the Mask Generator (*Mask_Generator_Net*) and the Inpainting Generator (*Generator_Net*) follow the same architecture as prescribed in the paper, except for the first Conv block and the last Transposed Conv block: both of them have 64 final channels rather than 128.

**NOTE:** Both these networks are saved in the **Generator.py** file within the ./Module folder.

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.autograd

from tensorboardX import SummaryWriter

class Binarization(torch.autograd.Function):
    """Rounding with a straight-through gradient.

    Forward: every element is rounded to the nearest integer.
    Backward: the incoming gradient is returned unchanged, as if the rounding
    were the identity function.
    """

    @staticmethod
    def forward(ctx, input):
        # Nothing is stashed on ctx: the backward pass does not need the input.
        rounded = input.round()
        return rounded

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient on as is.
        return grad_output

class Mask_Generator_Net(nn.Module):
    """U-Net style mask generator that outputs a single-channel binary mask.

    Every block runs three parallel 5x5 (transposed) convolutions with
    dilation 1, 2 and 5, applies ELU to each and concatenates the results on
    the channel axis. The encoder halves the resolution three times with max
    pooling, the decoder doubles it three times with upsampling, and each
    decoder block also receives the encoder output of matching resolution.
    The head maps to one channel, squashes it with a hard sigmoid and rounds
    it through ``Binarization``.

    Args:
        in_channels: number of channels of the input tensor (default 6).

    Note: the skip concatenations only line up when the input height and
    width survive three 2x poolings followed by three 2x upsamplings, i.e.
    when they are multiples of 8.
    """

    def __init__(self, in_channels=6):
        super().__init__()
        # ---- Encoder block 1: in_channels -> 32 + 16 + 16 = 64 channels ----
        self.l11 = nn.Conv2d(in_channels, 32, kernel_size=5, padding='same')
        self.l12 = nn.Conv2d(in_channels, 16, kernel_size=5, padding='same', dilation=2)
        self.l13 = nn.Conv2d(in_channels, 16, kernel_size=5, padding='same', dilation=5)
        self.d1 = nn.MaxPool2d(2)
        # ---- Encoder block 2: 64 -> 64 + 32 + 32 = 128 channels ----
        self.l21 = nn.Conv2d(64, 64, kernel_size=5, padding='same')
        self.l22 = nn.Conv2d(64, 32, kernel_size=5, padding='same', dilation=2)
        self.l23 = nn.Conv2d(64, 32, kernel_size=5, padding='same', dilation=5)
        self.d2 = nn.MaxPool2d(2)
        # ---- Encoder block 3: 128 -> 128 + 64 + 64 = 256 channels ----
        self.l31 = nn.Conv2d(128, 128, kernel_size=5, padding='same')
        self.l32 = nn.Conv2d(128, 64, kernel_size=5, padding='same', dilation=2)
        self.l33 = nn.Conv2d(128, 64, kernel_size=5, padding='same', dilation=5)
        self.d3 = nn.MaxPool2d(2)

        # ---- Bottleneck block 4: 256 -> 256 + 128 + 128 = 512 channels ----
        self.l41 = nn.Conv2d(256, 256, kernel_size=5, padding='same')
        self.l42 = nn.Conv2d(256, 128, kernel_size=5, padding='same', dilation=2)
        self.l43 = nn.Conv2d(256, 128, kernel_size=5, padding='same', dilation=5)

        # ---- Decoder block 1: (512 upsampled + 256 skip) -> 256 channels ----
        # Paddings of 2 / 4 / 10 pair with dilations 1 / 2 / 5 so the 5x5
        # transposed convolutions keep the spatial size unchanged.
        self.u1 = nn.Upsample(scale_factor=2)
        self.tl11 = nn.ConvTranspose2d(256 + 512, 128, kernel_size=5, padding=2)
        self.tl12 = nn.ConvTranspose2d(256 + 512, 64, kernel_size=5, dilation=2, padding=4)
        self.tl13 = nn.ConvTranspose2d(256 + 512, 64, kernel_size=5, dilation=5, padding=10)
        # ---- Decoder block 2: (256 upsampled + 128 skip) -> 128 channels ----
        self.u2 = nn.Upsample(scale_factor=2)
        self.tl21 = nn.ConvTranspose2d(128 + 256, 64, kernel_size=5, padding=2)
        self.tl22 = nn.ConvTranspose2d(128 + 256, 32, kernel_size=5, padding=4, dilation=2)
        self.tl23 = nn.ConvTranspose2d(128 + 256, 32, kernel_size=5, padding=10, dilation=5)
        # ---- Decoder block 3: (128 upsampled + 64 skip) -> 64 channels ----
        self.u3 = nn.Upsample(scale_factor=2)
        self.tl31 = nn.ConvTranspose2d(64 + 128, 32, kernel_size=5, padding=2)
        self.tl32 = nn.ConvTranspose2d(64 + 128, 16, kernel_size=5, padding=4, dilation=2)
        self.tl33 = nn.ConvTranspose2d(64 + 128, 16, kernel_size=5, padding=10, dilation=5)
        # ---- Head: (64 decoder channels + raw input) -> 8 -> 1 channel ----
        self.conv1 = nn.ConvTranspose2d(64 + in_channels, 8, kernel_size=3, padding=1)
        self.conv2 = nn.ConvTranspose2d(8, 1, kernel_size=3, padding=1)
        self.mask = Binarization.apply

    @staticmethod
    def _elu_branches(features, plain, dilated2, dilated5):
        """Run the three parallel layers on ``features``; return their ELU outputs in order."""
        return [F.elu(layer(features)) for layer in (plain, dilated2, dilated5)]

    def forward(self, x):
        """Return the rounded mask (one channel, same spatial size as ``x``)."""
        # Encoder: keep each block's full-resolution output for the skip connections.
        enc1 = torch.cat(self._elu_branches(x, self.l11, self.l12, self.l13), dim=1)
        enc2 = torch.cat(self._elu_branches(self.d1(enc1), self.l21, self.l22, self.l23), dim=1)
        enc3 = torch.cat(self._elu_branches(self.d2(enc2), self.l31, self.l32, self.l33), dim=1)

        # Bottleneck at 1/8 resolution.
        bottleneck = torch.cat(self._elu_branches(self.d3(enc3), self.l41, self.l42, self.l43), dim=1)

        # Decoder: upsample, append the matching encoder output, run the block.
        dec = torch.cat([self.u1(bottleneck), enc3], dim=1)
        dec = torch.cat(self._elu_branches(dec, self.tl11, self.tl12, self.tl13), dim=1)

        dec = torch.cat([self.u2(dec), enc2], dim=1)
        dec = torch.cat(self._elu_branches(dec, self.tl21, self.tl22, self.tl23), dim=1)

        dec = torch.cat([self.u3(dec), enc1], dim=1)
        last_branches = self._elu_branches(dec, self.tl31, self.tl32, self.tl33)
        # The raw network input is appended to the last decoder block's output.
        dec = torch.cat(last_branches + [x], dim=1)

        # Head: ELU conv, then hard sigmoid, then rounding to {0, 1}.
        dec = F.elu(self.conv1(dec))
        soft_mask = F.hardsigmoid(self.conv2(dec))
        return self.mask(soft_mask)

class Generator_Net(nn.Module):
    """U-Net style inpainting generator that outputs a 3-channel image in [0, 1].

    Every block runs three parallel 5x5 (transposed) convolutions with
    dilation 1, 2 and 5, applies ELU to each and concatenates the results on
    the channel axis. The encoder halves the resolution three times with max
    pooling, the decoder doubles it three times with upsampling, and each
    decoder block also receives the encoder output of matching resolution.
    The head maps to three channels and squashes them with a hard sigmoid.

    Args:
        in_channels: number of channels of the input tensor (default 7).

    Note: the skip concatenations only line up when the input height and
    width survive three 2x poolings followed by three 2x upsamplings, i.e.
    when they are multiples of 8.
    """

    def __init__(self, in_channels=7):
        super().__init__()
        # ---- Encoder block 1: in_channels -> 32 + 16 + 16 = 64 channels ----
        self.l11 = nn.Conv2d(in_channels, 32, kernel_size=5, padding='same')
        self.l12 = nn.Conv2d(in_channels, 16, kernel_size=5, padding='same', dilation=2)
        self.l13 = nn.Conv2d(in_channels, 16, kernel_size=5, padding='same', dilation=5)
        self.d1 = nn.MaxPool2d(2)
        # ---- Encoder block 2: 64 -> 64 + 32 + 32 = 128 channels ----
        self.l21 = nn.Conv2d(64, 64, kernel_size=5, padding='same')
        self.l22 = nn.Conv2d(64, 32, kernel_size=5, padding='same', dilation=2)
        self.l23 = nn.Conv2d(64, 32, kernel_size=5, padding='same', dilation=5)
        self.d2 = nn.MaxPool2d(2)
        # ---- Encoder block 3: 128 -> 128 + 64 + 64 = 256 channels ----
        self.l31 = nn.Conv2d(128, 128, kernel_size=5, padding='same')
        self.l32 = nn.Conv2d(128, 64, kernel_size=5, padding='same', dilation=2)
        self.l33 = nn.Conv2d(128, 64, kernel_size=5, padding='same', dilation=5)
        self.d3 = nn.MaxPool2d(2)

        # ---- Bottleneck block 4: 256 -> 256 + 128 + 128 = 512 channels ----
        self.l41 = nn.Conv2d(256, 256, kernel_size=5, padding='same')
        self.l42 = nn.Conv2d(256, 128, kernel_size=5, padding='same', dilation=2)
        self.l43 = nn.Conv2d(256, 128, kernel_size=5, padding='same', dilation=5)

        # ---- Decoder block 1: (512 upsampled + 256 skip) -> 256 channels ----
        # Paddings of 2 / 4 / 10 pair with dilations 1 / 2 / 5 so the 5x5
        # transposed convolutions keep the spatial size unchanged.
        self.u1 = nn.Upsample(scale_factor=2)
        self.tl11 = nn.ConvTranspose2d(256 + 512, 128, kernel_size=5, padding=2)
        self.tl12 = nn.ConvTranspose2d(256 + 512, 64, kernel_size=5, dilation=2, padding=4)
        self.tl13 = nn.ConvTranspose2d(256 + 512, 64, kernel_size=5, dilation=5, padding=10)
        # ---- Decoder block 2: (256 upsampled + 128 skip) -> 128 channels ----
        self.u2 = nn.Upsample(scale_factor=2)
        self.tl21 = nn.ConvTranspose2d(128 + 256, 64, kernel_size=5, padding=2)
        self.tl22 = nn.ConvTranspose2d(128 + 256, 32, kernel_size=5, padding=4, dilation=2)
        self.tl23 = nn.ConvTranspose2d(128 + 256, 32, kernel_size=5, padding=10, dilation=5)
        # ---- Decoder block 3: (128 upsampled + 64 skip) -> 64 channels ----
        self.u3 = nn.Upsample(scale_factor=2)
        self.tl31 = nn.ConvTranspose2d(64 + 128, 32, kernel_size=5, padding=2)
        self.tl32 = nn.ConvTranspose2d(64 + 128, 16, kernel_size=5, padding=4, dilation=2)
        self.tl33 = nn.ConvTranspose2d(64 + 128, 16, kernel_size=5, padding=10, dilation=5)
        # ---- Head: (64 decoder channels + raw input) -> 8 -> 3 channels ----
        self.conv1 = nn.ConvTranspose2d(64 + in_channels, 8, kernel_size=3, padding=1)
        self.conv2 = nn.ConvTranspose2d(8, 3, kernel_size=3, padding=1)

    @staticmethod
    def _elu_branches(features, plain, dilated2, dilated5):
        """Run the three parallel layers on ``features``; return their ELU outputs in order."""
        return [F.elu(layer(features)) for layer in (plain, dilated2, dilated5)]

    def forward(self, x):
        """Return the generated image (three channels, same spatial size as ``x``)."""
        # Encoder: keep each block's full-resolution output for the skip connections.
        enc1 = torch.cat(self._elu_branches(x, self.l11, self.l12, self.l13), dim=1)
        enc2 = torch.cat(self._elu_branches(self.d1(enc1), self.l21, self.l22, self.l23), dim=1)
        enc3 = torch.cat(self._elu_branches(self.d2(enc2), self.l31, self.l32, self.l33), dim=1)

        # Bottleneck at 1/8 resolution.
        bottleneck = torch.cat(self._elu_branches(self.d3(enc3), self.l41, self.l42, self.l43), dim=1)

        # Decoder: upsample, append the matching encoder output, run the block.
        dec = torch.cat([self.u1(bottleneck), enc3], dim=1)
        dec = torch.cat(self._elu_branches(dec, self.tl11, self.tl12, self.tl13), dim=1)

        dec = torch.cat([self.u2(dec), enc2], dim=1)
        dec = torch.cat(self._elu_branches(dec, self.tl21, self.tl22, self.tl23), dim=1)

        dec = torch.cat([self.u3(dec), enc1], dim=1)
        last_branches = self._elu_branches(dec, self.tl31, self.tl32, self.tl33)
        # The raw network input is appended to the last decoder block's output.
        dec = torch.cat(last_branches + [x], dim=1)

        # Head: ELU conv, then a hard sigmoid to keep the output in [0, 1].
        dec = F.elu(self.conv1(dec))
        return F.hardsigmoid(self.conv2(dec))


## 3. CREATE THE ARCHITECTURE FOR DISCRIMINATOR

A WGAN was prescribed in the paper for stabilising the GAN training. There are multiple ways of doing this. In this implementation I used Spectral Norm as the way to enforce the Lipschitz continuity.

**THE CODE FOR THE SPECTRAL NORM IS BORROWED FROM** [DVD-GAN repository by Harrypotterrrr](https://github.com/Harrypotterrrr/DVD-GAN).<br>

**Architecture details:** I used a funnel network with 4 conv blocks. Each block contains a conv layer, whose weights are normalized using spectral norm, followed by a Leaky ReLU activation and a downsampling with factor 2. We used 64,128,256,512 channels respectively for each conv block. We then flatten the embedding and used a fully connected layer to get the discriminator score.

**NOTE:** The network is saved in the **Discriminators.py** file within the ./Module folder.

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import init

from Module.Normalization import  SpectralNorm

class Discriminator(nn.Module):
    """Funnel critic: four spectrally-normalised conv stages plus a linear score.

    Each stage is a 3x3 'same' convolution wrapped in ``SpectralNorm``,
    followed by a LeakyReLU and a 2x max pooling; the stage widths are
    64, 128, 256 and 512 channels. The result is flattened per sample and
    mapped to a single score by a spectrally-normalised linear layer.

    Args:
        in_channels: number of channels of the input tensor (default 4).

    Note: the linear layer takes 512*8*8 features, so after the four 2x
    poolings the feature map must be 8x8, i.e. the input must be 128x128.
    """

    def __init__(self, in_channels=4):
        super().__init__()
        widths = (in_channels, 64, 128, 256, 512)
        stages = []
        # Modules are appended in conv -> activation -> pool order so the
        # Sequential indices match a hand-written stack.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(SpectralNorm(nn.Conv2d(c_in, c_out, 3, padding='same')))
            stages.append(nn.LeakyReLU())
            stages.append(nn.MaxPool2d(2))
        self.pre_conv = nn.Sequential(*stages)
        self.linear = SpectralNorm(nn.Linear(512*8*8, 1))

    def forward(self, x):
        """Return one score per sample, shaped (batch, 1)."""
        features = self.pre_conv(x)
        # Unpacking insists on a 4-D feature map before flattening.
        batch, _channels, _height, _width = features.shape
        flat = features.view(batch, -1)
        return self.linear(flat)

## 4. REUSABLE UTILITIES

In [4]:
# Wasserstein Generator Loss
def wass_g_loss(fake_score):
    """Return the mean of the discriminator scores given to generated samples."""
    return fake_score.mean()

# Wasserstein Disciminator Loss
def wass_d_loss(real_score, fake_score):
    """Return mean(real_score) - mean(fake_score).

    This is the sign-flipped counterpart of ``wass_g_loss``: minimising it
    pushes scores of real samples down and scores of generated samples up.
    """
    real_mean = real_score.mean()
    fake_mean = fake_score.mean()
    return real_mean - fake_mean

# Mask density loss to encourage selecting only wanted_density% of the total pixels. L2 loss
def mask_density_loss(mask, wanted_density):
    """Squared error between the mask's pixel density and ``wanted_density``.

    The density is the sum of all mask entries divided by
    ``batch_size * height * width``, where height and width are read from the
    last two dimensions of ``mask`` rather than being fixed to 128x128, so
    the loss stays correct for any image size.

    Args:
        mask: tensor whose first dimension is the batch and whose last two
            dimensions are spatial (e.g. ``(B, 1, H, W)``).
        wanted_density: target fraction of selected pixels.

    Returns:
        A scalar tensor: ``(density - wanted_density) ** 2``.
    """
    pixel_count = mask.shape[0] * mask.shape[-2] * mask.shape[-1]
    density = torch.div(torch.sum(mask), pixel_count)
    return torch.pow((density - wanted_density), 2)
    

# Image reconstruction Loss. L1 loss used in accordance with the paper.
def image_loss(real_image, fake_image):
    """Mean absolute (L1) error between a reconstructed image and its target.

    Uses the ``reduction`` keyword; the legacy boolean ``reduce`` argument is
    deprecated and was only yielding a mean by accident when handed the
    string ``'mean'``.

    Args:
        real_image: target tensor.
        fake_image: reconstructed tensor, same shape as ``real_image``.

    Returns:
        A scalar tensor with the element-wise mean of ``|fake - real|``.
    """
    return F.l1_loss(fake_image, real_image, reduction='mean')

# Write TensorboardX logs for full joint training scenario
def write_log_full(writer, log_str, step, d_loss_real, d_loss_fake, ds_loss, density_loss, image_loss):
    """Log the scalars of one full-model training step plus the text log line.

    Each loss argument must expose ``.item()``; ``writer`` must expose
    ``add_scalar(tag, value, step)`` and ``add_text(tag, text, step)``.
    """
    tagged_values = (
        ('data/ds_loss_real', d_loss_real),
        ('data/ds_loss_fake', d_loss_fake),
        ('data/ds_loss', ds_loss),
        ('data/density_loss', density_loss),
        ('data/image_reconstruction_loss', image_loss),
    )
    for tag, value in tagged_values:
        writer.add_scalar(tag, value.item(), step)

    writer.add_text('logs', log_str, step)
    
# Write TensorboardX logs for No mask network scenario
def write_log_no_mask(writer, log_str, step, d_loss_real, d_loss_fake, ds_loss, image_loss):
    """Log the scalars of one no-mask-network training step plus the text log line.

    Each loss argument must expose ``.item()``; ``writer`` must expose
    ``add_scalar(tag, value, step)`` and ``add_text(tag, text, step)``.
    """
    tagged_values = (
        ('data/ds_loss_real', d_loss_real),
        ('data/ds_loss_fake', d_loss_fake),
        ('data/ds_loss', ds_loss),
        ('data/image_reconstruction_loss', image_loss),
    )
    for tag, value in tagged_values:
        writer.add_scalar(tag, value.item(), step)

    writer.add_text('logs', log_str, step)
    
# Set target device/devices for training based on availability.
def set_device(config):
    """Pick the training device from ``config.gpus`` and ``config.parallel``.

    Returns a ``(device, parallel, gpus)`` triple:
      * ``('cpu', False, "")`` when ``config.gpus`` is the empty string or
        CUDA is not available;
      * ``('cuda:0', True, gpus)`` when ``config.parallel is True`` and more
        than one GPU is listed;
      * ``('cuda:<first index>', False, gpus)`` otherwise.
    ``gpus`` is ``[0, 1, ..., len(config.gpus) - 1]``.

    Side effect: when ``config.gpus`` is not the empty string, the
    ``CUDA_VISIBLE_DEVICES`` environment variable is set to its
    comma-joined entries.
    """
    cpu_choice = ('cpu', False, "")

    if config.gpus == "":  # CPU explicitly requested
        return cpu_choice

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(config.gpus)

    if torch.cuda.is_available() is False:  # GPUs requested but none usable
        return cpu_choice

    # Indices are relative to the visible devices, hence the plain 0..n-1 range.
    gpus = list(range(len(config.gpus)))
    multi_gpu = config.parallel is True and len(gpus) > 1
    if multi_gpu:
        return 'cuda:0', True, gpus
    return 'cuda:' + str(gpus[0]), False, gpus


## 5. TRAINER CLASSES FOR DIFFERENT TRAINING SCENARIOS

For the sake of experimentation we are training models for two different scenarios.
1. **Full Model**: All models (Mask_Generator_Net + Generator_Net + Discriminator) are jointly trained. **Trainer** class
2. **No Mask N/w**: No Mask_Generator_Net. Instead, masks of target density are sampled. The Generator_Net and Discriminator used normally. **Trainer_NoMask** class

The structure of both the Trainer classes is motivated by trainer.py file of [DVD-GAN repository by Harrypotterrrr](https://github.com/Harrypotterrrr/DVD-GAN)

In [5]:
import time
import torch
import datetime

import torch.nn as nn
from torchvision.utils import save_image, make_grid

class Trainer(object):
    '''Joint trainer for the full model: Mask_Generator_Net + Generator_Net + Discriminator.

    All tensors created during training are placed on ``self.device`` (the
    device chosen by ``set_device(config)``), so they always live on the same
    device as the networks.
    '''

    def __init__(self, data_loader, config):
        '''Initialize the hyperparameters and paths for training. Initialize models/load pretrained models and create optimizers'''

        # Data loader
        self.data_loader = data_loader

        # Optimisation hyperparameters
        self.total_epoch = config.total_epoch
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.pretrained_model = config.pretrained_model  # falsy, or the step number to resume from

        self.use_tensorboard = config.use_tensorboard
        self.density = config.density
        self.d_iters = config.d_iters

        # Periods, expressed in epochs (converted to steps by epoch2step)
        self.log_epoch = config.log_epoch
        self.sample_epoch = config.sample_epoch
        self.model_save_epoch = config.model_save_epoch
        self.version = config.version

        # Output folders: every base path from the config gets the run version appended
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path, self.version)

        self.device, self.parallel, self.gpus = set_device(config)

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            print('load_pretrained_model...')
            self.load_pretrained_model()


    def select_opt_schr(self):
        '''Initialize the optimizers for the two Generator networks and the Discriminator network'''

        # Both generators share one optimizer, so a single step updates the pair.
        self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, list(self.G.parameters())+list(self.mask_G.parameters())), self.g_lr,
                                            eps=1e-07)
        self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
                                            eps=1e-07)

    def epoch2step(self):
        '''Convert epochs into number of steps based on the size of the dataloader'''

        self.epoch = 0
        step_per_epoch = len(self.data_loader)
        print("steps per epoch:", step_per_epoch)

        self.total_step = self.total_epoch * step_per_epoch
        self.log_step = self.log_epoch * step_per_epoch
        self.sample_step = self.sample_epoch * step_per_epoch
        self.model_save_step = self.model_save_epoch * step_per_epoch

    def train(self):
        '''Run the joint training loop.

        Per step: fetch one batch, then ``d_iters`` times generate masks and
        inpainted images, update both generators and update the discriminator.
        Logs every ``log_step`` steps and saves checkpoints every
        ``model_save_step`` steps.
        '''

        # Data iterator
        print("Inside the trainer!")
        data_iter = iter(self.data_loader)
        print("Iterator created!")
        self.epoch2step()

        # torch.save does not create missing folders, so make sure the target exists.
        os.makedirs(self.model_save_path, exist_ok=True)

        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 1

        # Start time
        print("=" * 30, "\nStart training...")
        start_time = time.time()

        self.D.train()
        self.G.train()
        self.mask_G.train()

        for step in range(start, self.total_step+1):

            try:
                real_imgs, _ = next(data_iter)
            except StopIteration:
                # Only an exhausted loader starts a new epoch; any other error
                # (including KeyboardInterrupt) must propagate.
                data_iter = iter(self.data_loader)
                real_imgs, _ = next(data_iter)
                self.epoch += 1

            real_imgs = real_imgs.to(self.device)

            # ================ update G and D d_iters times ================ #
            for _ in range(self.d_iters):

                # ============== Generate Masks =================== #
                # Noise is sampled directly on self.device so it matches the networks.
                mask_noise = torch.normal(0.0, 0.1, size=real_imgs.size(), device=self.device)
                img_masks = self.mask_G(torch.cat((real_imgs, mask_noise), 1))

                # ============== Create Masked images ============== #

                masked_imgs = img_masks * real_imgs

                # ============== Generator - Image reconstruction ================= #

                gen_noise = torch.normal(0.0, 0.1, size=real_imgs.size(), device=self.device)
                fake_imgs = self.G(torch.cat((masked_imgs, img_masks, gen_noise), 1))

                fake_input = torch.cat((fake_imgs, img_masks), 1)
                real_input = torch.cat((real_imgs, img_masks), 1)

                # ============== Calculate losses and update the networks =========== #

                d_fake = self.D(fake_input)

                # ============  Update the two Generator networks (Wasserstein Generator loss + Image Reconstruction loss + Mask Density loss)
                g_loss = 0.005*wass_g_loss(d_fake)
                mse_loss = image_loss(fake_imgs, real_imgs)  # L1 reconstruction loss despite the name
                density_loss = mask_density_loss(img_masks, self.density)

                g_combined_loss = g_loss + mse_loss + 100* density_loss
                self.reset_grad()
                g_combined_loss.backward()
                self.g_optimizer.step()

                # ============ Update the Discriminator network (Wasserstein Discriminator loss)
                # Inputs are detached so this pass does not backpropagate into the generators.
                d_fake = self.D(fake_input.detach())
                d_real = self.D(real_input.detach())
                d_loss = wass_d_loss(d_real, d_fake)
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()


            # ==================== print & save part ==================== #
            # Print out log info
            if step % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                start_time = time.time()
                log_str = "Epoch: [%d/%d], Step: [%d/%d], time: %s, d_real: %.4f, d_fake: %.4f, d_loss: %.4f, g_loss: %.4f, density_loss: %.4f, image_loss: %.4f" % \
                    (self.epoch+1, self.total_epoch, step, self.total_step, elapsed, torch.mean(d_real), torch.mean(d_fake), d_loss, g_loss, density_loss, mse_loss)

                if self.use_tensorboard is True:
                    write_log_full(self.writer, log_str, step, torch.mean(d_real), torch.mean(d_fake), d_loss, density_loss, mse_loss)
                print(log_str)

            # Save model
            if step % self.model_save_step == 0:
                torch.save(self.G.state_dict(),
                           os.path.join(self.model_save_path, '{}_G.pth'.format(step)))
                torch.save(self.mask_G.state_dict(),
                           os.path.join(self.model_save_path, '{}_mask_G.pth'.format(step)))
                torch.save(self.D.state_dict(),
                           os.path.join(self.model_save_path, '{}_D.pth'.format(step)))

    def build_model(self):
        '''Initialize all three networks along with their optimizers and load them onto relevant device'''

        print("=" * 30, '\nBuild_model...')
        self.mask_G = Mask_Generator_Net(6).to(self.device)
        self.G = Generator_Net(7).to(self.device)
        self.D = Discriminator(4).to(self.device)

        if self.parallel:
            print('Use parallel...')
            print('gpus:', os.environ["CUDA_VISIBLE_DEVICES"])
            self.mask_G = nn.DataParallel(self.mask_G, device_ids=self.gpus)
            self.G = nn.DataParallel(self.G, device_ids=self.gpus)
            self.D = nn.DataParallel(self.D, device_ids=self.gpus)

        self.select_opt_schr()
        print("Model building done!")

    def build_tensorboard(self):
        '''Initialize TensorboardX summary writer'''

        from tensorboardX import SummaryWriter
        self.writer = SummaryWriter(log_dir=self.log_path)

    def load_pretrained_model(self):
        '''Load pretrained models of all three networks from self.model_save_path folder'''

        # map_location lets a checkpoint saved on another device load onto the current one.
        self.G.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.pretrained_model)), map_location=self.device))
        self.mask_G.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_mask_G.pth'.format(self.pretrained_model)), map_location=self.device))
        self.D.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_D.pth'.format(self.pretrained_model)), map_location=self.device))
        print('loaded trained models (step: {})..!'.format(self.pretrained_model))

    def reset_grad(self):
        '''Reset the gradients before backward propagation of losses'''

        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        '''Save relevant samples at the current step'''

        real_images, _ = next(data_iter)
        # save_image writes straight to the path, so the folder must exist first.
        os.makedirs(self.sample_path, exist_ok=True)
        save_image(real_images, os.path.join(self.sample_path, 'real.png'))


In [6]:
import time
import torch
import datetime

import torch.nn as nn
from torchvision.utils import save_image, make_grid

class Trainer_NoMask(object):
    '''Trainer for the no-mask-network scenario: Generator_Net + Discriminator only.

    Masks are sampled at random with the target density instead of being
    produced by a mask generator. All tensors created during training are
    placed on ``self.device`` (the device chosen by ``set_device(config)``),
    so they always live on the same device as the networks.
    '''

    def __init__(self, data_loader, config):
        '''Initialize the hyperparameters and paths for training. Initialize models/load pretrained models and create optimizers'''

        # Data loader
        self.data_loader = data_loader

        # Optimisation hyperparameters
        self.total_epoch = config.total_epoch
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.pretrained_model = config.pretrained_model  # falsy, or the step number to resume from

        self.use_tensorboard = config.use_tensorboard
        self.density = config.density
        self.d_iters = config.d_iters

        # Periods, expressed in epochs (converted to steps by epoch2step)
        self.log_epoch = config.log_epoch
        self.sample_epoch = config.sample_epoch
        self.model_save_epoch = config.model_save_epoch
        self.version = config.version

        # Output folders: every base path from the config gets the run version appended
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path, self.version)

        self.device, self.parallel, self.gpus = set_device(config)

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            print('load_pretrained_model...')
            self.load_pretrained_model()


    def select_opt_schr(self):
        '''Initialize the optimizers for the Generator network(no Mask_Generator_Net) and the Discriminator network'''

        self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
                                            eps=1e-07)
        self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
                                            eps=1e-07)

    def epoch2step(self):
        '''Convert epochs into number of steps based on the size of the dataloader'''

        self.epoch = 0
        step_per_epoch = len(self.data_loader)
        print("steps per epoch:", step_per_epoch)

        self.total_step = self.total_epoch * step_per_epoch
        self.log_step = self.log_epoch * step_per_epoch
        self.sample_step = self.sample_epoch * step_per_epoch
        self.model_save_step = self.model_save_epoch * step_per_epoch

    def train(self):
        '''Run the training loop with randomly sampled masks.

        Per step: fetch one batch, then ``d_iters`` times sample masks,
        generate inpainted images, update the generator and update the
        discriminator. Logs every ``log_step`` steps and saves checkpoints
        every ``model_save_step`` steps.
        '''

        # Data iterator
        print("Inside the trainer!")
        data_iter = iter(self.data_loader)
        print("Iterator created!")
        self.epoch2step()

        # torch.save does not create missing folders, so make sure the target exists.
        os.makedirs(self.model_save_path, exist_ok=True)

        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 1

        # Start time
        print("=" * 30, "\nStart training...")
        start_time = time.time()

        self.D.train()
        self.G.train()

        for step in range(start, self.total_step+1):

            try:
                real_imgs, _ = next(data_iter)
            except StopIteration:
                # Only an exhausted loader starts a new epoch; any other error
                # (including KeyboardInterrupt) must propagate.
                data_iter = iter(self.data_loader)
                real_imgs, _ = next(data_iter)
                self.epoch += 1

            real_imgs = real_imgs.to(self.device)

            # ================ update G and D d_iters times ================ #
            for _ in range(self.d_iters):

                # ============== Generate random Masks of target density from a uniform random distribution=================== #

                B,C,H,W = real_imgs.shape
                # Each pixel is kept with probability self.density; one mask channel per image.
                # Sampled directly on self.device so it matches the networks.
                img_masks = (torch.rand(B,1,H,W, device=self.device)<=self.density)*1.0

                # ============== Create Masked images ============== #

                masked_imgs = img_masks * real_imgs

                # ============== Generator - Image reconstruction ================= #

                gen_noise = torch.normal(0.0, 0.1, size=real_imgs.size(), device=self.device)
                fake_imgs = self.G(torch.cat((masked_imgs, img_masks, gen_noise), 1))

                fake_input = torch.cat((fake_imgs, img_masks), 1)
                real_input = torch.cat((real_imgs, img_masks), 1)

                # ============== Calculate losses and update the networks =========== #

                d_fake = self.D(fake_input)

                # ============  Update the Generator_Net network (Wasserstein Generator loss + Image Reconstruction loss)

                g_loss = 0.005*wass_g_loss(d_fake)
                mse_loss = image_loss(fake_imgs, real_imgs)  # L1 reconstruction loss despite the name

                g_combined_loss = g_loss + mse_loss # No density Loss
                self.reset_grad()
                g_combined_loss.backward()
                self.g_optimizer.step()

                # ============ Update the Discriminator network (Wasserstein Discriminator loss)
                # Inputs are detached so this pass does not backpropagate into the generator.

                d_fake = self.D(fake_input.detach())
                d_real = self.D(real_input.detach())
                d_loss = wass_d_loss(d_real, d_fake)
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()


            # ==================== print & save part ==================== #
            # Print out log info
            if step % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                start_time = time.time()
                log_str = "Epoch: [%d/%d], Step: [%d/%d], time: %s, d_real: %.4f, d_fake: %.4f, d_loss: %.4f, g_loss: %.4f, image_loss: %.4f" % \
                    (self.epoch+1, self.total_epoch, step, self.total_step, elapsed, torch.mean(d_real), torch.mean(d_fake), d_loss, g_loss, mse_loss)

                if self.use_tensorboard is True:
                    write_log_no_mask(self.writer, log_str, step, torch.mean(d_real), torch.mean(d_fake), d_loss, mse_loss)
                print(log_str)

            # Save model
            if step % self.model_save_step == 0:
                torch.save(self.G.state_dict(),
                           os.path.join(self.model_save_path, '{}_G.pth'.format(step)))
                torch.save(self.D.state_dict(),
                           os.path.join(self.model_save_path, '{}_D.pth'.format(step)))

    def build_model(self):
        '''Initialize Generator_Net + Discriminator along with their optimizers and load them onto relevant device'''

        print("=" * 30, '\nBuild_model...')
        self.G = Generator_Net(7).to(self.device)
        self.D = Discriminator(4).to(self.device)

        if self.parallel:
            print('Use parallel...')
            print('gpus:', os.environ["CUDA_VISIBLE_DEVICES"])
            self.G = nn.DataParallel(self.G, device_ids=self.gpus)
            self.D = nn.DataParallel(self.D, device_ids=self.gpus)

        self.select_opt_schr()
        print("Model building done!")

    def build_tensorboard(self):
        '''Initialize TensorboardX summary writer'''

        from tensorboardX import SummaryWriter
        self.writer = SummaryWriter(log_dir=self.log_path)

    def load_pretrained_model(self):
        '''Load pretrained models of Generator_Net and Discriminator networks from self.model_save_path folder'''

        # map_location lets a checkpoint saved on another device load onto the current one.
        self.G.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.pretrained_model)), map_location=self.device))
        self.D.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_D.pth'.format(self.pretrained_model)), map_location=self.device))
        print('loaded trained models (step: {})..!'.format(self.pretrained_model))

    def reset_grad(self):
        '''Reset the gradients before backward propagation of losses'''

        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        '''Save relevant samples at the current step'''

        real_images, _ = next(data_iter)
        # save_image writes straight to the path, so the folder must exist first.
        os.makedirs(self.sample_path, exist_ok=True)
        save_image(real_images, os.path.join(self.sample_path, 'real.png'))


## 6. SET HYPERPARAMETERS FOR TRAINING

The following are the hyperparameters and their values that were used across the training scenarios. Please make changes to the default values if necessary before starting the training.

**Note:** Please change the paths to fit the OS file system. The following paths work on Windows.

In [7]:
class CONFIG():
    '''Hyperparameters, runtime switches and filesystem paths for a training run.

    Every setting is a plain instance attribute, so individual values can be
    overridden after construction (e.g. ``config.total_epoch = 2``).
    '''

    def __init__(self):

        # Training setting
        self.total_epoch = 200          # Total number of training epochs
        self.num_workers = 2            # Number of dataloader workers
        self.g_lr = 5e-5                # Learning rate for Generator networks (both Mask_Generator_Net and Generator_Net)
        self.d_lr = 5e-5                # Learning rate for the Discriminator network
        self.batch_size = 16            # Batch size for each epoch
        self.d_iters = 1                # Number of iterations for the Discriminator update step

        # Using pretrained
        self.pretrained_model = False   # Step number of the pretrained models to resume training from there

        # Misc
        self.parallel = True            # Use multiple GPUs if they are available
        self.gpus = ["0"]               # List of GPUs to use
        self.use_tensorboard = True     # Create tensorboardX logs

        # Paths. Built with os.path.join so the defaults use the separator of the
        # running OS: identical to the former '.\\...' strings on Windows and valid
        # relative paths on Linux/macOS.
        # NOTE(review): log_path lives under 'output_dlvc' while model_save_path lives
        # under 'outputs_dlvc' (and ends in 'modelsZ'); kept as-is -- confirm this is intended.
        self.data_dir = os.path.join('.', 'CelebA_HQ_facial_identity_dataset')   # Path of the root folder of dataset
        self.log_path = os.path.join('.', 'output_dlvc', 'logs')                 # Path of the folder to save tensorboard logs
        self.model_save_path = os.path.join('.', 'outputs_dlvc', 'modelsZ')      # Path of the folder to save trained models
        self.sample_path = os.path.join('.', 'outputs', 'samples')               # Path to save image samples. But code not written for it.

        # Epoch size
        self.log_epoch = 1              # Create a tensorboardX log
        self.sample_epoch = 10          # Save an image sample with current models. But code not written for it.
        self.model_save_epoch = 1       # Save new models for each network being trained

        self.density = 0.05             # Target density of pixels to be selected for each mask.
        self.version = ""               # Version of the run

In [8]:
# Configuration instance read by the dataset, folder-creation and training cells below.
config = CONFIG()

## 7. PREPARE DATASET

We use **CELEBA HQ Facial Identity dataset** in our experiments.
<table>
    <tr>
        <td> <img src="./Test_images/Celebrity1.jpg" width="200" /> </td>
        <td> <img src="./Test_images/Celebrity2.jpg" width="200" /> </td>
        <td> <img src="./Test_images/Celebrity3.jpg" width="200" /> </td>
    </tr>
</table>


Please download the dataset from the following link: [CELEBA_HQ_facial_identity_dataset](https://postechackr-my.sharepoint.com/:u:/g/personal/dongbinna_postech_ac_kr/ES-jbCNC6mNHhCyR4Nl1QpYBlxVOJ5YiVerhDpzmoS9ezA?download=1) and unzip it. Then set the `data_dir` attribute of the `CONFIG` object to the path of the dataset's root directory.

If you are using Linux, you can use the following commands:<br>
1. `wget "https://postechackr-my.sharepoint.com/:u:/g/personal/dongbinna_postech_ac_kr/ES-jbCNC6mNHhCyR4Nl1QpYBlxVOJ5YiVerhDpzmoS9ezA?download=1" -O CelebA_HQ_facial_identity_dataset.zip`
2. `unzip CelebA_HQ_facial_identity_dataset.zip -d ./CelebA_HQ_facial_identity_dataset`

Each image of the dataset has the resolution 1024x1024. For the experiments we resize them to 128x128. We mirror the images horizontally as a data augmentation method.

In [9]:
# Preprocessing pipeline applied to every training image.
transforms_train = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),    # downscale to 128x128
        transforms.RandomHorizontalFlip(),     # data augmentation: random horizontal mirroring
        transforms.ToTensor(),
    ]
)

# Images are read from the 'train' sub-folder of the configured dataset root.
train_dataset = datasets.ImageFolder(
    root=os.path.join(config.data_dir, 'train'),
    transform=transforms_train,
)

# Shuffled mini-batches; batch size and worker count come from the config.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
)

### CREATE FOLDERS FOR SAVING MODELS AND LOGS

In [10]:
def make_folder(path, version):
    '''Create the directory ``path/version`` (including missing parents) if it is absent.

    Args:
        path: base directory.
        version: sub-folder name joined onto ``path``; an empty string targets ``path`` itself.

    Raises:
        FileExistsError: if the target already exists but is not a directory.
    '''
    # exist_ok=True makes the call idempotent and removes the gap between
    # checking for the folder and creating it.
    os.makedirs(os.path.join(path, version), exist_ok=True)

# Ensure the model-checkpoint folder and the tensorboard-log folder exist
# (each under the configured run version) before training starts.
make_folder(config.model_save_path, config.version)
make_folder(config.log_path, config.version)

## 8. TRAINING

1. Initialize the trainer object that corresponds to the training scenario you want to run. Please uncomment the relevant line of code.
2. Call the train function of the trainer.

**NOTE:** The complete training was done on Google Colab. Therefore, the outputs aren't visible for the training scenario. To show that the code works I ran 2 epochs.

In [None]:
# Full training run (kept for reference):
# config.total_epoch=200
# Short demonstration run of 2 epochs.
config.total_epoch=2

# 1. To train the Full model, keep this line active and leave option 2 commented out
trainer = Trainer(train_dataloader, config)

# 2. To train the model with randomly generated masks (No Mask network), uncomment this line instead
# trainer = Trainer_NoMask(train_dataloader, config)

# Train the model
trainer.train()

Build_model...
Model building done!
Inside the trainer!
Iterator created!
steps per epoch: 267
Start training...




Epoch: [1/2], Step: [267/534], time: 0:05:00.686826, d_real: -31.9983, d_fake: -27.4074, d_loss: -4.5909, g_loss: -0.1374, density_loss: 0.0000, image_loss: 0.1527


## FINAL MODELS

The final models of both the training scenarios are saved in the DLVC_BestModels folder. Within it:
1. **Full Model**: Both generator models are saved in the "./DLVC_BestModels/Full_Training" folder.
2. **No Mask N/w**: The Generator_Net model is saved in the "./DLVC_BestModels/No_Mask" folder.