# Colab-DFDNet

Official repo: [csxmli2016/DFDNet](https://github.com/csxmli2016/DFDNet)

Original repo: [max-vasyuk/DFDNet](https://github.com/max-vasyuk/DFDNet)

My fork: [styler00dollar/Colab-DFDNet](https://github.com/styler00dollar/Colab-DFDNet)

This Colab focuses on training DFDNet models and using them afterwards. If you just want to try DFDNet, this isn't the Colab you are looking for: there are several other Colabs that already do that, and they can be found [here](https://github.com/styler00dollar/dl-colab-notebooks). Simply using the original pretrained weights with ``inference.py`` will result in weight errors.

In [None]:
# Print the GPU(s) visible to this runtime before starting.
!nvidia-smi

In [None]:
#@title install
# Clone the DFDNet repository this notebook works in and change into it.
!git clone https://github.com/max-vasyuk/DFDNet
%cd /content/DFDNet
# Dependencies listed by the repository itself.
!pip install -r requirements.txt
# Extra packages: pytorch-msssim and tensorboardX are imported by the
# run.py cell below; `trains` is only referenced in commented-out code there.
!pip install pytorch-msssim
!pip install trains
# NOTE(review): the reason for pinning PyJWT to 1.7.1 is not visible in this
# notebook (presumably a `trains` compatibility requirement) - confirm.
!pip install PyJWT==1.7.1
!pip install tensorboardX

In [None]:
#@title download data
# Fetch 16 files (by Google Drive id) into DictionaryCenter512.
!mkdir /content/DFDNet/DictionaryCenter512
%cd /content/DFDNet/DictionaryCenter512
!gdown --id 1sEB9j3s7Wj9aqPai1NF-MR7B-c0zfTin
!gdown --id 1H4kByBiVmZuS9TbrWUR5uSNY770Goid6
!gdown --id 10ctK3d9znZ9nGN3d1Z77xW3GGshbeKBb
!gdown --id 1gcwmrIZjPFVu-cHjdQD6P4luohkPsil-
!gdown --id 1rJ8cORPxbJsIVAiNrjBag0ihaY_Mvurn
!gdown --id 1LkfJv2a3ud-mefAc1eZMJuINuNdSYgYO
!gdown --id 1LH-nxD__icSJvTiAbXAXDch03oDtbpkZ
!gdown --id 1JRTStLFsQ8dwaQjQ8qG5fNyrOvo6Tcvd
!gdown --id 1Z4AkU1pOYTYpdbfljCgNMmPilhdEd0Kl
!gdown --id 1Z4e1ltB3ACbYKzkoMBuVtzZ7a310G4xc
!gdown --id 1fqWmi6-8ZQzUtZTp9UH4hyom7n4nl8aZ
!gdown --id 1wfHtsExLvSgfH_EWtCPjTF5xsw3YyvjC
!gdown --id 1Jr3Luf6tmcdKANcSLzvt0sjXr0QUIQ2g
!gdown --id 1sPd4_IMYgqGLol0gqhHjBedKKxFAxswR
!gdown --id 1eVFjXJRnBH4mx7ZbAmZRwVXZNUbgCQec
!gdown --id 1w0GfO_KY775ZVF3KMk74ya6QL_bNU4cJ
# Back in the repository root: download one more file and unpack data.zip
# (presumably the file just downloaded - confirm).
%cd /content/DFDNet
!gdown --id 1VE5tnOKcfL6MoV839IVCCw5FhJxIgml5
!7z x data.zip
# Download one file into weights/ (run.py also writes its checkpoints there).
!mkdir /content/DFDNet/weights/
%cd /content/DFDNet/weights/
!gdown --id 1SfKKZJduOGhDD27Xl01yDx0-YSEkL2Aa

# Training

The saving frequency can be adjusted by changing the value inside ``if i % 10 == 0 and i != 0:`` in ``run.py``.

In [None]:
#@title run.py (removing trains, printing instead, saving training images locally, tensorboard)
%%writefile /content/DFDNet/run.py
import os
import numpy as np
import cv2

import math
from PIL import Image
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch
import random
from skimage import transform as trans
from skimage import io
import sys
sys.path.append('FaceLandmarkDetection')
import face_alignment
import dlib
#import logging

from models import *
from models.model_resnet import MultiScaleDiscriminator
from data.custom_dataset import AlignedDataset
from torch.utils.data import DataLoader
import torchvision

from util.Loss import hinge_loss, hinge_loss_G
from pytorch_msssim import MS_SSIM

from torch.utils.tensorboard import SummaryWriter
#from trains import Task
from torchvision.utils import make_grid
from tqdm import tqdm

from util.losses import LossNetworkVgg19
from torchvision.utils import save_image

from tensorboardX import SummaryWriter
# Directory that receives the TensorBoard event files.
logdir='/content/DFDNet/logging'
# Module-level tensorboardX writer bound to `logdir`.
writer = SummaryWriter(logdir=logdir)

#logger = logging.getLogger('logging')

def tensor2im(input_image, norm=1, imtype=np.uint8):
    """Convert a (C, H, W) image tensor in [-1, 1] to an (H, W, C) array in [0, 255].

    Args:
        input_image: tensor indexed as (channel, height, width); values outside
            [-1, 1] are clamped before scaling.
        norm: unused; kept so existing callers keep working.
        imtype: numpy dtype of the returned array (the cast truncates).

    Returns:
        numpy array of dtype ``imtype`` with the channel axis last.
    """
    # Out-of-place clamp on purpose: `.cpu().float()` returns the very same
    # tensor when the input already is a CPU float32 tensor, so the previous
    # in-place `clamp_` silently modified the caller's data.
    image_numpy = input_image.detach().cpu().float().clamp(-1, 1).numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)


def _multiscale_batch(image):
    """Return the discriminator input dict: `image` plus 256/128/64 px resizes."""
    return {
        '1': image,
        '2': F.interpolate(image, (256, 256)),
        '4': F.interpolate(image, (128, 128)),
        '8': F.interpolate(image, (64, 64)),
    }


def train(device, dataset, trainloader, netG, netD, writer):
    """Adversarially train generator `netG` against discriminator `netD`.

    Each iteration first updates `netD` with a BCE loss on a real and a
    generated batch, then updates `netG` with
    ``MSE + perceptual (VGG19) + adversarial + 100 * (1 - MS-SSIM)``.
    Scalars are written to `writer` every iteration; every 10th iteration
    both networks are checkpointed to ``weights/`` and a preview image
    (input, output, target) is saved to the working directory.

    Args:
        device: torch device the batches are moved to.
        dataset: unused here; kept for interface compatibility.
        trainloader: iterable of dicts with keys 'A' (degraded image),
            'C' (target image) and 'part_locations'. The label tensor has
            shape (1,), so the discriminator is expected to emit a single
            value per scale, i.e. batch size 1.
        netG: generator, called as ``netG(A, part_locations=...)``.
        netD: discriminator, called with a dict keyed '1', '2', '4', '8' and
            returning a dict keyed 'prediction_1' ... 'prediction_8'.
        writer: object providing ``add_scalar(tag, value, step)``.
    """
    optimizerG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    optimizerD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

    # Convention for real and fake labels during training
    real_label = 1.
    fake_label = 0.

    criterionG = torch.nn.MSELoss()
    criterionD = torch.nn.BCELoss()
    # NOTE(review): the images are normalised to [-1, 1] by the dataset while
    # MS_SSIM is configured with data_range=255 - confirm this is intended.
    ms_ssim_module = MS_SSIM(data_range=255, size_average=True, channel=3)

    weights_layer_perceptual = [0.5, 1., 2., 4.]
    vgg19_model = torchvision.models.vgg19(pretrained=True)
    vgg19_model.to(device)

    perceptual_loss_vgg19 = LossNetworkVgg19(vgg19_model)
    perceptual_loss_vgg19.to(device)
    perceptual_loss_vgg19.eval()
    del vgg19_model

    prediction_keys = ('prediction_1', 'prediction_2', 'prediction_4', 'prediction_8')

    def discriminator_loss(predictions, target):
        # Sum of the BCE losses over the four discriminator scales.
        return sum(criterionD(predictions[key].view(-1), target) for key in prediction_keys)

    num_epochs = 15
    steps_per_epoch = len(trainloader)

    for epoch in range(num_epochs):
        for i, data in enumerate(trainloader):
            global_step = epoch * steps_per_epoch + i

            data_a = data['A'].to(device)
            data_c = data['C'].to(device)
            data_part_locations = data['part_locations']

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            netD.zero_grad()

            # All-real batch
            label = torch.full((1,), real_label, dtype=torch.float, device=device)
            errD_real = discriminator_loss(netD(_multiscale_batch(data_c)), label)
            errD_real.backward()

            # All-fake batch. Every scale is built from the detached output so
            # the discriminator loss does not back-propagate through netG
            # (previously only the full-resolution entry was detached).
            fake = netG(data_a, part_locations=data_part_locations)
            label.fill_(fake_label)
            errD_fake = discriminator_loss(netD(_multiscale_batch(fake.detach())), label)
            errD_fake.backward()

            errD = errD_real + errD_fake

            # Update D
            optimizerD.step()

            ###########################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            fake = netG(data_a, part_locations=data_part_locations)
            mse_loss = criterionG(fake, data_c)

            # Perceptual (VGG19 feature) loss; the target features need no graph.
            perceptual_dst_pred = perceptual_loss_vgg19(fake)
            with torch.no_grad():
                perceptual_dst_img = perceptual_loss_vgg19(data_c)
            L_p = 0
            for k, layer_weight in enumerate(weights_layer_perceptual):
                L_p += layer_weight * F.mse_loss(perceptual_dst_pred[k], perceptual_dst_img[k])

            ms_ssim_loss = 1 - ms_ssim_module(fake, data_c)

            # D was just updated, so run the fake batch through it again,
            # this time keeping the graph back to G.
            output = netD(_multiscale_batch(fake))

            # Adversarial loss: 4 minus the summed predictions of the four scales.
            adversarial = 4. - (output['prediction_1'].view(-1) + output['prediction_2'].view(-1) + output['prediction_4'].view(-1) + output['prediction_8'].view(-1))
            errG = mse_loss + L_p + adversarial + 100*ms_ssim_loss

            # Calculate gradients for G and update it
            errG.backward()
            optimizerG.step()

            # tensorboard logging (once per iteration)
            writer.add_scalar('perceptual_loss', L_p.item(), global_step)
            writer.add_scalar('mse_loss', mse_loss.item(), global_step)
            writer.add_scalar('adversarial', adversarial.item(), global_step)
            writer.add_scalar('ms-ssim', 100*ms_ssim_loss.item(), global_step)
            writer.add_scalar('loss_G', errG.item(), global_step)
            writer.add_scalar('loss_D', errD.item(), global_step)

            # Checkpoint, preview image and console stats
            if i % 10 == 0 and i != 0:
                torch.save(netG.state_dict(), f'weights/netG_epoch_{epoch}_i_{i}.pth')
                torch.save(netD.state_dict(), f'weights/netD_epoch_{epoch}_i_{i}.pth')

                # lr, fake, hr - mapped from [-1, 1] to [0, 1] first, because
                # save_image expects [0, 1] and would otherwise clip every
                # negative value to black.
                images = [
                    (image.detach().cpu().clamp(-1, 1) + 1) / 2
                    for image in (data_a[0], fake[0], data_c[0])
                ]
                report_img = make_grid(images)

                print(f'epoch: {epoch}, iter: {i}, perceptual_loss: {L_p.item():.3f}, mse_loss: {mse_loss.item():.3f}, adversarial: {adversarial.item():.3f}, ms-ssim: {100*ms_ssim_loss.item():.3f}, loss_G: {errG.item():.3f}, loss_D: {errD.item():.3f}')
                save_image(report_img, f'epoch_{epoch}_i_{i}.png')

def main():
    """Build the dataset, generator and discriminator, then run training.

    TensorBoard scalars go through the module-level `writer` (the one bound
    to `logdir`); previously a second, default-configured SummaryWriter was
    created here and shadowed it, so nothing was logged to `logdir`.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    dataset = AlignedDataset('dataset_celeba/images', fine_size=512)
    # train() builds its label tensors with shape (1,), hence batch_size=1.
    trainloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=8)

    # Follow the detected device instead of hard-coding 'cuda:0'.
    gpu_ids = ['cuda:0'] if device.type == 'cuda' else []
    netG = networks.define_G('UNetDictFace', gpu_ids)
    netD = MultiScaleDiscriminator(scales=(1, 2, 4, 8))
    netD.to(device)

    # To resume from a checkpoint:
#     netG.load_state_dict(torch.load('weights/netG_30k_epoch0_exp2_1.pth'))
#     netD.load_state_dict(torch.load('weights/netD_30k_epoch0_exp2_1.pth'))

    try:
        train(device, dataset, trainloader, netG, netD, writer)
    finally:
        # Flush pending events even when training is interrupted.
        writer.close()


if __name__ == '__main__':
    main()

In [None]:
#@title custom_dataset.py (paths)
%%writefile /content/DFDNet/data/custom_dataset.py
# -- coding: utf-8 --
import os.path
import os
import random
import torchvision.transforms as transforms
import torch
from PIL import Image, ImageFilter
import numpy as np
import cv2
import math
from scipy.io import loadmat
from PIL import Image
import PIL
from torch.utils.data import Dataset, DataLoader

import glob

class AlignedDataset(Dataset):
    """Training pairs (degraded image A, target image C) plus facial part boxes.

    Images are the ``*.png`` files found recursively below
    ``/content/DFDNet/ffhq``; the landmarks of image ``<name>`` are read from
    ``/content/DFDNet/landmarks/<name>.txt`` (``<name>`` includes the image
    extension).
    """

    def __init__(self, root_dir, fine_size=512, transform=None):
        """
        Args:
            root_dir: accepted for interface compatibility but ignored - the
                image root is fixed to the notebook's path below.
            fine_size: side length images are resized to.
            transform: stored on the instance; not applied by __getitem__.
        """
        self.root_dir = '/content/DFDNet/ffhq'
        self.pathes = glob.glob(self.root_dir + '/**/*.png', recursive=True)
        self.transform = transform
        self.fine_size = fine_size
        self.partpath = '/content/DFDNet/landmarks'

    def AddNoise(self,img): # noise
        """Add Gaussian noise (sigma drawn from 1..10) with probability ~0.9."""
        if random.random() > 0.9:
            return img
        self.sigma = np.random.randint(1, 11)
        img_tensor = torch.from_numpy(np.array(img)).float()
        noise = torch.randn(img_tensor.size()).mul_(self.sigma/1.0)

        noiseimg = torch.clamp(noise+img_tensor,0,255)
        return Image.fromarray(np.uint8(noiseimg.numpy()))

    def AddBlur(self,img): # gaussian blur or motion blur
        """Blur the image with probability ~0.9 (Gaussian ~65%, motion ~35%)."""
        if random.random() > 0.9:
            return img
        img = np.array(img)
        if random.random() > 0.35: # gaussian blur
            blursize = random.randint(1,17) * 2 + 1 # odd kernel size in 3..35
            blursigma = random.randint(3, 20)
            img = cv2.GaussianBlur(img, (blursize,blursize), blursigma/10)
        else: # motion blur, kernel file path is relative to the working directory
            M = random.randint(1,32)
            KName = './data/MotionBlurKernel/m_%02d.mat' % M
            k = loadmat(KName)['kernel']
            k = k.astype(np.float32)
            k /= np.sum(k)
            img = cv2.filter2D(img,-1,k)
        return Image.fromarray(img)

    def AddDownSample(self,img): # downsampling
        """Bicubically shrink the image with probability ~0.95."""
        if random.random() > 0.95:
            return img
        sampler = random.randint(20, 100)*1.0
        img = img.resize((int(self.fine_size/sampler*10.0), int(self.fine_size/sampler*10.0)), Image.BICUBIC)
        return img

    def AddJPEG(self,img): # JPEG compression
        """Round-trip the image through JPEG (quality 40..80) with probability ~0.6."""
        if random.random() > 0.6:
            return img
        imQ = random.randint(40, 80)
        img = np.array(img)
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY),imQ] # (0,100),higher is better,default is 95
        _, encA = cv2.imencode('.jpg', img, encode_param)
        img = cv2.imdecode(encA,1)
        return Image.fromarray(img)

    def AddUpSample(self,img):
        """Bicubically resize the image back to fine_size x fine_size."""
        return img.resize((self.fine_size, self.fine_size), Image.BICUBIC)

    def __getitem__(self, index): # indexation
        """Return {'A': degraded, 'C': target, 'path', 'part_locations'} for one image.

        Both tensors are normalised with mean 0.5 / std 0.5 per channel. Only
        AddBlur is applied to A; the other Add* degradations are not used here.
        """
        path = self.pathes[index]
        Imgs = Image.open(path).convert('RGB')

        A = Imgs.resize((self.fine_size, self.fine_size))
        A = transforms.ColorJitter(0.3, 0.3, 0.3, 0)(A)
        C = A
        A = self.AddBlur(A)

        ImgName = path.split('/')[-1]
        part_locations = self.get_part_location(self.partpath, ImgName, 2)

        A = transforms.ToTensor()(A)
        C = transforms.ToTensor()(C)

        A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
        C = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(C)

        return {'A': A, 'C': C, 'path': path, 'part_locations': part_locations}

    @staticmethod
    def _part_box(points):
        """Square box (x1, y1, x2, y2) centred on the mean of `points`.

        The half side is half of the larger coordinate extent, but at least 16.
        """
        center = np.mean(points, 0)
        half_side = np.max((np.max(np.max(points, 0) - np.min(points, 0)) / 2, 16))
        return np.hstack((center - half_side + 1, center + half_side)).astype(int)

    def get_part_location(self, landmarkpath, imgname, downscale=1):
        """Read the landmarks of `imgname` and return the four facial part boxes.

        Args:
            landmarkpath: directory holding ``<imgname>.txt``, one
                whitespace-separated point per line (indices up to 67 are used).
            imgname: image file name, including its extension.
            downscale: divisor applied to every coordinate.

        Returns:
            Tuple (left eye, right eye, nose, mouth) of int arrays
            ``[x1, y1, x2, y2]``.
        """
        with open(os.path.join(landmarkpath, imgname + '.txt'),'r') as f:
            # `float` replaces `np.float` (removed in NumPy 1.24); split()
            # tolerates repeated spaces and blank lines are skipped.
            Landmarks = [[float(v) for v in line.split()] for line in f if line.strip()]
        Landmarks = np.array(Landmarks)/downscale # 512 * 512

        Map_LE = list(range(17,22)) + list(range(36,42))
        Map_RE = list(range(22,27)) + list(range(42,48))
        Map_NO = list(range(29,36))
        Map_MO = list(range(48,68))

        Location_LE = self._part_box(Landmarks[Map_LE])
        Location_RE = self._part_box(Landmarks[Map_RE])
        Location_NO = self._part_box(Landmarks[Map_NO])
        Location_MO = self._part_box(Landmarks[Map_MO])
        return Location_LE, Location_RE, Location_NO, Location_MO

    def __len__(self):
        return len(self.pathes)

    def name(self):
        return 'AlignedDataset'

In [None]:
#@title networks.py (debugging prints (disabled))
%%writefile /content/DFDNet/models/networks.py
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.nn import Parameter as P
from util import util
from torchvision import models
import scipy.io as sio
import numpy as np
import scipy.ndimage
import torch.nn.utils.spectral_norm as SpectralNorm

from torch.autograd import Function
from math import sqrt
import random
import os
import math

from sync_batchnorm import convert_model
####

###############################################################################
# Helper Functions
###############################################################################
def get_norm_layer(norm_type='instance'):
    """Return a normalisation-layer factory for `norm_type`.

    'batch' gives an affine BatchNorm2d partial, 'instance' a non-affine
    InstanceNorm2d partial that tracks running stats, 'none' gives None.
    Any other value raises NotImplementedError.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    if norm_type == 'none':
        return None
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)


def get_scheduler(optimizer, opt):
    """Create the learning-rate scheduler selected by ``opt.lr_policy``.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        opt: options object. Reads ``lr_policy`` and, depending on it,
            ``epoch_count`` / ``niter`` / ``niter_decay`` ('lambda') or
            ``lr_decay_iters`` ('step').

    Returns:
        A LambdaLR, StepLR or ReduceLROnPlateau instance.

    Raises:
        NotImplementedError: for an unknown policy. (The exception used to be
        *returned*, with an unformatted message, so callers silently received
        an exception object instead of a scheduler.)
    """
    policy = opt.lr_policy
    if policy == 'lambda':
        def lambda_rule(epoch):
            # Factor 1.0 for the first `niter` epochs, then linear decay
            # towards 0 over the following `niter_decay` epochs.
            return 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    if policy == 'step':
        return lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    if policy == 'plateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    raise NotImplementedError('learning rate policy [%s] is not implemented' % policy)


def init_weights(net, init_type='normal', gain=0.02):
    """Re-initialise the weights of `net` in place.

    Modules whose class name contains 'Conv' or 'Linear' (and that have a
    `weight`) get the scheme named by `init_type` ('normal', 'xavier',
    'kaiming' or 'orthogonal') and a zero bias; modules whose class name
    contains 'BatchNorm2d' get weights ~ N(1, gain) and a zero bias.
    An unknown `init_type` raises NotImplementedError.
    """
    def _init_conv_or_linear(layer):
        if init_type == 'normal':
            init.normal_(layer.weight.data, 0.0, gain)
        elif init_type == 'xavier':
            init.xavier_normal_(layer.weight.data, gain=gain)
        elif init_type == 'kaiming':
            init.kaiming_normal_(layer.weight.data, a=0, mode='fan_in')
        elif init_type == 'orthogonal':
            init.orthogonal_(layer.weight.data, gain=gain)
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        bias = getattr(layer, 'bias', None)
        if bias is not None:
            init.constant_(bias.data, 0.0)

    def init_func(m):
        layer_kind = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in layer_kind or 'Linear' in layer_kind):
            _init_conv_or_linear(m)
            return
        if 'BatchNorm2d' in layer_kind:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init_flag=True):
    """Optionally place `net` on GPU(s) and initialise its weights.

    With a non-empty `gpu_ids`, CUDA must be available: the network is passed
    through `convert_model`, moved to the first listed device and wrapped in
    DataParallel. With `init_flag` set, `init_weights` is applied afterwards.
    Returns the (possibly wrapped) network.
    """
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert(torch.cuda.is_available())
        net = convert_model(net)
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)

    if not init_flag:
        return net

    init_weights(net, init_type, gain=init_gain)
    return net


