In [None]:
# Mount Google Drive at /content/gdrive; the weight files and the dataset used
# further down are read from (and written to) paths under this mount point.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [None]:
# Fetch the CUSP PyTorch sources and make the checkout the working directory,
# so the repo-local modules imported in the next cell (training, legacy,
# torch_utils, dnnlib) can be resolved.
!git clone https://github.com/guillermogotre/cusp-pytorch/
%cd cusp-pytorch

Cloning into 'cusp-pytorch'...
remote: Enumerating objects: 185, done.[K
remote: Total 185 (delta 0), reused 0 (delta 0), pack-reused 185[K
Receiving objects: 100% (185/185), 9.44 MiB | 16.33 MiB/s, done.
Resolving deltas: 100% (83/83), done.
/content/cusp-pytorch


In [None]:
import os
import time
import pickle

import torch
import torch.nn.functional as F

import PIL.Image
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Custom modules
from training.networks import VGG, module_no_grad
import legacy
from torch_utils import misc
import dnnlib

# GDrive authentication and Download
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

In [None]:
import os

# Folder whose sub-directories are processed by the generation cell.
# NOTE(review): this lists FGNET73/train, whereas `data_dir` in the configuration
# cell points at FGNET64/train -- confirm which root is intended.
train_root = '/content/gdrive/MyDrive/FGNET73/train'

# Keep sub-directories only (a stray file in the root would otherwise be treated
# as an image folder later on) and sort them so the processing order is stable.
with os.scandir(train_root) as entries:
  dir_list = sorted(entry.name for entry in entries if entry.is_dir())

In [None]:
# Sanity check: how many entries were collected into dir_list
print(len(dir_list))

82


In [None]:
def Average(lst):
    """Return the arithmetic mean of the values in `lst`.

    Args:
        lst: A sequence, or any other iterable (e.g. a generator), of values
            that support addition and division by an int.

    Returns:
        ``sum(lst) / len(lst)``.

    Raises:
        ZeroDivisionError: If `lst` is empty.
    """
    # One-shot iterables have no len(); materialise them so they can be both
    # summed and counted. Sized inputs are used as-is.
    if not hasattr(lst, '__len__'):
        lst = list(lst)
    if len(lst) == 0:
        # Same exception type as the bare division, with an actionable message.
        raise ZeroDivisionError("Average() of an empty sequence is undefined")
    return sum(lst) / len(lst)

In [None]:
# Identifier of the (single) configuration defined below
FFHQ_RR_KEY = "hrfae"  # Model trained on HRFAE dataset

# Active configuration
KEY = FFHQ_RR_KEY

# Per-configuration settings:
#   side    -> side length the input images are resized to
#   classes -> (first, last) target age label used when generating samples
configs = {
    FFHQ_RR_KEY: {
        "side": 224,
        "classes": (20, 60),
    },
}

# CUDA device the generator is placed on
device = torch.device("cuda", 0)

# Side of input images for the active configuration
side = configs[KEY]["side"]

# Weights paths (generator snapshot and DEX VGG classifier)
weights_path = "/content/gdrive/MyDrive/CUSP_implement/network-snapshot-002408.pkl"
vgg_path = "/content/gdrive/MyDrive/CUSP_implement/dex_imdb_wiki.caffemodel.pt"

# Root of the image folders that are read and augmented.
# NOTE(review): the directory-listing cell reads FGNET73/train while this points
# at FGNET64/train -- confirm both are meant to differ.
data_dir = "/content/gdrive/MyDrive/FGNET64/train"

In [None]:
def load_model(model_path,vgg_path,device):
    """Load the generator snapshot and attach the DEX VGG classifier to it.

    Args:
        model_path: Path of the network snapshot read by
            ``legacy.load_network_pkl``.
        vgg_path: Path of the VGG classifier state dict (read with
            ``torch.load``).
        device: Torch device the returned network is moved to.

    Returns:
        The ``'G_ema'`` entry of the snapshot, with the VGG classifier set as
        ``skip_grad_blur.model.classifier``, moved to `device`, switched to
        eval mode and with ``requires_grad`` disabled.
    """
    # NOTE(security): the snapshot is unpickled ("weights and source code"),
    # which can execute arbitrary code -- only load files from a trusted source.
    with open(model_path,'rb') as f:
        contents = legacy.load_network_pkl(f) # Pickles weights and source code

    # Get exponential moving average model
    G_ema = contents['G_ema']

    # Load DEX VGG classifier.
    # map_location='cpu' makes loading independent of the device the checkpoint
    # was saved from; load_state_dict copies the values into `vgg`, and the
    # final .to(device) below is applied after `vgg` has been attached.
    vgg = VGG()
    vgg_state_dict = torch.load(vgg_path, map_location='cpu')
    # Stored keys use '-' where the module's parameter names use '_'
    vgg_state_dict = {k.replace('-', '_'): v for k, v in vgg_state_dict.items()}
    vgg.load_state_dict(vgg_state_dict)
    module_no_grad(vgg) #!important

    # Set classifier
    G_ema.skip_grad_blur.model.classifier = vgg
    # Move to the target device, eval mode, no grad
    G_ema = G_ema.to(device).eval().requires_grad_(False)

    return G_ema

# Build the generator once; the generation cell below passes this global to run_model.
G_ema = load_model(weights_path, vgg_path, device)

In [None]:
def run_model(G, img, label, global_blur_val=None, mask_blur_val=None, return_msk = False):
    """Run the generator `G` on `img`, conditioned on the target class `label`.

    Args:
        G: Generator exposing ``content_enc``, ``style_enc``, ``style_map``,
            ``attr_map``, ``skip_transf``, ``skip_grad_blur``, ``_batch_blur``,
            ``image_dec`` and ``learn_mask``.
        img: Batch of input images (4-D tensor; the style features are averaged
            over dims 2 and 3).
        label: Target class index/indices -- an int, a sequence of ints or an
            integer tensor (one entry per image).
        global_blur_val: Value forwarded as ``blur_val`` for the blur applied to
            every transformed skip connection.
        mask_blur_val: Value forwarded as ``blur_val`` for the blur that is
            blended in outside the mask.
        return_msk: If True, return a 3-tuple instead of the image only.

    Returns:
        The decoded image batch, or -- when `return_msk` is True --
        ``(img_out, msk, cam)`` if ``G.learn_mask`` is not None, else
        ``(img_out, None, None)``. ``msk`` is the mask at the resolution it was
        last resized to; ``cam`` is the raw output of ``G.skip_grad_blur``.
    """
    # Transform label to One Hot Encoding.
    # as_tensor accepts ints, sequences and tensors alike; unlike torch.tensor
    # it neither copies nor warns when `label` already is a tensor.
    cls = torch.nn.functional.one_hot(
        torch.as_tensor(label),
        num_classes=G.attr_map.fc0.init_args[0]
    ).to(img.device)

    # Content encoder (only the skip connections are used)
    _,c_out_skip = G.content_enc(img)

    # Style encoder: spatially averaged features
    s_out = G.style_enc(img)[0].mean((2, 3))

    truncation_psi=1
    truncation_cutoff=None
    s_out = G.style_map(s_out, None, truncation_psi, truncation_cutoff)

    # Age mapping
    a_out = G.attr_map(cls.to(s_out.device), None, truncation_psi, truncation_cutoff)

    # Style mapping and Age mapping are interleaved for the corresponding
    # weight demodulation modules
    w = G.__interleave_attr_style__(a_out, s_out)

    # Global blur on every skip connection that has a transform
    for i,(f,_) in enumerate(zip(G.skip_transf, c_out_skip)):
        if f is not None:
            c_out_skip[i] = G._batch_blur(c_out_skip[i], blur_val = global_blur_val)

    # Masked blur: keep the features where the mask is 1, blurred ones where it is 0
    cam = G.skip_grad_blur(img.float())
    msk = cam
    for i, (f, c) in enumerate(zip(G.skip_transf, c_out_skip)):
        if f is not None:
            im_size = c.size(-1)
            blur_c = G._batch_blur(c, blur_val= mask_blur_val)
            # The mask is resized from its current (possibly already resized)
            # value, not from `cam`, whenever the feature map size changes.
            if msk.size(2) != im_size:
                msk = F.interpolate(msk,size=(im_size,im_size), mode='area')
            merged_c = c * msk + blur_c * (1 - msk)
            c_out_skip[i] = merged_c

    # Decoder
    img_out = G.image_dec(c_out_skip, w)

    if return_msk:
        to_return = (img_out,msk,cam) if G.learn_mask is not None else (img_out,None,None)
    else:
        to_return = img_out

    return to_return

In [None]:
# Convert a channel-first image tensor into a channel-last uint8 array
def to_uint8(im_tensor):
    """Map a (C, H, W) tensor with values around [-1, 1] to an (H, W, C) uint8 array.

    Values are shifted and scaled by ``(x + 1) * 128`` and clipped to [0, 255]
    before the cast.
    """
    channels_first = im_tensor.detach().cpu().numpy()
    channels_last = channels_first.transpose((1, 2, 0))
    scaled = (channels_last + 1) * (256 / 2)
    return np.clip(scaled, 0, 255).astype(np.uint8)

In [None]:
%%capture
for dir in dir_list:
  img_count = sum(len(files) for _, _, files in os.walk(os.path.join(data_dir, dir)))
  steps = 30 // img_count
  sample_images_path= os.path.join(data_dir, dir) 

  # Read image filenames
  filenames_batch = [
        os.path.join(sample_images_path,f) 
        for  f in next(iter(os.walk(sample_images_path)))[2] 
        if f[-4:] == '.JPG'
      ]

  # Read images
  imgs = [np.array(PIL.Image.open(f).resize((side,side)).convert("RGB"),dtype=np.float32).transpose((2,0,1)) for f in filenames_batch]
  # Transform to tensors
  im_in_tensor = (torch.tensor(np.array(imgs))/256*2-1).cuda() # Values {-1,1}

  # Repeat images N times
  n_images = im_in_tensor.shape[0]
  im_in_tensor_exp = im_in_tensor[:,None].expand([n_images,steps,*im_in_tensor.shape[1:]]).reshape([-1,*im_in_tensor.shape[1:]])
  # Labels range for examples generation
  data_labels_range = configs[KEY]['classes']
  # Define target ages
  labels_exp = torch.tensor(np.repeat(np.linspace(*data_labels_range,steps,dtype=int)[:,None],n_images,1).T.reshape(-1))

  
  batch_size = img_count
  # Run model
  im_out_tensor_exp = torch.concat([run_model(
      G_ema,
      mini_im,
      mini_label,
      global_blur_val=0, # CUSP global blur
      mask_blur_val=0)   # CUSP masked blur
      for mini_im, mini_label
      in zip(
          im_in_tensor_exp.split(batch_size),
          labels_exp.split(batch_size)
      )])
  # Transform to [batch_size, N_ages, W, H , C]
  im_out_tensor = im_out_tensor_exp.reshape([-1,steps,*im_out_tensor_exp.shape[1:]])

  for fname, im_in, im_out, age_labels in zip(
        filenames_batch,im_in_tensor,im_out_tensor, 
        labels_exp.numpy().reshape(-1,steps)
        ):
    age_labels = [i for i in age_labels]
    image_name = fname.split('/')[-1]
    # For every [input,step...]
    for im,l in zip(im_out,age_labels):
        im = Image.fromarray(to_uint8(im)).resize((250,250))
        saved_image_name = image_name.split('.')[0] + f'_GAN_{str(l)}'+'.JPG'
        im.save(os.path.join(sample_images_path, saved_image_name))

--------------------------------------------------------------------------------------------------------