In [1]:
# General imports

from argparse import Namespace
import time
import sys
import pprint
import numpy as np
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision import utils
import matplotlib.pyplot as plt
import cv2
import random
import glob
from tqdm import tqdm
import PIL
import PIL.Image
%matplotlib inline

# Importing pSp related

pSp_root = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath('__file__')), 'pixel2style2pixel'))
sys.path.insert(0, pSp_root)

from pixel2style2pixel.datasets import augmentations
from pixel2style2pixel.utils.common import tensor2im, log_input_image
from pixel2style2pixel.models.psp import pSp
from pixel2style2pixel.models.stylegan2.model import Generator # Importing stylegan2 model from pSp repo, same thing eventually

In [2]:
# Build the pSp encoder / StyleGANv2 decoder pair and restore the pretrained weights

device = 'cuda'

# Preprocessing applied to every image before it is handed to the encoder:
# resize to 256x256, convert to a tensor, then normalize each channel with mean 0.5 / std 0.5.
_encoder_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

ENDECODER_ARGS = {
    "pSp_model_path": "pretrained_models/psp_ffhq_encode.pt",
    "StyleGANv2_model_path": "pretrained_models/stylegan2-ffhq-config-f.pt",
    "transform": _encoder_transform,
}

# The 'opts' entry stored inside the checkpoint is the starting point for the model options;
# the entries below are then overridden for this notebook.
ckpt = torch.load(ENDECODER_ARGS['pSp_model_path'], map_location='cpu')
pSp_opts = ckpt['opts']
pSp_opts.update({
    'checkpoint_path': ENDECODER_ARGS['pSp_model_path'],
    'stylegan_weights': ENDECODER_ARGS['StyleGANv2_model_path'],
    'learn_in_w': False,
    'output_size': 1024,
})

# Instantiate the network, switch it to evaluation mode and move it to the GPU
pSp_net = pSp(Namespace(**pSp_opts))
pSp_net.eval()
pSp_net.cuda();

Loading pSp from checkpoint: pretrained_models/psp_ffhq_encode.pt


In [3]:
# Define the function to normalize images and convert to numpy

# This normalization block is taken from the original torch repository:
# https://github.com/pytorch/vision/blob/89d2b38cbc3254ed7ed7b43393e4635979ac12eb/torchvision/utils.py

def norm_ip(img, low, high):
    """Clamp ``img`` to ``[low, high]`` and rescale it to ``[0, 1]``, in place.

    The tensor passed in is modified directly; nothing is returned.
    """
    img.clamp_(min=low, max=high)
    # The 1e-5 floor keeps the division finite when low == high.
    img.sub_(low).div_(max(high - low, 1e-5))

def norm_range(t, value_range):
    """Rescale ``t`` to ``[0, 1]`` in place.

    Args:
        t: tensor to normalize; it is modified in place.
        value_range: ``(low, high)`` pair used for clamping and scaling, or
            ``None`` to use the tensor's own minimum and maximum.
    """
    if value_range is not None:
        norm_ip(t, value_range[0], value_range[1])
    else:
        norm_ip(t, float(t.min()), float(t.max()))

def normalize_image_and_convert_to_numpy(image):
    """Convert an image tensor with values in ``[-1, 1]`` to a uint8 NumPy array.

    Values outside ``[-1, 1]`` are clamped. Axis 0 of ``image`` is moved to the
    last position, so a channel-first tensor comes back channel-last.

    Args:
        image: 3-D tensor; it is left untouched by this function.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` living on the CPU.
    """
    # Work on a private copy: the normalization helpers operate in place, and the
    # caller's tensor (often a view into a whole batch) must not be altered.
    # detach() additionally keeps the in-place ops legal for tensors that require grad.
    image = image.detach().clone()
    norm_range(image, (-1, 1))
    # +0.5 before the uint8 cast rounds to nearest instead of truncating.
    return image.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

# Average-pools any spatial size to 256x256; used for the optional downscaling of generated images.
pool_logit = torch.nn.AdaptiveAvgPool2d((256, 256))

In [4]:
# Define a function that generates a random image

def generate_random_image(downscale=False):
    """Decode a freshly sampled random vector into an image.

    Args:
        downscale: when true, the decoder output is average-pooled to 256x256
            (via ``pool_logit``) before being converted to an image.

    Returns:
        A ``(PIL image, latent)`` pair, where ``latent`` is the latent tensor
        returned by the decoder for the sampled vector.
    """
    z = torch.randn(1, 512, device=device)
    # Inference only: without no_grad the decoder builds an autograd graph and the
    # returned latent keeps it alive (generate_image_given_code already guards this way).
    with torch.no_grad():
        gen_img, latent = pSp_net.decoder(
            [z], truncation=1, truncation_latent=None, return_latents=True
        )
        if downscale:
            gen_img = pool_logit(gen_img)
    return Image.fromarray(normalize_image_and_convert_to_numpy(gen_img[0])), latent