# compute adaptive instance norm
def calc_mean_std(feat, eps=1e-5):
    """Per-channel mean and std of a 3-D feature map.

    `feat` must have exactly three dimensions; statistics are taken over
    everything but the first. `eps` is added to the variance before the
    square root to avoid a zero std. Both results have shape (C, 1, 1).
    """
    shape = feat.size()
    assert (len(shape) == 3)
    channels = shape[0]
    flat = feat.contiguous().view(channels, -1)
    std = (flat.var(dim=1) + eps).sqrt().view(channels, 1, 1)
    mean = flat.mean(dim=1).view(channels, 1, 1)
    return mean, std


def adaptive_instance_normalization(content_feat, style_feat):  # content_feat is degraded feature, style is ref feature
    """Re-normalise a 3-D `content_feat` to the per-channel stats of `style_feat`.

    The content feature is standardised with its own mean/std and then
    scaled and shifted with the style feature's std/mean. Both inputs must
    agree on their first dimension.
    """
    assert (content_feat.size()[:1] == style_feat.size()[:1])
    target_shape = content_feat.size()
    ref_mean, ref_std = calc_mean_std(style_feat)
    src_mean, src_std = calc_mean_std(content_feat)

    centered = content_feat - src_mean.expand(target_shape)
    standardized = centered / src_std.expand(target_shape)
    return standardized * ref_std.expand(target_shape) + ref_mean.expand(target_shape)

def calc_mean_std_4D(feat, eps=1e-5):
    """Per-sample, per-channel mean and std of a 4-D feature map.

    `feat` must have exactly four dimensions; statistics are taken over the
    last two. `eps` is added to the variance before the square root to avoid
    a zero std. Both results have shape (N, C, 1, 1).
    """
    shape = feat.size()
    assert (len(shape) == 4)
    batch, channels = shape[0], shape[1]
    flat = feat.view(batch, channels, -1)
    std = (flat.var(dim=2) + eps).sqrt().view(batch, channels, 1, 1)
    mean = flat.mean(dim=2).view(batch, channels, 1, 1)
    return mean, std

def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degradate feature
    """Re-normalise a 4-D `content_feat` to the per-channel stats of `style_feat`.

    The content feature is standardised with its own (N, C, 1, 1) mean/std
    and then scaled and shifted with the style statistics, which are applied
    by broadcasting (no shape check is performed).
    """
    target_shape = content_feat.size()
    ref_mean, ref_std = calc_mean_std_4D(style_feat)
    src_mean, src_std = calc_mean_std_4D(content_feat)
    centered = content_feat - src_mean.expand(target_shape)
    standardized = centered / src_std.expand(target_shape)
    return standardized * ref_std + ref_mean

def define_G(which_model_netG, gpu_ids=[]):
    """Build the generator named `which_model_netG`.

    Only 'UNetDictFace' is known; it is created with argument 64 and handed
    to `init_net` with weight initialisation disabled. Any other name raises
    NotImplementedError.
    """
    if which_model_netG != 'UNetDictFace':
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    generator = UNetDictFace(64)
    return init_net(generator, 'normal', 0.02, gpu_ids, False)


##############################################################################
# Classes
############################################################################################################################################


def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True):
    """Two spectrally-normalised convolutions with a LeakyReLU(0.2) in between.

    The first convolution maps in_channels -> out_channels, the second keeps
    out_channels. Padding is ``((kernel_size - 1) // 2) * dilation``.
    `norm_layer` is accepted but not used.
    """
    pad = ((kernel_size-1)//2)*dilation

    def _sn_conv(channels_in):
        return SpectralNorm(conv_layer(channels_in, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=pad, bias=bias))

    first = _sn_conv(in_channels)
    activation = nn.LeakyReLU(0.2)
    second = _sn_conv(out_channels)
    return nn.Sequential(first, activation, second)
class MSDilateBlock(nn.Module):
    """Four parallel `convU` branches with individual dilations, fused residually.

    The branch outputs are concatenated along the channel axis, reduced back
    to `in_channels` by a spectrally-normalised convolution and added to the
    input.
    """
    def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True):
        super(MSDilateBlock, self).__init__()
        branches = [
            convU(in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[k], bias=bias)
            for k in range(4)
        ]
        self.conv1, self.conv2, self.conv3, self.conv4 = branches
        fuse = conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias)
        self.convi = SpectralNorm(fuse)

    def forward(self, x):
        branch_outputs = [branch(x) for branch in (self.conv1, self.conv2, self.conv3, self.conv4)]
        stacked = torch.cat(branch_outputs, 1)
        return self.convi(stacked) + x

##############################UNetFace#########################
class AdaptiveInstanceNorm(nn.Module):
    """Instance-normalise the input, then apply the style's per-channel stats."""
    def __init__(self, in_channel):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        scale_src, shift_src = None, None
        shift_src, scale_src = calc_mean_std_4D(style)
        normalized = self.norm(input)
        shape = input.size()
        return scale_src.expand(shape) * normalized + shift_src.expand(shape)

class BlurFunctionBackward(Function):
    """Autograd function computing the gradient of the depthwise blur.

    Forward convolves the incoming gradient with the flipped kernel; its own
    backward (the double-backward of the blur) convolves with the plain
    kernel. Both use padding 1 and one group per channel.
    """
    @staticmethod
    def forward(ctx, grad_output, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)
        channels = grad_output.shape[1]
        return F.conv2d(grad_output, kernel_flip, padding=1, groups=channels)

    @staticmethod
    def backward(ctx, gradgrad_output):
        kernel, _flipped = ctx.saved_tensors
        channels = gradgrad_output.shape[1]
        grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=channels)
        # No gradients for the two kernel arguments.
        return grad_input, None, None


class BlurFunction(Function):
    """Depthwise 3x3 blur as a custom autograd function.

    Forward convolves the input with `kernel` (padding 1, one group per
    channel); backward delegates to BlurFunctionBackward so the operation
    stays differentiable twice.
    """
    @staticmethod
    def forward(ctx, input, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)
        channels = input.shape[1]
        return F.conv2d(input, kernel, padding=1, groups=channels)

    @staticmethod
    def backward(ctx, grad_output):
        kernel, kernel_flip = ctx.saved_tensors
        # No gradients for the two kernel arguments.
        return BlurFunctionBackward.apply(grad_output, kernel, kernel_flip), None, None


# Functional entry point used by the Blur module.
blur = BlurFunction.apply


class Blur(nn.Module):
    """Depthwise blur with a fixed, normalised 3x3 binomial kernel.

    The kernel and its flipped copy are registered as buffers `weight` and
    `weight_flip`, each repeated once per channel.
    """
    def __init__(self, channel):
        super().__init__()

        taps = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
        kernel = (taps / taps.sum()).view(1, 1, 3, 3)
        flipped = torch.flip(kernel, [2, 3])

        self.register_buffer('weight', kernel.repeat(channel, 1, 1, 1))
        self.register_buffer('weight_flip', flipped.repeat(channel, 1, 1, 1))

    def forward(self, input):
        return blur(input, self.weight, self.weight_flip)

class EqualLR:
    """Forward pre-hook that rescales a parameter by sqrt(2 / fan_in) on each call.

    `apply` moves the parameter `<name>` to `<name>_orig` and registers the
    hook; before every forward pass the hook sets `<name>` to the rescaled
    tensor.
    """
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        raw = getattr(module, self.name + '_orig')
        # fan_in = input channels times the number of elements of one kernel slice
        fan_in = raw.data.size(1) * raw.data[0][0].numel()
        scale = sqrt(2 / fan_in)
        return raw * scale

    @staticmethod
    def apply(module, name):
        hook = EqualLR(name)

        original = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(original.data))
        module.register_forward_pre_hook(hook)

        return hook

    def __call__(self, module, input):
        setattr(module, self.name, self.compute_weight(module))

def equal_lr(module, name='weight'):
    """Install the EqualLR hook for parameter `name` and return `module` itself."""
    _hook = EqualLR.apply(module, name)
    return module

