# ECGAN

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive.
# NOTE(review): DATA_PATH below is the relative './data', not a path under
# /content/drive — confirm the working directory (or a symlink) makes the
# Drive data reachable from there.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Root directory that holds the three parallel RAF-DB variants.
DATA_PATH = './data'
# Unmasked RAF-DB images; also holds the `{split}_labels.csv` files read by RAFDataset.
RAFDB_PATH = f'{DATA_PATH}/RAFDB'
# Masked counterparts of the RAF-DB images.
MASKED_RAFDB_PATH = f'{DATA_PATH}/Masked_RAFDB'
# Binary mask images, one per masked image.
BINARY_RAFDB_PATH = f'{DATA_PATH}/Binary_RAFDB'

## Imports

In [None]:
import numpy as np
import torch
from torch import nn
from torchvision import transforms
import torch.nn.functional as F
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # prefer the GPU when present

## Face Inpainting Architecture 

The building blocks of the generator and discriminators are adapted from https://github.com/daviddirethucus/Face-Mask_Inpainting.
Below is the original face-inpainting architecture from 'A Novel GAN-Based Network for Unmasking of Masked Face' by Nizam Ud Din et al.

In [None]:
def crop(image, new_shape):
    """Center-crop the spatial dimensions of a 4-D tensor.

    Args:
        image: tensor indexed as (batch, channel, height, width).
        new_shape: shape-like sequence; only entries 2 and 3 (target height
            and width) are read.

    Returns:
        The slice of ``image`` of size ``new_shape[2] x new_shape[3]``
        centred on the middle of ``image``'s spatial grid.
    """
    target_h, target_w = new_shape[2], new_shape[3]
    top = image.shape[2] // 2 - round(target_h / 2)
    left = image.shape[3] // 2 - round(target_w / 2)
    return image[:, :, top:top + target_h, left:left + target_w]

class ContractingBlock(nn.Module):
    """Encoder stage: 3x3 conv that doubles the channels, then optional
    InstanceNorm, optional Dropout, LeakyReLU(0.2) and a 2x2 max-pool that
    halves height and width.

    Args:
        input_channels: channels of the incoming feature map.
        use_in: apply InstanceNorm2d after the convolution.
        use_dropout: apply Dropout (p=0.5 default) after the normalisation.
    """

    def __init__(self, input_channels, use_in=True, use_dropout=False):
        super().__init__()
        output_channels = input_channels * 2
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.use_in = use_in
        self.use_dropout = use_dropout
        # Optional sub-modules are only created (and registered) when enabled.
        if use_in:
            self.insnorm = nn.InstanceNorm2d(output_channels)
        if use_dropout:
            self.drop = nn.Dropout()

    def forward(self, x):
        out = self.conv(x)
        if self.use_in:
            out = self.insnorm(out)
        if self.use_dropout:
            out = self.drop(out)
        return self.maxpool(self.activation(out))

