<a href="https://colab.research.google.com/github/zaidbhat1234/StyleGAN2-ADA/blob/main/StyleGAN2_ada.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##Mounting your google drive containing the code files.


In [None]:
# Mount Google Drive inside the Colab VM so the project files stored there are reachable.
from google.colab import drive
drive.mount('/content/gdrive')
# Move to the Drive root; the next cell changes directory relative to this location.
%cd /content/gdrive/My\ Drive/

##The code should be arranged in this order of directories

In [None]:
# Enter the StyleGAN2 PyTorch checkout (path is relative to the current directory).
# Later cells import `model` and load 'ffhq2.pt' relative to this directory.
%cd KAUST_Internship/stylegan2-ada/stylegan2-pytorch/

##Resolving dependencies and importing libraries

In [None]:
# Download the ninja v1.8.2 binary, unpack it into /usr/local/bin and register it as /usr/bin/ninja.
# NOTE(review): presumably required to build the custom ops used by `model` -- confirm.
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force 
import argparse
import torch
from torchvision import utils
from model import Generator
from tqdm import tqdm
import lpips
import math
import torch
import torch.optim as optim
import torch.nn as nn
from collections import OrderedDict
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import torchvision
from torchvision import models
from torchvision.utils import save_image
import numpy as np
from math import log10
import matplotlib.pyplot as plt

##Setup for running on GPU

In [None]:
# Pick the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda:0'
else:
  device = 'cpu'
print(device)

##Function to read some images in a loop for testing, where 'i' is the index of the image.

In [None]:
def read_img(i):
  """Load test image ``i`` and return it as a batched tensor on ``device``.

  Args:
    i: image name as a string; it is concatenated into the path
      ``.../StyleGAN_LatentEditor/images/<i>.png``.

  Returns:
    The RGB image converted with ``transforms.ToTensor`` and given a leading
    batch dimension. Its shape is printed before returning.
  """
  img_path = '/content/gdrive/My Drive/KAUST_Internship/StyleGAN_LatentEditor/images/'+i+'.png'
  # convert() runs inside the `with` block, so the pixel data is read before the file closes.
  with open(img_path,"rb") as f:
    rgb = Image.open(f).convert("RGB")
  image = transforms.ToTensor()(rgb).unsqueeze(0).to(device)
  print(image.shape)
  return image

##Load the StyleGan2 generator from pre-trained weights

In [None]:
# Build the generator and restore the exponential-moving-average weights stored under "g_ema".
# NOTE(review): arguments are presumably (image size, style dim, mapping layers) -- confirm in model.py.
g_ema = Generator(1024, 512, 8)
# map_location='cpu' lets a checkpoint saved from a GPU load on CPU-only runtimes as well;
# the weights are moved to `device` below once they are inside the module.
checkpoint = torch.load('ffhq2.pt', map_location='cpu')
g_ema.load_state_dict(checkpoint["g_ema"])
g_ema.eval()
g_ema = g_ema.to(device)

##Get the mean latent code from the SG2 network and generate the corresponding image.

In [None]:
# Pure inference: disable autograd so no graph is kept for the 4096 mapped samples
# or for the full-resolution synthesis pass.
with torch.no_grad():
  mean_latent = g_ema.mean_latent(4096)
  print(mean_latent.shape)
  # NOTE(review): the generator is called without input_is_latent=True, so `mean_latent`
  # is treated like a z input here -- confirm this is intended.
  img,_ = g_ema([mean_latent])
# Generator output is rescaled from [-1, 1] to [0, 1] before saving.
img = (img +1.0)/2.0
save_image(img.clamp(0,1),"outputs/mean_latent.png")

## Randomly generate images from random input codes 'z' (disabled; kept for reference)

In [None]:
# Disabled experiment, kept inside a string literal so it does not execute:
# draws 20 random z vectors and saves the generated images as outputs/random_SG2-<n>.png.
"""
for i in range(20):
  z = torch.randn(1,512,device = device)
  
  img,_ = g_ema([z])
  img = (img +1.0)/2.0
  save_image(img.clamp(0,1),"outputs/random_SG2-{}.png".format(i+1))"""