class EqualConv2d(nn.Module):
    """Conv2d with N(0, 1) weights, zero bias and the `equal_lr` hook applied.

    All arguments are forwarded to ``nn.Conv2d``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__()
        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        # With bias=False the layer has no bias tensor; zeroing it
        # unconditionally raised AttributeError on None.
        if conv.bias is not None:
            conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        return self.conv(input)

class NoiseInjection(nn.Module):
    """Add noise to an image, scaled by a learnable per-channel weight.

    The weight has shape (1, channel, 1, 1) and starts at zero, so the
    module initially returns the image unchanged.
    """
    def __init__(self, channel):
        super().__init__()
        initial_scale = torch.zeros(1, channel, 1, 1)
        self.weight = nn.Parameter(initial_scale)

    def forward(self, image, noise):
        scaled_noise = self.weight * noise
        return image + scaled_noise

class StyledUpBlock(nn.Module):
    """Convolution block modulated by a style feature, followed by 2x upsampling.

    forward(input, style): `input` goes through `conv1` and a LeakyReLU, is
    multiplied by ScaleModel1(style) and shifted by ShiftModel1(style), and
    the result is passed through `convup` (bilinear 2x upsample + conv).
    """
    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False):
        super().__init__()
        if upsample:
            self.conv1 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                # NOTE(review): this Blur runs *before* the convolution, i.e. on a
                # tensor with in_channel channels, yet it is built for out_channel.
                # Blur uses one kernel per channel, so this only works when
                # in_channel == out_channel (compare Blur(in_channel) in the other
                # branch). Not changed here because the buffer shape is part of
                # saved checkpoints - confirm before fixing.
                Blur(out_channel),
                # EqualConv2d(in_channel, out_channel, kernel_size, padding=padding),
                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
                nn.LeakyReLU(0.2),
            )
        else:
            self.conv1 = nn.Sequential(
                Blur(in_channel),
                # EqualConv2d(in_channel, out_channel, kernel_size, padding=padding)
                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
                nn.LeakyReLU(0.2),
            )
        # Output stage: always upsamples by 2, independent of `upsample`.
        self.convup = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                # EqualConv2d(out_channel, out_channel, kernel_size, padding=padding),
                SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
                nn.LeakyReLU(0.2),
                # Blur(out_channel),
            )
        # self.noise1 = equal_lr(NoiseInjection(out_channel))
        # self.adain1 = AdaptiveInstanceNorm(out_channel)
        self.lrelu1 = nn.LeakyReLU(0.2)

        # self.conv2 = EqualConv2d(out_channel, out_channel, kernel_size, padding=padding)
        # self.noise2 = equal_lr(NoiseInjection(out_channel))
        # self.adain2 = AdaptiveInstanceNorm(out_channel)
        # self.lrelu2 = nn.LeakyReLU(0.2)

        # Maps the style feature (in_channel channels) to a multiplicative scale.
        self.ScaleModel1 = nn.Sequential(
            # Blur(in_channel),
            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
            # nn.Conv2d(in_channel,out_channel,3, 1, 1),
            nn.LeakyReLU(0.2, True),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))
            # nn.Conv2d(out_channel, out_channel, 3, 1, 1)
        )
        # Maps the style feature to an additive shift; the final Sigmoid keeps
        # the shift in (0, 1).
        self.ShiftModel1 = nn.Sequential(
            # Blur(in_channel),
            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
            # nn.Conv2d(in_channel,out_channel,3, 1, 1),
            nn.LeakyReLU(0.2, True),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
            nn.Sigmoid(),
            # nn.Conv2d(out_channel, out_channel, 3, 1, 1)
        )

    def forward(self, input, style):
        """Return the style-modulated, 2x upsampled feature map."""
        out = self.conv1(input)
#         out = self.noise1(out, noise)
        out = self.lrelu1(out)

        # Scale-and-shift modulation derived from the style feature.
        Shift1 = self.ShiftModel1(style)
        Scale1 = self.ScaleModel1(style)
        out = out * Scale1 + Shift1
        # out = self.adain1(out, style)
        outup = self.convup(out)

        return outup

##############################################################################
##Face Dictionary
##############################################################################
class VGGFeat(torch.nn.Module):
    """
    Frozen VGG-19 wrapper that returns the outputs of four consecutive stages
    of ``vgg19().features`` (layers [0, 8), [8, 17), [17, 26), [26, 35)).

    Input: (B, C, H, W), RGB, [-1, 1]
    """
    def __init__(self, weight_path='./weights/vgg19.pth'):
        """Build VGG-19, split it into stages and load weights from ``weight_path``."""
        super().__init__()
        self.model = models.vgg19(pretrained=False)
        self.build_vgg_layers()

        # map_location='cpu' lets a checkpoint that was saved from a GPU be
        # read on a CPU-only machine; load_state_dict copies the values into
        # the existing parameters, so the module's own device is unaffected.
        self.model.load_state_dict(torch.load(weight_path, map_location='cpu'))

        # NOTE(review): these constants are registered as trainable parameters
        # and are not covered by the freeze loop below (which only walks
        # self.model). Kept as-is so existing checkpoints keep matching --
        # confirm whether they are meant to be optimised.
        self.register_parameter("RGB_mean", nn.Parameter(torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)))
        self.register_parameter("RGB_std", nn.Parameter(torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)))

        # Freeze every VGG weight (the stage modules share these tensors).
        for param in self.model.parameters():
            param.requires_grad = False

    def build_vgg_layers(self):
        """Group ``self.model.features`` into four sequential stages.

        The stages reuse the layer objects of ``self.model`` (no copies), and
        each layer keeps its original index as its sub-module name.
        """
        vgg_pretrained_features = self.model.features
        # feature_layers = [0, 3, 8, 17, 26, 35]
        feature_layers = [0, 8, 17, 26, 35]
        stages = []
        for start, stop in zip(feature_layers[:-1], feature_layers[1:]):
            stage = torch.nn.Sequential()
            for j in range(start, stop):
                stage.add_module(str(j), vgg_pretrained_features[j])
            stages.append(stage)
        self.features = torch.nn.ModuleList(stages)

    def preprocess(self, x):
        """Map ``x`` from [-1, 1] to [0, 1], normalise it with RGB_mean/RGB_std,
        and upsample to 224x224 when the width is below 224."""
        x = (x + 1) / 2
        x = (x - self.RGB_mean) / self.RGB_std
        if x.shape[3] < 224:
            x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        return x

    def forward(self, x):
        """Return the list of outputs of the four stages, in order."""
        x = self.preprocess(x)
        features = []
        for m in self.features:
            x = m(x)
            features.append(x)
        return features

def compute_sum(x, axis=None, keepdim=False):
    """Sum tensor ``x`` over ``axis``, reducing one dimension at a time.

    Args:
        x: input tensor.
        axis: ``None`` or an empty sequence to reduce every dimension, a single
            int, or an iterable of ints. Negative indices count from the last
            dimension; duplicates are ignored.
        keepdim: keep reduced dimensions with size 1.

    Returns:
        The reduced tensor.
    """
    ndim = x.dim()
    if axis is None:
        axis = []
    elif isinstance(axis, int):
        # A bare int (including 0) selects exactly that dimension.
        axis = [axis]
    else:
        axis = list(axis)

    if axis:
        # Normalise negative indices up front: once a dimension has been
        # dropped, a negative index would point at a different dimension.
        dims = {d + ndim if d < 0 else d for d in axis}
    else:
        dims = range(ndim)

    # Reduce from the highest dimension down so the remaining indices stay
    # valid when keepdim is False.
    for d in sorted(dims, reverse=True):
        x = torch.sum(x, dim=d, keepdim=keepdim)
    return x
def ToRGB(in_channel):
    """Build a head that maps ``in_channel`` feature maps to 3 output channels.

    Layout: spectral-normalised 3x3 conv (same width) -> LeakyReLU(0.2) ->
    spectral-normalised 3x3 conv to 3 channels. Stride 1 and padding 1 keep the
    spatial size.
    """
    hidden_conv = SpectralNorm(nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1))
    activation = nn.LeakyReLU(0.2)
    rgb_conv = SpectralNorm(nn.Conv2d(in_channel, 3, kernel_size=3, stride=1, padding=1))
    return nn.Sequential(hidden_conv, activation, rgb_conv)

def AttentionBlock(in_channel):
    """Build a channel-preserving two-conv block.

    Layout: spectral-normalised 3x3 conv -> LeakyReLU(0.2) -> spectral-normalised
    3x3 conv, all with ``in_channel`` channels, stride 1 and padding 1.
    """
    first_conv = SpectralNorm(nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1))
    activation = nn.LeakyReLU(0.2)
    second_conv = SpectralNorm(nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1))
    return nn.Sequential(first_conv, activation, second_conv)

class UNetDictFace(nn.Module):
    """Generator that refines VGG features with per-part dictionary features.

    For each of four VGG feature scales and each facial part (left eye, right
    eye, nose, mouth), the part region is cut out of the feature map, matched
    against a pre-computed dictionary of part features, and blended with the
    best match through a learned attention block. The refined feature maps
    then modulate a chain of ``StyledUpBlock`` decoders.

    NOTE: ``forward`` only reads the part locations of batch element 0 (see the
    hard-coded ``b = 0``); training in this notebook runs with ``--batchSize 1``.
    """

    # (dictionary key, attention-module prefix) per facial part, in the order
    # in which the parts are indexed in ``part_locations``.
    _PARTS = (('left_eye', 'le'), ('right_eye', 're'), ('nose', 'no'), ('mouth', 'mo'))

    def __init__(self, ngf=64, dictionary_path='./DictionaryCenter512'):
        """
        Args:
            ngf: base channel width of the decoder.
            dictionary_path: directory holding ``<part>_<size>_center.npy``
                files for every part and every size in ``feature_sizes``.
        """
        super().__init__()

        self.part_sizes = np.array([80,80,50,110]) # size for 512
        self.feature_sizes = np.array([256,128,64,32])
        self.channel_sizes = np.array([128,256,512,512])

        # The dictionaries are plain Python dicts of tensors (not parameters or
        # buffers); forward() moves the entries it needs next to the input.
        self.Dict_256 = {}
        self.Dict_128 = {}
        self.Dict_64 = {}
        self.Dict_32 = {}
        for part_idx, (part_name, _prefix) in enumerate(self._PARTS):
            for f_size, channels in zip(self.feature_sizes, self.channel_sizes):
                # SECURITY: allow_pickle=True lets np.load unpickle arbitrary
                # objects; only point dictionary_path at trusted files.
                centers = torch.from_numpy(np.load(
                    os.path.join(dictionary_path, '{}_{}_center.npy'.format(part_name, f_size)),
                    allow_pickle=True))
                # Part size at this scale: 512-px part size divided by 512 / f_size.
                side = self.part_sizes[part_idx] // (512 // f_size)
                getattr(self, 'Dict_' + str(f_size))[part_name] = centers.reshape(
                    centers.size(0), channels, side, side)

        # One attention block per (part, scale): le_256, le_128, ..., mo_32.
        # Part-major creation order keeps the module registration order stable.
        for _name, prefix in self._PARTS:
            for f_size, channels in zip(self.feature_sizes, self.channel_sizes):
                setattr(self, '{}_{}'.format(prefix, f_size), AttentionBlock(int(channels)))

        #norm
        self.VggExtract = VGGFeat()

        ######################
        self.MSDilate = MSDilateBlock(ngf*8, dilation = [4,3,2,1])  #

        self.up0 = StyledUpBlock(ngf*8,ngf*8)
        self.up1 = StyledUpBlock(ngf*8, ngf*4) #
        self.up2 = StyledUpBlock(ngf*4, ngf*2) #
        self.up3 = StyledUpBlock(ngf*2, ngf) #
        self.up4 = nn.Sequential( # 128
            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            UpResBlock(ngf),
            UpResBlock(ngf),
            nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )
        # The to_rgb heads are not used by forward(); they are still created so
        # the set of registered parameters stays the same.
        self.to_rgb0 = ToRGB(ngf*8)
        self.to_rgb1 = ToRGB(ngf*4)
        self.to_rgb2 = ToRGB(ngf*2)
        self.to_rgb3 = ToRGB(ngf*1)

    def _swap_part(self, part_feature, dict_feature, attention):
        """Blend one cropped part feature with its best-matching dictionary entry.

        Args:
            part_feature: crop of the VGG feature map covering the part.
            dict_feature: dictionary tensor for this part and scale.
            attention: attention module for this part and scale.

        Returns:
            ``attention(swap - part_feature) * swap + part_feature`` where
            ``swap`` is the selected dictionary entry resized to the crop.
        """
        # Bring the crop to the dictionary's spatial size so they can be compared.
        resized = F.interpolate(part_feature, (dict_feature.size(2), dict_feature.size(3)),
                                mode='bilinear', align_corners=False)
        # Re-normalise the dictionary entries to the crop's channel statistics.
        dict_norm = adaptive_instance_normalization_4D(dict_feature, resized)

        # Full-size correlation gives one score per dictionary entry.
        score = F.conv2d(resized, dict_norm)
        score = F.softmax(score.view(-1), dim=0)
        index = torch.argmax(score)

        swap = F.interpolate(dict_norm[index:index+1], (part_feature.size(2), part_feature.size(3)))
        return attention(swap - part_feature) * swap + part_feature

    def forward(self, input, part_locations):
        """Restore ``input`` using dictionary-refined VGG features.

        Args:
            input: image batch passed to ``VGGFeat`` (RGB in [-1, 1] according
                to its docstring).
            part_locations: indexable by part (left eye, right eye, nose,
                mouth) and then by batch element, yielding four coordinates in
                512-px space that are used as ``[x0, y0, x1, y1]`` when slicing.

        Returns:
            Output of ``self.up4`` (its last layer is ``Tanh``).
        """
        # VGG feature shapes, per the original author's notes:
        #0 128 * 256 * 256
        #1 256 * 128 * 128
        #2 512 * 64 * 64
        #3 512 * 32 * 32
        VggFeatures = self.VggExtract(input) #VggFeatures = list object
        b = 0  # only the first batch element's locations are used
        UpdateVggFeatures = []
        for i, f_size in enumerate(self.feature_sizes):
            cur_feature = VggFeatures[i]
            update_feature = cur_feature.clone()
            dicts_feature = getattr(self, 'Dict_'+str(f_size))
            scale = 512/f_size  # factor from 512-px coordinates to this scale

            for part_idx, (part_name, prefix) in enumerate(self._PARTS):
                dict_feature = dicts_feature[part_name].to(input)
                loc = (part_locations[part_idx][b] // scale).int()

                # Crops are always taken from the untouched cur_feature, so
                # overlapping part boxes do not feed into each other.
                part_feature = cur_feature[:,:,loc[1]:loc[3],loc[0]:loc[2]].clone()
                attention = getattr(self, prefix+'_'+str(f_size))
                update_feature[:,:,loc[1]:loc[3],loc[0]:loc[2]] = self._swap_part(
                    part_feature, dict_feature, attention)

            UpdateVggFeatures.append(update_feature)

        # Decode from the coarsest scale, injecting refined features as style.
        fea_vgg = self.MSDilate(VggFeatures[3])
        fea_up0 = self.up0(fea_vgg, UpdateVggFeatures[3])
        fea_up1 = self.up1(fea_up0, UpdateVggFeatures[2])
        fea_up2 = self.up2(fea_up1, UpdateVggFeatures[1])
        fea_up3 = self.up3(fea_up2, UpdateVggFeatures[0])

        output = self.up4(fea_up3)
        return output


class UpResBlock(nn.Module):
    """Residual block computing ``x + conv(LeakyReLU(conv(x)))``.

    Both convolutions are 3x3, stride 1, padding 1 and keep ``dim`` channels.
    ``norm_layer`` is accepted but not used.
    """

    def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d):
        super().__init__()
        entry_conv = conv_layer(dim, dim, 3, 1, 1)
        exit_conv = conv_layer(dim, dim, 3, 1, 1)
        # The in-place activation only touches the fresh output of entry_conv.
        self.Model = nn.Sequential(entry_conv, nn.LeakyReLU(0.2, True), exit_conv)

    def forward(self, x):
        residual = self.Model(x)
        return x + residual

class VggClassNet(nn.Module):
    """Frozen pretrained VGG-19 ``features`` stack that returns selected layer outputs."""

    def __init__(self, select_layer=('0', '5', '10', '19')):
        """
        Args:
            select_layer: names (string indices) of the layers in
                ``vgg19().features`` whose outputs ``forward`` should return.
        """
        super(VggClassNet, self).__init__()
        # Copy into a fresh list so instances never share (or mutate) the default.
        self.select = list(select_layer)
        self.vgg = models.vgg19(pretrained=True).features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the layers and return the selected outputs in layer order."""
        features = []
        pending = set(self.select)
        for name, layer in self.vgg.named_children():
            x = layer(x)
            if name in self.select:
                features.append(x)
                pending.discard(name)
                if not pending:
                    # Every requested feature is collected; the remaining
                    # layers cannot contribute to the result.
                    break
        return features


# Running this module directly only prints a marker message; no network is built.
if __name__ == '__main__':
    print('this is network')




In [None]:
%cd /content/DFDNet
!python run.py --batchSize 1

# Testing

Add data to ``/content/input/<image.png>``. Currently, only ``.png`` files are searched for; change the extension inside ``custom_dataset.py`` if you need a different one. Change the model path to point to your own model. Output is written to ``/content/``. Also, don't forget to create the landmarks first.

In [None]:
#@title inference.py (fixing import, changing paths)
%%writefile /content/DFDNet/inference.py
import os
import numpy as np
import cv2

import math
from PIL import Image
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch
import random
from skimage import transform as trans
from skimage import io
import sys
sys.path.append('FaceLandmarkDetection')
import face_alignment
import dlib

from models import *
from data.custom_dataset import AlignedDataset
from torch.utils.data import DataLoader
import torchvision

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from trains import Task
from torchvision.utils import make_grid

from options.test_options import TestOptions
from data.image_folder import make_dataset

from torchvision.utils import save_image


# Inference driver: restores every image of the dataset with a trained generator
# and writes "<index>.jpg" (input and output side by side) to /content/.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# values don't matter, configure custom_dataset.py
dataset = AlignedDataset('/content/checked/', fine_size=256)
loader = DataLoader(dataset,
                    batch_size=1,
                    shuffle=False,
                    num_workers=1)
netG = networks.define_G('UNetDictFace', ['cuda:0'])

# map_location keeps loading independent of the device the checkpoint was
# saved from; load_state_dict copies the values into netG's own tensors.
netG.load_state_dict(torch.load('/content/DFDNet/weights/netG_epoch_0_i_10.pth',
                                map_location=device))
# Inference only: switch layers to eval behaviour and skip autograd bookkeeping.
netG.eval()

with torch.no_grad():
    for i, data in enumerate(tqdm(loader), 0):
        data_a = data['A'].to(device)
        data_part_locations = data['part_locations']

        out = netG(data_a, part_locations=data_part_locations)

        # Input and restored image side by side, still in [-1, 1].
        report_img = make_grid([data_a[0].cpu(), out[0].cpu()])
        # CHW in [-1, 1] -> HWC in [0, 255], made explicitly 8-bit for JPEG.
        result = 255 * (report_img.permute(1, 2, 0).numpy() + 1) / 2
        result = np.clip(result, 0, 255).round().astype(np.uint8)
        # The tensors are RGB while OpenCV writes BGR.
        cv2.imwrite(f'/content/{i}.jpg', cv2.cvtColor(result, cv2.COLOR_RGB2BGR))



In [None]:
#@title custom_dataset.py (paths, disabling augmentations, disabling resize)
%%writefile /content/DFDNet/data/custom_dataset.py
# -- coding: utf-8 --
import os.path
import os
import random
import torchvision.transforms as transforms
import torch
from PIL import Image, ImageFilter
import numpy as np
import cv2
import math
from scipy.io import loadmat
from PIL import Image
import PIL
from torch.utils.data import Dataset, DataLoader

import glob