# Define a function that generates an image for the given code

def generate_image_given_code(code, downscale=False):
    """Decode a latent code into an image, with autograd disabled.

    Args:
        code: latent tensor passed to the decoder as an already-mapped latent
            (``input_is_latent=True``); noise randomization is switched off.
        downscale: when true, the decoder output is average-pooled to 256x256
            (via ``pool_logit``) before being converted to an image.

    Returns:
        A ``(PIL image, latent)`` pair, where ``latent`` is the tensor returned by the decoder.
    """
    decoder_options = dict(
        truncation=1,
        truncation_latent=None,
        return_latents=True,
        input_is_latent=True,
        randomize_noise=False,
    )
    with torch.no_grad():
        decoded, decoded_latent = pSp_net.decoder([code], **decoder_options)
        frame = pool_logit(decoded) if downscale else decoded
        pixels = normalize_image_and_convert_to_numpy(frame[0])
        return Image.fromarray(pixels), decoded_latent

# Define a function that encodes a given image and returns the code alongside its decoding (its 'fake' recreation)

def encode_given_image_return_code_and_recreation(image, downscale=False):
    """Encode an image to a latent code and decode that code back to an image.

    Args:
        image: PIL image to encode. Images that are not in RGB mode are converted
            first, because the preprocessing normalizes exactly three channels.
        downscale: forwarded to ``generate_image_given_code``.

    Returns:
        The ``(recreated PIL image, latent)`` pair produced by
        ``generate_image_given_code`` for the encoded latent.
    """
    # Grayscale / RGBA / palette inputs would otherwise fail in the 3-channel Normalize.
    if isinstance(image, Image.Image) and image.mode != 'RGB':
        image = image.convert('RGB')

    img_transforms = ENDECODER_ARGS['transform']
    transformed_image = img_transforms(image)

    # Inference only: keep the encoder out of autograd so the returned code does not
    # hold on to the encoder's activation graph.
    with torch.no_grad():
        latent = pSp_net.encoder(transformed_image.unsqueeze(0).to(device))
        # The encoder output is an offset; the model's average latent is added back.
        latent = latent + pSp_net.latent_avg.repeat(latent.shape[0], 1)

    image, latent = generate_image_given_code(latent, downscale)
    return image, latent

# Define a function that generates an image or encodes a given image and returns image-code pair

def generate_an_image_code_pair(image=None, downscale=False):
    """Return an ``(image, code)`` pair, either sampled or obtained from ``image``.

    With no ``image`` a random image is generated; otherwise the given image is
    encoded and recreated. ``downscale`` is forwarded unchanged in both cases.
    """
    if image is not None:
        return encode_given_image_return_code_and_recreation(image, downscale)
    return generate_random_image(downscale)
    
# Define a function that loads an image from its path, encodes it, and returns (loaded image, generated image, code)

def load_image_and_encode(path):
    """Open the image at ``path`` and run it through the encoder/decoder.

    Returns:
        A ``(loaded image, recreated image, latent code)`` triple; the last two
        come from ``encode_given_image_return_code_and_recreation``.
    """
    source_image = PIL.Image.open(path)
    recreation, latent_code = encode_given_image_return_code_and_recreation(source_image)
    return source_image, recreation, latent_code

In [5]:
# Define a function to neutralize latent code

def neutralize_latent_code(code, neutral_dir, neutral_strength=20):
    """Set the component of ``code`` along ``neutral_dir`` to ``neutral_strength``.

    Both inputs are flattened to vectors. The projection of ``code`` onto the
    unit-length ``neutral_dir`` is removed and replaced by ``neutral_strength``
    times that unit direction; every orthogonal component is left unchanged.

    Args:
        code: latent tensor, e.g. of shape ``(1, 18, 512)``. It is not modified.
        neutral_dir: direction tensor with the same number of elements as ``code``.
        neutral_strength: target coordinate along the unit direction.

    Returns:
        A tensor of shape ``(1, *code.shape[-2:])`` (``(1, 18, 512)`` for 1-D
        input), placed on the GPU when CUDA is available and left on the CPU otherwise.

    Raises:
        ValueError: if ``neutral_dir`` has zero norm (the direction is undefined).
    """
    # Pure torch arithmetic on CPU copies; avoids relying on implicit
    # tensor <-> NumPy conversions inside np.dot / np.linalg.norm.
    flat_code = code.detach().cpu().flatten()
    flat_dir = neutral_dir.detach().cpu().flatten().to(flat_code.dtype)

    dir_norm = flat_dir.norm()
    if float(dir_norm) == 0.0:
        raise ValueError("neutral_dir must have a non-zero norm")
    unit_dir = flat_dir / dir_norm

    # Replace the coordinate along unit_dir with neutral_strength in one step.
    projection = torch.dot(unit_dir, flat_code)
    neutral_code = flat_code + (neutral_strength - projection) * unit_dir

    # Keep the input's own (styles, channels) layout instead of assuming 18 x 512;
    # flat input falls back to the original 18 x 512 layout.
    layout = tuple(code.shape[-2:]) if code.dim() >= 2 else (18, 512)
    neutral_code = neutral_code.reshape(1, *layout)
    return neutral_code.cuda() if torch.cuda.is_available() else neutral_code

