In [None]:
import os

# Root of the HyperStyle checkout, relative to the notebook's starting directory.
# In Colab, clone into /content and use instead:
#   os.chdir('/content'); CODE_DIR = 'hyperstyle'
CODE_DIR = '.'

In [None]:
# !git clone https://github.com/yuval-alaluf/hyperstyle.git $CODE_DIR

In [None]:
# !wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
# !sudo unzip ninja-linux.zip -d /usr/local/bin/
# !sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force

In [None]:
# Move into the code directory so every relative path used below resolves against it.
os.chdir('./' + CODE_DIR)

In [None]:
from argparse import Namespace
import time
import os
import sys
import pprint
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms

sys.path.append(".")
sys.path.append("..")

from models.stylegan2.model import Generator
from utils.common import tensor2im
from utils.inference_utils import run_inversion
from utils.domain_adaptation_utils import run_domain_adaptation
from utils.model_utils import load_model, load_generator

%load_ext autoreload
%autoreload 2

## Step 1: Select Experiment Type
Select which experiment you wish to perform inference on:

In [None]:
#@title Select which experiment you wish to perform inference on: { run: "auto" }
experiment_type = 'ffhq_hypernet' #@param ['ffhq_hypernet', 'cars_hypernet', 'afhq_wild_hypernet']
# Used as the key into MODEL_PATHS, W_ENCODERS_PATHS and EXPERIMENT_DATA_ARGS defined below.

## Step 2: Prepare to Download Pretrained Models 
As part of this repository, we provide pretrained models for each of the above experiments. Here, we'll create the download command needed for downloading the desired model.

In [None]:
def get_download_model_command(file_id, file_name, save_dir=None):
    """Build a shell command that downloads a Google Drive file with gdown.

    Args:
        file_id: Google Drive file id of the checkpoint.
        file_name: file name to save the checkpoint under.
        save_dir: directory to save into. Defaults to
            ``<parent of cwd>/<CODE_DIR>/pretrained_models``, i.e.
            ``../pretrained_models`` when ``CODE_DIR == '.'``. Created if missing.

    Returns:
        The command string, meant to be run with ``os.system``.
    """
    import shlex

    if save_dir is None:
        save_dir = os.path.join(os.path.dirname(os.getcwd()), CODE_DIR, "pretrained_models")
    os.makedirs(save_dir, exist_ok=True)
    # Pass the file as a Drive URL instead of `--id`, which is deprecated in newer
    # gdown releases. Quote both arguments: the URL contains `?` (a shell glob
    # character) and the save path may contain spaces.
    url = f"https://drive.google.com/uc?id={file_id}"
    output_path = os.path.join(save_dir, file_name)
    return f"gdown {shlex.quote(url)} -O {shlex.quote(output_path)}"

In [None]:
# Google Drive file id and local file name of each pretrained HyperStyle checkpoint.
MODEL_PATHS = dict(
    ffhq_hypernet=dict(id="1C3dEIIH1y8w1-zQMCyx7rDF0ndswSXh4", name="hyperstyle_ffhq.pt"),
    cars_hypernet=dict(id="1WZ7iNv5ENmxXFn6dzPeue1jQGNp6Nr9d", name="hyperstyle_cars.pt"),
    afhq_wild_hypernet=dict(id="1OMAKYRp3T6wzGr0s3887rQK-5XHlJ2gp", name="hyperstyle_afhq_wild.pt"),
)
path = MODEL_PATHS[experiment_type]
hyperstyle_download_command = get_download_model_command(path["id"], path["name"])

In [None]:
# Google Drive file id and local file name of the WEncoder paired with each experiment.
W_ENCODERS_PATHS = dict(
    ffhq_hypernet=dict(id="1M-hsL3W_cJKs77xM1mwq2e9-J0_m7rHP", name="faces_w_encoder.pt"),
    cars_hypernet=dict(id="1GZke8pfXMSZM9mfT-AbP1Csyddf5fas7", name="cars_w_encoder.pt"),
    afhq_wild_hypernet=dict(id="1MhEHGgkTpnTanIwuHYv46i6MJeet2Nlr", name="afhq_wild_w_encoder.pt"),
)
path = W_ENCODERS_PATHS[experiment_type]
w_encoder_download_command = get_download_model_command(path["id"], path["name"])