class AlignedDataset(Dataset):
    """Dataset of ``.png`` images with facial-part boxes read from landmark files.

    Images are collected recursively from ``/content/input`` and landmarks are
    read from ``/content/landmark_output/<image file name>.txt``. Each item is a
    dict with the normalised image under both ``'A'`` and ``'C'``, its ``'path'``
    and the four ``'part_locations'`` boxes. The ``Add*`` degradation helpers
    are kept but are not called by ``__getitem__``.
    """

    def __init__(self, root_dir, fine_size=512, transform=None):
        """
        Args:
            root_dir: accepted for signature compatibility; the value is
                ignored and the input folder is fixed to ``/content/input``.
            fine_size: side length used by the resampling helpers.
            transform: stored on the instance; not applied by ``__getitem__``.
        """
        self.root_dir = '/content/input'
        # Sorted so that dataset indices are reproducible between runs.
        self.pathes = sorted(glob.glob(self.root_dir + '/**/*.png', recursive=True))
        self.transform = transform
        self.fine_size = fine_size
        self.partpath = '/content/landmark_output'

    def AddNoise(self,img): # noise
        """With probability 0.9, add Gaussian noise with a random sigma in [1, 10]."""
        if random.random() > 0.9: #
            return img
        self.sigma = np.random.randint(1, 11)
        img_tensor = torch.from_numpy(np.array(img)).float()
        noise = torch.randn(img_tensor.size()).mul_(self.sigma/1.0)

        noiseimg = torch.clamp(noise+img_tensor,0,255)
        return Image.fromarray(np.uint8(noiseimg.numpy()))

    def AddBlur(self,img): # gaussian blur or motion blur
        """With probability 0.9, apply a Gaussian blur or a stored motion-blur kernel."""
        if random.random() > 0.9: #
            return img
        img = np.array(img)
        if random.random() > 0.35: ##gaussian blur
            blursize = random.randint(1,17) * 2 + 1 ## odd kernel sizes 3, 5, ..., 35
            blursigma = random.randint(3, 20)
            img = cv2.GaussianBlur(img, (blursize,blursize), blursigma/10)
        else: #motion blur
            M = random.randint(1,32)
            KName = './data/MotionBlurKernel/m_%02d.mat' % M
            k = loadmat(KName)['kernel']
            k = k.astype(np.float32)
            k /= np.sum(k)  # normalise so the kernel sums to 1
            img = cv2.filter2D(img,-1,k)
        return Image.fromarray(img)

    def AddDownSample(self,img): # downsampling
        """With probability 0.95, bicubically resize to ``fine_size * 10 / s``, s in [20, 100]."""
        if random.random() > 0.95: #
            return img
        sampler = random.randint(20, 100)*1.0
        img = img.resize((int(self.fine_size/sampler*10.0), int(self.fine_size/sampler*10.0)), Image.BICUBIC)
        return img

    def AddJPEG(self,img): # JPEG compression
        """With probability 0.6, round-trip the image through JPEG at quality 40-80."""
        if random.random() > 0.6:
            return img
        imQ = random.randint(40, 80)
        img = np.array(img)
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY),imQ] # (0,100),higher is better,default is 95
        _, encA = cv2.imencode('.jpg', img, encode_param)
        img = cv2.imdecode(encA,1)
        return Image.fromarray(img)

    def AddUpSample(self,img):
        """Bicubically resize the image to ``fine_size`` x ``fine_size``."""
        return img.resize((self.fine_size, self.fine_size), Image.BICUBIC)

    def __getitem__(self, index): # indexation
        """Load image ``index`` and return tensors plus part locations."""
        path = self.pathes[index]
        image = Image.open(path).convert('RGB')

        # Resizing, colour jitter and blur are disabled here, so the input (A)
        # and the reference (C) are built from the same unmodified image.
        ImgName = os.path.basename(path)
        # Landmarks are divided by 2 before the boxes are computed.
        part_locations = self.get_part_location(self.partpath, ImgName, 2)

        # ToTensor gives [0, 1]; Normalize with 0.5/0.5 maps that to [-1, 1].
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        A = normalize(transforms.ToTensor()(image))
        C = normalize(transforms.ToTensor()(image))

        return {'A': A, 'C': C, 'path': path, 'part_locations': part_locations}

    @staticmethod
    def _part_box(points, min_half_size=16):
        """Return ``[x0, y0, x1, y1]`` of a square box centred on ``points``.

        The half side is half of the larger coordinate extent of ``points``,
        but at least ``min_half_size``.
        """
        center = np.mean(points, 0)
        extent = np.max(points, 0) - np.min(points, 0)
        half = np.max((np.max(extent) / 2, min_half_size))
        return np.hstack((center - half + 1, center + half)).astype(int)

    def get_part_location(self, landmarkpath, imgname, downscale=1):
        """Read ``<imgname>.txt`` landmarks and derive the four part boxes.

        Args:
            landmarkpath: directory containing the landmark text files.
            imgname: image file name; ``'.txt'`` is appended to find the file.
            downscale: divisor applied to every landmark coordinate.

        Returns:
            Tuple ``(left_eye, right_eye, nose, mouth)`` of integer arrays
            ``[x0, y0, x1, y1]``.
        """
        Landmarks = []
        with open(os.path.join(landmarkpath, imgname + '.txt'),'r') as f:
            for line in f:
                # Built-in float replaces the removed ``np.float`` alias;
                # split() tolerates the trailing newline and blank lines.
                values = [float(v) for v in line.split()]
                if values:
                    Landmarks.append(values)
        Landmarks = np.array(Landmarks)/downscale # 512 * 512

        # Landmark index groups used for each part.
        Map_LE = list(np.hstack((range(17,22), range(36,42))))
        Map_RE = list(np.hstack((range(22,27), range(42,48))))
        Map_NO = list(range(29,36))
        Map_MO = list(range(48,68))

        Location_LE = self._part_box(Landmarks[Map_LE])
        Location_RE = self._part_box(Landmarks[Map_RE])
        Location_NO = self._part_box(Landmarks[Map_NO])
        Location_MO = self._part_box(Landmarks[Map_MO])
        return Location_LE, Location_RE, Location_NO, Location_MO

    def __len__(self): #
        """Number of images found under the input folder."""
        return len(self.pathes)

    def name(self):
        return 'AlignedDataset'

In [None]:
%cd /content/DFDNet
!python inference.py

# Landmark generation

Install that and restart runtime.

In [None]:
!pip install face-alignment
!pip install matplotlib --upgrade

In [None]:
%cd /content/
import face_alignment
from skimage import io
import numpy as np
import glob
from tqdm import tqdm
import os
import shutil

# Detect 2D landmarks for every image under the input folder. Images with a
# detected face get a "<file name>.txt" landmark file and are moved to the
# "checked" folder; images without a detection are moved to the "failed" folder.

# Newer face-alignment releases expose the 2D landmark type as ``TWO_D`` while
# older ones use ``_2D``; accept whichever the installed version provides.
landmarks_type = getattr(face_alignment.LandmarksType, 'TWO_D', None)
if landmarks_type is None:
    landmarks_type = face_alignment.LandmarksType._2D
fa = face_alignment.FaceAlignment(landmarks_type, flip_input=False)

unchecked_input_path = '/content/input' #@param {type:"string"}
checked_output_path = '/content/checked' #@param {type:"string"}
failed_output_path = '/content/failed' #@param {type:"string"}
landmark_output_path = '/content/landmark_output' #@param {type:"string"}

for directory in (unchecked_input_path, checked_output_path,
                  failed_output_path, landmark_output_path):
    os.makedirs(directory, exist_ok=True)

files = glob.glob(unchecked_input_path + '/**/*.png', recursive=True)
files.extend(glob.glob(unchecked_input_path + '/**/*.jpg', recursive=True))
err_files = []

for f in tqdm(files):
    image = io.imread(f)
    # Drop an alpha channel if present so the detector always receives RGB.
    if image.ndim == 3 and image.shape[2] == 4:
        image = image[..., :3]
    preds = fa.get_landmarks(image)
    if preds is not None and len(preds) > 0:
        # Only the first detected face is saved, one "x y" pair per line.
        np.savetxt(os.path.join(landmark_output_path, os.path.basename(f) + ".txt"),
                   preds[0], delimiter=' ', fmt='%1.3f')
        shutil.move(f, os.path.join(checked_output_path, os.path.basename(f)))
    else:
        err_files.append(f)
        shutil.move(f, os.path.join(failed_output_path, os.path.basename(f)))

if err_files:
    print('No face detected in {} file(s); moved to {}'.format(len(err_files), failed_output_path))

# [Experimental] Training with own features

The goal is to avoid using the provided ```.npy``` files and to create the needed feature files yourself instead. Feature saving was added to ```networks.py```, so you can extract features from good-looking images. Training runs as normal, but during training the features are concatenated into one array and saved to a file with torch after a certain number of iterations. Change ```self.amount_features``` if you want to configure the number of features; the default is 20. Wait until it reports that all the features were saved (one print for each size: 32, 64, 128, 256).

In [None]:
#@title networks.py (added feature generation)
%%writefile /content/DFDNet/models/networks.py
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.nn import Parameter as P
from util import util
from torchvision import models
import scipy.io as sio
import numpy as np
import scipy.ndimage
import torch.nn.utils.spectral_norm as SpectralNorm

from torch.autograd import Function
from math import sqrt
import random
import os
import math

from sync_batchnorm import convert_model
####

###############################################################################
# Helper Functions
###############################################################################
def get_norm_layer(norm_type='instance'):
    """Return a normalisation-layer factory for ``norm_type``.

    ``'batch'`` gives ``BatchNorm2d`` with affine parameters, ``'instance'``
    gives ``InstanceNorm2d`` without affine parameters but with running
    statistics, and ``'none'`` gives ``None``.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    if norm_type == 'none':
        return None
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)


def get_scheduler(optimizer, opt):
    """Create the learning-rate scheduler selected by ``opt.lr_policy``.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        opt: options object. ``'lambda'`` reads ``epoch_count``, ``niter`` and
            ``niter_decay``; ``'step'`` reads ``lr_decay_iters``; ``'plateau'``
            needs no extra fields.

    Returns:
        The constructed scheduler.

    Raises:
        NotImplementedError: if ``opt.lr_policy`` is not one of the above.
    """
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            # Factor stays at 1.0 while the max() term is 0, then decreases
            # linearly by 1 / (niter_decay + 1) per epoch.
            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    else:
        # Raise (not return) the error, and actually format the policy name in.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)

    return scheduler


def init_weights(net, init_type='normal', gain=0.02):
    """Re-initialise Conv/Linear and BatchNorm2d layers of ``net`` in place.

    Conv/Linear weights follow ``init_type`` (``'normal'``, ``'xavier'``,
    ``'kaiming'`` or ``'orthogonal'``) and their biases are zeroed.
    BatchNorm2d weights are drawn from N(1, gain) and biases are zeroed.

    Raises:
        NotImplementedError: for an unknown ``init_type``, raised when the
            first Conv/Linear layer is visited.
    """
    def fill_weight(weight):
        if init_type == 'normal':
            init.normal_(weight, 0.0, gain)
        elif init_type == 'xavier':
            init.xavier_normal_(weight, gain=gain)
        elif init_type == 'kaiming':
            init.kaiming_normal_(weight, a=0, mode='fan_in')
        elif init_type == 'orthogonal':
            init.orthogonal_(weight, gain=gain)
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

    def init_func(m):
        # Layers are recognised by class name, so wrappers match as well.
        layer_name = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in layer_name or 'Linear' in layer_name):
            fill_weight(m.weight.data)
            bias = getattr(m, 'bias', None)
            if bias is not None:
                init.constant_(bias.data, 0.0)
            return
        if 'BatchNorm2d' in layer_name:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init_flag=True):
    """Optionally move ``net`` to GPU(s) and optionally re-initialise its weights.

    When ``gpu_ids`` is non-empty, CUDA must be available; the network is
    passed through ``convert_model``, moved to ``gpu_ids[0]`` and wrapped in
    ``DataParallel``. When ``init_flag`` is true, ``init_weights`` is applied
    with ``init_type`` and ``init_gain``.

    Returns:
        The (possibly wrapped) network.
    """
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert(torch.cuda.is_available())
        net = convert_model(net)
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)

    if not init_flag:
        return net

    init_weights(net, init_type, gain=init_gain)
    return net


# compute adaptive instance norm
def calc_mean_std(feat, eps=1e-5):
    """Per-channel mean and standard deviation of a 3-D ``(C, H, W)`` tensor.

    ``eps`` is added to the variance before the square root to avoid a zero
    standard deviation. Both results have shape ``(C, 1, 1)``.
    """
    shape = feat.size()
    assert (len(shape) == 3)
    channels = shape[0]

    std = (feat.contiguous().view(channels, -1).var(dim=1) + eps).sqrt().view(channels, 1, 1)
    mean = feat.contiguous().view(channels, -1).mean(dim=1).view(channels, 1, 1)
    return mean, std


def adaptive_instance_normalization(content_feat, style_feat):  # content_feat is degraded feature, style is ref feature
    """Re-normalise ``content_feat`` to the per-channel statistics of ``style_feat``.

    Both inputs are 3-D tensors with the same channel count. The content is
    whitened with its own mean/std and then scaled and shifted with the
    style's std/mean.
    """
    assert (content_feat.size()[:1] == style_feat.size()[:1])
    target_shape = content_feat.size()

    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    centered = content_feat - content_mean.expand(target_shape)
    whitened = centered / content_std.expand(target_shape)
    return whitened * style_std.expand(target_shape) + style_mean.expand(target_shape)

def calc_mean_std_4D(feat, eps=1e-5):
    """Per-sample, per-channel mean and std of a 4-D ``(N, C, H, W)`` tensor.

    Args:
        feat: 4-D tensor; it may be non-contiguous (e.g. a permuted or sliced
            view).
        eps: small value added to the variance to avoid divide-by-zero.

    Returns:
        ``(mean, std)``, each of shape ``(N, C, 1, 1)``.
    """
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    # reshape (unlike view) also accepts non-contiguous inputs, matching the
    # 3-D variant which makes its input contiguous first.
    flat = feat.reshape(N, C, -1)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degradate feature
    """Re-normalise 4-D ``content_feat`` to the channel statistics of ``style_feat``.

    The content is whitened per sample and channel; the style's std and mean
    (shape ``(N, C, 1, 1)``) are then applied by broadcasting, so the two
    batch sizes need not be equal as long as they broadcast.
    """
    target_shape = content_feat.size()

    style_mean, style_std = calc_mean_std_4D(style_feat)
    content_mean, content_std = calc_mean_std_4D(content_feat)

    centered = content_feat - content_mean.expand(target_shape)
    whitened = centered / content_std.expand(target_shape)
    return whitened * style_std + style_mean

def define_G(which_model_netG, gpu_ids=[]):
    """Build the generator named by ``which_model_netG``.

    Only ``'UNetDictFace'`` is recognised. The network is created with a base
    width of 64 and handed to ``init_net`` with the init flag set to False.

    Args:
        which_model_netG: name of the generator architecture.
        gpu_ids: forwarded unchanged to ``init_net``.

    Raises:
        NotImplementedError: for any other model name.
    """
    if which_model_netG != 'UNetDictFace':
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    generator = UNetDictFace(64)
    return init_net(generator, 'normal', 0.02, gpu_ids, False)


##############################################################################
# Classes
############################################################################################################################################


def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True):
    """Two spectrally-normalised convolutions with a LeakyReLU(0.2) in between.

    Both convolutions share kernel size, stride, dilation and bias setting.
    The padding is ``((kernel_size - 1) // 2) * dilation``. ``norm_layer`` is
    accepted for signature compatibility but is not used.

    Returns:
        ``nn.Sequential`` of (conv, LeakyReLU, conv).
    """
    shared = dict(
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
        padding=((kernel_size-1)//2)*dilation,
        bias=bias,
    )
    first = SpectralNorm(conv_layer(in_channels, out_channels, **shared))
    activation = nn.LeakyReLU(0.2)
    second = SpectralNorm(conv_layer(out_channels, out_channels, **shared))
    return nn.Sequential(first, activation, second)
class MSDilateBlock(nn.Module):
    """Multi-scale dilated residual block.

    Four parallel ``convU`` branches (``conv1`` .. ``conv4``), one per entry of
    ``dilation``, all read the same input. Their outputs are concatenated on
    the channel axis, fused back to ``in_channels`` by ``convi`` and added to
    the input.
    """

    def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True):
        super().__init__()
        # Attribute names conv1..conv4 are kept so parameter names do not change.
        for branch in range(4):
            setattr(
                self,
                'conv%d' % (branch + 1),
                convU(in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[branch], bias=bias),
            )
        fuse = conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias)
        self.convi = SpectralNorm(fuse)

    def forward(self, x):
        """Return ``convi(cat(branches)) + x``."""
        branches = [conv(x) for conv in (self.conv1, self.conv2, self.conv3, self.conv4)]
        fused = self.convi(torch.cat(branches, 1))
        return fused + x

##############################UNetFace#########################
class AdaptiveInstanceNorm(nn.Module):
    """Instance-normalise the input, then apply the statistics of a style tensor."""

    def __init__(self, in_channel):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        """Return ``std(style) * InstanceNorm(input) + mean(style)``.

        The style statistics come from ``calc_mean_std_4D`` and are expanded to
        the shape of ``input``.
        """
        style_mean, style_std = calc_mean_std_4D(style)
        normalized = self.norm(input)
        target_shape = input.size()
        return style_std.expand(target_shape) * normalized + style_mean.expand(target_shape)

class BlurFunctionBackward(Function):
    """Autograd function implementing the backward pass of the blur.

    It is a ``Function`` in its own right so that the blur can be
    differentiated twice.
    """

    @staticmethod
    def forward(ctx, grad_output, kernel, kernel_flip):
        """Convolve the incoming gradient with the flipped kernel, one group per channel."""
        ctx.save_for_backward(kernel, kernel_flip)
        channels = grad_output.shape[1]
        return F.conv2d(grad_output, kernel_flip, padding=1, groups=channels)

    @staticmethod
    def backward(ctx, gradgrad_output):
        """Convolve the second-order gradient with the unflipped kernel.

        The two kernels receive no gradient.
        """
        kernel, _ = ctx.saved_tensors
        channels = gradgrad_output.shape[1]
        second_order = F.conv2d(gradgrad_output, kernel, padding=1, groups=channels)
        return second_order, None, None


class BlurFunction(Function):
    """Per-channel 3x3 blur with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, input, kernel, kernel_flip):
        """Convolve ``input`` with ``kernel``, one group per channel, padding 1."""
        ctx.save_for_backward(kernel, kernel_flip)
        channels = input.shape[1]
        return F.conv2d(input, kernel, padding=1, groups=channels)

    @staticmethod
    def backward(ctx, grad_output):
        """Delegate to ``BlurFunctionBackward``; the kernels receive no gradient."""
        kernel, kernel_flip = ctx.saved_tensors
        return BlurFunctionBackward.apply(grad_output, kernel, kernel_flip), None, None