# Define a function to transfer emotion from code A to B

def transfer_emotion_on_code(code_A, code_B, neutral_dir, neutral_strength):
    """Move the non-neutral part of ``code_A`` onto the neutralized ``code_B``.

    Both codes are neutralized with ``neutralize_latent_code``; the difference
    between ``code_A`` and its neutral version is then added to the neutral
    version of ``code_B``.
    """
    neutral_A, neutral_B = (
        neutralize_latent_code(latent, neutral_dir, neutral_strength)
        for latent in (code_A, code_B)
    )
    emotion_offset = code_A - neutral_A
    return emotion_offset + neutral_B

# Define a function to transfer emotion from image A to B

def _neutral_reencoded_code(code, neutral_dir, neutral_strength):
    """Neutralize ``code``, render the result, re-encode the render and return that code."""
    neutral_image, _ = generate_image_given_code(
        neutralize_latent_code(code, neutral_dir, neutral_strength)
    )
    return encode_given_image_return_code_and_recreation(neutral_image)[1]

def transfer_emotion_on_image(image_A, image_B, neutral_dir, neutral_strength):
    """Render image B carrying the expression offset extracted from image A.

    Each image is encoded; its code is neutralized, rendered and re-encoded. The
    offset between A's code and A's re-encoded neutral code is added to B's
    re-encoded neutral code, and the sum is decoded.

    Args:
        image_A: PIL image supplying the offset.
        image_B: PIL image receiving the offset.
        neutral_dir: direction passed to ``neutralize_latent_code``.
        neutral_strength: strength passed to ``neutralize_latent_code``.

    Returns:
        The decoded PIL image.
    """
    code_A = encode_given_image_return_code_and_recreation(image_A)[1]
    code_B = encode_given_image_return_code_and_recreation(image_B)[1]

    code_A_neu_inv = _neutral_reencoded_code(code_A, neutral_dir, neutral_strength)
    code_B_neu_inv = _neutral_reencoded_code(code_B, neutral_dir, neutral_strength)

    transfer_code = (code_A - code_A_neu_inv) + code_B_neu_inv

    return generate_image_given_code(transfer_code)[0]


# Define a function to transfer emotion from image A to B that utilizes pre-computed delta

def transfer_emotion_on_image_using_delta(delta, code_B_neu_inv):
    """Decode ``delta + code_B_neu_inv`` and return only the resulting image."""
    shifted_code = delta + code_B_neu_inv
    transferred_image, _ = generate_image_given_code(shifted_code)
    return transferred_image

In [6]:
# Load up directions

def _load_direction(npy_path):
    """Read the array stored at ``npy_path`` and return it as a float32 tensor on ``device``."""
    return torch.from_numpy(np.load(npy_path).astype(np.float32)).to(device)

# One direction file per emotion, numbered 0-7 in alphabetical order of the emotion names
anger_dir = _load_direction('main_directions/0.npy')
contempt_dir = _load_direction('main_directions/1.npy')
disgust_dir = _load_direction('main_directions/2.npy')
fear_dir = _load_direction('main_directions/3.npy')
happiness_dir = _load_direction('main_directions/4.npy')
neutral_dir = _load_direction('main_directions/5.npy')
sadness_dir = _load_direction('main_directions/6.npy')
surprise_dir = _load_direction('main_directions/7.npy')

In [7]:
# Generate basis images for each emotion

# image_src, gen_image_src, gen_code_src = load_image_and_encode('basis.jpg')

# image_anger = generate_image_given_code(gen_code_src + anger_dir * 100)[0]
# image_contempt = generate_image_given_code(gen_code_src + contempt_dir * 100)[0]
# image_disgust = generate_image_given_code(gen_code_src + disgust_dir * 100)[0]
# image_fear = generate_image_given_code(gen_code_src + fear_dir * 300)[0]
# image_happiness = generate_image_given_code(gen_code_src + happiness_dir * 50)[0]
# image_neutral = generate_image_given_code(gen_code_src + neutral_dir * 50)[0]
# image_sadness = generate_image_given_code(gen_code_src + sadness_dir * 200)[0]
# image_surprise = generate_image_given_code(gen_code_src + surprise_dir * 200)[0]

