# Colab-DFDNet-lightning

Official repo: [csxmli2016/DFDNet](https://github.com/csxmli2016/DFDNet)

Original repo: [max-vasyuk/DFDNet](https://github.com/max-vasyuk/DFDNet)

My fork: [styler00dollar/Colab-DFDNet](https://github.com/styler00dollar/Colab-DFDNet)

This notebook ports the DFDNet training code to PyTorch Lightning. Warning: currently only the MSE loss is implemented!

In [None]:
# Print the NVIDIA GPU status for this runtime (shell command run via IPython).
!nvidia-smi

In [None]:
#@title GPU
# Setup for a GPU runtime: install PyTorch Lightning, fetch the DFDNet
# sources and install their requirements plus a few extra packages.
!pip install pytorch-lightning -U
# Clone the training code and work from inside the checkout.
!git clone https://github.com/max-vasyuk/DFDNet
%cd /content/DFDNet
!pip install -r requirements.txt
# Extra packages installed on top of requirements.txt.
!pip install pytorch-msssim
!pip install trains
# NOTE(review): PyJWT is pinned to 1.7.1, presumably for compatibility with
# one of the packages above - confirm before changing the pin.
!pip install PyJWT==1.7.1
!pip install tensorboardX

In [None]:
#@title TPU  (restart runtime afterwards)
# Alternative setup via the upstream PyTorch/XLA env-setup script, left disabled here.
#!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
#!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
#!pip install pytorch-lightning
!pip install lightning-flash

import collections
from datetime import datetime, timedelta
import os
import requests
import threading

# Pair of (wheel version suffix, server-side XRT version string).
_VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
# Selected via the Colab form; must be one of the CONFIG keys below.
VERSION = "xrt==1.15.0"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
# The nightly server string embeds yesterday's date formatted as YYYYMMDD.
CONFIG = {
    'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
    'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
        (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
}[VERSION]
# Wheel file names are built from the selected wheel version.
# NOTE(review): the names hard-code the cp36 tag, so they presumably only
# match a Python 3.6 runtime - confirm against the current Colab image.
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

# Update TPU XRT version
def update_server_xrt():
  """Ask the Colab TPU host to switch its XRT runtime to ``CONFIG.server``.

  The host is taken from the ``COLAB_TPU_ADDR`` environment variable (the
  part before the first ``:``). A POST is sent to its ``requestversion``
  endpoint on port 8475 and the response object is printed once the
  request returns.
  """
  target_version = CONFIG.server
  print('Updating server-side XRT to {} ...'.format(target_version))
  tpu_host = os.environ['COLAB_TPU_ADDR'].split(':')[0]
  url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
      TPU_ADDRESS=tpu_host,
      XRT_VERSION=target_version,
  )
  response = requests.post(url)
  print('Done updating server-side XRT: {}'.format(response))

# Run the server-side XRT update in a background thread so it overlaps with
# the wheel downloads below; update.join() later waits for it to finish.
update = threading.Thread(target=update_server_xrt)
update.start()

# Install Colab TPU compat PyTorch/TPU wheels and dependencies
# The stock torch/torchvision are removed first, then the wheels named above
# are copied from DIST_BUCKET and installed.
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
!pip install "$TORCH_WHEEL"
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
!sudo apt-get install libomp5
update.join()

!pip install pytorch-lightning


# NOTE(review): this runs the upstream env-setup script with --version nightly
# after the wheels selected by VERSION were already installed above, and then
# installs pytorch-lightning a second time - confirm which path is intended.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py > /dev/null
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev > /dev/null
!pip install pytorch-lightning > /dev/null


# Fetch the DFDNet sources and their Python dependencies
# (same steps as the GPU cell).
!git clone https://github.com/max-vasyuk/DFDNet
%cd /content/DFDNet
!pip install -r requirements.txt
!pip install pytorch-msssim
!pip install trains
!pip install PyJWT==1.7.1
!pip install tensorboardX

In [None]:
#@title download data
# Download sixteen files by Google Drive id into DictionaryCenter512.
# NOTE(review): presumably the per-part dictionary .npy files
# (4 face parts x 4 feature scales) - confirm.
!mkdir /content/DFDNet/DictionaryCenter512
%cd /content/DFDNet/DictionaryCenter512
!gdown --id 1sEB9j3s7Wj9aqPai1NF-MR7B-c0zfTin
!gdown --id 1H4kByBiVmZuS9TbrWUR5uSNY770Goid6
!gdown --id 10ctK3d9znZ9nGN3d1Z77xW3GGshbeKBb
!gdown --id 1gcwmrIZjPFVu-cHjdQD6P4luohkPsil-
!gdown --id 1rJ8cORPxbJsIVAiNrjBag0ihaY_Mvurn
!gdown --id 1LkfJv2a3ud-mefAc1eZMJuINuNdSYgYO
!gdown --id 1LH-nxD__icSJvTiAbXAXDch03oDtbpkZ
!gdown --id 1JRTStLFsQ8dwaQjQ8qG5fNyrOvo6Tcvd
!gdown --id 1Z4AkU1pOYTYpdbfljCgNMmPilhdEd0Kl
!gdown --id 1Z4e1ltB3ACbYKzkoMBuVtzZ7a310G4xc
!gdown --id 1fqWmi6-8ZQzUtZTp9UH4hyom7n4nl8aZ
!gdown --id 1wfHtsExLvSgfH_EWtCPjTF5xsw3YyvjC
!gdown --id 1Jr3Luf6tmcdKANcSLzvt0sjXr0QUIQ2g
!gdown --id 1sPd4_IMYgqGLol0gqhHjBedKKxFAxswR
!gdown --id 1eVFjXJRnBH4mx7ZbAmZRwVXZNUbgCQec
!gdown --id 1w0GfO_KY775ZVF3KMk74ya6QL_bNU4cJ
# Download one more file into the repository root, then extract data.zip there
# (presumably the file just downloaded - confirm).
%cd /content/DFDNet
!gdown --id 1VE5tnOKcfL6MoV839IVCCw5FhJxIgml5
!7z x data.zip
# Download into the weights directory (presumably the vgg19.pth file listed
# under "Preconfigured paths" - confirm).
!mkdir /content/DFDNet/weights/
%cd /content/DFDNet/weights/
!gdown --id 1SfKKZJduOGhDD27Xl01yDx0-YSEkL2Aa

Preconfigured paths:
```
/content/DFDNet/DictionaryCenter512 (numpy dictionary files)
/content/DFDNet/ffhq (dataset path)
/content/DFDNet/landmarks (landmarks for that dataset)
/content/DFDNet/weights/vgg19.pth (vgg19 path)

/content/validation (validation path)
/content/landmarks (landmark path for validation)
```

In [None]:
#@title init.py
import torch.nn.init as init

def weights_init(net, init_type = 'kaiming', init_gain = 0.02):
    """Initialize the weights of ``net`` in place.

    Every submodule visited by ``net.apply`` is handled according to its
    class name:

    * names containing ``Conv`` (and having a weight): initialised with
      the method selected by ``init_type``;
    * names containing ``BatchNorm2d``: weight ~ N(1.0, 0.02), bias = 0;
    * names containing ``Linear``: weight ~ N(0, 0.01), bias = 0.

    Modules whose ``weight`` or ``bias`` is ``None`` (e.g. ``bias=False`` or
    ``affine=False``) are skipped for that tensor instead of raising.

    Parameters:
        net (network)       -- network to be initialized
        init_type (str)     -- the name of an initialization method:
                               normal | xavier | kaiming | orthogonal
        init_gain (float)   -- scaling factor for normal, xavier and
                               orthogonal (not used by kaiming).

    Raises:
        NotImplementedError: if ``init_type`` is unknown and ``net``
            contains at least one Conv module.
    """

    def init_func(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)

        if weight is not None and classname.find('Conv') != -1:
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain = init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight.data, a = 0, mode = 'fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain = init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            # affine=False leaves weight and bias as None.
            if weight is not None:
                init.normal_(weight.data, 1.0, 0.02)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif classname.find('Linear') != -1:
            if weight is not None:
                init.normal_(weight, 0, 0.01)
            # bias=False leaves bias as None.
            if bias is not None:
                init.constant_(bias, 0)

    # Apply the initialization function <init_func>
    print('Initialization method [{:s}]'.format(init_type))
    net.apply(init_func)

In [None]:
#@title loss.py
import torch.nn as nn
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class TVLoss(pl.LightningModule):
    """Total-variation loss over the last two axes of a 4-D tensor.

    The squared differences between neighbouring elements along axis 2 and
    axis 3 are summed, each sum is divided by the element count of one
    sample of the corresponding shifted slice, and the result is scaled by
    ``2 * TVLoss_weight`` and divided by the size of axis 0.
    """

    def __init__(self,TVLoss_weight=1):
        super().__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self,x):
        n_batch = x.size(0)
        # Shifted views: "lower"/"right" neighbours versus the originals.
        below, above = x[:, :, 1:, :], x[:, :, :-1, :]
        right, left = x[:, :, :, 1:], x[:, :, :, :-1]
        vertical = torch.pow(below - above, 2).sum()
        horizontal = torch.pow(right - left, 2).sum()
        per_sample = vertical / self._tensor_size(below) + horizontal / self._tensor_size(right)
        return self.TVLoss_weight * 2 * per_sample / n_batch

    def _tensor_size(self,t):
        """Return the number of elements in one sample of ``t`` (axes 1-3)."""
        _, c, h, w = t.size()
        return c * h * w


class hinge_loss(pl.LightningModule):
    """Hinge loss computed from discriminator scores.

    Returns ``mean(relu(1 - dis_real)) + mean(relu(1 + dis_fake))``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, dis_fake, dis_real):
        real_term = F.relu(1. - dis_real).mean()
        fake_term = F.relu(1. + dis_fake).mean()
        return real_term + fake_term


class hinge_loss_G(pl.LightningModule):
    """Generator-side hinge loss: the negated mean of ``dis_fake``."""

    def __init__(self):
        super().__init__()

    def forward(self, dis_fake):
        return -dis_fake.mean()

In [None]:
#@title checkpoint.py
#https://github.com/PyTorchLightning/pytorch-lightning/issues/2534
import os
import pytorch_lightning as pl

class CheckpointEveryNSteps(pl.Callback):
    """
    Callback that writes a checkpoint whenever ``trainer.global_step`` is a
    multiple of ``save_step_frequency`` and then triggers an evaluation run,
    instead of relying on Lightning's validation-loss based checkpointing.
    """

    def __init__(
        self,
        save_step_frequency,
        prefix="Checkpoint",
        use_modelcheckpoint_filename=False,
        save_path = '/content/'
    ):
        """
        Args:
            save_step_frequency: number of global steps between checkpoints
            prefix: file name prefix, only used when
                use_modelcheckpoint_filename=False
            use_modelcheckpoint_filename: reuse the filename of the trainer's
                ModelCheckpoint callback instead of building our own
            save_path: directory the checkpoint file is written to
        """
        self.save_path = save_path
        self.use_modelcheckpoint_filename = use_modelcheckpoint_filename
        self.prefix = prefix
        self.save_step_frequency = save_step_frequency

    def on_batch_end(self, trainer: pl.Trainer, _):
        """ Save a checkpoint (and validate) when the step count hits the frequency """
        # NOTE(review): on_batch_end and trainer.run_evaluation are tied to the
        # installed Lightning version - confirm they exist in the one used here.
        step = trainer.global_step
        if step % self.save_step_frequency != 0:
            return

        if self.use_modelcheckpoint_filename:
            filename = trainer.checkpoint_callback.filename
        else:
            filename = f"{self.prefix}_{trainer.current_epoch}_{step}.ckpt"

        # Checkpoints go to save_path rather than the ModelCheckpoint dirpath.
        trainer.save_checkpoint(os.path.join(self.save_path, filename))
        # run validation once checkpoint was made
        trainer.run_evaluation()

#Trainer(callbacks=[CheckpointEveryNSteps()])

Edit the motion blur kernel path inside the ``data.py`` cell below if your kernels are stored somewhere else.

In [None]:
#@title data.py
import os.path
import os
import random
import torchvision.transforms as transforms
import torch
from PIL import Image, ImageFilter
import numpy as np
import cv2
import math
from scipy.io import loadmat
from PIL import Image
import PIL
from torch.utils.data import Dataset, DataLoader

class DS(Dataset):
    """Training dataset of face images paired with degraded copies.

    Every ``*png`` file in ``root_dir`` is one sample. ``__getitem__``
    returns ``(A, C, path, part_locations)``:

    * ``C`` -- the (optionally transformed) colour-jittered image,
    * ``A`` -- ``C`` after random blur and random JPEG compression,
    * ``path`` -- the image file path,
    * ``part_locations`` -- the four boxes from ``get_part_location``.

    ``A`` and ``C`` are tensors normalised with mean 0.5 and std 0.5.

    Args:
        root_dir: directory containing the training images.
        fine_size: target edge length used by the down/up-sampling helpers.
        transform: optional callable applied to the PIL image first.
        partpath: directory with ``<image name>.txt`` landmark files.
        kernel_dir: directory with the ``m_XX.mat`` motion blur kernels.
    """

    def __init__(self, root_dir, fine_size=512, transform=None, partpath=None,
                 kernel_dir='/content/DFDNet/data/MotionBlurKernel'):
        self.root_dir = root_dir
        self.pathes = [os.path.join(self.root_dir, x) for x in os.listdir(self.root_dir) if x.endswith('png')]
        self.transform = transform
        self.fine_size = fine_size
        self.partpath = partpath
        self.kernel_dir = kernel_dir

    def AddNoise(self,img): # noise
        """Add Gaussian noise (sigma drawn from 1..10) with probability 0.9."""
        if random.random() > 0.9:
            return img
        self.sigma = np.random.randint(1, 11)
        img_tensor = torch.from_numpy(np.array(img)).float()
        noise = torch.randn(img_tensor.size()).mul_(self.sigma/1.0)

        noiseimg = torch.clamp(noise+img_tensor,0,255)
        return Image.fromarray(np.uint8(noiseimg.numpy()))

    def AddBlur(self,img): # gaussian blur or motion blur
        """Blur the image with probability 0.9 (Gaussian 65%, motion 35%)."""
        if random.random() > 0.9:
            return img
        img = np.array(img)
        if random.random() > 0.35: # gaussian blur
            blursize = random.randint(1,17) * 2 + 1 # odd kernel size in 3..35
            blursigma = random.randint(3, 20)
            img = cv2.GaussianBlur(img, (blursize,blursize), blursigma/10)
        else: # motion blur
            M = random.randint(1,32)
            KName = os.path.join(self.kernel_dir, 'm_%02d.mat' % M)
            k = loadmat(KName)['kernel']
            k = k.astype(np.float32)
            k /= np.sum(k)  # normalise so the kernel sums to 1
            img = cv2.filter2D(img,-1,k)
        return Image.fromarray(img)

    def AddDownSample(self,img): # downsampling
        """Bicubically shrink the image with probability 0.95."""
        if random.random() > 0.95:
            return img
        sampler = random.randint(20, 100)*1.0
        # Resulting edge length is fine_size * 10 / sampler.
        small = int(self.fine_size/sampler*10.0)
        img = img.resize((small, small), Image.BICUBIC)
        return img

    def AddJPEG(self,img): # JPEG compression
        """Round-trip the image through JPEG (quality 40..80) with probability 0.6."""
        if random.random() > 0.6:
            return img
        imQ = random.randint(40, 80)
        img = np.array(img)
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY),imQ] # (0,100),higher is better,default is 95
        _, encA = cv2.imencode('.jpg', img, encode_param)
        img = cv2.imdecode(encA,1)
        return Image.fromarray(img)

    def AddUpSample(self,img):
        """Bicubically resize the image to ``fine_size`` x ``fine_size``."""
        return img.resize((self.fine_size, self.fine_size), Image.BICUBIC)

    def __getitem__(self, index): # indexation
        path = self.pathes[index]
        Imgs = Image.open(path).convert('RGB')

        # Without a transform the image is used as loaded (previously this
        # left ``A`` unbound and raised a NameError).
        A = self.transform(Imgs) if self.transform is not None else Imgs

        A = transforms.ColorJitter(0.3, 0.3, 0.3, 0)(A)
        C = A  # clean target, kept before the degradations below
        A = self.AddBlur(A)
        A = self.AddJPEG(A)

        ImgName = os.path.basename(path)
        part_locations = self.get_part_location(self.partpath, ImgName, 2)

        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        A = normalize(to_tensor(A))
        C = normalize(to_tensor(C))

        return A, C, path, part_locations

    @staticmethod
    def _part_box(landmarks, indices):
        """Return an int box around the landmarks selected by ``indices``.

        The box is centred on the mean of the selected points; its half
        size is half of the largest per-column extent, but at least 16.
        """
        pts = landmarks[indices]
        center = np.mean(pts, 0)
        half = np.max((np.max(np.max(pts, 0) - np.min(pts, 0)) / 2, 16))
        return np.hstack((center - half + 1, center + half)).astype(int)

    def get_part_location(self, landmarkpath, imgname, downscale=1):
        """Read ``<imgname>.txt`` and return boxes for the four face parts.

        Args:
            landmarkpath: directory containing the landmark text files.
            imgname: image file name; ``'.txt'`` is appended to it.
            downscale: divisor applied to all landmark coordinates.

        Returns:
            Tuple ``(left_eye, right_eye, nose, mouth)`` of int arrays, each
            the lower corner followed by the upper corner of the box.
        """
        Landmarks = []
        with open(os.path.join(landmarkpath, imgname + '.txt'),'r') as f:
            for line in f:
                # The builtin float replaces np.float (removed in NumPy 1.24);
                # split() also tolerates trailing or repeated whitespace.
                tmp = [float(i) for i in line.split()]
                if tmp:
                    Landmarks.append(tmp)
        Landmarks = np.array(Landmarks)/downscale # 512 * 512

        # NOTE(review): index ranges presumably follow the 68-point facial
        # landmark convention - confirm against the landmark files.
        Map_LE = list(range(17,22)) + list(range(36,42))
        Map_RE = list(range(22,27)) + list(range(42,48))
        Map_NO = list(range(29,36))
        Map_MO = list(range(48,68))

        Location_LE = self._part_box(Landmarks, Map_LE)  # left eye
        Location_RE = self._part_box(Landmarks, Map_RE)  # right eye
        Location_NO = self._part_box(Landmarks, Map_NO)  # nose
        Location_MO = self._part_box(Landmarks, Map_MO)  # mouth
        return Location_LE, Location_RE, Location_NO, Location_MO

    def __len__(self):
        return len(self.pathes)

    def name(self):
        return 'AlignedDataset'




class DS_val(Dataset):
    """Validation dataset: images are returned without synthetic degradation.

    Every ``*png`` file in ``root_dir`` is one sample. ``__getitem__``
    returns ``(A, path, part_locations)`` where ``A`` is the image as a
    tensor normalised with mean 0.5 and std 0.5.

    Args:
        root_dir: directory containing the validation images.
        fine_size: stored but not used by this class.
        transform: stored but not applied in ``__getitem__``.
        partpath: directory with ``<image name>.txt`` landmark files.
    """

    def __init__(self, root_dir, fine_size=512, transform=None, partpath=None):
        self.root_dir = root_dir
        self.pathes = [os.path.join(self.root_dir, x) for x in os.listdir(self.root_dir) if x.endswith('png')]
        self.transform = transform
        self.fine_size = fine_size
        self.partpath = partpath

    def __getitem__(self, index): # indexation
        path = self.pathes[index]
        # NOTE(review): self.transform is never applied here although
        # landmarks are halved (downscale=2) below - confirm the validation
        # images already have the size the landmarks are scaled to.
        A = Image.open(path).convert('RGB')

        ImgName = os.path.basename(path)
        part_locations = self.get_part_location(self.partpath, ImgName, 2)

        A = transforms.ToTensor()(A)
        A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)

        return A, path, part_locations

    @staticmethod
    def _part_box(landmarks, indices):
        """Return an int box around the landmarks selected by ``indices``.

        The box is centred on the mean of the selected points; its half
        size is half of the largest per-column extent, but at least 16.
        """
        pts = landmarks[indices]
        center = np.mean(pts, 0)
        half = np.max((np.max(np.max(pts, 0) - np.min(pts, 0)) / 2, 16))
        return np.hstack((center - half + 1, center + half)).astype(int)

    def get_part_location(self, landmarkpath, imgname, downscale=1):
        """Read ``<imgname>.txt`` and return boxes for the four face parts.

        Args:
            landmarkpath: directory containing the landmark text files.
            imgname: image file name; ``'.txt'`` is appended to it.
            downscale: divisor applied to all landmark coordinates.

        Returns:
            Tuple ``(left_eye, right_eye, nose, mouth)`` of int arrays, each
            the lower corner followed by the upper corner of the box.
        """
        Landmarks = []
        with open(os.path.join(landmarkpath, imgname + '.txt'),'r') as f:
            for line in f:
                # The builtin float replaces np.float (removed in NumPy 1.24);
                # split() also tolerates trailing or repeated whitespace.
                tmp = [float(i) for i in line.split()]
                if tmp:
                    Landmarks.append(tmp)
        Landmarks = np.array(Landmarks)/downscale # 512 * 512

        # NOTE(review): index ranges presumably follow the 68-point facial
        # landmark convention - confirm against the landmark files.
        Map_LE = list(range(17,22)) + list(range(36,42))
        Map_RE = list(range(22,27)) + list(range(42,48))
        Map_NO = list(range(29,36))
        Map_MO = list(range(48,68))

        Location_LE = self._part_box(Landmarks, Map_LE)  # left eye
        Location_RE = self._part_box(Landmarks, Map_RE)  # right eye
        Location_NO = self._part_box(Landmarks, Map_NO)  # nose
        Location_MO = self._part_box(Landmarks, Map_MO)  # mouth
        return Location_LE, Location_RE, Location_NO, Location_MO

    def __len__(self):
        return len(self.pathes)

    def name(self):
        return 'ValDataset'

In [None]:
#@title dataloader.py
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class DFNetDataModule(pl.LightningDataModule):
    """Lightning data module wiring the DS / DS_val datasets to dataloaders.

    ``setup`` builds one training dataset (``DS``) and two ``DS_val``
    datasets (validation and test); the three ``*_dataloader`` methods wrap
    them with the configured batch size and worker count.
    """

    def __init__(
        self,
        training_path: str = './',
        train_partpath: str = './',
        validation_path: str = './',
        val_partpath: str = './',
        test_path: str = './',
        batch_size: int = 5,
        num_workers: int = 2,
    ):
        super().__init__()
        # Image directories.
        self.training_dir = training_path
        self.validation_dir = validation_path
        self.test_dir = test_path
        # Loader settings.
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Edge length used for resizing/cropping and as DS fine_size.
        self.size = 512
        # Landmark directories.
        self.train_partpath = train_partpath
        self.val_partpath = val_partpath

    def setup(self, stage=None):
        """Create the train, validation and test datasets."""
        resize_and_crop = transforms.Compose([
            transforms.Resize(size=self.size),
            transforms.CenterCrop(size=self.size),
        ])

        self.DFDNetdataset_train = DS(
            root_dir=self.training_dir,
            transform=resize_and_crop,
            fine_size=self.size,
            partpath=self.train_partpath,
        )
        self.DFDNetdataset_validation = DS_val(
            root_dir=self.validation_dir,
            transform=resize_and_crop,
            partpath=self.val_partpath,
        )
        # NOTE(review): the test dataset reads from validation_dir with a
        # fixed landmark path; self.test_dir is not used - confirm intent.
        self.DFDNetdataset_test = DS_val(
            root_dir=self.validation_dir,
            transform=resize_and_crop,
            partpath='/content/landmarks',
        )

    def _make_loader(self, dataset):
        """Wrap ``dataset`` in a DataLoader with the module-wide settings."""
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def train_dataloader(self):
        return self._make_loader(self.DFDNetdataset_train)

    def val_dataloader(self):
        return self._make_loader(self.DFDNetdataset_validation)

    def test_dataloader(self):
        return self._make_loader(self.DFDNetdataset_test)

In [None]:
#@title CustomTrainClass.py
import pytorch_lightning as pl
from torchvision.utils import save_image

class CustomTrainClass(pl.LightningModule):
  """Lightning module that trains ``UNetDictFace`` with an MSE loss.

  Batches are expected in the layout produced by the datasets in this
  notebook: ``(A, C, path, part_locations)`` for training and
  ``(A, path, part_locations)`` for validation and testing.
  """

  def __init__(self):
    super().__init__()

    self.netG = UNetDictFace(64)

    # NOTE(review): weights_init visits every Conv submodule of netG; if netG
    # contains pretrained layers, confirm they are not meant to be preserved.
    weights_init(self.netG, 'kaiming')

    # Establish convention for real and fake labels during training
    self.real_label = 1.
    self.fake_label = 0.

    self.criterionG = torch.nn.MSELoss()

  def forward(self, *args, **kwargs):
    """Delegate to the generator, forwarding all arguments unchanged."""
    return self.netG(*args, **kwargs)

  def training_step(self, train_batch, batch_idx):
    """Return the MSE between the restored image and the clean target."""
    # train_batch[0] = A (degraded input)
    # train_batch[1] = C (clean target)
    # train_batch[2] = path
    # train_batch[3] = part_locations
    data_a = train_batch[0]
    data_c = train_batch[1]

    fake = self.netG(data_a, part_locations=train_batch[3])
    return self.criterionG(fake, data_c)

  def configure_optimizers(self):
    """Adam over the generator parameters (lr 2e-4, betas 0.5/0.999)."""
    optimizer = torch.optim.Adam(self.netG.parameters(), lr=2e-4, betas=(0.5, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    return optimizer

  def validation_step(self, train_batch, train_idx):
    """Restore a validation batch and save one image per input file.

    Each output goes to
    ``/content/validation_output/<file stem>/<global_step>.png``.
    """
    # train_batch[0] = lr
    # train_batch[1] = path
    # train_batch[2] = part_locations
    out = self.netG(train_batch[0], part_locations=train_batch[2])

    validation_output = '/content/validation_output/'

    # train_batch[1] holds one path per sample; enumerate keeps the output
    # index in step with the file (previously every file was written from
    # out[0] because the counter was reset on each iteration).
    # NOTE(review): save_image is called without rescaling; confirm the
    # value range of the generator output matches what it expects.
    for counter, f in enumerate(train_batch[1]):
      filename = os.path.splitext(os.path.basename(f))[0]
      out_dir = os.path.join(validation_output, filename)
      os.makedirs(out_dir, exist_ok=True)
      save_image(out[counter], os.path.join(out_dir, str(self.trainer.global_step) + '.png'))

  def test_step(self, train_batch, train_idx):
    """Restore a test batch and save each result as ``<file stem>.png``.

    Uses the same batch layout as ``validation_step``; results are written
    to ``/content/test_output/``.
    """
    # train_batch[0] = lr
    # train_batch[1] = path
    # train_batch[2] = part_locations
    print("testing")
    test_output = '/content/test_output/'
    os.makedirs(test_output, exist_ok=True)

    out = self.netG(train_batch[0], part_locations=train_batch[2])

    for counter, f in enumerate(train_batch[1]):
      filename = os.path.splitext(os.path.basename(f))[0]
      save_image(out[counter], os.path.join(test_output, filename + '.png'))



Edit the VGG weight path (``weight_path``) and ``dictionary_path`` inside the ``model.py`` cell below if your files are stored somewhere else.

In [None]:
#@title model.py

import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.nn import Parameter as P
#from util import util
from torchvision import models
#import scipy.io as sio
import numpy as np
#import scipy.ndimage
import torch.nn.utils.spectral_norm as SpectralNorm

from torch.autograd import Function
from math import sqrt
import random
import os
import math

import pytorch_lightning as pl


class VGGFeat(pl.LightningModule):
    """
    Frozen VGG19 feature extractor returning four intermediate feature maps.

    Input: (B, C, H, W), RGB, [-1, 1]

    Args:
        weight_path: path to a VGG19 state dict loadable by ``torch.load``.
    """
    def __init__(self, weight_path='/content/DFDNet/weights/vgg19.pth'):
        super().__init__()
        self.model = models.vgg19(pretrained=False)
        self.build_vgg_layers()

        # Load onto the CPU first so the checkpoint also loads on machines
        # without the device it was saved from.
        self.model.load_state_dict(torch.load(weight_path, map_location='cpu'))

        # Normalisation constants stay registered as parameters (same
        # state_dict keys) but are frozen so an optimizer cannot update them.
        self.register_parameter(
            "RGB_mean",
            nn.Parameter(torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False),
        )
        self.register_parameter(
            "RGB_std",
            nn.Parameter(torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False),
        )

        # self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def build_vgg_layers(self):
        """Split ``self.model.features`` into four consecutive stages.

        Stage ``i`` holds layers ``feature_layers[i]`` up to (excluding)
        ``feature_layers[i + 1]`` and is stored in ``self.features``.
        """
        vgg_pretrained_features = self.model.features
        # feature_layers = [0, 3, 8, 17, 26, 35]
        feature_layers = [0, 8, 17, 26, 35]
        stages = []
        for start, stop in zip(feature_layers[:-1], feature_layers[1:]):
            module_layers = torch.nn.Sequential()
            for j in range(start, stop):
                module_layers.add_module(str(j), vgg_pretrained_features[j])
            stages.append(module_layers)
        self.features = torch.nn.ModuleList(stages)

    def preprocess(self, x):
        """Map [-1, 1] input to mean/std-normalised values; upsample small inputs."""
        x = (x + 1) / 2
        x = (x - self.RGB_mean) / self.RGB_std
        # Only the last axis is checked before resizing to 224 x 224.
        if x.shape[3] < 224:
            x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        return x

    def forward(self, x):
        """Return the list of outputs of each stage, applied in sequence."""
        x = self.preprocess(x)
        features = []
        for m in self.features:
            x = m(x)
            features.append(x)
        return features


class StyledUpBlock(pl.LightningModule):
    """Convolution block modulated by a style map, followed by 2x upsampling.

    ``forward(input, style)`` runs ``conv1`` on ``input``, multiplies the
    result by ``ScaleModel1(style)``, adds ``ShiftModel1(style)`` and feeds
    the sum through ``convup`` (bilinear 2x upsample + conv + LeakyReLU).
    """

    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False):
        super().__init__()

        # conv1: optional 2x upsample, blur, spectral-norm conv, LeakyReLU.
        if upsample:
            head = [
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                Blur(out_channel),
            ]
        else:
            head = [Blur(in_channel)]
        head.append(SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)))
        head.append(nn.LeakyReLU(0.2))
        self.conv1 = nn.Sequential(*head)

        # convup: 2x upsample, spectral-norm conv, LeakyReLU.
        up_layers = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
            nn.LeakyReLU(0.2),
        ]
        self.convup = nn.Sequential(*up_layers)

        self.lrelu1 = nn.LeakyReLU(0.2)

        # Two-conv branches mapping the style map to a scale and a shift;
        # the shift branch ends in a sigmoid.
        scale_layers = [
            SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2, True),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
        ]
        self.ScaleModel1 = nn.Sequential(*scale_layers)

        shift_layers = [
            SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2, True),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
            nn.Sigmoid(),
        ]
        self.ShiftModel1 = nn.Sequential(*shift_layers)

    def forward(self, input, style):
        feat = self.lrelu1(self.conv1(input))

        shift = self.ShiftModel1(style)
        scale = self.ScaleModel1(style)
        modulated = feat * scale + shift

        return self.convup(modulated)





class StyledUpBlock(nn.Module):
    """Style-modulated convolution block with a trailing 2x upsample.

    This ``nn.Module`` definition comes after a ``pl.LightningModule`` class
    of the same name in this cell, so this is the one the name refers to
    once the cell has run.

    ``forward(input, style)`` computes
    ``convup(lrelu1(conv1(input)) * ScaleModel1(style) + ShiftModel1(style))``.
    """

    @staticmethod
    def _sn_conv(c_in, c_out, kernel_size, padding):
        """Build a spectrally normalised Conv2d (stride 1)."""
        return SpectralNorm(nn.Conv2d(c_in, c_out, kernel_size, padding=padding))

    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False):
        super().__init__()
        sn_conv = self._sn_conv

        # Layers in front of the first conv depend on the upsample flag.
        if not upsample:
            pre = (Blur(in_channel),)
        else:
            pre = (
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                Blur(out_channel),
            )
        self.conv1 = nn.Sequential(
            *pre,
            sn_conv(in_channel, out_channel, kernel_size, padding),
            nn.LeakyReLU(0.2),
        )
        self.convup = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            sn_conv(out_channel, out_channel, kernel_size, padding),
            nn.LeakyReLU(0.2),
        )
        self.lrelu1 = nn.LeakyReLU(0.2)

        # Scale branch: conv -> LeakyReLU -> conv on the style map.
        self.ScaleModel1 = nn.Sequential(
            sn_conv(in_channel, out_channel, 3, 1),
            nn.LeakyReLU(0.2, True),
            sn_conv(out_channel, out_channel, 3, 1),
        )
        # Shift branch: same layout, squashed by a sigmoid.
        self.ShiftModel1 = nn.Sequential(
            sn_conv(in_channel, out_channel, 3, 1),
            nn.LeakyReLU(0.2, True),
            sn_conv(out_channel, out_channel, 3, 1),
            nn.Sigmoid(),
        )

    def forward(self, input, style):
        out = self.conv1(input)
        out = self.lrelu1(out)

        Shift1 = self.ShiftModel1(style)
        Scale1 = self.ScaleModel1(style)

        return self.convup(out * Scale1 + Shift1)

class UNetDictFace(pl.LightningModule):
    """Dictionary-guided face restoration generator.

    ``forward`` extracts four feature maps with ``VGGFeat``. At every scale the
    regions of four facial parts (left eye, right eye, nose, mouth) are matched
    against a component dictionary loaded from ``dictionary_path``; the
    best-scoring dictionary entry is blended back into the feature map through
    a per-part, per-scale attention block. The updated feature maps then
    condition a chain of ``StyledUpBlock`` decoders that produce the output.

    NOTE: ``forward`` reads the part locations of batch element 0 only and
    applies them to the whole batch.
    """

    # (dictionary key, attention-module prefix) in the order the regions are
    # written back into the feature map; later parts win where regions overlap.
    _PARTS = (('left_eye', 'le'), ('right_eye', 're'), ('nose', 'no'), ('mouth', 'mo'))

    def __init__(self, ngf=64, dictionary_path='/content/DFDNet/DictionaryCenter512'):
        """Load the component dictionaries and build all sub-modules.

        Args:
            ngf: Base channel width of the decoder (``ngf*8`` ... ``ngf``).
            dictionary_path: Directory holding ``<part>_<size>_center.npy``
                for every part and every feature size.
        """
        super().__init__()

        self.part_sizes = np.array([80,80,50,110]) # size for 512
        self.feature_sizes = np.array([256,128,64,32])
        self.channel_sizes = np.array([128,256,512,512])

        # Plain Python dicts (not registered buffers): they are not part of the
        # state_dict and are moved to the input's device/dtype in ``forward``.
        self.Dict_256 = {}
        self.Dict_128 = {}
        self.Dict_64 = {}
        self.Dict_32 = {}
        for part_idx, (part, _) in enumerate(self._PARTS):
            for f_size, channels in zip(self.feature_sizes, self.channel_sizes):
                f_size, channels = int(f_size), int(channels)
                # NOTE(review): allow_pickle=True executes arbitrary code if the
                # .npy files are untrusted -- only load dictionaries from a
                # trusted source.
                centers = torch.from_numpy(np.load(
                    os.path.join(dictionary_path, '{}_{}_center.npy'.format(part, f_size)),
                    allow_pickle=True))
                # Part size is given for a 512px image; scale it to this level.
                side = int(self.part_sizes[part_idx]) // (512 // f_size)
                getattr(self, 'Dict_{}'.format(f_size))[part] = centers.reshape(
                    centers.size(0), channels, side, side)

        # One attention block per part and scale, registered as le_256, le_128,
        # le_64, le_32, re_256, ... (same names and order as before so existing
        # checkpoints keep loading).
        for _, prefix in self._PARTS:
            for f_size, channels in zip(self.feature_sizes, self.channel_sizes):
                setattr(self, '{}_{}'.format(prefix, int(f_size)), AttentionBlock(int(channels)))

        #norm
        self.VggExtract = VGGFeat()

        ######################
        self.MSDilate = MSDilateBlock(ngf*8, dilation = [4,3,2,1])  #

        self.up0 = StyledUpBlock(ngf*8,ngf*8)
        self.up1 = StyledUpBlock(ngf*8, ngf*4) #
        self.up2 = StyledUpBlock(ngf*4, ngf*2) #
        self.up3 = StyledUpBlock(ngf*2, ngf) #
        self.up4 = nn.Sequential( # 128
            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            UpResBlock(ngf),
            UpResBlock(ngf),
            nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )
        # The to_rgb heads are not called by ``forward``; they stay registered
        # so the parameter set (and state_dict keys) is unchanged.
        self.to_rgb0 = ToRGB(ngf*8)
        self.to_rgb1 = ToRGB(ngf*4)
        self.to_rgb2 = ToRGB(ngf*2)
        self.to_rgb3 = ToRGB(ngf*1)

    def _restore_part(self, part_feature, dict_feature, attention):
        """Blend the best-matching dictionary entry into one cropped part region."""
        resized = F.interpolate(part_feature, (dict_feature.size(2), dict_feature.size(3)), mode='bilinear', align_corners=False)
        dict_norm = adaptive_instance_normalization_4D(dict_feature, resized)

        # Correlate the region with every dictionary entry and keep the best one.
        score = F.conv2d(resized, dict_norm)
        score = F.softmax(score.view(-1), dim=0)
        index = torch.argmax(score)
        swap = F.interpolate(dict_norm[index:index+1], (part_feature.size(2), part_feature.size(3)))

        return attention(swap - part_feature) * swap + part_feature

    def forward(self, input, part_locations):
        """Restore ``input`` using the component dictionaries.

        Args:
            input: Image batch passed to ``self.VggExtract``.
            part_locations: Indexable with four entries (left eye, right eye,
                nose, mouth). ``part_locations[k][0]`` must provide at least
                four values used as ``(x0, y0, x1, y1)``; they are divided by
                ``512 / feature_size`` to index each feature map (presumably
                pixel coordinates in a 512px image -- TODO confirm with the
                data module).

        Returns:
            The output of ``self.up4`` (3 channels, ``Tanh`` range).
        """
        # Per the original annotations the four feature maps are
        # 128x256x256, 256x128x128, 512x64x64 and 512x32x32.
        VggFeatures = self.VggExtract(input)
        b = 0  # only the first sample's part locations are used
        UpdateVggFeatures = []
        for i, f_size in enumerate(self.feature_sizes):
            cur_feature = VggFeatures[i]
            update_feature = cur_feature.clone()
            dicts_feature = getattr(self, 'Dict_'+str(f_size))
            scale = 512/f_size

            for part_idx, (part, prefix) in enumerate(self._PARTS):
                location = (part_locations[part_idx][b] // scale).int()
                rows = slice(location[1], location[3])
                cols = slice(location[0], location[2])

                # Regions are always cropped from the untouched feature map, so
                # earlier write-backs never influence later parts.
                part_feature = cur_feature[:, :, rows, cols].clone()
                dict_feature = dicts_feature[part].to(input)
                attention = getattr(self, prefix+'_'+str(f_size))
                update_feature[:, :, rows, cols] = self._restore_part(part_feature, dict_feature, attention)

            UpdateVggFeatures.append(update_feature)

        fea_vgg = self.MSDilate(VggFeatures[3])
        fea_up0 = self.up0(fea_vgg, UpdateVggFeatures[3])
        fea_up1 = self.up1(fea_up0, UpdateVggFeatures[2])
        fea_up2 = self.up2(fea_up1, UpdateVggFeatures[1])
        fea_up3 = self.up3(fea_up2, UpdateVggFeatures[0])

        return self.up4(fea_up3)




# smaller blocks
def AttentionBlock(in_channel):
    """Build a channel-preserving attention head.

    Returns an ``nn.Sequential`` of two spectral-normalised 3x3 convolutions
    (stride 1, padding 1) with a ``LeakyReLU(0.2)`` in between; both
    convolutions map ``in_channel`` to ``in_channel``.
    """
    layers = [
        SpectralNorm(nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1)),
        nn.LeakyReLU(0.2),
        SpectralNorm(nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1)),
    ]
    return nn.Sequential(*layers)


class MSDilateBlock(pl.LightningModule):
    """Four parallel dilated ``convU`` branches fused by a residual convolution.

    Each branch sees the same input with its own dilation rate
    (``dilation[0]`` .. ``dilation[3]``). The branch outputs are concatenated
    along the channel dimension, reduced back to ``in_channels`` by a
    spectral-normalised convolution, and added to the input.
    """

    def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True):
        super().__init__()

        def make_branch(rate):
            # All branches share every setting except the dilation rate.
            return convU(in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=rate, bias=bias)

        self.conv1 = make_branch(dilation[0])
        self.conv2 = make_branch(dilation[1])
        self.conv3 = make_branch(dilation[2])
        self.conv4 = make_branch(dilation[3])

        fuse = conv_layer(
            in_channels*4,
            in_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size-1)//2,
            bias=bias,
        )
        self.convi = SpectralNorm(fuse)

    def forward(self, x):
        """Return ``convi(cat(branches)) + x``."""
        branches = [self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)]
        stacked = torch.cat(branches, 1)
        return self.convi(stacked) + x


class VggClassNet(pl.LightningModule):
    """Frozen VGG-19 feature extractor that returns selected intermediate outputs.

    Args:
        select_layer: Names (string indices) of the layers in
            ``vgg19().features`` whose outputs are collected by ``forward``.
    """

    def __init__(self, select_layer = ['0','5','10','19']):
        super(VggClassNet, self).__init__()
        # Copy the argument: the default list is a single shared object, so
        # storing it by reference would let one instance's edits to
        # ``self.select`` leak into every other instance.
        self.select = list(select_layer)
        self.vgg = models.vgg19(pretrained=True).features
        # The extractor is used as a fixed feature network; no gradients.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the VGG layers and return the selected activations.

        Returns:
            List of tensors, one per selected layer, in network order.
        """
        features = []
        for name, layer in self.vgg.named_children():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features




class UpResBlock(pl.LightningModule):
    """Residual block: ``x + conv(LeakyReLU(conv(x)))`` with 3x3 convolutions.

    ``norm_layer`` is accepted for signature compatibility but is not used.
    """

    def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d):
        super().__init__()
        body = [
            conv_layer(dim, dim, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            conv_layer(dim, dim, 3, 1, 1),
        ]
        self.Model = nn.Sequential(*body)

    def forward(self, x):
        """Add the block's residual branch to ``x``."""
        return x + self.Model(x)

In [None]:
#@title functions.py
import pytorch_lightning as pl

def compute_sum(x, axis=None, keepdim=False):
    """Sum ``x`` over one or several dimensions, one dimension at a time.

    Args:
        x: Tensor to reduce.
        axis: ``None`` or an empty sequence to reduce over every dimension,
            a single ``int``, or a sequence of ints. Negative indices count
            from the last dimension.
        keepdim: Keep reduced dimensions with size 1.

    Returns:
        The reduced tensor.
    """
    ndim = len(x.shape)
    if isinstance(axis, int):
        # A bare int (including 0, which is falsy) selects a single dimension.
        dims = [axis]
    elif axis is None or len(axis) == 0:
        dims = range(ndim)
    else:
        dims = axis

    # Resolve negative indices up front and drop duplicates; otherwise the
    # high-to-low reduction order below would hit the wrong dimension once
    # earlier sums have removed dimensions (keepdim=False).
    resolved = {d + ndim if d < 0 else d for d in dims}
    for d in sorted(resolved, reverse=True):
        x = torch.sum(x, dim=d, keepdim=keepdim)
    return x



def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True):
    """Build two spectral-normalised convolutions separated by ``LeakyReLU(0.2)``.

    Both convolutions share kernel size, stride, dilation and bias setting, and
    use padding ``((kernel_size-1)//2)*dilation``. ``norm_layer`` is accepted
    but not used.
    """
    pad = ((kernel_size-1)//2)*dilation
    shared = dict(kernel_size=kernel_size, stride=stride, dilation=dilation, padding=pad, bias=bias)

    # Layers are created strictly in network order so random initialisation
    # consumes the RNG in the same sequence as before.
    first = SpectralNorm(conv_layer(in_channels, out_channels, **shared))
    activation = nn.LeakyReLU(0.2)
    second = SpectralNorm(conv_layer(out_channels, out_channels, **shared))
    return nn.Sequential(first, activation, second)


class AdaptiveInstanceNorm(pl.LightningModule):
    """Instance-normalise the input, then re-scale it with the style's statistics."""

    def __init__(self, in_channel):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        """Return ``style_std * norm(input) + style_mean``.

        The style statistics come from ``calc_mean_std_4D(style)`` and are
        expanded to the shape of ``input``.
        """
        mean, std = calc_mean_std_4D(style)
        normalized = self.norm(input)
        target_shape = input.size()
        return std.expand(target_shape) * normalized + mean.expand(target_shape)

class BlurFunctionBackward(Function):
    """Autograd function implementing the backward pass of the blur.

    Its ``forward`` filters the incoming gradient with the flipped kernel;
    its own ``backward`` filters with the unflipped kernel. Both are depthwise
    convolutions (one group per channel) with padding 1.
    """

    @staticmethod
    def forward(ctx, grad_output, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)
        channels = grad_output.shape[1]
        return F.conv2d(grad_output, kernel_flip, padding=1, groups=channels)

    @staticmethod
    def backward(ctx, gradgrad_output):
        kernel, _ = ctx.saved_tensors
        channels = gradgrad_output.shape[1]
        grad = F.conv2d(gradgrad_output, kernel, padding=1, groups=channels)
        # The two kernels receive no gradient.
        return grad, None, None


class BlurFunction(Function):
    """Depthwise 3x3 filtering with a hand-written backward pass.

    ``forward`` convolves each channel with ``kernel`` (padding 1);
    ``backward`` delegates to ``BlurFunctionBackward`` with the same kernels.
    """

    @staticmethod
    def forward(ctx, input, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)
        channels = input.shape[1]
        return F.conv2d(input, kernel, padding=1, groups=channels)

    @staticmethod
    def backward(ctx, grad_output):
        saved_kernel, saved_kernel_flip = ctx.saved_tensors
        grad = BlurFunctionBackward.apply(grad_output, saved_kernel, saved_kernel_flip)
        # The two kernels receive no gradient.
        return grad, None, None


# Functional entry point used by the Blur module.
blur = BlurFunction.apply


class Blur(pl.LightningModule):
    """Per-channel 3x3 blur using the normalised kernel ``[[1,2,1],[2,4,2],[1,2,1]]``.

    The kernel and its spatially flipped copy are stored as buffers
    ``weight`` and ``weight_flip``, repeated once per channel.
    """

    def __init__(self, channel):
        super().__init__()

        taps = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
        kernel = taps.view(1, 1, 3, 3)
        kernel = kernel / kernel.sum()
        flipped = torch.flip(kernel, [2, 3])

        per_channel = (channel, 1, 1, 1)
        self.register_buffer('weight', kernel.repeat(*per_channel))
        self.register_buffer('weight_flip', flipped.repeat(*per_channel))

    def forward(self, input):
        """Blur ``input`` with the stored kernels."""
        return blur(input, self.weight, self.weight_flip)

class EqualLR:
    """Forward pre-hook that rescales a parameter by ``sqrt(2 / fan_in)`` on every call.

    ``apply`` moves the parameter ``name`` to ``name + '_orig'`` and installs
    the hook, which recomputes the scaled tensor and sets it as ``name``
    before each forward pass.
    """

    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        """Return the stored ``<name>_orig`` parameter scaled by ``sqrt(2 / fan_in)``."""
        raw = getattr(module, self.name + '_orig')
        # fan_in = size of dim 1 times the number of elements in one
        # [out, in] slice of the parameter.
        per_slice = raw.data[0][0].numel()
        fan_in = raw.data.size(1) * per_slice
        return raw * sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        """Install the hook on ``module`` for parameter ``name`` and return it."""
        hook = EqualLR(name)

        original = getattr(module, name)
        # Swap the registered parameter for its '_orig' twin; the hook sets
        # the plain attribute before every forward.
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(original.data))
        module.register_forward_pre_hook(hook)

        return hook

    def __call__(self, module, input):
        setattr(module, self.name, self.compute_weight(module))

def equal_lr(module, name='weight'):
    """Install the ``EqualLR`` hook for parameter ``name`` and return ``module``.

    The module is modified in place; it is returned so the call can be used
    inline when building layers.
    """
    EqualLR.apply(module, name=name)
    return module

class EqualConv2d(pl.LightningModule):
    """``nn.Conv2d`` with normal-initialised weights and the ``equal_lr`` hook.

    All positional and keyword arguments are forwarded to ``nn.Conv2d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        # With bias=False the convolution has no bias tensor to reset.
        if conv.bias is not None:
            conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        """Apply the wrapped convolution."""
        return self.conv(input)

class NoiseInjection(pl.LightningModule):
    """Add noise to an image, scaled by a learned per-channel weight.

    The weight has shape ``(1, channel, 1, 1)`` and starts at zero, so the
    module is initially an identity on ``image``.
    """

    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        """Return ``image + weight * noise``."""
        scaled_noise = self.weight * noise
        return image + scaled_noise


def ToRGB(in_channel):
    """Build a head mapping ``in_channel`` feature channels to 3 output channels.

    Two spectral-normalised 3x3 convolutions (stride 1, padding 1) with a
    ``LeakyReLU(0.2)`` in between; the first keeps the channel count, the
    second reduces it to 3.
    """
    layers = [
        SpectralNorm(nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1)),
        nn.LeakyReLU(0.2),
        SpectralNorm(nn.Conv2d(in_channel, 3, kernel_size=3, stride=1, padding=1)),
    ]
    return nn.Sequential(*layers)

def adaptive_instance_normalization_4D(content_feat, style_feat):
    """Re-normalise ``content_feat`` to the per-channel statistics of ``style_feat``.

    Per the original note, ``content_feat`` is the reference feature and
    ``style_feat`` is the degraded feature. The content is standardised with
    its own per-sample, per-channel mean/std and then scaled and shifted with
    the style's mean/std (which must broadcast against the content).
    """
    shape = content_feat.size()
    style_mean, style_std = calc_mean_std_4D(style_feat)
    content_mean, content_std = calc_mean_std_4D(content_feat)

    centered = content_feat - content_mean.expand(shape)
    standardized = centered / content_std.expand(shape)
    return standardized * style_std + style_mean

def calc_mean_std_4D(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and std of a 4-D tensor.

    Args:
        feat: Tensor with exactly four dimensions ``(N, C, ...)``.
        eps: Added to the variance before the square root to avoid a zero std.

    Returns:
        ``(mean, std)``, each of shape ``(N, C, 1, 1)``.
    """
    shape = feat.size()
    assert (len(shape) == 4)
    N, C = shape[:2]

    # Flatten the trailing dimensions so statistics are taken per (N, C) pair.
    flat = feat.view(N, C, -1)
    std = (flat.var(dim=2) + eps).sqrt().view(N, C, 1, 1)
    mean = flat.mean(dim=2).view(N, C, 1, 1)
    return mean, std

Restart with ``runtime > restart runtime`` once if you see ``AttributeError: module 'PIL.TiffTags' has no attribute 'IFD'``.

In [None]:
# Training
%cd /content/
dm = DFNetDataModule(training_path = '/content/DFDNet/ffhq/', train_partpath = '/content/DFDNet/landmarks', validation_path = '/content/validation/', val_partpath='/content/landmarks', batch_size=1)
model = CustomTrainClass()
#model = model.load_from_checkpoint('/content/Checkpoint_0_450.ckpt') # start training from checkpoint, warning: apperantly global_step will be reset to zero and overwriting validation images, you could manually make an offset

#weights_init(model, 'kaiming')
# GPU
#trainer = pl.Trainer(gpus=1, max_epochs=10, progress_bar_refresh_rate=1, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=10, save_path='/content/')])
# GPU with AMP (amp_level='O1' = mixed precision)
# currently not working
#trainer = pl.Trainer(gpus=1, precision=16, amp_level='O1', max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
# TPU
#trainer = pl.Trainer(tpu_cores=8, max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
trainer.fit(model, dm)

# Landmark generation

Install the packages in the next cell, then restart the runtime.

In [None]:
# Provides the ``face_alignment`` package imported by the landmark cell below.
!pip install face-alignment
# NOTE(review): the reason for upgrading matplotlib is not visible here --
# presumably a dependency conflict; confirm before removing.
!pip install matplotlib --upgrade

In [None]:
%cd /content/
import face_alignment
from skimage import io
import numpy as np
import glob
from tqdm import tqdm
import os
import shutil

fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)