## Step 3: Define Inference Parameters

Below we have a dictionary defining parameters such as the path to the pretrained model to use and the path to the image to perform inference on.  
While we provide default values to run this script, feel free to change as needed.

In [None]:
def _hyperstyle_transform(size):
    """Build the input transform shared by every experiment.

    Resizes to ``size`` (torchvision order: height, width), converts to a tensor
    and normalizes each RGB channel with mean 0.5 / std 0.5, mapping [0, 1] to [-1, 1].
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])


# Per-experiment inputs. Paths are relative to the notebook's working directory and
# assume checkpoints live in ../pretrained_models, which is where
# get_download_model_command saves them when CODE_DIR == '.'. If you run from the
# repository root instead (e.g. Colab with CODE_DIR = 'hyperstyle'), replace the
# leading '../' with './' in every path below.
EXPERIMENT_DATA_ARGS = {
    "ffhq_hypernet": {
        "model_path": "../pretrained_models/hyperstyle_ffhq.pt",
        "w_encoder_path": "../pretrained_models/faces_w_encoder.pt",
        "image_path": "../notebooks/images/face_image.jpg",
        "transform": _hyperstyle_transform((256, 256)),
    },
    "cars_hypernet": {
        "model_path": "../pretrained_models/hyperstyle_cars.pt",
        "w_encoder_path": "../pretrained_models/cars_w_encoder.pt",
        "image_path": "../notebooks/images/car_image.jpg",
        # Car inputs are resized to 192 tall x 256 wide.
        "transform": _hyperstyle_transform((192, 256)),
    },
    "afhq_wild_hypernet": {
        "model_path": "../pretrained_models/hyperstyle_afhq_wild.pt",
        "w_encoder_path": "../pretrained_models/afhq_wild_w_encoder.pt",
        "image_path": "../notebooks/images/afhq_wild_image.jpg",
        "transform": _hyperstyle_transform((256, 256)),
    },
    # Used in a later part of the notebook; its checkpoint paths are defined there.
    "domain_adaptation": {
        "image_path": "../notebooks/images/domain_adaptation.jpg",
        "transform": _hyperstyle_transform((256, 256)),
    },
}

In [None]:
# Paths and input transform for the experiment selected in Step 1.
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]

To reduce the number of requests to Google Drive, we first check whether the model was already downloaded and saved, and only download it if it was not.  
We'll download the model for the selected experiment and save it to the folder `../pretrained_models`.

We also need to verify that the model was downloaded correctly. All of our models should weigh approximately 1.3GB.
Note that if the file weighs only a few KB, you have most likely hit a "quota exceeded" error from Google Drive (the cells below treat any file smaller than 1MB as a failed download). In that case, try downloading the model again after a few hours.

In [None]:
hyperstyle_ckpt_path = EXPERIMENT_ARGS['model_path']
# Anything under ~1MB is not a real checkpoint: when Google Drive refuses the
# download (e.g. quota exceeded), only a few KB end up on disk.
if os.path.exists(hyperstyle_ckpt_path) and os.path.getsize(hyperstyle_ckpt_path) >= 1000000:
    print(f'HyperStyle model for {experiment_type} already exists!')
else:
    print(f'Downloading HyperStyle model for {experiment_type}...')
    os.system(hyperstyle_download_command)
    # If the download failed outright the file may not exist at all; check that
    # first so the user gets this error rather than a FileNotFoundError.
    if not os.path.exists(hyperstyle_ckpt_path) or os.path.getsize(hyperstyle_ckpt_path) < 1000000:
        raise ValueError("Pretrained model was unable to be downloaded correctly!")
    print('Done.')

In addition, we need to download the WEncoder for the desired domain.

In [None]:
w_encoder_ckpt_path = EXPERIMENT_ARGS['w_encoder_path']
# Anything under ~1MB is not a real checkpoint: when Google Drive refuses the
# download (e.g. quota exceeded), only a few KB end up on disk.
if os.path.exists(w_encoder_ckpt_path) and os.path.getsize(w_encoder_ckpt_path) >= 1000000:
    print(f'WEncoder model for {experiment_type} already exists!')
else:
    print(f'Downloading the WEncoder model for {experiment_type}...')
    os.system(w_encoder_download_command)
    # If the download failed outright the file may not exist at all; check that
    # first so the user gets this error rather than a FileNotFoundError.
    if not os.path.exists(w_encoder_ckpt_path) or os.path.getsize(w_encoder_ckpt_path) < 1000000:
        raise ValueError("Pretrained model was unable to be downloaded correctly!")
    print('Done.')

## Step 4: Load Pretrained Model
We assume that you have downloaded all relevant models and placed them in the directory defined by the above dictionary.

In [None]:
model_path = EXPERIMENT_ARGS['model_path']
# update_opts hands load_model the local path of the WEncoder checkpoint.
net, opts = load_model(
    model_path,
    update_opts=dict(w_encoder_checkpoint_path=EXPERIMENT_ARGS['w_encoder_path']),
)
print('Model successfully loaded!')
pprint.pprint(vars(opts))

## Step 5: Visualize Input

In [None]:
# Load the sample input for the selected experiment and force 3 RGB channels.
image_path = EXPERIMENT_DATA_ARGS[experiment_type]["image_path"]
original_image = Image.open(image_path)
original_image = original_image.convert("RGB")

In [None]:
# PIL's Image.resize takes (width, height). Cars are shown 256 wide x 192 tall,
# matching the (height, width) = (192, 256) transform used for this experiment;
# all other domains are square.
if experiment_type == 'cars_hypernet':
    original_image = original_image.resize((256, 192))
else:
    original_image = original_image.resize((256, 256))

In [None]:
# Display the resized input image as this cell's output.
original_image

### Align Image

Note: in this notebook we'll run alignment on the input image when working on the human facial domain.

In [None]:
def run_alignment(image_path):
    """Align a face image using dlib's 68-point facial landmark predictor.

    On first use, downloads and decompresses
    ``shape_predictor_68_face_landmarks.dat`` into the current working directory.

    Args:
        image_path: path of the image file to align.

    Returns:
        The aligned image produced by ``align_face`` (its ``.size`` is printed).

    Raises:
        RuntimeError: if the landmark predictor file is still missing after the
            download/decompression attempt.
    """
    # Imported here rather than at the top of the notebook, presumably so dlib is
    # only required when alignment is actually run.
    import dlib
    from scripts.align_faces_parallel import align_face

    predictor_path = "shape_predictor_68_face_landmarks.dat"
    if not os.path.exists(predictor_path):
        print('Downloading files for aligning face image...')
        os.system('wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2')
        os.system('bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2')
        # os.system does not raise on failure; check the outcome here so a failed
        # download is reported clearly instead of failing later inside dlib.
        if not os.path.exists(predictor_path):
            raise RuntimeError(
                f"Could not download/extract {predictor_path}; "
                "check network access and that wget and bzip2 are installed.")
        print('Done.')
    predictor = dlib.shape_predictor(predictor_path)
    aligned_image = align_face(filepath=image_path, predictor=predictor)
    print("Aligned image has shape: {}".format(aligned_image.size))
    return aligned_image

In [None]:
# Face inputs must be aligned before inversion; set to True if yours already is.
input_is_aligned = False
input_image = (run_alignment(image_path)
               if experiment_type == "ffhq_hypernet" and not input_is_aligned
               else original_image)

input_image.resize((256, 256))

## Step 6: Perform Inference

In [None]:
# Resize + normalize the input with the experiment's transform to get the network input tensor.
img_transforms = EXPERIMENT_ARGS['transform']
transformed_image = img_transforms(input_image)

Now we'll run inference. By default, we'll run using 5 inference steps. You can change the parameter in the cell below.

In [None]:
# 5 inference steps; keep outputs at the generator's full resolution.
opts.n_iters_per_batch, opts.resize_outputs = 5, False

In [None]:
# Invert the single input image; gradients are not needed at inference time.
with torch.no_grad():
    tic = time.time()
    result_batch, result_latents, _ = run_inversion(
        transformed_image.unsqueeze(0).cuda(), net, opts, return_intermediate_results=True
    )
    toc = time.time()
    print('Inference took {:.4f} seconds.'.format(toc - tic))

### Visualize Result

We'll visualize the input image and its final reconstruction side by side.

In [None]:
if opts.dataset_type == "cars_encode":
    resize_amount = (256, 192) if opts.resize_outputs else (512, 384)
else:
    resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size)

In [None]:
def get_coupled_results(result_batch, transformed_image, resize_to=None):
    """Place the input image and its final reconstruction side by side.

    Args:
        result_batch: output of ``run_inversion(..., return_intermediate_results=True)``.
            ``result_batch[0]`` holds the outputs for the single image in the batch
            (presumably one per inference step); its last entry is used as the
            final reconstruction.
        transformed_image: the normalized input tensor fed to the network.
        resize_to: (width, height) for both panels. Defaults to the global
            ``resize_amount`` computed in the previous cell.

    Returns:
        A PIL image with the input on the left and the reconstruction on the right.
    """
    if resize_to is None:
        resize_to = resize_amount
    result_tensors = result_batch[0]  # there's one image in our batch
    final_rec = tensor2im(result_tensors[-1]).resize(resize_to)
    input_im = tensor2im(transformed_image).resize(resize_to)
    res = np.concatenate([np.array(input_im), np.array(final_rec)], axis=1)
    return Image.fromarray(res)

Note that the input image is shown on the left and the final reconstruction (the output of the last inference step) on the right.

In [None]:
# Build the input | reconstruction comparison and display it as the cell output.
res = get_coupled_results(result_batch, transformed_image)
res

In [None]:
# save image: write the side-by-side comparison into ./outputs under the input's file name
outputs_path = "./outputs"
os.makedirs(outputs_path, exist_ok=True)
res.save(os.path.join(outputs_path, os.path.basename(image_path)))

# Domain Adaptation

In the paper, we show that the weight offsets predicted by HyperStyle over the FFHQ domain are also applicable to fine-tuned generators such as toonify and StyleGAN-NADA.


We demonstrate this idea below.

In [None]:
generator_type = 'toonify' #@param ['toonify', 'pixar', 'sketch', 'disney_princess']
# Key into FINETUNED_MODELS below; every generator listed there is selectable.

In [None]:
# download fine-tuned generator
FINETUNED_MODELS = {
    "toonify": {'id': '1r3XVCt_WYUKFZFxhNH-xO2dTtF6B5szu', 'name': 'toonify.pt'},
    "pixar": {'id': '1trPW-To9L63x5gaXrbAIPkOU0q9f_h05', 'name': 'pixar.pt'},
    "sketch": {'id': '1aHhzmxT7eD90txAN93zCl8o9CUVbMFnD', 'name': 'sketch.pt'},
    "disney_princess": {'id': '1rXHZu4Vd0l_KCiCxGbwL9Xtka7n3S2NB', 'name': 'disney_princess.pt'}
}
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS['domain_adaptation']
path = FINETUNED_MODELS[generator_type]

# Matches where get_download_model_command saves files when CODE_DIR == '.';
# use "./pretrained_models" when running from the repository root.
generator_path = os.path.join("../pretrained_models", path['name'])

# A file of only a few KB means Google Drive refused the download (e.g. quota
# exceeded), so treat anything under ~1MB as missing.
if os.path.exists(generator_path) and os.path.getsize(generator_path) >= 1000000:
    print(f'Fine-tuned {generator_type} generator already exists!')
else:
    print(f'Downloading fine-tuned {generator_type} generator...')
    download_command = get_download_model_command(file_id=path["id"], file_name=path["name"])
    os.system(download_command)
    if not os.path.exists(generator_path) or os.path.getsize(generator_path) < 1000000:
        raise ValueError("Pretrained model was unable to be downloaded correctly!")
    print('Done.')

In [None]:
# load model: the fine-tuned generator checkpoint from the previous cell
fine_tuned_generator = load_generator(generator_path)
print(f'Fine-tuned {generator_type} generator successfully loaded!')

In [None]:
# load ReStyle e4e:
# NOTE: 'restlye_e4e.pt' is misspelled but kept as-is so that previously
# downloaded checkpoints are still found.
RESTYLE_E4E_MODELS = {'id': '1e2oXVeBPXMQoUoC_4TNwAWpOPpSEhE_e', 'name': 'restlye_e4e.pt'}

# Matches where get_download_model_command saves files when CODE_DIR == '.';
# use "./pretrained_models" when running from the repository root.
restyle_e4e_path = os.path.join("../pretrained_models", RESTYLE_E4E_MODELS['name'])

# A file of only a few KB means Google Drive refused the download (e.g. quota
# exceeded), so treat anything under ~1MB as missing.
if os.path.exists(restyle_e4e_path) and os.path.getsize(restyle_e4e_path) >= 1000000:
    print('ReStyle-e4e model already exists!')
else:
    print('Downloading ReStyle-e4e model...')
    download_command = get_download_model_command(file_id=RESTYLE_E4E_MODELS["id"], file_name=RESTYLE_E4E_MODELS["name"])
    os.system(download_command)
    if not os.path.exists(restyle_e4e_path) or os.path.getsize(restyle_e4e_path) < 1000000:
        raise ValueError("Pretrained model was unable to be downloaded correctly!")
    print('Done.')

In [None]:
# load restyle-e4e model (is_restyle_encoder=True presumably selects load_model's ReStyle code path)
restyle_e4e, restyle_e4e_opts = load_model(restyle_e4e_path, is_restyle_encoder=True)
print('ReStyle-e4e model successfully loaded!')

In [None]:
# load image. Note that uploaded images must be aligned first, example image is already aligned.
image_path = EXPERIMENT_DATA_ARGS['domain_adaptation']["image_path"]
input_is_aligned = True
input_image = (Image.open(image_path).convert("RGB") if input_is_aligned
               else run_alignment(image_path))

input_image.resize((256, 256))

In [None]:
# transform image with the domain-adaptation resize + normalization
img_transforms = EXPERIMENT_ARGS['transform']
transformed_image = img_transforms(input_image)

In [None]:
# Both the ReStyle-e4e encoder and HyperStyle: 5 iterations, outputs kept at full resolution.
restyle_e4e_opts.n_iters_per_batch, restyle_e4e_opts.resize_outputs = 5, False
opts.n_iters_per_batch, opts.resize_outputs = 5, False

In [None]:
# The weight offsets transferred to the fine-tuned generator are those predicted
# over the FFHQ domain, so this section needs the FFHQ HyperStyle model (`net`)
# loaded in Step 4. Fail early with a clear message instead of a shape mismatch.
if experiment_type != 'ffhq_hypernet':
    raise ValueError("Domain adaptation requires experiment_type = 'ffhq_hypernet'; "
                     "select it in Step 1 and re-run the notebook.")

with torch.no_grad():
    tic = time.time()
    result, _ = run_domain_adaptation(transformed_image.unsqueeze(0).cuda(),
                                      net,
                                      opts,
                                      fine_tuned_generator,
                                      restyle_e4e,
                                      restyle_e4e_opts)
    toc = time.time()
    print('Inference took {:.4f} seconds.'.format(toc - tic))

In [None]:
# Input on the left, domain-adapted result on the right, both at the display size computed earlier.
input_im = tensor2im(transformed_image).resize(resize_amount)
final_res = tensor2im(result[0]).resize(resize_amount)
res = Image.fromarray(np.concatenate([np.array(input_im), np.array(final_res)], axis=1))
res

In [None]:
# save image under ./outputs/domain_adaptation/<generator_type>/, named after the input file
outputs_path = f"./outputs/domain_adaptation/{generator_type}"
os.makedirs(outputs_path, exist_ok=True)
res.save(os.path.join(outputs_path, os.path.basename(image_path)))