class ExpandingBlock(nn.Module):
    """Decoder stage with a skip connection.

    A stride-2 transposed conv doubles height/width and halves the channels;
    the skip tensor is center-cropped to that size and concatenated, and a
    3x3 conv maps the result back to ``input_channels // 2`` channels,
    followed by optional InstanceNorm and LeakyReLU(0.2).

    Args:
        input_channels: channels of ``x``; the skip tensor is expected to
            supply ``input_channels - input_channels // 2`` channels so the
            concatenation matches ``conv2``'s input width.
        use_in: apply InstanceNorm2d after ``conv2``.
    """

    def __init__(self, input_channels, use_in=True):
        super().__init__()
        half = input_channels // 2
        self.tconv = nn.ConvTranspose2d(
            input_channels, half, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.conv2 = nn.Conv2d(input_channels, half, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        if use_in:
            self.insnorm = nn.InstanceNorm2d(half)
        self.use_in = use_in

    def forward(self, x, skip_x):
        upsampled = self.tconv(x)
        skip = crop(skip_x, upsampled.shape)
        merged = self.conv2(torch.cat([upsampled, skip], dim=1))
        if self.use_in:
            merged = self.insnorm(merged)
        return self.activation(merged)

class FeatureMapBlock(nn.Module):
    """1x1 convolution that projects ``input_channels`` to ``output_channels``
    without changing the spatial size."""

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class SE_Block(nn.Module):
    """Squeeze-and-excitation channel re-weighting.

    Global-average-pools each channel, passes the per-channel vector through
    a two-layer bottleneck MLP (``channels // reduction`` hidden units,
    ReLU, Sigmoid), and scales every channel of the input by the result.
    Expects a 4-D (batch, channels, H, W) input.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        bottleneck = channels // reduction
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, chans, _, _ = x.shape
        pooled = self.squeeze(x).view(batch, chans)
        gates = self.excitation(pooled).view(batch, chans, 1, 1)
        return x * gates.expand_as(x)
    
class AtrousConv(nn.Module):
    """Bottleneck of four channel-preserving dilated 3x3 convolutions
    (dilation 2, 4, 8, 16; padding equals dilation so H and W are kept),
    each followed by BatchNorm and ReLU.

    NOTE: a single ``batchnorm`` module is shared by all four stages, so its
    running statistics mix the four stages. This is kept as-is because
    checkpoints store exactly one ``batchnorm.*`` entry.
    """

    def __init__(self, input_channels):
        super().__init__()
        # Registered in the order aconv2, aconv4, aconv8, aconv16.
        for rate in (2, 4, 8, 16):
            setattr(
                self,
                f"aconv{rate}",
                nn.Conv2d(input_channels, input_channels, kernel_size=3, stride=1,
                          dilation=rate, padding=rate),
            )
        self.batchnorm = nn.BatchNorm2d(input_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        for conv in (self.aconv2, self.aconv4, self.aconv8, self.aconv16):
            x = self.activation(self.batchnorm(conv(x)))
        return x

## ECGAN Model

In [None]:
class UNetIICGAN(nn.Module):
    """Label-conditioned U-Net generator (ECGAN) for masked-face inpainting.

    The class label is embedded, tiled over the spatial grid and
    concatenated to the input channels. The encoder has a 1x1 projection
    plus five ContractingBlocks, with squeeze-and-excitation after the first
    three, followed by an AtrousConv bottleneck. The decoder has five
    ExpandingBlocks with skip connections, then a 1x1 projection and tanh.

    Args:
        input_channels: channels of the image input (6 in this notebook: the
            masked image concatenated with the binary mask).
        output_channels: channels of the generated image.
        hidden_channels: width after the first projection; doubles per stage.
        num_classes: number of label ids accepted by the embedding.

    NOTE(review): five 2x max-pools combined with crop-based skips mean H and W
    should be multiples of 32. Otherwise the output is smaller than the input:
    e.g. 48x48 goes down to 1x1 and back up to only 32x32.
    NOTE(review): nn.Embedding requires labels in [0, num_classes), while
    EXPRESSION_MAP in this notebook is keyed 1..7. Confirm the labels fed
    here are 0-indexed.
    """

    def __init__(self, input_channels, output_channels, hidden_channels=32, num_classes=7):
        super(UNetIICGAN, self).__init__()

        self.num_classes = num_classes
        # One embedding channel per class, matching the `+ num_classes` input
        # width of `upfeature` below. The original width (input_channels + 1)
        # only agreed with it when num_classes == input_channels + 1. For the
        # default 6/7 configuration the shape is identical, so existing
        # checkpoints still load.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        self.upfeature = FeatureMapBlock(input_channels + num_classes, hidden_channels)
        # Encoder: each stage doubles the channels and halves H and W.
        self.contract1 = ContractingBlock(hidden_channels, use_in=False, use_dropout=True)
        self.contract2 = ContractingBlock(hidden_channels * 2, use_dropout=True)
        self.contract3 = ContractingBlock(hidden_channels * 4, use_dropout=True)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        self.contract5 = ContractingBlock(hidden_channels * 16)

        self.atrous_conv = AtrousConv(hidden_channels * 32)

        # Decoder: each stage halves the channels, doubles H and W and fuses the
        # matching encoder output through a center-cropped skip connection.
        self.expand0 = ExpandingBlock(hidden_channels * 32)
        self.expand1 = ExpandingBlock(hidden_channels * 16)
        self.expand2 = ExpandingBlock(hidden_channels * 8)
        self.expand3 = ExpandingBlock(hidden_channels * 4)
        self.expand4 = ExpandingBlock(hidden_channels * 2)
        self.downfeature = FeatureMapBlock(hidden_channels, output_channels)

        # Channel re-weighting after the first three encoder stages.
        self.se1 = SE_Block(hidden_channels * 2)
        self.se2 = SE_Block(hidden_channels * 4)
        self.se3 = SE_Block(hidden_channels * 8)

        self.tanh = nn.Tanh()

    def forward(self, x, labels):
        """Generate an image from ``x`` conditioned on ``labels``.

        Args:
            x: input batch indexed (batch, input_channels, H, W).
            labels: integer class ids of shape (batch,).

        Returns:
            Tensor with ``output_channels`` channels, squashed into (-1, 1) by tanh.
        """
        # (batch, num_classes) -> (batch, num_classes, 1, 1), tiled over H x W.
        label_emb = self.label_emb(labels).unsqueeze(-1).unsqueeze(-1)
        x = torch.cat((x, label_emb.repeat(1, 1, x.size(2), x.size(3))), dim=1)

        x0 = self.upfeature(x)
        x1 = self.se1(self.contract1(x0))
        x2 = self.se2(self.contract2(x1))
        x3 = self.se3(self.contract3(x2))
        x4 = self.contract4(x3)
        x5 = self.atrous_conv(self.contract5(x4))
        x6 = self.expand0(x5, x4)
        x7 = self.expand1(x6, x3)
        x8 = self.expand2(x7, x2)
        x9 = self.expand3(x8, x1)
        x10 = self.expand4(x9, x0)
        xn = self.downfeature(x10)

        return self.tanh(xn)

class Discriminator_whole_CGAN(nn.Module):
    """Label-conditioned discriminator over the whole image.

    ``forward(x, y, labels)`` embeds ``labels`` into ``label_channels``
    channels and tiles them over the spatial grid. It concatenates
    ``[x, y, label_emb]`` along the channel axis and applies a 1x1
    projection, four ContractingBlocks (each halving H and W) and a final
    1x1 conv, returning a one-channel map of logits.

    Args:
        input_channels: combined channel count of ``x`` and ``y``. In
            ``cgan_inpaint_in`` this is 9: a 3-channel image plus the
            6-channel masked-image/binary-mask condition.
        hidden_channels: width after the first projection.
        num_classes: number of label ids accepted by the embedding; labels
            must be in [0, num_classes).
        label_channels: width of the label embedding.
    """

    def __init__(self, input_channels, hidden_channels=8, num_classes=7, label_channels=3):
        super(Discriminator_whole_CGAN, self).__init__()
        self.label_emb = nn.Embedding(num_classes, label_channels)
        # Derived from input_channels instead of the former hard-coded 12
        # (== 9 + 3), which silently ignored the constructor argument.
        # The result is identical for the existing call with input_channels=9.
        self.upfeature = FeatureMapBlock(input_channels + label_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, use_in=False)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        self.contract3 = ContractingBlock(hidden_channels * 4)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=1)

    def forward(self, x, y, labels):
        """Score image ``x`` given condition ``y`` and class ``labels`` (returns logits)."""
        # (batch, label_channels) -> (batch, label_channels, H, W)
        label_emb = self.label_emb(labels)
        label_emb = label_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, x.size(2), x.size(3))

        x = torch.cat([x, y, label_emb], dim=1)

        x0 = self.upfeature(x)
        x1 = self.contract1(x0)
        x2 = self.contract2(x1)
        x3 = self.contract3(x2)
        x4 = self.contract4(x3)
        return self.final(x4)

class Discriminator_mask_CGAN(nn.Module):
    """Label-conditioned discriminator used for the masked-region loss.

    It has the same structure as the whole-image discriminator. ``forward``
    concatenates ``[x, y, label_emb]`` and applies a 1x1 projection and four
    ContractingBlocks, with Dropout after the second. A final 1x1 conv
    returns a one-channel map of logits.

    Args:
        input_channels: combined channel count of ``x`` and ``y``. In
            ``cgan_inpaint_in`` this is 9: a 3-channel image plus the
            6-channel masked-image/binary-mask condition.
        hidden_channels: width after the first projection.
        num_classes: number of label ids accepted by the embedding; labels
            must be in [0, num_classes).
        label_channels: width of the label embedding.
    """

    def __init__(self, input_channels, hidden_channels=8, num_classes=7, label_channels=3):
        super(Discriminator_mask_CGAN, self).__init__()
        self.label_emb = nn.Embedding(num_classes, label_channels)
        # Derived from input_channels instead of the former hard-coded 12
        # (== 9 + 3), which silently ignored the constructor argument.
        # The result is identical for the existing call with input_channels=9.
        self.upfeature = FeatureMapBlock(input_channels + label_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, use_in=False)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        self.contract3 = ContractingBlock(hidden_channels * 4)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=1)
        self.dropout = nn.Dropout()

    def forward(self, x, y, labels):
        """Score image ``x`` given condition ``y`` and class ``labels`` (returns logits)."""
        # (batch, label_channels) -> (batch, label_channels, H, W)
        label_emb = self.label_emb(labels)
        label_emb = label_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, x.size(2), x.size(3))

        x = torch.cat([x, y, label_emb], dim=1)

        x0 = self.upfeature(x)
        x1 = self.contract1(x0)
        x2 = self.dropout(self.contract2(x1))
        x3 = self.contract3(x2)
        x4 = self.contract4(x3)
        return self.final(x4)

def loadm(model, state, skip_keys=('upfeature.conv.weight',)):
    """Overlay pretrained tensors from ``state`` onto ``model`` and return it.

    Every entry of ``state`` whose key is not in ``skip_keys`` replaces the
    model's current value. The merged dict is then loaded with
    ``strict=False``, so keys missing from ``state`` or unknown to the model
    are tolerated; shape mismatches on loaded keys still raise.

    By default the first 1x1 projection weight is skipped. Presumably its
    input-channel count differs between the checkpoint and these
    label-conditioned models (TODO confirm).

    Args:
        model: ``nn.Module`` updated in place.
        state: mapping of parameter/buffer names to tensors.
        skip_keys: collection of ``state`` keys to leave at the model's own values.

    Returns:
        The same ``model`` instance.
    """
    model_state_dict = model.state_dict()
    model_state_dict.update(
        (key, value) for key, value in state.items() if key not in skip_keys
    )
    model.load_state_dict(model_state_dict, strict=False)
    return model

def cgan_inpaint_in(model_path):
    """Build the ECGAN generator and both discriminators and load pretrained weights.

    The generator takes 6 input channels and produces 3 output channels,
    optimised with Adam at lr 3e-4. The whole-image and mask discriminators
    take 9 input channels and use Adam at lr 1e-4. The checkpoint at
    ``model_path`` must contain the keys "gen", "disc_whole" and "disc_mask";
    each is applied with ``loadm``.

    Returns:
        (gen, gen_opt, disc_whole, disc_whole_opt, disc_mask, disc_mask_opt)
    """
    gen_lr, disc_lr = 0.0003, 0.0001
    gen_in_channels, gen_out_channels, disc_in_channels = 6, 3, 9

    gen = UNetIICGAN(gen_in_channels, gen_out_channels).to(device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr=gen_lr)
    disc_whole = Discriminator_whole_CGAN(disc_in_channels).to(device)
    disc_whole_opt = torch.optim.Adam(disc_whole.parameters(), lr=disc_lr)
    disc_mask = Discriminator_mask_CGAN(disc_in_channels).to(device)
    disc_mask_opt = torch.optim.Adam(disc_mask.parameters(), lr=disc_lr)

    # torch.load deserialises with pickle (unless weights_only is in effect):
    # only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=torch.device(device))
    # load_state_dict copies into the existing parameter tensors, so the
    # optimizers created above still track the loaded weights.
    gen = loadm(gen, checkpoint["gen"])
    disc_whole = loadm(disc_whole, checkpoint["disc_whole"])
    disc_mask = loadm(disc_mask, checkpoint["disc_mask"])

    return gen, gen_opt, disc_whole, disc_whole_opt, disc_mask, disc_mask_opt

## Data

### a) Define Dataset

In [None]:
class RAFDataset(Dataset):
    """RAF-DB samples that pair unmasked, masked and binary-mask images with a label.

    For a given split, the directories ``<unmask_path>/<split>``,
    ``<mask_path>/<split>`` and ``<binary_path>/<split>`` are each listed,
    sorted by filename and paired position by position. Each sample's label
    comes from the ``label`` column of ``<unmask_path>/<split>_labels.csv``,
    on the row whose ``image`` column equals the unmasked image's filename.

    Items are ``(unmasked_img, masked_img, binary_img, label)``.
    """

    def __init__(
        self,
        unmask_path,
        mask_path,
        binary_path,
        split,
        transform=None
    ):
        """
        Args:
            unmask_path (string): path to non-masked directory
            mask_path (string): path to masked directory
            binary_path (string): path to binary mask directory
            split (string): the target split i.e. 'train' or 'test'
            transform (optional, callable): optional transform to apply on the
                unmasked and masked images

        Raises:
            ValueError: if the three split directories contain different
                numbers of files, or the labels CSV lists an image twice.
            KeyError: if an unmasked image has no row in the labels CSV.
        """
        labels_df = pd.read_csv(os.path.join(unmask_path, f"{split}_labels.csv"))

        self.transform = transform

        # Binary mask transform (always applied, independent of `transform`).
        # NOTE(review): ImageNet mean/std puts mask values outside [-1, 1],
        # yet discmask_loss_func rescales the mask with (x + 1) / 2 and uses
        # it as a 0..1 blending weight. Confirm the intended normalisation.
        self.binary_tf = transforms.Compose([
            transforms.Resize(size=(48, 48)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        unmask_dir = os.path.join(unmask_path, split)
        mask_dir = os.path.join(mask_path, split)
        binary_dir = os.path.join(binary_path, split)

        # os.listdir returns entries in arbitrary, filesystem-dependent order,
        # so index-based pairing across directories is only meaningful after
        # sorting each listing.
        unmask_files = sorted(os.listdir(unmask_dir))
        mask_files = sorted(os.listdir(mask_dir))
        binary_files = sorted(os.listdir(binary_dir))
        if not len(unmask_files) == len(mask_files) == len(binary_files):
            raise ValueError(
                f"split '{split}' has mismatched file counts: {len(unmask_files)} unmasked, "
                f"{len(mask_files)} masked, {len(binary_files)} binary"
            )

        # Build the filename -> label lookup once instead of scanning the
        # whole DataFrame for every file.
        if not labels_df['image'].is_unique:
            raise ValueError(f"{split}_labels.csv lists some images more than once")
        label_by_image = dict(zip(labels_df['image'], labels_df['label']))

        self.data = []
        self.labels = []
        for unmask_name, mask_name, binary_name in zip(unmask_files, mask_files, binary_files):
            if unmask_name not in label_by_image:
                raise KeyError(f"no label for image '{unmask_name}' in {split}_labels.csv")
            self.data.append((
                os.path.join(unmask_dir, unmask_name),
                os.path.join(mask_dir, mask_name),
                os.path.join(binary_dir, binary_name),
            ))
            self.labels.append(label_by_image[unmask_name])

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to 3-channel RGB, and close the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        unmasked_img_path, masked_img_path, binary_img_path = self.data[idx]
        label = self.labels[idx]

        # All images are forced to RGB. The masked image and the binary mask
        # are concatenated into the generator's 6-channel input, so each must
        # have 3 channels.
        unmasked_img = self._load_rgb(unmasked_img_path)
        masked_img = self._load_rgb(masked_img_path)
        binary_img = self.binary_tf(self._load_rgb(binary_img_path))

        if self.transform:
            unmasked_img = self.transform(unmasked_img)
            masked_img = self.transform(masked_img)

        return unmasked_img, masked_img, binary_img, label

### b) Preprocessing and loading data

In [None]:
# Deterministic preprocessing for the unmasked and masked images:
# resize to 48x48, convert to a float tensor, normalise with ImageNet statistics.
# NOTE(review): SSIM, PerceptualNet and img_to_display all undo normalisation
# with (x + 1) / 2, which assumes data in [-1, 1] rather than ImageNet
# normalisation. Also, UNetIICGAN's five 2x downsamplings turn a 48x48 input
# into a 32x32 output. Confirm the intended size and normalisation.
transform = transforms.Compose(
    [
        transforms.Resize(size=(48, 48)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Both splits read the same three RAF-DB roots and share one image transform.
train_dataset = RAFDataset(RAFDB_PATH, MASKED_RAFDB_PATH, BINARY_RAFDB_PATH,
                           split='train', transform=transform)
valid_dataset = RAFDataset(RAFDB_PATH, MASKED_RAFDB_PATH, BINARY_RAFDB_PATH,
                           split='test', transform=transform)

# Shuffle only the training data; the validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)

## Loss functions

To train ECGAN for face inpainting, we combine discriminator, adversarial, reconstruction (L1 + SSIM), and perceptual (VGG-19) losses.

In [None]:
from torch.autograd import Variable
from math import exp

def normalize(img):
    """Affinely map values from [-1, 1] to [0, 1], i.e. (img + 1) / 2."""
    return (img + 1) / 2
def anti_normalize(img):
    """Affinely map values from [0, 1] back to [-1, 1], i.e. img * 2 - 1."""
    return img * 2 - 1

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` that sums to 1.

    The kernel is centred at ``window_size // 2`` with standard deviation
    ``sigma`` and returned as a float32 tensor.
    """
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = torch.Tensor([exp(-(pos - center) ** 2 / denom) for pos in range(window_size)])
    return weights / weights.sum()

def create_window(window_size, channel, sigma=1.5):
    """Build a depthwise 2-D Gaussian kernel for SSIM.

    Args:
        window_size: side length of the square kernel.
        channel: number of input channels; the kernel is repeated once per
            channel for ``F.conv2d(..., groups=channel)``.
        sigma: standard deviation of the Gaussian (default 1.5, as before).

    Returns:
        Float tensor of shape (channel, 1, window_size, window_size).
    """
    kernel_1d = gaussian(window_size, sigma).unsqueeze(1)
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    # The former `Variable(...)` wrapper has been a no-op since PyTorch 0.4
    # (it is deprecated); a plain tensor is equivalent.
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding=window_size//2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size//2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    """SSIM module for image batches with values in [-1, 1].

    Inputs are rescaled to [0, 1] before calling ``_ssim``. The Gaussian
    window is cached and rebuilt whenever the channel count or the tensor
    type (which includes CPU vs CUDA) of the input changes.
    """

    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        # Affinely map [-1, 1] to [0, 1].
        img1 = (img1 + 1) / 2
        img2 = (img2 + 1) / 2
        _, channel, _, _ = img1.size()

        cache_hit = channel == self.channel and self.window.data.type() == img1.data.type()
        if not cache_hit:
            fresh = create_window(self.window_size, channel)
            if img1.is_cuda:
                fresh = fresh.cuda(img1.get_device())
            self.window = fresh.type_as(img1)
            self.channel = channel

        return _ssim(img1, img2, self.window, self.window_size, channel, self.size_average)
_RECON_SSIM = None  # SSIM module shared by all recon_loss calls; built on first use


def recon_loss(gt, fake, recon_criterion):
    """Return the pixel loss and SSIM between ground-truth and generated images.

    Args:
        gt, fake: image batches in [-1, 1] (SSIM rescales them to [0, 1]).
        recon_criterion: pixel loss called as ``recon_criterion(gt, fake)``,
            e.g. ``nn.L1Loss()``.

    Returns:
        ``(l1_loss, ssim_loss)``. Note that ``ssim_loss`` is the SSIM
        similarity itself; callers turn it into a loss via ``1 - ssim_loss``.
    """
    global _RECON_SSIM
    # The original built a new SSIM (and its Gaussian window, twice) on
    # every call. SSIM.forward already rebuilds its window when the channel
    # count or tensor type changes, so one instance can be reused.
    # NOTE: that cache key ignores the CUDA device index, which is fine for
    # the single `device` used in this notebook.
    if _RECON_SSIM is None:
        _RECON_SSIM = SSIM()
    ssim_loss = _RECON_SSIM(gt, fake)
    l1_loss = recon_criterion(gt, fake)
    return l1_loss, ssim_loss

# Perceptual Loss

from torchvision.models import vgg19

class PerceptualNet(nn.Module):
    """Frozen VGG-19 feature-matching (perceptual) loss.

    ``forward(inputs, targets)`` processes both batches as follows:
      1. Rescale with (x + 1) / 2.
      2. Optionally resize to 224x224.
      3. Pass through four consecutive slices of the pretrained VGG-19
         ``features`` stack: [:4], [4:9], [9:16], [16:23].
    It returns the sum of the L1 distances between the matching feature maps.

    Args:
        name: unused; kept for backward compatibility.
        resize: whether to bilinearly resize inputs to 224x224 first.

    NOTE(review): ``mean``/``std`` hold ImageNet statistics, but ``forward``
    never applies them, so VGG sees (x + 1) / 2 rather than
    ImageNet-normalised input. Confirm whether that is intended.
    """

    def __init__(self, name = "vgg19", resize=True):
        super(PerceptualNet, self).__init__()
        # Load the pretrained weights once and slice that single feature stack.
        # The original instantiated and loaded VGG-19 four times.
        features = vgg19(pretrained=True).features
        blocks = [
            features[:4].eval(),
            features[4:9].eval(),
            features[9:16].eval(),
            features[16:23].eval(),
        ]
        self.blocks = torch.nn.ModuleList(blocks).to(device)
        # Freeze VGG. The original loop iterated the Sequential's *modules* and
        # set a plain `requires_grad` attribute on them, which left every
        # weight trainable. requires_grad_ applies to the parameters themselves.
        self.blocks.requires_grad_(False)
        self.transform = torch.nn.functional.interpolate
        # Stored as buffers, not Parameters, so they follow .to() and are never
        # trained. The original `Parameter(...).to(device)` was registered as a
        # trainable parameter on CPU but as an unregistered tensor on CUDA.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
        )
        self.resize = resize

    def forward(self, inputs, targets):
        """Return the summed per-slice L1 feature distance between ``inputs`` and ``targets``."""
        # Inputs without 3 channels (presumably single-channel) are tiled to
        # the 3 channels VGG expects.
        if inputs.shape[1] != 3:
            inputs = inputs.repeat(1, 3, 1, 1)
            targets = targets.repeat(1, 3, 1, 1)
        # Affinely map [-1, 1] to [0, 1].
        inputs = (inputs + 1) / 2
        targets = (targets + 1) / 2
        if self.resize:
            inputs = self.transform(inputs, mode='bilinear', size=(224, 224), align_corners=False)
            targets = self.transform(targets, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = inputs
        y = targets
        for block in self.blocks:
            x = block(x)
            y = block(y)
            loss += torch.nn.functional.l1_loss(x, y)
        return loss

_PERCEP_NET = None  # PerceptualNet shared by all percep_loss calls; built on first use


def percep_loss(gt, fake):
    """Return the VGG-19 perceptual loss between ``gt`` and ``fake``.

    The PerceptualNet is constructed on the first call and reused afterwards.
    The original rebuilt it on every call, reloading pretrained VGG-19 weights
    on every generator step.
    """
    global _PERCEP_NET
    if _PERCEP_NET is None:
        _PERCEP_NET = PerceptualNet()
    return _PERCEP_NET(gt, fake)

def discwhole_loss_func(disc_whole, gt, mask, binary, label, fake, adv_criterion, lambda_Dwhole):
    """Weighted adversarial loss for the whole-image discriminator.

    The discriminator is conditioned on the concatenated masked image and
    binary mask. It is asked to output 0 for ``fake`` (detached, so no
    gradient reaches the generator) and 1 for ``gt``.

    Returns:
        ``lambda_Dwhole`` times the mean of the fake and real loss terms.
    """
    condition = torch.cat((mask, binary), 1)
    # Scored in the same order as before: fake first, then real.
    fake_logits = disc_whole(fake.detach(), condition, label)
    real_logits = disc_whole(gt, condition, label)
    fake_term = adv_criterion(fake_logits, torch.zeros_like(fake_logits))
    real_term = adv_criterion(real_logits, torch.ones_like(real_logits))
    return lambda_Dwhole * (fake_term + real_term) / 2


def discmask_loss_func(disc_mask, gt, fake, mask, binary, label, adv_criterion, lambda_Dmask):
    """Weighted adversarial loss for the mask discriminator.

    The "fake" sample is a composite built from ``mask`` and ``fake`` using
    the rescaled binary mask as a blend weight. Where that weight is 0 the
    composite keeps the masked input; where it is 1 it takes the generated
    pixels. The composite is detached and scored against 0, and ``gt``
    against 1. Both are conditioned on the concatenated masked image and
    binary mask.

    NOTE(review): the blend assumes ``binary`` lies in [-1, 1] so that
    ``normalize`` yields 0..1 weights. Confirm this against the binary-mask
    preprocessing.

    Returns:
        ``lambda_Dmask`` times the mean of the fake and real loss terms.
    """
    mask01 = normalize(mask)
    binary01 = normalize(binary)
    fake01 = normalize(fake)

    composite = anti_normalize(mask01 * (1 - binary01) + fake01 * binary01)
    condition = torch.cat((mask, binary), 1)
    # Scored in the same order as before (fake first): disc_mask uses dropout.
    fake_logits = disc_mask(composite.detach(), condition, label)
    real_logits = disc_mask(gt, condition, label)

    fake_term = adv_criterion(fake_logits, torch.zeros_like(fake_logits))
    real_term = adv_criterion(real_logits, torch.ones_like(real_logits))
    return lambda_Dmask * (fake_term + real_term) / 2


def gen_adv_loss(gen, disc, gt, mask, binary, label, adv_criterion):
    """Generate an image and compute the generator's adversarial loss against ``disc``.

    The generator and discriminator are both conditioned on the concatenated
    masked image and binary mask. The loss pushes ``disc``'s output on the
    fake towards the "real" target 1. ``gt`` is accepted for signature
    symmetry but not used.

    Returns:
        ``(adv_loss, fake)``
    """
    condition = torch.cat((mask, binary), 1)
    fake = gen(condition, label)
    logits = disc(fake, condition, label)
    return adv_criterion(logits, torch.ones_like(logits)), fake

def generator_loss(cur_step, gen, disc_whole, disc_mask, gt, mask, binary, label,
                   adv_criterion, recon_criterion, lambda_recon, lambda_adv_whole,
                   lambda_adv_mask, mask_disc_start_step=3516 * 6):
    """Total generator loss for one training step.

    The loss is built as:
        lambda_recon * (0.5 * L1 + 0.5 * (1 - SSIM) + perceptual)
        + lambda_adv_whole * adversarial loss against ``disc_whole``.
    From ``cur_step >= mask_disc_start_step`` onwards, the term
    ``lambda_adv_mask`` * adversarial loss against ``disc_mask`` is added.

    Args:
        mask_disc_start_step: step at which the mask-discriminator term is
            switched on. The default 3516 * 6 is the value inherited from the
            original code; confirm it for this dataset and batch size.

    Returns:
        ``(gen_loss, fake, l1_loss, ssim_loss, perceptual_loss)``

    NOTE(review): discmask_loss_func trains ``disc_mask`` on a composite
    (generated pixels only inside the mask), but here the generator is scored
    on the raw ``fake``. Confirm which input is intended.
    """
    # Generate exactly once. The original ran the generator a second time for
    # the mask discriminator. With dropout active in training mode this gave a
    # different `fake`, so the whole-image adversarial term scored an image the
    # reconstruction/perceptual terms never saw, and the forward pass ran twice.
    adver_loss_whole, fake = gen_adv_loss(gen, disc_whole, gt, mask, binary, label, adv_criterion)
    l1_loss, ssim_loss = recon_loss(gt, fake, recon_criterion)
    reconstruction_loss = l1_loss * 0.5 + (1 - ssim_loss) * 0.5
    perceptual_loss = percep_loss(gt, fake)
    gen_loss = (lambda_recon * (reconstruction_loss + perceptual_loss)
                + lambda_adv_whole * adver_loss_whole)

    if cur_step >= mask_disc_start_step:
        # Score the same generated image with the mask discriminator.
        input_imgs = torch.cat((mask, binary), 1)
        mask_pred = disc_mask(fake, input_imgs, label)
        adver_loss_mask = adv_criterion(mask_pred, torch.ones_like(mask_pred))
        gen_loss = gen_loss + lambda_adv_mask * adver_loss_mask

    return gen_loss, fake, l1_loss, ssim_loss, perceptual_loss

# FID

from torchvision.models import inception_v3
import scipy.linalg

# Pretrained Inception-v3 used as a frozen feature extractor for FID. It is
# moved to `device` and put in eval mode, and its final classifier is replaced
# by Identity, so the forward pass returns the features that would have fed
# the classifier.
inception_model = inception_v3(pretrained=True).to(device).eval()
inception_model.fc = torch.nn.Identity()

def matrix_sqrt(x):
    """Return the principal square root of the square matrix ``x`` as a tensor.

    The root is computed with ``scipy.linalg.sqrtm`` on the CPU, and any
    imaginary part of the result is discarded. The returned tensor lives on
    ``x``'s device with ``x``'s dtype.
    """
    y = scipy.linalg.sqrtm(x.cpu().detach().numpy())
    # The legacy `torch.Tensor(data, device=...)` constructor rejects CUDA
    # devices. as_tensor builds the result directly on x's device.
    return torch.as_tensor(np.real(y), dtype=x.dtype, device=x.device)

def frechet_distance(mu_x, mu_y, sigma_x, sigma_y):
    """Fréchet distance between the Gaussians N(mu_x, sigma_x) and N(mu_y, sigma_y).

    Computes ||mu_x - mu_y||^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x @ sigma_y)).
    """
    mean_term = torch.norm(mu_x - mu_y) ** 2
    cov_sqrt = matrix_sqrt(sigma_x @ sigma_y)
    cov_term = torch.trace(sigma_x + sigma_y - 2 * cov_sqrt)
    return mean_term + cov_term

def get_covariance(features):
    """Return the covariance matrix of ``features``, with rows as samples.

    Uses ``np.cov(..., rowvar=False)``, so each column is one variable. The
    result is a float32 CPU tensor. The added ``.cpu()`` lets features that
    live on a CUDA device be passed directly; the original ``.numpy()`` call
    raises for CUDA tensors.
    """
    return torch.Tensor(np.cov(features.detach().cpu().numpy(), rowvar=False))

## Initialise Model

In [None]:
# Loss criteria and weights.
adv_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()
lambda_recon = 100
lambda_Dwhole, lambda_Dmask = 0.3, 0.7          # discriminator loss weights
lambda_adv_whole, lambda_adv_mask = 0.3, 0.7    # generator adversarial weights

# Training length and model dimensions.
# NOTE(review): input_dim, output_dim, disc_dim and lr repeat values that are
# hard-coded inside cgan_inpaint_in; they are not passed to it.
num_epochs = 2
input_dim, output_dim, disc_dim = 6, 3, 9
lr = 0.0003

model_path = './models/Inpaint_UNet.pth'
gen, gen_opt, disc_whole, disc_whole_opt, disc_mask, disc_mask_opt = cgan_inpaint_in(model_path)
# cgan_inpaint_in already placed the models on `device`; these calls are no-ops.
gen.to(device)
disc_whole.to(device)
disc_mask.to(device)

## Train model

In [None]:
# Expression label id (1-based) -> display name.
EXPRESSION_MAP = dict(enumerate(
    ['Surprise', 'Fear', 'Disgust', 'Happy', 'Sad', 'Angry', 'Neutral'],
    start=1,
))

def img_to_display(images, display_title, labels=None, num_samples=3):
    """Plot the first ``num_samples`` images of a batch side by side.

    Args:
        images (tensor -> shape(B, C, H, W)): batch of images. Each is shown
            after the (x + 1) / 2 rescale, clamped to [0, 1].
        display_title (string): title of the figure.
        labels (tensor -> shape[B], optional): expression labels, looked up in
            EXPRESSION_MAP and shown as subplot titles.
        num_samples (int): number of leading images to display.

    Raises:
        ValueError: if ``num_samples`` exceeds the batch size.
    """
    # The original `assert ValueError(...)` asserted a (truthy) exception
    # object, so it never fired and the loop later failed with IndexError.
    if num_samples > images.shape[0]:
        raise ValueError("num_samples was greater than the image batch_size")

    # squeeze=False keeps `axes` a 2-D array even when num_samples == 1
    # (otherwise a bare Axes is returned, which has no `.flat`).
    fig, axes = plt.subplots(nrows=1, ncols=num_samples, figsize=(10, 5), squeeze=False,
                             subplot_kw={'xticks': [], 'yticks': []})

    for i, ax in enumerate(axes.flat):
        image = images[i].detach().cpu().permute(1, 2, 0)
        # imshow expects floats in [0, 1]. Clamp explicitly rather than rely
        # on matplotlib's warn-and-clip behaviour.
        ax.imshow(((image + 1) / 2).clamp(0, 1))

        # Write label as ax title
        if labels is not None:
            label = labels[i].detach().cpu().item()
            ax.set_title(EXPRESSION_MAP[label])

    fig.suptitle(display_title, y=0.87, size=18)
    plt.tight_layout()
    plt.show()
    
def ecgan_train(gen, disc_whole, disc_mask, disc_whole_opt, disc_mask_opt, train_loader, valid_loader, num_epochs, cur_step=0, display_step=400, save_model=True):
    """Training-loop scaffold for ECGAN (currently incomplete).

    As written, it puts the three models in train mode. Then, for each of
    ``num_epochs`` epochs, it:
      - moves only the first batch of ``train_loader`` to ``device``,
      - shows the ground-truth and masked images with ``img_to_display``
        when ``cur_step % display_step == 0``,
      - increments ``cur_step`` and leaves the epoch (``break``).
    No losses are computed and no optimizer steps are taken.

    NOTE(review): the following are not yet wired up:
      - ``disc_whole_opt``, ``disc_mask_opt``, ``valid_loader`` and
        ``save_model`` are unused;
      - the mean-loss accumulators and feature lists are never updated;
      - there is no generator-optimizer parameter;
      - the updated ``cur_step`` is not returned.
    """
    gen.train()
    disc_whole.train()
    disc_mask.train()

    # Running-loss and FID-feature accumulators (not updated yet).
    mean_generator_loss = 0
    mean_disc_whole_loss = 0
    mean_disc_mask_loss = 0
    fake_features_list = []
    real_features_list = []
    
    for epoch in range(0, num_epochs):
        train_loop = tqdm(train_loader, leave=False)
        for gt, mask, binary, labels in train_loop:
            gt = gt.to(device)
            mask = mask.to(device)
            binary = binary.to(device)
            labels = labels.to(device)
            
            # Generator conditioning input (masked image + binary mask); unused
            # while the generator call below stays commented out.
            with torch.no_grad():
                input_imgs = torch.cat((mask, binary), 1)
                # fake = gen(input_imgs, labels)
            
            # Show ground-truth and masked images every display_step steps
            # (the binary masks are not displayed).
            if cur_step % display_step == 0:
                img_to_display(gt, 'Ground truth', labels)
                img_to_display(mask, 'Masked images')
                # img_to_display(fake, 'Fake images')
            
            
            cur_step += 1
            train_loop.set_description(f"Epoch [{epoch + 1}/{num_epochs}]")
            # Stops after the first batch of every epoch (scaffold behaviour).
            break


In [None]:
# Run the training routine with the models and optimizers loaded above.
ecgan_train(
    gen=gen,
    disc_whole=disc_whole,
    disc_mask=disc_mask,
    disc_whole_opt=disc_whole_opt,
    disc_mask_opt=disc_mask_opt,
    train_loader=train_loader,
    valid_loader=valid_loader,
    num_epochs=num_epochs,
    cur_step=0,
    display_step=400,
    save_model=True,
)