In [None]:
from argparse import Namespace
from tqdm import tqdm
from pathlib import Path
import numpy as np
from PIL import Image

import torch

from models.psp import pSp
from utils.img_utils import preproc, postproc
from utils.load_utils import load_network_pkl, blend_models

In [None]:
# Dataset: all PNGs from both LHQ_256 splits (train + test), in sorted order.
dataset_path = Path("/datasets/RD/photo_data/LHQ_256")
imgs = sorted(
    png
    for pattern in ("train_imgs/*.png", "test_imgs/*.png")
    for png in dataset_path.glob(pattern)
)

# Generator pickles (pretrained / finetuned) and the e4e encoder checkpoint.
pretrain_path = "weights/stylegan2-pretrained.pkl"
finetune_path = "weights/stylegan2-finetuned.pkl"
ckpt_path = "weights/lhq-e4e-encoder.pt"

# Level passed to blend_models; also embedded in the output directory name.
blend_lv = 8

# Seeds 0, 300, ..., 89700 (300 values); not consumed by the cells visible here.
seeds = range(0, 90000, 300)

# Output directory, created up front; exist_ok keeps re-runs idempotent.
mainpath = Path(f"imgs/e4e-imgs-{blend_lv}-v2")
mainpath.mkdir(parents=True, exist_ok=True)

In [None]:
def _load_g_ema(pkl_path):
    """Open a network pickle, take its 'G_ema' entry and move it to CUDA."""
    # NOTE(review): load_network_pkl presumably unpickles arbitrary objects;
    # only point pkl_path at checkpoints you trust.
    with open(pkl_path, 'rb') as pkl_file:
        return load_network_pkl(pkl_file)['G_ema'].to("cuda")


G_pretrain = _load_g_ema(pretrain_path)
G_finetune = _load_g_ema(finetune_path)

# Combine the two generators at level `blend_lv` (see blend_models for semantics).
G_blended = blend_models(G_pretrain, G_finetune, blend_lv)

In [None]:
# Rebuild the pSp/e4e encoder from the options saved in its checkpoint,
# overriding the fields that depend on where/how this notebook runs.
# NOTE(review): this reads the whole checkpoint only to get "opts"; pSp is
# presumably reloading the weights itself from checkpoint_path.
opts = torch.load(ckpt_path, map_location="cpu")["opts"]
opts.update(checkpoint_path=ckpt_path, device="cuda", model_type='stylegan2')
encoder_args = Namespace(**opts)
net = pSp(encoder_args).eval().to("cuda")

In [None]:
# Invert one dataset image with the encoder, render the resulting latents through
# the pretrained, finetuned and blended generators, and show the results side by
# side as [input | pretrained | finetuned | blended].
IMG_ID = 2700  # image to visualize; matched against the 7-digit zero-padded file stem
IMG_SIZE = (256, 256)  # resolution the input is resized to before encoding

img_stem = f"{IMG_ID:07d}"
img_path = next((x for x in imgs if x.stem == img_stem), None)
if img_path is None:
    raise FileNotFoundError(
        f"No image with stem {img_stem!r} among {len(imgs)} dataset images"
    )

# Decode inside the context manager so the file handle is closed, and force
# 3-channel RGB so palette/grayscale/RGBA PNGs still stack with the generator
# outputs below.
with Image.open(img_path) as src:
    img = src.convert("RGB").resize(IMG_SIZE)
img_t = preproc(img).to("cuda")

# NOTE(review): not used in this cell; kept in case later cells read them.
y_hat, latent = None, None
results_batch = []
results_latent = []

with torch.no_grad():
    # Only the returned latents (ws) are used; the encoder's own output is discarded.
    _, ws = net(img_t, randomize_noise=False, return_latents=True)
    img_pretrain = G_pretrain.synthesis(ws, noise_mode="const")
    img_finetune = G_finetune.synthesis(ws, noise_mode="const")
    img_blended = G_blended.synthesis(ws, noise_mode="const")

# NOTE(review): assumes postproc returns an HxWx3 uint8 array whose height matches
# IMG_SIZE, so the panels can be concatenated horizontally — confirm.
final_img = np.concatenate(
    [
        np.asarray(img),
        postproc(img_pretrain[0]),
        postproc(img_finetune[0]),
        postproc(img_blended[0]),
    ],
    axis=1,
)
Image.fromarray(final_img)