# image_anger.save('quant_experiments/basis/anger_src.jpg', 'JPEG')
# image_contempt.save('quant_experiments/basis/contempt_src.jpg', 'JPEG')
# image_disgust.save('quant_experiments/basis/disgust_src.jpg', 'JPEG')
# image_fear.save('quant_experiments/basis/fear_src.jpg', 'JPEG')
# image_happiness.save('quant_experiments/basis/happiness_src.jpg', 'JPEG')
# image_neutral.save('quant_experiments/basis/neutral_src.jpg', 'JPEG')
# image_sadness.save('quant_experiments/basis/sadness_src.jpg', 'JPEG')
# image_surprise.save('quant_experiments/basis/surprise_src.jpg', 'JPEG')

In [8]:
# Load up basis images and precompute deltas, then save them

import pickle

emotion_deltas_path = 'emotion_deltas_psp.pkl'
emotions = [
    'anger', 'contempt', 'disgust', 'fear',
    'happiness', 'neutral', 'sadness', 'surprise',
]

# Recipe that produced the pickle (kept for reference, disabled to avoid recomputation):
#
# emotion_deltas = dict()
#
# for emotion in emotions:
#     _, _, gen_code = load_image_and_encode(f'quant_experiments/basis_aligned/{emotion}_src.jpg')
#     image_neu, gen_code_neu = generate_image_given_code(neutralize_latent_code(gen_code, neutral_dir, 20))
#     _, gen_code_neu_inv = encode_given_image_return_code_and_recreation(image_neu)
#     emotion_deltas[emotion] = (gen_code - gen_code_neu_inv)
#
# with open(emotion_deltas_path, 'wb') as file:
#     pickle.dump(emotion_deltas, file)

# Read the precomputed deltas back from disk.
# NOTE(review): pickle.load can execute arbitrary code; only load a file produced by the recipe above.
with open(emotion_deltas_path, 'rb') as deltas_file:
    emotion_deltas = pickle.load(deltas_file)

In [40]:
# Apply every precomputed emotion delta to every .jpg under input_dir and save the
# results under output_dir_base/<emotion>/, mirroring the input directory layout.
input_dir = 'datasets/celeba_hq_aligned'
output_dir_base = 'quant_experiments/transfer'

# Pure inference: with autograd disabled, the codes kept across the inner loop do not
# hold on to the encoder's activation graph.
with torch.no_grad():
    for root, _, files in os.walk(input_dir):
        # Depends only on the directory being walked, so compute it once per directory.
        relative_path = os.path.relpath(root, input_dir)
        for file in files:
            if not file.lower().endswith('.jpg'):
                continue
            file_path = os.path.join(root, file)
            _, _, code = load_image_and_encode(file_path)

            # Neutralize -> render -> re-encode once per image; shared by all emotions.
            image_neu, _ = generate_image_given_code(neutralize_latent_code(code, neutral_dir, 20))
            _, code_neu_inv = encode_given_image_return_code_and_recreation(image_neu)

            for emotion in emotions:
                output_file_dir = os.path.join(output_dir_base, emotion, relative_path)
                # exist_ok avoids the separate existence check before creating the directory.
                os.makedirs(output_file_dir, exist_ok=True)
                output_file_path = os.path.join(output_file_dir, file)

                image_transfer = transfer_emotion_on_image_using_delta(emotion_deltas[emotion], code_neu_inv)

                image_transfer.save(output_file_path)

In [9]:
# Save the plain encoder/decoder recreation of every .jpg under input_dir into
# output_dir_base/default/, mirroring the input directory layout.
input_dir = 'datasets/celeba_hq_aligned'
output_dir_base = 'quant_experiments/transferll'
# NOTE(review): this rebinds the module-level `emotions` list; re-running the emotion
# transfer cell afterwards would look up 'default' in emotion_deltas. Restore the full
# list (or re-run the cell that defines it) before doing so.
emotions = ['default']

# Pure inference: disable autograd so the encoder graph is not retained per image.
with torch.no_grad():
    for root, _, files in os.walk(input_dir):
        # Depends only on the directory being walked, so compute it once per directory.
        relative_path = os.path.relpath(root, input_dir)
        for file in files:
            if not file.lower().endswith('.jpg'):
                continue
            file_path = os.path.join(root, file)
            _, ggen_img, code = load_image_and_encode(file_path)
            for emotion in emotions:
                output_file_dir = os.path.join(output_dir_base, emotion, relative_path)
                # exist_ok avoids the separate existence check before creating the directory.
                os.makedirs(output_file_dir, exist_ok=True)
                output_file_path = os.path.join(output_file_dir, file)

                ggen_img.save(output_file_path)