unchecked_input_path = '/content/validation' #@param {type:"string"}
checked_output_path = '/content/validation' #@param {type:"string"}
failed_output_path = '/content/validation' #@param {type:"string"}
landmark_output_path = '/content/landmarks' #@param {type:"string"}

if not os.path.exists(unchecked_input_path):
    os.makedirs(unchecked_input_path)
if not os.path.exists(checked_output_path):
    os.makedirs(checked_output_path)
if not os.path.exists(failed_output_path):
    os.makedirs(failed_output_path)
if not os.path.exists(landmark_output_path):
    os.makedirs(landmark_output_path)

files = glob.glob(unchecked_input_path + '/**/*.png', recursive=True)
files_jpg = glob.glob(unchecked_input_path + '/**/*.jpg', recursive=True)
files.extend(files_jpg)
err_files=[]

for f in tqdm(files):
  input = io.imread(f)
  preds = fa.get_landmarks(input)
  #print(preds)
  if preds is not None:
    np.savetxt(os.path.join(landmark_output_path, os.path.basename(f)+".txt"), preds[0], delimiter=' ', fmt='%1.3f')   # X is an array
    shutil.move(f, os.path.join(checked_output_path,os.path.basename(f)))
  else:
    shutil.move(f, os.path.join(failed_output_path,os.path.basename(f)))