# Functional alias used by the Blur module.
blur = BlurFunction.apply


class Blur(nn.Module):
    """Fixed 3x3 binomial blur applied to each channel independently.

    The kernel ``[[1, 2, 1], [2, 4, 2], [1, 2, 1]] / 16`` and its spatial flip
    are stored as buffers ``weight`` and ``weight_flip``, repeated once per
    channel.
    """

    def __init__(self, channel):
        super().__init__()

        # Outer product of [1, 2, 1] with itself yields the 3x3 binomial kernel.
        taps = torch.tensor([1, 2, 1], dtype=torch.float32)
        kernel = (taps[:, None] * taps[None, :]).view(1, 1, 3, 3)
        kernel = kernel / kernel.sum()
        mirrored = torch.flip(kernel, [2, 3])

        self.register_buffer('weight', kernel.repeat(channel, 1, 1, 1))
        self.register_buffer('weight_flip', mirrored.repeat(channel, 1, 1, 1))

    def forward(self, input):
        """Blur ``input`` with the stored kernels."""
        return blur(input, self.weight, self.weight_flip)

class EqualLR:
    """Forward pre-hook that rescales a module's weight on every call.

    The raw parameter lives under ``<name>_orig``. Before each forward pass
    the attribute ``<name>`` is set to the raw weight times
    ``sqrt(2 / fan_in)``.
    """

    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        """Return the stored ``<name>_orig`` parameter scaled by ``sqrt(2 / fan_in)``."""
        raw = getattr(module, self.name + '_orig')
        # fan_in = size of dim 1 times the element count of one [0][0] slice.
        fan_in = raw.data.size(1) * raw.data[0][0].numel()
        return raw * sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        """Move parameter ``name`` to ``<name>_orig`` and install the hook on ``module``."""
        hook = EqualLR(name)

        original = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(original.data))
        module.register_forward_pre_hook(hook)

        return hook

    def __call__(self, module, input):
        setattr(module, self.name, self.compute_weight(module))

def equal_lr(module, name='weight'):
    """Install the ``EqualLR`` weight-rescaling hook on ``module`` and return ``module``.

    Args:
        module: module whose parameter ``name`` is to be rescaled at call time.
        name: name of the parameter to wrap (``'weight'`` by default).
    """
    EqualLR.apply(module, name)
    return module

class EqualConv2d(nn.Module):
    """``nn.Conv2d`` with equalised learning rate.

    All constructor arguments are forwarded to ``nn.Conv2d``. The weight is
    re-initialised from a standard normal, the bias (when present) is zeroed,
    and the convolution is wrapped with ``equal_lr`` so the weight is rescaled
    on every forward pass.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        # With bias=False nn.Conv2d stores bias as None; zeroing it
        # unconditionally would raise AttributeError.
        if conv.bias is not None:
            conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        """Apply the wrapped convolution."""
        return self.conv(input)

class NoiseInjection(nn.Module):
    """Add noise to a feature map through a learned per-channel gain.

    The gain is a ``(1, channel, 1, 1)`` parameter initialised to zero, so the
    module is an identity on ``image`` until the gain is changed.
    """

    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        """Return ``image + weight * noise``."""
        scaled_noise = self.weight * noise
        return image + scaled_noise

class StyledUpBlock(nn.Module):
    """Decoder stage: convolve, modulate with a style feature, upsample 2x.

    ``forward(input, style)`` does the following:

    1. ``conv1``: blur, spectrally-normalised conv ``in_channel -> out_channel``
       and LeakyReLU. With ``upsample=True`` a bilinear 2x upsample is
       prepended.
    2. ``ScaleModel1`` / ``ShiftModel1`` map ``style`` (``in_channel`` maps) to
       a per-pixel scale and a sigmoid-bounded shift with ``out_channel`` maps.
       The result is ``out * scale + shift``, so ``style`` must match the
       spatial size of the ``conv1`` output.
    3. ``convup``: bilinear 2x upsample, spectrally-normalised conv, LeakyReLU.

    Args:
        in_channel: channels of both ``input`` and ``style``.
        out_channel: channels of the returned feature map.
        kernel_size: kernel size of the ``conv1`` / ``convup`` convolutions.
        padding: padding of the ``conv1`` / ``convup`` convolutions.
        upsample: prepend a 2x bilinear upsample to ``conv1``.
    """

    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False):
        super().__init__()
        if upsample:
            self.conv1 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                # The blur runs BEFORE the convolution, so it sees in_channel
                # maps. Its depthwise kernel must therefore be built for
                # in_channel (it was out_channel, which made this branch fail
                # whenever in_channel != out_channel).
                Blur(in_channel),
                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
                nn.LeakyReLU(0.2),
            )
        else:
            self.conv1 = nn.Sequential(
                Blur(in_channel),
                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
                nn.LeakyReLU(0.2),
            )
        self.convup = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
                nn.LeakyReLU(0.2),
            )
        # NOTE: conv1 already ends in a LeakyReLU, so forward() applies the
        # activation twice. Kept as is, since removing it would change outputs.
        self.lrelu1 = nn.LeakyReLU(0.2)

        self.ScaleModel1 = nn.Sequential(
            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
            nn.LeakyReLU(0.2, True),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))
        )
        self.ShiftModel1 = nn.Sequential(
            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
            nn.LeakyReLU(0.2, True),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
            nn.Sigmoid(),
        )

    def forward(self, input, style):
        """Return the style-modulated, 2x-upsampled feature map."""
        out = self.conv1(input)
        out = self.lrelu1(out)

        Shift1 = self.ShiftModel1(style)
        Scale1 = self.ScaleModel1(style)
        out = out * Scale1 + Shift1
        outup = self.convup(out)

        return outup

##############################################################################
##Face Dictionary
##############################################################################
class VGGFeat(torch.nn.Module):
    """
    VGG-19 feature extractor returning four intermediate feature maps.

    ``model.features`` is cut at indices ``[0, 8, 17, 26, 35]`` into four
    consecutive stages; ``forward`` returns the output of each stage.

    Input: (B, C, H, W), RGB, [-1, 1]

    The VGG weights are frozen. ``RGB_mean`` and ``RGB_std`` are registered as
    parameters and are not covered by that freeze.
    """
    def __init__(self, weight_path='./weights/vgg19.pth'):
        super().__init__()
        self.model = models.vgg19(pretrained=False)
        self.build_vgg_layers()

        # map_location='cpu': the freshly built model lives on the CPU, so the
        # checkpoint should load there regardless of the device it was saved
        # from. Without it a CUDA-saved file cannot be loaded on a CPU-only
        # machine.
        self.model.load_state_dict(torch.load(weight_path, map_location='cpu'))

        self.register_parameter("RGB_mean", nn.Parameter(torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)))
        self.register_parameter("RGB_std", nn.Parameter(torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)))

        # self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def build_vgg_layers(self):
        """Group ``self.model.features`` into the four stages held by ``self.features``.

        The stages hold the very same layer objects as ``self.model``, so
        loading weights into ``self.model`` also updates them.
        """
        vgg_pretrained_features = self.model.features
        # feature_layers = [0, 3, 8, 17, 26, 35]
        feature_layers = [0, 8, 17, 26, 35]
        stages = []
        for start, stop in zip(feature_layers[:-1], feature_layers[1:]):
            module_layers = torch.nn.Sequential()
            for j in range(start, stop):
                # Keep the original layer index as the sub-module name.
                module_layers.add_module(str(j), vgg_pretrained_features[j])
            stages.append(module_layers)
        self.features = torch.nn.ModuleList(stages)

    def preprocess(self, x):
        """Map ``[-1, 1]`` input to ``[0, 1]`` and normalise with ``RGB_mean`` / ``RGB_std``.

        Inputs narrower than 224 pixels are bilinearly resized to 224x224.
        Only the width (``shape[3]``) is checked.
        """
        x = (x + 1) / 2
        x = (x - self.RGB_mean) / self.RGB_std
        if x.shape[3] < 224:
            x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        return x

    def forward(self, x):
        """Return the list of the four stage outputs, shallowest first."""
        x = self.preprocess(x)
        features = []
        for m in self.features:
            x = m(x)
            features.append(x)
        return features

def compute_sum(x, axis=None, keepdim=False):
    """Sum ``x`` over the given axes, one axis at a time.

    Args:
        x: tensor to reduce.
        axis: a single axis, an iterable of axes, or ``None`` / an empty
            iterable to reduce over every axis. Negative axes count from the
            end. Each axis is reduced once.
        keepdim: keep the reduced axes with size 1.

    Returns:
        The reduced tensor.
    """
    ndim = len(x.shape)
    if axis is None:
        axes = []
    else:
        try:
            axes = list(axis)
        except TypeError:
            # A bare integer axis. Testing truthiness instead would silently
            # turn axis=0 into "reduce everything".
            axes = [axis]
    if not axes:
        axes = range(ndim)
    # Normalise negative axes before ordering. Reducing from the highest axis
    # down keeps the remaining indices valid when keepdim is False, which only
    # works if all indices are expressed from the front.
    normalized = {a + ndim if a < 0 else a for a in axes}
    for i in sorted(normalized, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x
def ToRGB(in_channel):
    """Build a head that maps ``in_channel`` feature maps to 3 channels.

    Two spectrally-normalised 3x3 convolutions (stride 1, padding 1) with a
    LeakyReLU(0.2) in between; the second one outputs 3 channels.
    """
    hidden = SpectralNorm(nn.Conv2d(in_channel,in_channel,3, 1, 1))
    activation = nn.LeakyReLU(0.2)
    to_three = SpectralNorm(nn.Conv2d(in_channel,3,3, 1, 1))
    return nn.Sequential(hidden, activation, to_three)

def AttentionBlock(in_channel):
    """Build a channel-preserving attention head.

    Two spectrally-normalised 3x3 convolutions (stride 1, padding 1) with a
    LeakyReLU(0.2) in between; both keep ``in_channel`` channels.
    """
    layers = [
        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
        nn.LeakyReLU(0.2),
        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
    ]
    return nn.Sequential(*layers)

class UNetDictFace(nn.Module):
    """Dictionary-guided face generator.

    For each of the four VGG feature scales (256, 128, 64, 32) the crops of
    the left eye, right eye, nose and mouth are compared with a pre-computed
    dictionary of component features. The best match is blended back into the
    VGG feature through a learned attention map. The updated features then
    modulate a chain of ``StyledUpBlock`` decoders.

    This variant also dumps the resized component crops of the first
    ``amount_features`` forward passes to ``<TAG>_feature_resize_<scale>.pt``
    in the working directory, with TAG in LE/RE/NO/MO.

    Args:
        ngf: base channel width of the decoder.
        dictionary_path: directory holding ``<part>_<scale>_center.npy`` files.

    Author's notes on the VGG feature shapes (channels * height * width),
    presumably for 512x512 input:
        0: 128 * 256 * 256
        1: 256 * 128 * 128
        2: 512 * 64 * 64
        3: 512 * 32 * 32
    """

    # (dictionary key, attention-module prefix, tag used in the feature dumps)
    _PARTS = (
        ('left_eye', 'le', 'LE'),
        ('right_eye', 're', 'RE'),
        ('nose', 'no', 'NO'),
        ('mouth', 'mo', 'MO'),
    )

    def __init__(self, ngf=64, dictionary_path='./DictionaryCenter512'):
        super().__init__()

        self.part_sizes = np.array([80,80,50,110]) # size for 512
        self.feature_sizes = np.array([256,128,64,32])
        self.channel_sizes = np.array([128,256,512,512])

        # One plain dict per scale, mapping part name -> (K, C, h, w) dictionary
        # tensor. They are not registered as buffers; forward() moves them to
        # the input's device and dtype on every call.
        self.Dict_256 = {}
        self.Dict_128 = {}
        self.Dict_64 = {}
        self.Dict_32 = {}
        for j, (part, _, _) in enumerate(self._PARTS):
            for f_size, channels in zip(self.feature_sizes, self.channel_sizes):
                # Component side length at this scale: 512-pixel size // (512 // f_size).
                side = self.part_sizes[j] // (512 // f_size)
                path = os.path.join(dictionary_path, '{}_{}_center.npy'.format(part, f_size))
                # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
                # objects. Only point dictionary_path at files you trust.
                centers = torch.from_numpy(np.load(path, allow_pickle=True))
                getattr(self, 'Dict_{}'.format(f_size))[part] = centers.reshape(
                    centers.size(0), channels, side, side)

        # Attention heads le_256 .. mo_32. Creation order and attribute names
        # are kept so parameter names and ordering do not change.
        for _, prefix, _ in self._PARTS:
            for f_size, channels in zip(self.feature_sizes, self.channel_sizes):
                setattr(self, '{}_{}'.format(prefix, f_size), AttentionBlock(int(channels)))

        #norm
        self.VggExtract = VGGFeat()

        ######################
        self.MSDilate = MSDilateBlock(ngf*8, dilation = [4,3,2,1])  #

        self.up0 = StyledUpBlock(ngf*8,ngf*8)
        self.up1 = StyledUpBlock(ngf*8, ngf*4) #
        self.up2 = StyledUpBlock(ngf*4, ngf*2) #
        self.up3 = StyledUpBlock(ngf*2, ngf) #
        self.up4 = nn.Sequential( # 128
            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            UpResBlock(ngf),
            UpResBlock(ngf),
            nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )
        # The to_rgb heads are created but not used by forward().
        self.to_rgb0 = ToRGB(ngf*8)
        self.to_rgb1 = ToRGB(ngf*4)
        self.to_rgb2 = ToRGB(ngf*2)
        self.to_rgb3 = ToRGB(ngf*1)

        # Number of forward passes recorded so far, per scale.
        self.count_256 = 0
        self.count_128 = 0
        self.count_64 = 0
        self.count_32 = 0
        # Forward passes to record before the feature dumps are written.
        self.amount_features = 20

    def _record_part_features(self, f_size, resized_features):
        """Accumulate the resized part crops of one scale and dump them once.

        The crops are appended to ``<TAG>_save_<f_size>`` along the batch axis.
        When ``amount_features`` forward passes have been recorded, each
        accumulator is written to ``<TAG>_feature_resize_<f_size>.pt``. After
        that nothing more is recorded.

        The crops are detached first. Storing them with their autograd history
        would keep the graph of every recorded forward pass alive, and
        recording without a bound would grow memory for the whole run.
        """
        count_name = 'count_{}'.format(f_size)
        count = getattr(self, count_name)
        if count >= self.amount_features:
            return

        save_names = ['{}_save_{}'.format(tag, f_size) for _, _, tag in self._PARTS]
        for save_name, feature in zip(save_names, resized_features):
            feature = feature.detach()
            if count > 0:
                feature = torch.cat((getattr(self, save_name), feature), 0)
            setattr(self, save_name, feature)

        count += 1
        setattr(self, count_name, count)

        if count == self.amount_features:
            for (_, _, tag), save_name in zip(self._PARTS, save_names):
                torch.save(getattr(self, save_name), '{}_feature_resize_{}.pt'.format(tag, f_size))
            print("generated features for size {}".format(f_size))

    def forward(self, input, part_locations):
        """Restore ``input`` guided by the component dictionaries.

        Args:
            input: image batch passed to ``VGGFeat``, i.e. (B, C, H, W) RGB in
                [-1, 1].
            part_locations: indexable as ``part_locations[part][sample]`` with
                parts ordered left eye, right eye, nose, mouth. Each entry is
                a tensor of four values used as ``[x0, y0, x1, y1]`` in
                512-pixel coordinates.

        Returns:
            The 3-channel output of ``up4``, which ends in a Tanh.
        """
        VggFeatures = self.VggExtract(input) #VggFeatures = list object
        # Only the boxes of the first sample are read. They are applied to
        # every sample in the batch.
        b = 0
        UpdateVggFeatures = []
        for i, f_size in enumerate(self.feature_sizes):
            cur_feature = VggFeatures[i]
            update_feature = cur_feature.clone()
            dicts_feature = getattr(self, 'Dict_'+str(f_size))
            scale = 512/f_size  # 512-pixel coordinates -> coordinates of this feature map

            # Cut each component out of the untouched feature and resize it to
            # the spatial size of its dictionary entries.
            crops = []
            for k, (part, prefix, _) in enumerate(self._PARTS):
                dict_feature = dicts_feature[part].to(input)
                location = (part_locations[k][b] // scale).int()
                rows = slice(location[1], location[3])
                cols = slice(location[0], location[2])
                feature = cur_feature[:, :, rows, cols].clone()
                feature_resize = F.interpolate(
                    feature, (dict_feature.size(2), dict_feature.size(3)),
                    mode='bilinear', align_corners=False)
                crops.append((prefix, dict_feature, feature, feature_resize, rows, cols))

            self._record_part_features(f_size, [crop[3] for crop in crops])

            for prefix, dict_feature, feature, feature_resize, rows, cols in crops:
                # Give every dictionary entry the statistics of the degraded crop.
                dict_feature_norm = adaptive_instance_normalization_4D(dict_feature, feature_resize)

                # Full-size correlation of the crop with every dictionary entry;
                # the best-scoring entry is selected.
                score = F.conv2d(feature_resize, dict_feature_norm)
                score = F.softmax(score.view(-1), dim=0)
                index = torch.argmax(score)
                swap_feature = F.interpolate(
                    dict_feature_norm[index:index+1], (feature.size(2), feature.size(3)))

                # The attention map is predicted from the residual between the
                # dictionary entry and the crop, and gates the entry.
                attention = getattr(self, prefix+'_'+str(f_size))(swap_feature - feature)
                att_feature = attention * swap_feature

                # Written in part order, so a later part overwrites an earlier
                # one where their boxes overlap.
                update_feature[:, :, rows, cols] = att_feature + feature

            UpdateVggFeatures.append(update_feature)

        fea_vgg = self.MSDilate(VggFeatures[3])
        # Decode from the coarsest scale up; each stage is modulated by the
        # dictionary-updated VGG feature of the matching scale.
        fea_up0 = self.up0(fea_vgg, UpdateVggFeatures[3])
        fea_up1 = self.up1( fea_up0, UpdateVggFeatures[2]) #
        fea_up2 = self.up2(fea_up1, UpdateVggFeatures[1]) #
        fea_up3 = self.up3(fea_up2, UpdateVggFeatures[0]) #

        output = self.up4(fea_up3) #

        return output


class UpResBlock(nn.Module):
    """Residual block: ``x + conv(LeakyReLU(conv(x)))``.

    Both convolutions are 3x3, stride 1, padding 1 and keep ``dim`` channels.
    ``norm_layer`` is accepted for signature compatibility but is not used.
    """

    def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d):
        super().__init__()
        layers = [
            conv_layer(dim, dim, 3, 1, 1),
            nn.LeakyReLU(0.2,True),
            conv_layer(dim, dim, 3, 1, 1),
        ]
        self.Model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the output of the residual branch."""
        residual = self.Model(x)
        return x + residual