In [None]:
# Disabled experiment, kept inside a string literal so it does not execute:
# estimates the latent mean/std from 10000 mapped noise samples, then renders an
# all-zero (1, 18, 512) latent to gen.png.
"""n_mean_latent = 10000
noise_sample = torch.randn(n_mean_latent, 512, device=device)
latent_out = g_ema.style(noise_sample)

latent_mean = latent_out.mean(0)
latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5
latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(1, 1)
latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)
print(np.shape(latent_in))
latents = torch.zeros((1,18,512))
img_gen, _ = g_ema([latents], input_is_latent=True)
img_gen = (img_gen+1.0)/2.0
save_image(img_gen,"gen.png")"""

## VGG perceptual-loss network: returns feature maps from four consecutive slices of the pre-trained VGG-16 `features` stack (layer indices 0–1, 2–3, 4–13 and 14–20)

In [None]:
class VGG16_perceptual(torch.nn.Module):
    """Feature extractor built from four consecutive slices of pre-trained VGG-16.

    The ``features`` stack of ``models.vgg16(pretrained=True)`` is cut at layer
    indices 2, 4, 14 and 21; ``forward`` returns the activation after each cut.
    With ``requires_grad=False`` (default) all weights are frozen.
    """

    def __init__(self, requires_grad=False):
        super(VGG16_perceptual, self).__init__()
        backbone = models.vgg16(pretrained=True).features
        # (start, stop) layer ranges; each becomes self.slice1 .. self.slice4 and the
        # sub-modules keep their original index as name.
        stage_bounds = ((0, 2), (2, 4), (4, 14), (14, 21))
        for stage_no, (start, stop) in enumerate(stage_bounds, start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                stage.add_module(str(layer_idx), backbone[layer_idx])
            setattr(self, "slice{}".format(stage_no), stage)
        if not requires_grad:
            for weight in self.parameters():
                weight.requires_grad = False

    def forward(self, X):
        """Run ``X`` through the four slices in sequence and return every intermediate output."""
        h_relu1_1 = self.slice1(X)
        h_relu1_2 = self.slice2(h_relu1_1)
        h_relu3_2 = self.slice3(h_relu1_2)
        h_relu4_2 = self.slice4(h_relu3_2)
        return h_relu1_1, h_relu1_2, h_relu3_2, h_relu4_2

##Loss function to calculate MSE and Perceptual losses

In [None]:
def loss_function(syn_img, img, img_p, MSE_loss, upsample, perceptual):
  """Return the pixel loss and the perceptual loss between a synthesised and a real image.

  Args:
    syn_img: synthesised image at full size.
    img: real image at full size (compared with ``syn_img`` directly).
    img_p: real image already resized for the perceptual network.
    MSE_loss: callable computing the loss between two tensors.
    upsample: callable that resizes ``syn_img`` to the perceptual network's input size.
    perceptual: callable returning four feature tensors for an image.

  Returns:
    (mse, per_loss): ``MSE_loss(syn_img, img)`` and the sum of ``MSE_loss`` over the
    four pairs of feature tensors.
  """
  # Features of the resized synthetic image first, then of the resized real image.
  s0, s1, s2, s3 = perceptual(upsample(syn_img))
  r0, r1, r2, r3 = perceptual(img_p)

  # Pixel-space loss at the original resolution.
  mse = MSE_loss(syn_img,img)

  # Feature-space loss accumulated level by level.
  per_loss = 0
  for syn_feat, real_feat in ((s0, r0), (s1, r1), (s2, r2), (s3, r3)):
    per_loss += MSE_loss(syn_feat, real_feat)

  return mse, per_loss

##Calculate PSNR

In [None]:
def _psnr_from_mse(value):
  """PSNR in dB for a single MSE value, assuming a peak signal of 1; a zero MSE maps to +inf."""
  if value == 0:
    return float("inf")
  return 10 * log10(1 / value)


def PSNR(mse, flag = 0):
  """Convert mean-squared error into peak signal-to-noise ratio (dB, peak value 1).

  Args:
    mse: with ``flag == 0`` a single-element tensor (or number) holding one MSE;
      with ``flag == 1`` a tensor/sequence of per-image MSE values.
    flag: 0 for a single image, 1 to average the PSNR over a batch of MSE values.

  Returns:
    A Python float. ``inf`` is returned for an MSE of exactly zero instead of
    raising ``ZeroDivisionError``.

  Raises:
    ValueError: if ``flag`` is neither 0 nor 1, or the batch is empty. Previously any
      non-zero flag failed with an unbound local variable.
  """
  if flag == 0:
    return _psnr_from_mse(float(mse))
  if flag == 1:
    per_image = [float(v) for v in torch.as_tensor(mse).flatten()]
    if not per_image:
      raise ValueError("PSNR needs at least one MSE value when flag == 1")
    return sum(_psnr_from_mse(v) for v in per_image) / len(per_image)
  raise ValueError("flag must be 0 (single image) or 1 (batch), got {!r}".format(flag))
psnr_total = []

##Noise Regulariser

In [None]:
def noise_regularize(noises):
    """Penalise spatial self-correlation of each noise map at several resolutions.

    For every map, the squared mean of the product between the map and a copy of
    itself rolled by one position (along dim 3 and along dim 2) is added to the
    loss. The map is then 2x2 average-pooled and the measurement repeated until
    its side length (taken from ``shape[2]``) is 8 or smaller.

    Returns 0 for an empty list, otherwise a scalar tensor.
    """

    def _add_shift_terms(acc, field):
        # Same summation order as a plain running sum: acc + horizontal + vertical.
        return (
            acc
            + (field * torch.roll(field, shifts=1, dims=3)).mean().pow(2)
            + (field * torch.roll(field, shifts=1, dims=2)).mean().pow(2)
        )

    total = 0
    for field in noises:
        res = field.shape[2]
        total = _add_shift_terms(total, field)
        while res > 8:
            # 2x2 average pooling via reshape + mean over the two size-2 axes.
            field = field.reshape([-1, 1, res // 2, 2, res // 2, 2]).mean([3, 5])
            res //= 2
            total = _add_shift_terms(total, field)
    return total

In [None]:
def latent_noise(latent, strength):
    """Return ``latent`` plus standard-normal noise of the same shape scaled by ``strength``."""
    return latent + strength * torch.randn_like(latent)

In [None]:
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Learning-rate schedule with a linear warm-up and a cosine ramp-down.

    Args:
        t: progress through the optimisation, from 0 (start) to 1 (end).
        initial_lr: peak learning rate.
        rampdown: fraction of the run, counted from the end, over which the rate decays.
        rampup: fraction of the run, counted from the start, over which the rate rises.

    Returns:
        ``initial_lr`` scaled by the product of the cosine factor and the warm-up factor.
    """
    decay_phase = min(1, (1 - t) / rampdown)
    cosine_factor = 0.5 - 0.5 * math.cos(decay_phase * math.pi)
    warmup_factor = min(1, t / rampup)
    return initial_lr * (cosine_factor * warmup_factor)

##Initial value of Noise to be optimised

In [None]:
def init_noise():
  """Create fresh, trainable noise tensors matching the generator's noise inputs.

  Each tensor returned by ``g_ema.make_noise()`` is copied, refilled in place with
  standard-normal values and flagged as requiring gradients. The number of noise
  tensors is printed.
  """
  noises = [
    template.repeat(1, 1, 1, 1).normal_().requires_grad_(True)
    for template in g_ema.make_noise()
  ]
  print(len(noises))
  return noises

In [None]:
# Spread of the mapped latents, used by embedding_function_n to scale its latent noise.
# The previous code measured it over the rows of `mean_latent` only while dividing by a
# hard-coded 10000, which yields a degenerate (zero for a single row) value. It is now
# estimated from `n_mean_latent` mapped noise samples, as in the disabled cell above.
n_mean_latent = 10000
with torch.no_grad():
  noise_sample = torch.randn(n_mean_latent, 512, device=device)
  latent_samples = g_ema.style(noise_sample)
  latent_std = ((latent_samples - latent_samples.mean(0)).pow(2).sum() / n_mean_latent) ** 0.5

  # Starting point for the optimisation: same computation as initialise().
  # NOTE(review): `mean_latent` is passed through the mapping network again here -- confirm intended.
  latent_out = g_ema.style(mean_latent)
latent_mean = latent_out.mean(0)
latent_in = latent_mean.detach().clone().unsqueeze(0)
# Repeat the single code once per generator layer to obtain a W+ latent.
latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)
latent_in.requires_grad = True
latents = latent_in

##Initialising latent vector W+ from mean latent W

In [None]:
#Mean latent w+
def initialise():
  """Build the starting W+ latent for the optimisation.

  ``mean_latent`` is passed through ``g_ema.style``, averaged over dim 0 and the
  resulting code is repeated ``g_ema.n_latent`` times along a new layer axis.

  Returns:
    A fresh leaf tensor with ``requires_grad=True``; each call returns an
    independent copy.
  """
  # Only the values are needed, so no autograd graph is recorded for the mapping pass.
  # NOTE(review): `mean_latent` goes through the mapping network again here -- confirm intended.
  with torch.no_grad():
    latent_mean = g_ema.style(mean_latent).mean(0)
  latents = latent_mean.detach().clone().unsqueeze(0).unsqueeze(1).repeat(1, g_ema.n_latent, 1)
  latents.requires_grad = True
  return latents

##Function to Penalise successive W+ codes to be similar

In [None]:
def penalise_w(latent):
  """Sum of MSE differences between successive layer codes of a W+ latent.

  Args:
    latent: tensor indexed as ``latent[:, layer, :]``; any number of layers is
      accepted (the pair count was previously fixed at 17, i.e. 18 layers).

  Returns:
    The accumulated loss: a scalar tensor, or the integer 0 when there are fewer
    than two layers.
  """
  MSE_loss = nn.MSELoss(reduction="mean")
  loss = 0
  for i in range(latent.shape[1] - 1):
    loss += MSE_loss(latent[:,i+1,:],latent[:,i,:])
  return loss

## Compute average W for hacky smoothing

In [None]:
def average_w(latent1, lambdaa=0.2):
  """Pull every layer code of a W+ latent towards the mean over layers, in place.

  Each layer becomes ``lambdaa * layer + (1 - lambdaa) * mean`` where ``mean`` is
  taken over dim 1 before any layer is modified. All layers are updated; the
  earlier loop stopped after 17 and left the last of 18 layers untouched.

  Args:
    latent1: tensor indexed as ``latent1[:, layer, :]``; it is modified in place so
      that an optimizer holding this tensor keeps seeing the smoothed values.
    lambdaa: weight kept from the original layer code (default 0.2).

  Returns:
    The same tensor object ``latent1``.
  """
  # no_grad makes the in-place update legal even when latent1 is a leaf that requires grad.
  with torch.no_grad():
    mean = torch.mean(latent1, dim=1, keepdim=True)
    latent1.mul_(lambdaa).add_((1.0-lambdaa) * mean)
  return latent1
#average_w(latents)
# Disabled usage sketch, kept inside a string literal so it does not execute: inside an
# optimisation loop it would call average_w every 50 iterations, switching requires_grad
# off around the call.
"""
 if (e+1)%50 ==0:
      #Hacky Smoothing
      print("hacky smoothing")
      latents.requires_grad = False
      latents = average_w(latents)
      latents.requires_grad = True"""

###Masking the image for an experiment which required calculating loss on the centre cropped image rather than the downsampled image.

In [None]:
# Binary mask selecting the central 256x256 window (rows/cols 384..639) of a 1024x1024
# image, replicated over 3 channels and given a leading batch axis.
zero = torch.zeros(1024,1024)
zero[384:640,384:640] = torch.ones(256,256)
mask = zero.unsqueeze(0).repeat(3, 1, 1).unsqueeze(0)
print(mask.shape)
mask = mask.to(device)

##Embedding Function to optimise the latent code W+ for GAN inversion.

In [None]:
def embedding_function(image,latent,x,num_steps=1500,lr=0.01,log_every=500):
  """Optimise a W+ latent so that the generator output matches ``image``.

  The loss is the pixel MSE at full size plus the VGG-16 perceptual loss computed on
  images resized by a factor 256/1024 (see ``loss_function``).

  Args:
    image: target image batch on ``device``.
      (assumes a 1x3x1024x1024 tensor with values in [0, 1] -- TODO confirm)
    latent: starting latent, fed to the generator with ``input_is_latent=True``.
      It is optimised in place when it already lives on ``device``; otherwise a
      copy on ``device`` is optimised.
    x: tag inserted into the file names of the saved snapshots.
    num_steps: number of Adam steps (default 1500, the previous fixed value).
    lr: Adam learning rate (default 0.01, the previous fixed value).
    log_every: print the losses and save a snapshot every this many steps.

  Returns:
    (latents, final_psnr): the optimised latent and the PSNR measured in the last
    iteration (0 when ``num_steps`` is 0).
  """
  # Resizer used for the perceptual term (scale 256/1024, i.e. a 4x reduction).
  upsample = torch.nn.Upsample(scale_factor = 256/1024, mode = 'bilinear')
  img_p = upsample(image.clone())
  # The centre-crop variant of this experiment is disabled, so the masked copies of the
  # image are no longer computed here and the function does not depend on the global `mask`.

  #Initialise VGG-perceptual loss
  perceptual1 = VGG16_perceptual().to(device)

  #MSE loss object
  MSE_loss = nn.MSELoss(reduction="mean")

  # Tensor.to() returns a new tensor when a move is needed; the old code discarded that
  # result. Keep the caller's tensor when it is already in place, otherwise optimise a
  # detached copy on the target device (detach makes it a leaf the optimizer accepts).
  latents = latent
  moved = latents.to(device)
  if moved is not latents:
    latents = moved.detach()
  latents.requires_grad = True

  #Optimizer to change latent code in each backward step (ordered list rather than a set)
  optimizer = optim.Adam([latents],lr=lr)

  #Loop to optimise latent vector to match the generated image to input image
  loss_ = []
  loss_psnr = []
  final_psnr = 0
  for e in range(num_steps):
    optimizer.zero_grad()
    syn_img,_ = g_ema([latents], input_is_latent=True)
    # Generator output is rescaled from [-1, 1] to [0, 1] to match the target image.
    syn_img = (syn_img+1.0)/2.0
    mse, per_loss = loss_function(syn_img, image, img_p, MSE_loss, upsample, perceptual1)
    psnr = PSNR(mse, flag = 0)

    # Disabled variants of this experiment: noise regularisation, the W+ smoothness
    # penalty (penalise_w) and an LPIPS loss on downsampled / centre-cropped images.

    loss = per_loss +mse
    loss.backward()
    optimizer.step()
    loss_np=loss.detach().cpu().numpy()
    loss_psnr.append(psnr)
    loss_.append(loss_np)
    final_psnr = psnr

    if (e+1)%log_every==0:
      loss_p=per_loss.detach().cpu().numpy()
      loss_m=mse.detach().cpu().numpy()
      print("iter{}: loss -- {},  mse_loss --{},  percep_loss --{}, psnr --{}".format(e+1,loss_np,loss_m,loss_p,psnr))
      save_image(syn_img.clamp(0,1),"outputs/Step1-2VGG-{}-{}.png".format(x,e+1)) #Save Images

  # One figure per call, so curves from successive images do not pile up on the same axes.
  plt.figure()
  plt.plot(loss_, label = 'Loss = MSELoss + Perceptual')
  plt.plot(loss_psnr, label = 'PSNR')
  plt.legend()
  return latents, final_psnr

In [None]:
# Single-image run. `image` and `x` were not defined anywhere before this cell, so it
# failed on a fresh kernel; load test image 1 explicitly and use its index as the tag.
x = 1
image = read_img(str(x))
latents, psnr = embedding_function(image,latents,x) #Calling the embedding function to optimise latent W. It returns the optimised latent code which can be further optimised for other experiments

##Embedding function to optimise W+ along with noise

In [None]:
def embedding_function_n(image,latent,x,num_steps=1000,initial_lr=0.1):
  """Optimise a W+ latent together with the generator noise maps (LPIPS + MSE + noise regulariser).

  Relies on the notebook-level names ``noises`` (list of trainable noise tensors, see
  ``init_noise``) and ``latent_std`` (tensor scaling the noise added to the latent).

  Args:
    image: target image batch on ``device``.
      (assumes a 1x3x1024x1024 tensor with values in [0, 1] -- TODO confirm)
    latent: starting latent, fed to the generator with ``input_is_latent=True``.
      It is optimised in place when it already lives on ``device``; otherwise a
      copy on ``device`` is optimised.
    x: tag inserted into the file name of the saved snapshot.
    num_steps: number of Adam steps (default 1000, the previous fixed value).
    initial_lr: peak learning rate handed to ``get_lr`` (default 0.1, as before).

  Returns:
    (latents, final_psnr): the optimised latent and the PSNR measured in the last
    iteration (0 when ``num_steps`` is 0).
  """
  # Resizer used for the perceptual term (scale 256/1024, i.e. a 4x reduction).
  upsample = torch.nn.Upsample(scale_factor = 256/1024, mode = 'bilinear')
  img_p = upsample(image.clone())

  # LPIPS perceptual loss. The VGG-16 network that used to be built and evaluated on every
  # step is gone: its result was discarded and only the pixel MSE was taken from it.
  perceptual = lpips.PerceptualLoss(model="net-lin", net="vgg")

  #MSE loss object
  MSE_loss = nn.MSELoss(reduction="mean")

  # Tensor.to() returns a new tensor when a move is needed; the old code discarded that
  # result. Keep the caller's tensor when it is already in place, otherwise optimise a
  # detached copy on the target device.
  latents = latent
  moved = latents.to(device)
  if moved is not latents:
    latents = moved.detach()
  latents.requires_grad = True

  #Optimizer to change latent code and noise maps in each backward step
  optimizer = optim.Adam([latents]+noises,lr=initial_lr)

  pbar = tqdm(range(num_steps))

  #Loop to optimise latent vector to match the generated image to input image
  loss_ = []
  loss_psnr = []
  final_psnr = 0
  for e in pbar:
    # Progress in [0, 1) drives both the learning-rate schedule and the noise decay.
    t = e / num_steps
    lr = get_lr(t, initial_lr)
    optimizer.param_groups[0]["lr"] = lr
    # Latent noise shrinks quadratically and reaches zero after 75% of the run.
    noise_strength = latent_std * 0.05 * max(0, 1 - t / 0.75) ** 2
    latent_n = latent_noise(latents, noise_strength.item())

    syn_img,_ = g_ema([latent_n], input_is_latent=True, noise= noises)
    # Generator output is rescaled from [-1, 1] to [0, 1] to match the target image.
    syn_img = (syn_img+1.0)/2.0
    syn_img_p = upsample(syn_img)
    mse = MSE_loss(syn_img, image)
    psnr = PSNR(mse, flag = 0)
    final_psnr = psnr
    n_loss = noise_regularize(noises)
    per_loss = perceptual(syn_img_p, img_p).sum()
    #loss_w = penalise_w(latents)
    loss = per_loss + mse +1e5 * n_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_np=loss.detach().cpu().numpy()
    loss_psnr.append(psnr)
    loss_.append(loss_np)
    # Report and save once, after the final step.
    if e+1 == num_steps:
      loss_p=per_loss.detach().cpu().numpy()
      loss_m=mse.detach().cpu().numpy()
      print("iter{}: loss -- {},  mse_loss --{},  percep_loss --{}, psnr --{}".format(e+1,loss_np,loss_m,loss_p,psnr))
      save_image(syn_img.clamp(0,1),"outputs/Step2-LPIPS-256 scratch-{}-{}.png".format(x,e+1))

  plt.plot(loss_, label = 'Loss = MSELoss + Perceptual')
  plt.plot(loss_psnr, label = 'PSNR')
  plt.legend()
  return latents, final_psnr

In [None]:
# Single-image run of the latent + noise optimisation. The call was missing the required
# tag argument, and `image` / `noises` are not defined before this cell on a fresh kernel,
# so they are set up here the same way the evaluation loop does it.
x = 1
image = read_img(str(x))
noises = init_noise()
latents,psnr = embedding_function_n(image,latents,x)

##Loop to iterate over 10 images to find the average PSNR

In [None]:
# PSNR collectors. Only t_psnr_s is filled below; the appends to t_psnr and t_psnr_c
# are commented out, so those two lists stay empty.
t_psnr=[]
t_psnr_c=[]
t_psnr_s = []
# Images are named 1.png .. 10.png, hence the i+1.
for i in range(10):
  image = read_img(str(i+1))
  # Fresh starting latent and noise maps for every image.
  latents = initialise()
  noises = init_noise()
  latents, psnr = embedding_function(image, latents,i+1) #only vgg
  #latents, psnr1 = embedding_function_n(image, latents,i+1) #LPIPS continued
  #t_psnr.append(psnr)
  #t_psnr_c.append(psnr1)
  t_psnr_s.append(psnr)

In [None]:
# Prints the raw per-image PSNR lists (not their averages, despite the labels).
print('average psnr vgg', t_psnr)
print('average psnr lpips continue', t_psnr_c)

In [None]:
print('average psnr 2lpips', t_psnr_s)
# Average over the images actually processed instead of a hard-coded 10;
# an empty list gives NaN rather than a misleading 0.
av1 = np.mean(t_psnr_s) if len(t_psnr_s) else float('nan')
print('average',av1)

In [None]:
# Average over the entries actually collected instead of a hard-coded 10; a list that
# was never filled gives NaN rather than a misleading 0.
av = np.mean(t_psnr) if len(t_psnr) else float('nan')
av1 = np.mean(t_psnr_c) if len(t_psnr_c) else float('nan')
print('average',av,av1)

##Incomplete code in TensorFlow before switching to Pytorch


In [None]:
# Run the repository's projector script on one image with the 1024px FFHQ checkpoint.
!python projector.py --ckpt ffhq2.pt --size 1024 '/content/gdrive/My Drive/KAUST_Internship/StyleGAN_LatentEditor/images/ryan_01.png'
#Official TF version
# Installs TensorFlow 1.14 (CPU and GPU packages) for the legacy code below.
# NOTE(review): the runtime presumably needs a restart after these installs -- confirm.
!pip install tensorflow==1.14
!pip install tensorflow-gpu==1.14
import argparse
import os
import pickle
import re

import numpy as np
import PIL.Image

import dnnlib
import dnnlib.tflib as tflib
import torch
import tensorflow as tf
from torchvision import models
tflib.init_tf()
network_pkl = 'ffhq.pkl'
print('Loading networks from "%s"...' % network_pkl)
# NOTE(review): pickle.load can execute arbitrary code; only load network pickles from a trusted source.
with dnnlib.util.open_url(network_pkl) as fp:
  _G, _D, Gs = pickle.load(fp)
# load Image
img = PIL.Image.open('images/img1.png').convert('RGB')
img = np.array(img, dtype=np.uint8)
# HWC uint8 -> CHW float32, rescaled from [0, 255] to [-1, 1].
img = img.astype(np.float32).transpose([2, 0, 1]) * (2 / 255) - 1
print(img.shape)

# Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
def downsample(image):
  """Add a batch axis to ``image``, rescale it from [-1, 1] to [0, 255] and area-downsample it.

  The argument is now actually used; previously the function ignored it and read
  the notebook-level ``img`` instead.

  Args:
    image: array/tensor whose axis 1 (after batching: axis 2) is the spatial size
      that is compared against 256. (assumes CHW layout with square images -- TODO confirm)

  Returns:
    A TensorFlow tensor; when the spatial size exceeds 256 it is reduced by the
    integer factor ``size // 256`` using block averaging.
  """
  img_sample = tf.expand_dims(image,axis=0)
  img_sample = (img_sample + 1) * (255 / 2)
  sh = img_sample.shape.as_list()
  if sh[2] > 256:
    factor = sh[2] // 256
    # Block-average: split both spatial axes into (size // factor, factor) and mean over the factor axes.
    img_sample = tf.reduce_mean(tf.reshape(img_sample, [-1, sh[1], sh[2] // factor, factor, sh[2] // factor, factor]), axis=[3,5])
  return img_sample

# NOTE(review): this definition reuses the name of the three-argument PyTorch
# `embedding_function` defined earlier in the notebook; running this cell replaces it.
def embedding_function(image):
  """Unfinished TensorFlow-era prototype: downsample ``image`` and print its first VGG feature map.

  Returns nothing; it only prints shapes and the first feature tensor.
  """
  img_clone = downsample(image)
  print(image.shape,img_clone.shape)
  perceptual = VGG16_perceptual()
  # NOTE(review): tf.make_ndarray presumably expects a TensorProto rather than a tf.Tensor -- confirm.
  img_clone = tf.make_ndarray(img_clone)
  img_clone = torch.from_numpy(img_clone)
  syn0, syn1, syn2, syn3 = perceptual(img_clone)
  print(syn0)

embedding_function(img)

# Synthesis from an all-zero W+ latent (graph tensor, not evaluated here).
w = np.zeros((1, 18, *Gs.input_shape[1:]))
img = tf.cast(Gs.components.synthesis.get_output_for(w),tf.float32)
img1 = tflib.convert_images_to_uint8(img, nchw_to_nhwc=True)
print(img)

# Same rescale + block-average downsampling as downsample(), applied to the synthesis output.
proc_images_expr = (img + 1) * (255 / 2)
sh = proc_images_expr.shape.as_list()
if sh[2] > 256:
  factor = sh[2] // 256
  proc_images_expr = tf.reduce_mean(tf.reshape(proc_images_expr, [-1, sh[1], sh[2] // factor, factor, sh[2] // factor, factor]), axis=[3,5])
print(proc_images_expr.shape)

# Keep the seed in its own variable: the output file names below previously interpolated
# the RandomState object itself instead of the seed value.
seed = 7
rnd = np.random.RandomState(seed)
z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
# NOTE(review): the random z above is immediately replaced by zeros -- confirm which one is wanted.
z = np.zeros((1,512))
print(z.shape)
#tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
Gs_kwargs = {
        'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        'randomize_noise': False
    }
label = np.zeros([1] + Gs.input_shapes[1][1:])
print(label.shape)
images = Gs.run(z, label, **Gs_kwargs) # [minibatch, height, width, channel]

PIL.Image.fromarray(images[0], 'RGB').save(f'out/seed{seed}.png')

# Image produced by the synthesis network alone from an all-zero W+ latent.
w = np.zeros((1, 18, *Gs.input_shape[1:]))
print(w.shape)
images = Gs.components.synthesis.run(w, output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
PIL.Image.fromarray(images[0], 'RGB').save(f'out/w_zeros{seed}.png')