class VggClassNet(nn.Module):
    """Frozen pretrained VGG-19 feature stack that taps selected layers.

    Args:
        select_layer: names (string indices into ``vgg19().features``) of the
            layers whose outputs ``forward`` collects.
    """

    def __init__(self, select_layer = ['0','5','10','19']):
        super().__init__()
        self.select = select_layer
        self.vgg = models.vgg19(pretrained=True).features
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, x):
        """Run ``x`` through every layer and return the outputs of the selected ones.

        All layers are executed, including those after the last selected one.
        """
        wanted = self.select
        collected = []
        children = self.vgg._modules
        for key in children:
            x = children[key](x)
            if key in wanted:
                collected.append(x)
        return collected


# Running this module as a script only prints a marker line.
if __name__ == '__main__':
    print('this is network')




In [None]:
#@title custom_dataset.py (paths, disabling augmentations)
%%writefile /content/DFDNet/data/custom_dataset.py
# -- coding: utf-8 --
import os.path
import os
import random
import torchvision.transforms as transforms
import torch
from PIL import Image, ImageFilter
import numpy as np
import cv2
import math
from scipy.io import loadmat
from PIL import Image
import PIL
from torch.utils.data import Dataset, DataLoader

import glob

class AlignedDataset(Dataset):
    """Dataset of face images together with per-image facial-part boxes.

    Every ``*.png`` found (recursively) under the image directory is one
    sample. For each image a landmark text file named ``<image file name>.txt``
    is read from the landmark directory and converted into four square boxes
    (left eye, right eye, nose, mouth).

    The degradation helpers (``AddNoise``, ``AddBlur``, ``AddDownSample``,
    ``AddJPEG``, ``AddUpSample``) are available but are not applied by
    ``__getitem__``.
    """

    def __init__(self, root_dir, fine_size=512, transform=None):
        """Collect the image paths.

        Args:
            root_dir: accepted for interface compatibility but ignored; the
                image and landmark directories are fixed Colab paths.
            fine_size: side length (pixels) every image is resized to.
            transform: stored on the instance; not used by ``__getitem__``.
        """
        self.root_dir = '/content/DFDNet/ffhq'
        self.pathes = glob.glob(self.root_dir + '/**/*.png', recursive=True)
        self.transform = transform
        self.fine_size = fine_size
        self.partpath = '/content/DFDNet/landmarks'

    def AddNoise(self,img): # noise
        """Add Gaussian noise (sigma drawn from 1..10) to a PIL image; skipped ~10% of the time."""
        if random.random() > 0.9: #
            return img
        self.sigma = np.random.randint(1, 11)
        img_tensor = torch.from_numpy(np.array(img)).float()
        noise = torch.randn(img_tensor.size()).mul_(self.sigma/1.0)

        noiseimg = torch.clamp(noise+img_tensor,0,255)
        return Image.fromarray(np.uint8(noiseimg.numpy()))

    def AddBlur(self,img): # gaussian blur or motion blur
        """Apply Gaussian blur or a motion-blur kernel to a PIL image; skipped ~10% of the time."""
        if random.random() > 0.9: #
            return img
        img = np.array(img)
        if random.random() > 0.35: ##gaussian blur
            blursize = random.randint(1,17) * 2 + 1 # odd kernel size in 3..35
            blursigma = random.randint(3, 20)
            img = cv2.GaussianBlur(img, (blursize,blursize), blursigma/10)
        else: #motion blur
            M = random.randint(1,32)
            # Kernel files are looked up relative to the current working directory.
            KName = './data/MotionBlurKernel/m_%02d.mat' % M
            k = loadmat(KName)['kernel']
            k = k.astype(np.float32)
            k /= np.sum(k)
            img = cv2.filter2D(img,-1,k)
        return Image.fromarray(img)

    def AddDownSample(self,img): # downsampling
        """Bicubically shrink a PIL image by a random factor of 2x-10x; skipped ~5% of the time."""
        if random.random() > 0.95: #
            return img
        sampler = random.randint(20, 100)*1.0
        img = img.resize((int(self.fine_size/sampler*10.0), int(self.fine_size/sampler*10.0)), Image.BICUBIC)
        return img

    def AddJPEG(self,img): # JPEG compression
        """Round-trip a PIL image through JPEG at quality 40..80; skipped ~40% of the time."""
        if random.random() > 0.6:
            return img
        imQ = random.randint(40, 80)
        img = np.array(img)
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY),imQ] # (0,100),higher is better,default is 95
        _, encA = cv2.imencode('.jpg', img, encode_param)
        img = cv2.imdecode(encA,1)
        return Image.fromarray(img)

    def AddUpSample(self,img):
        """Bicubically resize a PIL image to ``fine_size`` x ``fine_size``."""
        return img.resize((self.fine_size, self.fine_size), Image.BICUBIC)

    def __getitem__(self, index): # indexation
        """Return one sample as a dict.

        Keys: ``'A'`` and ``'C'`` (both the same resized image as a tensor
        normalized with mean 0.5 / std 0.5 per channel; augmentations are
        disabled), ``'path'`` (image path) and ``'part_locations'`` (the four
        boxes from ``get_part_location`` with the landmarks halved).
        """
        path = self.pathes[index]
        Imgs = Image.open(path).convert('RGB')

        A = Imgs.resize((self.fine_size, self.fine_size))
        C = A

        # The landmark file is looked up by the image's file name (with extension).
        tmps = path.split('/')
        ImgName = tmps[-1]
        part_locations = self.get_part_location(self.partpath, ImgName, 2)

        A = transforms.ToTensor()(A)
        C = transforms.ToTensor()(C)

        A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
        C = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(C)

        return {'A': A, 'C': C, 'path': path, 'part_locations': part_locations}

    @staticmethod
    def _part_box(landmarks, indices):
        """Return ``[x0, y0, x1, y1]`` (int) of a square box around the selected landmarks.

        The box is centred on the mean of the points; its half side is half of
        the larger coordinate extent, but never less than 16.
        """
        points = landmarks[indices]
        center = np.mean(points, 0)
        half_side = np.max((np.max(np.max(points, 0) - np.min(points, 0)) / 2, 16))
        return np.hstack((center - half_side + 1, center + half_side)).astype(int)

    def get_part_location(self, landmarkpath, imgname, downscale=1):
        """Read a landmark file and derive the four facial-part boxes.

        Args:
            landmarkpath: directory containing the landmark files.
            imgname: image file name; ``imgname + '.txt'`` is opened.
            downscale: divisor applied to every landmark coordinate.

        Returns:
            Tuple ``(left_eye, right_eye, nose, mouth)`` of int arrays, each
            ``[x0, y0, x1, y1]``.

        The file is expected to hold one whitespace-separated ``x y`` pair per
        line; rows 17..67 are indexed, so at least 68 rows are required.
        """
        Landmarks = []
        with open(os.path.join(landmarkpath, imgname + '.txt'),'r') as f:
            for line in f:
                # ``float`` replaces the ``np.float`` alias removed from NumPy;
                # ``split()`` tolerates repeated blanks and the trailing newline.
                values = [float(i) for i in line.split()]
                if values:  # ignore blank lines
                    Landmarks.append(values)
        Landmarks = np.array(Landmarks)/downscale # 512 * 512

        Map_LE = list(range(17,22)) + list(range(36,42))
        Map_RE = list(range(22,27)) + list(range(42,48))
        Map_NO = list(range(29,36))
        Map_MO = list(range(48,68))

        Location_LE = self._part_box(Landmarks, Map_LE)  # left eye
        Location_RE = self._part_box(Landmarks, Map_RE)  # right eye
        Location_NO = self._part_box(Landmarks, Map_NO)  # nose
        Location_MO = self._part_box(Landmarks, Map_MO)  # mouth
        return Location_LE, Location_RE, Location_NO, Location_MO

    def __len__(self): #
        """Number of images found."""
        return len(self.pathes)

    def name(self):
        """Return the dataset's display name."""
        return 'AlignedDataset'

In [None]:
# generating feature files
# (switches to the repository directory, then launches run.py with a batch size of 1)
%cd /content/DFDNet
!python run.py --batchSize 1

In [None]:
#@title viewing shape of saved file
import torch
# Load one of the saved left-eye feature files from the repository directory.
test = torch.load('/content/DFDNet/LE_feature_resize_256.pt')
# Show its shape and whether it was stored as a CUDA tensor.
print(test.shape)
print(test.is_cuda)

Features are no longer loaded from ```.npy``` files; they are loaded directly with ```torch.load``` and stored in a dict. The generated feature files are not compatible with the old NumPy-based code.

In [None]:
#@title networks.py (train with own features, replacing numpy code)
%%writefile /content/DFDNet/models/networks.py
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.nn import Parameter as P
from util import util
from torchvision import models
import scipy.io as sio
import numpy as np
import scipy.ndimage
import torch.nn.utils.spectral_norm as SpectralNorm

from torch.autograd import Function
from math import sqrt
import random
import os
import math

from sync_batchnorm import convert_model
####

###############################################################################
# Helper Functions
###############################################################################
def get_norm_layer(norm_type='instance'):
    """Return a normalization-layer factory for the given name.

    ``'batch'`` -> ``BatchNorm2d`` with ``affine=True``; ``'instance'`` ->
    ``InstanceNorm2d`` with ``affine=False, track_running_stats=True``;
    ``'none'`` -> ``None``.

    Raises:
        NotImplementedError: for any other name.
    """
    if norm_type == 'none':
        return None
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)


def get_scheduler(optimizer, opt):
    """Build the learning-rate scheduler selected by ``opt.lr_policy``.

    Args:
        optimizer: the optimizer whose learning rate is scheduled.
        opt: options object. ``lr_policy`` chooses the scheduler:
            ``'lambda'`` (reads ``epoch_count``, ``niter``, ``niter_decay``),
            ``'step'`` (reads ``lr_decay_iters``) or ``'plateau'``.

    Returns:
        The constructed ``torch.optim.lr_scheduler`` instance.

    Raises:
        NotImplementedError: if ``opt.lr_policy`` is not one of the above.
    """
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            # Factor stays at 1.0 until ``niter`` is passed, then decreases
            # linearly over the following ``niter_decay`` epochs.
            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    else:
        # Raise (the exception used to be returned as if it were a scheduler)
        # and actually interpolate the policy name into the message.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)

    return scheduler


def init_weights(net, init_type='normal', gain=0.02):
    """Re-initialize the weights of ``net`` in place.

    Modules whose class name contains ``Conv`` or ``Linear`` get their weight
    initialized with the chosen scheme and their bias (if any) set to zero.
    Modules whose class name contains ``BatchNorm2d`` get weight ~ N(1, gain)
    and bias 0; layers without affine parameters are left untouched.

    Args:
        net: module to initialize (visited recursively via ``net.apply``).
        init_type: ``'normal'``, ``'xavier'``, ``'kaiming'`` or ``'orthogonal'``.
        gain: std for ``'normal'`` and BatchNorm weights; gain for
            ``'xavier'`` / ``'orthogonal'``.

    Raises:
        NotImplementedError: if ``init_type`` is unknown and a Conv/Linear
            module is encountered.
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # With affine=False both weight and bias are None; skip them
            # instead of failing on ``None.data``.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init_flag=True):
    """Optionally move ``net`` to GPU(s) and optionally initialize its weights.

    Args:
        net: the network.
        init_type: scheme name forwarded to ``init_weights``.
        init_gain: gain forwarded to ``init_weights``.
        gpu_ids: when non-empty, the net is passed through ``convert_model``,
            moved to ``gpu_ids[0]`` and wrapped in ``DataParallel``.
        init_flag: when false, weight initialization is skipped.

    Returns:
        The (possibly wrapped) network.
    """
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert(torch.cuda.is_available())
        net = convert_model(net)
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)

    if not init_flag:
        return net

    init_weights(net, init_type, gain=init_gain)
    return net


# compute adaptive instance norm
def calc_mean_std(feat, eps=1e-5):
    """Per-channel mean and standard deviation of a 3-D ``(C, H, W)`` tensor.

    ``eps`` is added to the variance before the square root so the std is
    never zero. Both results have shape ``(C, 1, 1)``.
    """
    dims = feat.size()
    assert (len(dims) == 3)
    C, _ = dims[:2]
    flat = feat.contiguous().view(C, -1)
    feat_mean = flat.mean(dim=1).view(C, 1, 1)
    feat_std = (flat.var(dim=1) + eps).sqrt().view(C, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat, style_feat):  # content_feat is degraded feature, style is ref feature
    """Re-normalize ``content_feat`` to the per-channel statistics of ``style_feat``.

    Both inputs are 3-D ``(C, H, W)`` tensors with the same channel count. The
    content is standardized per channel, then scaled by the style std and
    shifted by the style mean.
    """
    assert (content_feat.size()[:1] == style_feat.size()[:1])
    target_shape = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    centered = content_feat - content_mean.expand(target_shape)
    normalized_feat = centered / content_std.expand(target_shape)

    rescaled = normalized_feat * style_std.expand(target_shape)
    return rescaled + style_mean.expand(target_shape)

def calc_mean_std_4D(feat, eps=1e-5):
    """Per-sample, per-channel mean and std of a 4-D ``(N, C, H, W)`` tensor.

    ``eps`` is added to the variance before the square root so the std is
    never zero. Both results have shape ``(N, C, 1, 1)``.
    """
    dims = feat.size()
    assert (len(dims) == 4)
    N, C = dims[:2]
    flat = feat.view(N, C, -1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    feat_std = (flat.var(dim=2) + eps).sqrt().view(N, C, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degradate feature
    """Batched AdaIN: give ``content_feat`` the per-channel statistics of ``style_feat``.

    Both inputs are 4-D ``(N, C, H, W)`` tensors. The content is standardized
    per sample and channel; the style std/mean are then applied through
    broadcasting.
    """
    target_shape = content_feat.size()
    style_mean, style_std = calc_mean_std_4D(style_feat)
    content_mean, content_std = calc_mean_std_4D(content_feat)
    centered = content_feat - content_mean.expand(target_shape)
    normalized_feat = centered / content_std.expand(target_shape)
    return normalized_feat * style_std + style_mean

def define_G(which_model_netG, gpu_ids=[]):
    """Create the generator network by name.

    Only ``'UNetDictFace'`` is supported; it is built with ``ngf=64`` and
    passed to ``init_net`` with weight initialization turned off.

    Raises:
        NotImplementedError: for any other model name.
    """
    if which_model_netG != 'UNetDictFace':
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    netG = UNetDictFace(64)
    init_flag = False
    return init_net(netG, 'normal', 0.02, gpu_ids, init_flag)


##############################################################################
# Classes
############################################################################################################################################


def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True):
    """Two spectrally-normalized convolutions with a LeakyReLU(0.2) in between.

    Both convolutions share kernel size, stride, dilation and bias setting;
    padding is ``((kernel_size-1)//2)*dilation``. ``norm_layer`` is accepted
    but not used.
    """
    shared = dict(
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
        padding=((kernel_size-1)//2)*dilation,
        bias=bias,
    )
    first = SpectralNorm(conv_layer(in_channels, out_channels, **shared))
    activation = nn.LeakyReLU(0.2)
    second = SpectralNorm(conv_layer(out_channels, out_channels, **shared))
    return nn.Sequential(first, activation, second)
class MSDilateBlock(nn.Module):
    """Four parallel ``convU`` branches with individual dilations, fused and added to the input.

    Args:
        in_channels: channel count of input, of every branch and of the output.
        conv_layer: convolution constructor.
        norm_layer: forwarded to ``convU``.
        kernel_size: kernel size for all convolutions.
        dilation: four dilation rates, one per branch.
        bias: whether the convolutions carry a bias.
    """
    def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True):
        super(MSDilateBlock, self).__init__()
        # Branches are registered as conv1..conv4, in that order.
        for branch in range(4):
            setattr(
                self,
                'conv%d' % (branch + 1),
                convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[branch], bias=bias),
            )
        fuse = conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias)
        self.convi = SpectralNorm(fuse)

    def forward(self, x):
        """Concatenate the four branch outputs on the channel axis, fuse them, add ``x``."""
        branches = [self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)]
        stacked = torch.cat(branches, 1)
        return self.convi(stacked) + x

##############################UNetFace#########################
class AdaptiveInstanceNorm(nn.Module):
    """Instance-normalize the input, then apply the per-channel std/mean of a style tensor."""
    def __init__(self, in_channel):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        """Return ``style_std * InstanceNorm(input) + style_mean``, expanded to ``input``'s shape."""
        target_shape = input.size()
        style_mean, style_std = calc_mean_std_4D(style)
        normalized = self.norm(input)
        scaled = style_std.expand(target_shape) * normalized
        return scaled + style_mean.expand(target_shape)

class BlurFunctionBackward(Function):
    """Autograd function that applies the flipped blur kernel as a depthwise convolution.

    Its own backward applies the un-flipped kernel, so it stays differentiable.
    """
    @staticmethod
    def forward(ctx, grad_output, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)
        # groups == channel count -> every channel is filtered on its own.
        channels = grad_output.shape[1]
        return F.conv2d(grad_output, kernel_flip, padding=1, groups=channels)

    @staticmethod
    def backward(ctx, gradgrad_output):
        kernel, _unused_flip = ctx.saved_tensors
        channels = gradgrad_output.shape[1]
        propagated = F.conv2d(gradgrad_output, kernel, padding=1, groups=channels)
        # No gradients for the two kernel arguments.
        return propagated, None, None


class BlurFunction(Function):
    """Autograd function that applies a blur kernel as a depthwise convolution.

    The gradient is computed by ``BlurFunctionBackward`` using the flipped kernel.
    """
    @staticmethod
    def forward(ctx, input, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)
        # groups == channel count -> every channel is filtered on its own.
        channels = input.shape[1]
        return F.conv2d(input, kernel, padding=1, groups=channels)

    @staticmethod
    def backward(ctx, grad_output):
        saved_kernel, saved_flip = ctx.saved_tensors
        upstream = BlurFunctionBackward.apply(grad_output, saved_kernel, saved_flip)
        # No gradients for the two kernel arguments.
        return upstream, None, None

# Functional entry point used by the Blur module.
blur = BlurFunction.apply


class Blur(nn.Module):
    """Depthwise 3x3 blur with the normalized kernel ``[[1,2,1],[2,4,2],[1,2,1]] / 16``.

    Args:
        channel: number of channels the kernel is replicated for.
    """
    def __init__(self, channel):
        super().__init__()

        kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
        kernel = kernel.view(1, 1, 3, 3)
        kernel = kernel / kernel.sum()
        flipped = torch.flip(kernel, [2, 3])

        # Buffers rather than parameters: they follow the module across
        # devices and are saved in the state_dict without being trained.
        per_channel = (channel, 1, 1, 1)
        self.register_buffer('weight', kernel.repeat(*per_channel))
        self.register_buffer('weight_flip', flipped.repeat(*per_channel))

    def forward(self, input):
        """Blur ``input`` channel by channel."""
        return blur(input, self.weight, self.weight_flip)

class EqualLR:
    """Forward pre-hook that rescales a module's weight by ``sqrt(2 / fan_in)`` on every call.

    The raw weight is stored as the parameter ``<name>_orig``; the attribute
    ``<name>`` is recomputed from it before each forward pass.
    """
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        """Return the raw weight multiplied by ``sqrt(2 / fan_in)``."""
        raw = getattr(module, self.name + '_orig')
        # fan_in = input channels * elements of one kernel slice
        fan_in = raw.data.size(1) * raw.data[0][0].numel()
        scale = sqrt(2 / fan_in)
        return raw * scale

    @staticmethod
    def apply(module, name):
        """Install the hook on ``module`` for parameter ``name`` and return the hook object."""
        hook = EqualLR(name)

        original = getattr(module, name)
        # Replace the parameter with its ``_orig`` twin; the plain attribute
        # is re-created by the hook before each forward call.
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(original.data))
        module.register_forward_pre_hook(hook)

        return hook

    def __call__(self, module, input):
        setattr(module, self.name, self.compute_weight(module))

def equal_lr(module, name='weight'):
    """Attach the ``EqualLR`` weight-rescaling hook to ``module`` and hand the module back."""
    # The hook object returned by EqualLR.apply is not needed by callers.
    _ = EqualLR.apply(module, name)
    return module

class EqualConv2d(nn.Module):
    """``nn.Conv2d`` wrapped with the ``equal_lr`` weight-rescaling hook.

    All arguments are forwarded to ``nn.Conv2d``. The weight is drawn from a
    standard normal distribution and the bias, when present, is zeroed.
    """
    def __init__(self, *args, **kwargs):
        super().__init__()
        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        # ``bias=False`` leaves ``conv.bias`` as None; only zero a real bias.
        if conv.bias is not None:
            conv.bias.data.zero_()
        self.conv = equal_lr(conv)
    def forward(self, input):
        """Apply the wrapped convolution."""
        return self.conv(input)

class NoiseInjection(nn.Module):
    """Add noise to an image, scaled by a learned per-channel weight that starts at zero."""
    def __init__(self, channel):
        super().__init__()
        # Shape (1, channel, 1, 1) broadcasts over batch and spatial dims.
        initial = torch.zeros(1, channel, 1, 1)
        self.weight = nn.Parameter(initial)

    def forward(self, image, noise):
        """Return ``image + weight * noise``."""
        scaled_noise = self.weight * noise
        return image + scaled_noise

class StyledUpBlock(nn.Module):
    """Decoder block: blur + conv, modulation by a style feature, then 2x upsampling.

    The processed input is multiplied by ``ScaleModel1(style)``, shifted by
    ``ShiftModel1(style)`` (whose output passes through a sigmoid), and then
    upsampled by ``convup``.

    Args:
        in_channel: channels of ``input`` and of ``style``.
        out_channel: channels produced by the block.
        kernel_size: kernel size of the ``conv1`` / ``convup`` convolutions.
        padding: padding of the ``conv1`` / ``convup`` convolutions.
        upsample: when true, ``conv1`` additionally upsamples its input 2x
            before blurring.
    """
    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False):
        super().__init__()
        if upsample:
            self.conv1 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                # The blur runs before the convolution, so it sees in_channel
                # channels (it was sized with out_channel, which only works
                # when both counts are equal).
                Blur(in_channel),
                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
                nn.LeakyReLU(0.2),
            )
        else:
            self.conv1 = nn.Sequential(
                Blur(in_channel),
                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
                nn.LeakyReLU(0.2),
            )
        self.convup = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
                nn.LeakyReLU(0.2),
            )
        self.lrelu1 = nn.LeakyReLU(0.2)

        # Multiplicative modulation predicted from the style feature.
        self.ScaleModel1 = nn.Sequential(
            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
            nn.LeakyReLU(0.2, True),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))
        )
        # Additive modulation predicted from the style feature (sigmoid output).
        self.ShiftModel1 = nn.Sequential(
            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
            nn.LeakyReLU(0.2, True),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
            nn.Sigmoid(),
        )

    def forward(self, input, style):
        """Return the modulated, 2x-upsampled feature map.

        ``style`` must yield Scale/Shift maps that broadcast against the
        output of ``conv1``.
        """
        out = self.conv1(input)
        # conv1 already ends in a LeakyReLU; this second one is kept so the
        # numerical behaviour of existing checkpoints is unchanged.
        out = self.lrelu1(out)

        Shift1 = self.ShiftModel1(style)
        Scale1 = self.ScaleModel1(style)
        out = out * Scale1 + Shift1
        outup = self.convup(out)

        return outup

##############################################################################
##Face Dictionary
##############################################################################
class VGGFeat(torch.nn.Module):
    """
    VGG-19 feature extractor that returns the outputs of four consecutive stages.

    Input: (B, C, H, W), RGB, [-1, 1]
    """
    def __init__(self, weight_path='./weights/vgg19.pth'):
        super().__init__()
        self.model = models.vgg19(pretrained=False)
        self.build_vgg_layers()

        # The stages share their layers with self.model, so loading into the
        # model also fills them.
        checkpoint = torch.load(weight_path)
        self.model.load_state_dict(checkpoint)

        # Per-channel mean/std used by preprocess(); registered as parameters.
        self.RGB_mean = nn.Parameter(torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.RGB_std = nn.Parameter(torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # The VGG weights themselves are frozen.
        for frozen in self.model.parameters():
            frozen.requires_grad = False

    def build_vgg_layers(self):
        """Split ``model.features`` into four stages: layers [0,8), [8,17), [17,26), [26,35)."""
        backbone = self.model.features
        boundaries = [0, 8, 17, 26, 35]
        stages = []
        for start, stop in zip(boundaries[:-1], boundaries[1:]):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                # Keep the original layer index as the sub-module name.
                stage.add_module(str(layer_idx), backbone[layer_idx])
            stages.append(stage)
        self.features = torch.nn.ModuleList(stages)

    def preprocess(self, x):
        """Map ``x`` from [-1, 1] to [0, 1], normalize per channel, and upsize to 224x224 if the width is below 224."""
        unit_range = (x + 1) / 2
        normalized = (unit_range - self.RGB_mean) / self.RGB_std
        if normalized.shape[3] >= 224:
            return normalized
        return torch.nn.functional.interpolate(normalized, size=(224, 224), mode='bilinear', align_corners=False)

    def forward(self, x):
        """Return the list of stage outputs, each stage fed with the previous stage's output."""
        activations = []
        current = self.preprocess(x)
        for stage in self.features:
            current = stage(current)
            activations.append(current)
        return activations

def compute_sum(x, axis=None, keepdim=False):
    """Sum ``x`` over the given axes, one axis at a time from the highest index down.

    A falsy ``axis`` (``None``, empty) means "all axes".
    """
    dims = axis if axis else range(len(x.shape))
    total = x
    for dim in sorted(dims, reverse=True):
        total = torch.sum(total, dim=dim, keepdim=keepdim)
    return total
def ToRGB(in_channel):
    """Build a small head mapping ``in_channel`` feature channels to 3 output channels.

    Two spectrally-normalized 3x3 convolutions with a LeakyReLU(0.2) in between.
    """
    hidden = SpectralNorm(nn.Conv2d(in_channel,in_channel,3, 1, 1))
    to_three = SpectralNorm(nn.Conv2d(in_channel,3,3, 1, 1))
    return nn.Sequential(hidden, nn.LeakyReLU(0.2), to_three)

def AttentionBlock(in_channel):
    """Build a channel-preserving block: two spectrally-normalized 3x3 convs around a LeakyReLU(0.2)."""
    layers = [
        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
        nn.LeakyReLU(0.2),
        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
    ]
    return nn.Sequential(*layers)

class UNetDictFace(nn.Module):
    def __init__(self, ngf=64, dictionary_path='./DictionaryCenter512'):
        """Load the per-part feature dictionaries and build all sub-networks.

        Args:
            ngf: base channel width of the decoder.
            dictionary_path: kept for interface compatibility; the original
                ``.npy`` dictionary loading is disabled, so the value is not
                read. Features are loaded with ``torch.load`` from fixed
                ``/content/DFDNet/*_feature_resize_*.pt`` files instead.
        """
        super().__init__()

        self.part_sizes = np.array([80,80,50,110]) # size for 512
        self.feature_sizes = np.array([256,128,64,32])
        self.channel_sizes = np.array([128,256,512,512])

        # (dictionary key, feature-file prefix, attention-attribute prefix)
        part_table = (
            ('left_eye', 'LE', 'le'),
            ('right_eye', 'RE', 're'),
            ('nose', 'NO', 'no'),
            ('mouth', 'MO', 'mo'),
        )
        # (feature-map size, channel count of the attention block at that size)
        scale_table = ((256, 128), (128, 256), (64, 512), (32, 512))

        # One plain dict per feature-map size, keyed by part name.
        self.Dict_256, self.Dict_128, self.Dict_64, self.Dict_32 = {}, {}, {}, {}
        for scale, _ in scale_table:
            part_features = getattr(self, 'Dict_%d' % scale)
            for part_key, file_prefix, _ in part_table:
                part_features[part_key] = torch.load(
                    '/content/DFDNet/%s_feature_resize_%d.pt' % (file_prefix, scale)
                )

        # Attention blocks le_256 ... mo_32, registered part by part and,
        # within a part, from the largest to the smallest feature map.
        for _, _, attr_prefix in part_table:
            for scale, channels in scale_table:
                setattr(self, '%s_%d' % (attr_prefix, scale), AttentionBlock(channels))

        #norm
        self.VggExtract = VGGFeat()

        ######################
        self.MSDilate = MSDilateBlock(ngf*8, dilation = [4,3,2,1])  #

        self.up0 = StyledUpBlock(ngf*8,ngf*8)
        self.up1 = StyledUpBlock(ngf*8, ngf*4) #
        self.up2 = StyledUpBlock(ngf*4, ngf*2) #
        self.up3 = StyledUpBlock(ngf*2, ngf) #
        # Output head: conv, two residual blocks, conv to 3 channels, tanh.
        output_head = [
            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            UpResBlock(ngf),
            UpResBlock(ngf),
            nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.up4 = nn.Sequential(*output_head)
        self.to_rgb0 = ToRGB(ngf*8)
        self.to_rgb1 = ToRGB(ngf*4)
        self.to_rgb2 = ToRGB(ngf*2)
        self.to_rgb3 = ToRGB(ngf*1)

        # Per-size counters, all starting at zero.
        self.count_256 = 0
        self.count_128 = 0
        self.count_64 = 0
        self.count_32 = 0

    def forward(self, input, part_locations):
        #print("input.shape")
        #print(input.shape)
        VggFeatures = self.VggExtract(input) #VggFeatures = list object
        # for b in range(input.size(0)):
        b = 0
        UpdateVggFeatures = []
        for i, f_size in enumerate(self.feature_sizes):
            cur_feature = VggFeatures[i]
            #print("cur_feature.shape")
            #print(cur_feature.shape)

            update_feature = cur_feature.clone() #* 0
            cur_part_sizes = self.part_sizes // (512/f_size)
            dicts_feature = getattr(self, 'Dict_'+str(f_size))
            
            LE_Dict_feature = dicts_feature['left_eye'].to(input)
            RE_Dict_feature = dicts_feature['right_eye'].to(input)
            NO_Dict_feature = dicts_feature['nose'].to(input)
            MO_Dict_feature = dicts_feature['mouth'].to(input)

            le_location = (part_locations[0][b] // (512/f_size)).int()
            re_location = (part_locations[1][b] // (512/f_size)).int()
            no_location = (part_locations[2][b] // (512/f_size)).int()
            mo_location = (part_locations[3][b] // (512/f_size)).int()

            LE_feature = cur_feature[:,:,le_location[1]:le_location[3],le_location[0]:le_location[2]].clone()
            RE_feature = cur_feature[:,:,re_location[1]:re_location[3],re_location[0]:re_location[2]].clone()
            NO_feature = cur_feature[:,:,no_location[1]:no_location[3],no_location[0]:no_location[2]].clone()
            MO_feature = cur_feature[:,:,mo_location[1]:mo_location[3],mo_location[0]:mo_location[2]].clone()
            
            #resize
            LE_feature_resize = F.interpolate(LE_feature,(LE_Dict_feature.size(2),LE_Dict_feature.size(3)),mode='bilinear',align_corners=False)
            RE_feature_resize = F.interpolate(RE_feature,(RE_Dict_feature.size(2),RE_Dict_feature.size(3)),mode='bilinear',align_corners=False)
            NO_feature_resize = F.interpolate(NO_feature,(NO_Dict_feature.size(2),NO_Dict_feature.size(3)),mode='bilinear',align_corners=False)
            MO_feature_resize = F.interpolate(MO_feature,(MO_Dict_feature.size(2),MO_Dict_feature.size(3)),mode='bilinear',align_corners=False)
            
            #print("LE_feature_resize.shape")
            #print(LE_feature_resize.shape)
            
            #print("f_size")
            #print(f_size)
            """
            if f_size == 256:
              if self.count_256 == 0:
                #print("LE_save_256 = cur_feature")
                self.LE_save_256 = LE_feature_resize
                self.RE_save_256 = RE_feature_resize
                self.NO_save_256 = NO_feature_resize
                self.MO_save_256 = MO_feature_resize
                self.count_256 += 1
              else:
                #print("torch.cat((LE_save_256, LE_feature_resize), 1)")
                self.LE_save_256 = torch.cat((self.LE_save_256, LE_feature_resize), 0)
                self.RE_save_256 = torch.cat((self.RE_save_256, RE_feature_resize), 0)
                self.NO_save_256 = torch.cat((self.NO_save_256, NO_feature_resize), 0)
                self.MO_save_256 = torch.cat((self.MO_save_256, MO_feature_resize), 0)
                self.count_256 += 1

              if self.count_256 == 20:
                torch.save(self.LE_save_256, 'LE_feature_resize_256.pt')
                torch.save(self.RE_save_256, 'RE_feature_resize_256.pt')
                torch.save(self.NO_save_256, 'NO_feature_resize_256.pt')
                torch.save(self.MO_save_256, 'MO_feature_resize_256.pt')
              
            #############################################
            if f_size == 128:
              if self.count_128 == 0:
                self.LE_save_128 = LE_feature_resize
                self.RE_save_128 = RE_feature_resize
                self.NO_save_128 = NO_feature_resize
                self.MO_save_128 = MO_feature_resize
                self.count_128 += 1
              else:
                self.LE_save_128 = torch.cat((self.LE_save_128, LE_feature_resize), 0)
                self.RE_save_128 = torch.cat((self.RE_save_128, RE_feature_resize), 0)
                self.NO_save_128 = torch.cat((self.NO_save_128, NO_feature_resize), 0)
                self.MO_save_128 = torch.cat((self.MO_save_128, MO_feature_resize), 0)
                self.count_256 += 1

              if self.count_128 == 20:
                torch.save(self.LE_save_128, 'LE_feature_resize_128.pt')
                torch.save(self.RE_save_128, 'RE_feature_resize_128.pt')
                torch.save(self.NO_save_128, 'NO_feature_resize_128.pt')
                torch.save(self.MO_save_128, 'MO_feature_resize_128.pt')
            #############################################
            if f_size == 64:
              if self.count_64 == 0:
                self.LE_save_64 = LE_feature_resize
                self.RE_save_64 = RE_feature_resize
                self.NO_save_64 = NO_feature_resize
                self.MO_save_64 = MO_feature_resize
                self.count_64 += 1
              else:
                self.LE_save_64 = torch.cat((self.LE_save_64, LE_feature_resize), 0)
                self.RE_save_64 = torch.cat((self.RE_save_64, RE_feature_resize), 0)
                self.NO_save_64 = torch.cat((self.NO_save_64, NO_feature_resize), 0)
                self.MO_save_64 = torch.cat((self.MO_save_64, MO_feature_resize), 0)
                self.count_256 += 1

              if self.count_64 == 20:
                torch.save(self.LE_save_64, 'LE_feature_resize_64.pt')
                torch.save(self.RE_save_64, 'RE_feature_resize_64.pt')
                torch.save(self.NO_save_64, 'NO_feature_resize_64.pt')
                torch.save(self.MO_save_64, 'MO_feature_resize_64.pt')

            #############################################
            if f_size == 32:
              if self.count_32 == 0:
                self.LE_save_32 = LE_feature_resize
                self.RE_save_32 = RE_feature_resize
                self.NO_save_32 = NO_feature_resize
                self.MO_save_32 = MO_feature_resize
                self.count_32 += 1
              else:
                self.LE_save_32 = torch.cat((self.LE_save_32, LE_feature_resize), 0)
                self.RE_save_32 = torch.cat((self.RE_save_32, RE_feature_resize), 0)
                self.NO_save_32 = torch.cat((self.NO_save_32, NO_feature_resize), 0)
                self.MO_save_32 = torch.cat((self.MO_save_32, MO_feature_resize), 0)
                self.count_256 += 1

              if self.count_32 == 20:
                torch.save(self.LE_save_32, 'LE_feature_resize_32.pt')
                torch.save(self.RE_save_32, 'RE_feature_resize_32.pt')
                torch.save(self.NO_save_32, 'NO_feature_resize_32.pt')
                torch.save(self.MO_save_32, 'MO_feature_resize_32.pt')

            if self.count_256 == 20:
              print("features generated")
            """
            LE_Dict_feature_norm = adaptive_instance_normalization_4D(LE_Dict_feature, LE_feature_resize)
            RE_Dict_feature_norm = adaptive_instance_normalization_4D(RE_Dict_feature, RE_feature_resize)
            NO_Dict_feature_norm = adaptive_instance_normalization_4D(NO_Dict_feature, NO_feature_resize)
            MO_Dict_feature_norm = adaptive_instance_normalization_4D(MO_Dict_feature, MO_feature_resize)
            
            LE_score = F.conv2d(LE_feature_resize, LE_Dict_feature_norm)

            LE_score = F.softmax(LE_score.view(-1),dim=0)
            LE_index = torch.argmax(LE_score)
            LE_Swap_feature = F.interpolate(LE_Dict_feature_norm[LE_index:LE_index+1], (LE_feature.size(2), LE_feature.size(3)))

            LE_Attention = getattr(self, 'le_'+str(f_size))(LE_Swap_feature-LE_feature)
            LE_Att_feature = LE_Attention * LE_Swap_feature
            

            RE_score = F.conv2d(RE_feature_resize, RE_Dict_feature_norm)
            RE_score = F.softmax(RE_score.view(-1),dim=0)
            RE_index = torch.argmax(RE_score)
            RE_Swap_feature = F.interpolate(RE_Dict_feature_norm[RE_index:RE_index+1], (RE_feature.size(2), RE_feature.size(3)))
            
            RE_Attention = getattr(self, 're_'+str(f_size))(RE_Swap_feature-RE_feature)
            RE_Att_feature = RE_Attention * RE_Swap_feature

            NO_score = F.conv2d(NO_feature_resize, NO_Dict_feature_norm)
            NO_score = F.softmax(NO_score.view(-1),dim=0)
            NO_index = torch.argmax(NO_score)
            NO_Swap_feature = F.interpolate(NO_Dict_feature_norm[NO_index:NO_index+1], (NO_feature.size(2), NO_feature.size(3)))
            
            NO_Attention = getattr(self, 'no_'+str(f_size))(NO_Swap_feature-NO_feature)
            NO_Att_feature = NO_Attention * NO_Swap_feature

            
            MO_score = F.conv2d(MO_feature_resize, MO_Dict_feature_norm)
            MO_score = F.softmax(MO_score.view(-1),dim=0)
            MO_index = torch.argmax(MO_score)
            MO_Swap_feature = F.interpolate(MO_Dict_feature_norm[MO_index:MO_index+1], (MO_feature.size(2), MO_feature.size(3)))
            
            MO_Attention = getattr(self, 'mo_'+str(f_size))(MO_Swap_feature-MO_feature)
            MO_Att_feature = MO_Attention * MO_Swap_feature

            update_feature[:,:,le_location[1]:le_location[3],le_location[0]:le_location[2]] = LE_Att_feature + LE_feature
            update_feature[:,:,re_location[1]:re_location[3],re_location[0]:re_location[2]] = RE_Att_feature + RE_feature
            update_feature[:,:,no_location[1]:no_location[3],no_location[0]:no_location[2]] = NO_Att_feature + NO_feature
            update_feature[:,:,mo_location[1]:mo_location[3],mo_location[0]:mo_location[2]] = MO_Att_feature + MO_feature

            UpdateVggFeatures.append(update_feature) 
        
        fea_vgg = self.MSDilate(VggFeatures[3])
        #new version
        fea_up0 = self.up0(fea_vgg, UpdateVggFeatures[3])
        # out1 = F.interpolate(fea_up0,(512,512))
        # out1 = self.to_rgb0(out1)

        fea_up1 = self.up1( fea_up0, UpdateVggFeatures[2]) #
        # out2 = F.interpolate(fea_up1,(512,512))
        # out2 = self.to_rgb1(out2)

        fea_up2 = self.up2(fea_up1, UpdateVggFeatures[1]) #
        # out3 = F.interpolate(fea_up2,(512,512))
        # out3 = self.to_rgb2(out3)

        fea_up3 = self.up3(fea_up2, UpdateVggFeatures[0]) #
        # out4 = F.interpolate(fea_up3,(512,512))
        # out4 = self.to_rgb3(out4)

        output = self.up4(fea_up3) #
        
    
        return output  #+ out4 + out3 + out2 + out1
        #0 128 * 256 * 256
        #1 256 * 128 * 128
        #2 512 * 64 * 64
        #3 512 * 32 * 32


class UpResBlock(nn.Module):
    """Residual block computing ``x + conv(LeakyReLU(conv(x)))``.

    Args:
        dim: channel count passed as both input and output size of each conv.
        conv_layer: layer factory, called as ``conv_layer(dim, dim, 3, 1, 1)``.
        norm_layer: accepted for signature compatibility; the current layer
            stack does not use it.
    """

    def __init__(self, dim, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # The attribute name ``Model`` and the layer order are kept as they are
        # so parameter names (Model.0.*, Model.2.*) stay stable.
        branch = [
            conv_layer(dim, dim, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            conv_layer(dim, dim, 3, 1, 1),
        ]
        self.Model = nn.Sequential(*branch)

    def forward(self, x):
        """Return ``x`` plus the output of the two-convolution branch."""
        residual = self.Model(x)
        return x + residual

class VggClassNet(nn.Module):
    """Frozen VGG-19 feature stack that returns selected intermediate outputs.

    Args:
        select_layer: names of the sub-modules of ``vgg19().features`` whose
            outputs are collected. Inside the ``features`` container these
            names are the string indices of the layers.
    """

    def __init__(self, select_layer=('0', '5', '10', '19')):
        super(VggClassNet, self).__init__()
        # The default is an immutable tuple and the value is copied, so
        # instances never share (or alias a caller's) mutable list.
        self.select = list(select_layer)
        # pretrained=True asks torchvision for its pretrained VGG-19 weights.
        self.vgg = models.vgg19(pretrained=True).features
        # Freeze every parameter: this network is not trained.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the layers in order and collect selected outputs.

        Returns:
            list: the output of each layer whose name is in ``self.select``,
            in the order the layers are applied.
        """
        features = []
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features


# Executed as a script, this file only prints a marker line.
if __name__ == "__main__":
    print("this is network")




In [None]:
#@title custom_dataset.py (adding augmentations back)
%%writefile /content/DFDNet/data/custom_dataset.py
# -- coding: utf-8 --
import os.path
import os
import random
import torchvision.transforms as transforms
import torch
from PIL import Image, ImageFilter
import numpy as np
import cv2
import math
from scipy.io import loadmat
from PIL import Image
import PIL
from torch.utils.data import Dataset, DataLoader

import glob

class AlignedDataset(Dataset):
    """Dataset of face images with a blurred copy and facial-part boxes.

    Every ``*.png`` found recursively under ``/content/DFDNet/ffhq`` is one
    sample. For an image named ``<name>`` the landmarks are read from
    ``/content/DFDNet/landmarks/<name>.txt``.

    Args:
        root_dir: accepted for interface compatibility; the image directory
            is fixed to ``/content/DFDNet/ffhq`` regardless of this value.
        fine_size: side length (pixels) images are resized to.
        transform: stored on the instance; not applied by this class.
    """

    # Landmark row indices per facial part, in the order the boxes are
    # returned: left eye, right eye, nose, mouth.
    _PART_INDICES = (
        list(range(17, 22)) + list(range(36, 42)),
        list(range(22, 27)) + list(range(42, 48)),
        list(range(29, 36)),
        list(range(48, 68)),
    )
    # Smallest half side length allowed for a part box.
    _MIN_HALF_SIZE = 16

    def __init__(self, root_dir, fine_size=512, transform=None):
        self.root_dir = '/content/DFDNet/ffhq'
        self.pathes = glob.glob(self.root_dir + '/**/*.png', recursive=True)
        self.transform = transform
        self.fine_size = fine_size
        self.partpath = '/content/DFDNet/landmarks'

    def AddNoise(self, img):
        """Add Gaussian noise (sigma drawn from 1..10) to a PIL image.

        The image is returned unchanged when ``random.random() > 0.9``.
        """
        if random.random() > 0.9:
            return img
        self.sigma = np.random.randint(1, 11)
        img_tensor = torch.from_numpy(np.array(img)).float()
        noise = torch.randn(img_tensor.size()).mul_(self.sigma / 1.0)
        noiseimg = torch.clamp(noise + img_tensor, 0, 255)
        return Image.fromarray(np.uint8(noiseimg.numpy()))

    def AddBlur(self, img):
        """Apply Gaussian blur or motion blur to a PIL image.

        The image is returned unchanged when ``random.random() > 0.9``.
        Otherwise Gaussian blur is chosen when a second draw is ``> 0.35``
        and motion blur is used in the remaining case.
        """
        if random.random() > 0.9:
            return img
        img = np.array(img)
        if random.random() > 0.35:
            # Gaussian blur: odd kernel size in 3..35, sigma in 0.3..2.0.
            blursize = random.randint(1, 17) * 2 + 1
            blursigma = random.randint(3, 20)
            img = cv2.GaussianBlur(img, (blursize, blursize), blursigma / 10)
        else:
            # Motion blur: kernel file path is relative to the working dir.
            M = random.randint(1, 32)
            KName = './data/MotionBlurKernel/m_%02d.mat' % M
            k = loadmat(KName)['kernel']
            k = k.astype(np.float32)
            k /= np.sum(k)
            img = cv2.filter2D(img, -1, k)
        return Image.fromarray(img)

    def AddDownSample(self, img):
        """Bicubically shrink a PIL image to a random square size.

        The image is returned unchanged when ``random.random() > 0.95``.
        """
        if random.random() > 0.95:
            return img
        sampler = random.randint(20, 100) * 1.0
        side = int(self.fine_size / sampler * 10.0)
        return img.resize((side, side), Image.BICUBIC)

    def AddJPEG(self, img):
        """Round-trip a PIL image through JPEG at a quality drawn from 40..80.

        The image is returned unchanged when ``random.random() > 0.6``.
        """
        if random.random() > 0.6:
            return img
        imQ = random.randint(40, 80)
        img = np.array(img)
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), imQ]
        _, encA = cv2.imencode('.jpg', img, encode_param)
        img = cv2.imdecode(encA, 1)
        return Image.fromarray(img)

    def AddUpSample(self, img):
        """Bicubically resize a PIL image to ``fine_size`` x ``fine_size``."""
        return img.resize((self.fine_size, self.fine_size), Image.BICUBIC)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            dict: ``'A'`` is the image after ``AddBlur``, ``'C'`` the image
            before it (both converted to tensors and normalised with mean
            0.5 / std 0.5 per channel), ``'path'`` the image path and
            ``'part_locations'`` the four boxes from ``get_part_location``.
        """
        path = self.pathes[index]
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(path) as raw:
            Imgs = raw.convert('RGB')

        A = Imgs.resize((self.fine_size, self.fine_size))
        A = transforms.ColorJitter(0.3, 0.3, 0.3, 0)(A)
        C = A
        A = self.AddBlur(A)

        # Landmarks are divided by 2 here; presumably they were annotated at
        # twice the training resolution -- TODO confirm.
        ImgName = os.path.basename(path)
        part_locations = self.get_part_location(self.partpath, ImgName, 2)

        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        A = normalize(to_tensor(A))
        C = normalize(to_tensor(C))

        return {'A': A, 'C': C, 'path': path, 'part_locations': part_locations}

    @classmethod
    def _part_box(cls, points):
        """Return the integer square box ``[x0, y0, x1, y1]`` around ``points``."""
        center = np.mean(points, 0)
        extent = np.max(np.max(points, 0) - np.min(points, 0))
        half = max(extent / 2, cls._MIN_HALF_SIZE)
        return np.hstack((center - half + 1, center + half)).astype(int)

    def get_part_location(self, landmarkpath, imgname, downscale=1):
        """Read a landmark file and compute one box per facial part.

        Args:
            landmarkpath: directory holding the landmark files.
            imgname: image file name; ``.txt`` is appended to find the file.
            downscale: divisor applied to every landmark coordinate.

        Returns:
            tuple: four integer arrays ``[x0, y0, x1, y1]`` for left eye,
            right eye, nose and mouth. Each box is centred on the mean of the
            part's landmarks; its half side is half the larger landmark
            extent, but at least 16.
        """
        landmark_file = os.path.join(landmarkpath, imgname + '.txt')
        with open(landmark_file, 'r') as f:
            # split() tolerates repeated/trailing whitespace; blank lines are
            # skipped so they cannot produce ragged rows.
            rows = [[float(v) for v in line.split()] for line in f if line.strip()]
        Landmarks = np.array(rows) / downscale
        return tuple(self._part_box(Landmarks[idx]) for idx in self._PART_INDICES)

    def __len__(self):
        """Return the number of images found."""
        return len(self.pathes)

    def name(self):
        """Return the dataset's name."""
        return 'AlignedDataset'

In [None]:
# training with extracted features
%cd /content/DFDNet
!python run.py --batchSize 1

Now you can run the code from the ``Testing`` section above, but **don't** change ``networks.py`` and don't go back to the original version of that file. Stick to ``networks.py (train with own features, replacing numpy code)``, because the custom features must be loaded with the modified code.