# Child to Parents mappings in StyleGAN2 latent space using the Ridge Regression

In [2]:
from dataset.nok_mean import NokMeanDataset
from dataset.nok import NokDataset
from dataset.nok_aug import NokAugDataset
from utils.stylegan import StyleGAN2
from utils.eval import BaseEvaluator
from utils.viz import image_add_label
import os
import sys
import numpy as np
import torch, torchvision
from PIL import Image

from sklearn.linear_model import LinearRegression, Ridge
from sklearn.metrics import mean_squared_error


In [3]:
def load_data(split):
    X = []
    y = []
    for father, mother, child, *_ in NokMeanDataset(split=split):
        input = child.flatten(0)
        output = torch.cat([father.flatten(0), mother.flatten(0)], dim=0)
        X.append(input)
        y.append(output)
    X = torch.stack(X, dim=0)
    y = torch.stack(y, dim=0)
    return X, y

In [5]:
output_path = "/home/vidp/Documents/fri-2022-diploma/src/.tmp/stylegan2-ada-pytorch/child_to_parent_2"
os.mkdir(output_path)

In [6]:
X_train, y_train = load_data(split="train")
X_test, y_test = load_data(split="test")

Loaded 433 persons with 2676 images.
Average images per person: 6.180138568129331
Max images per person: 37
Min images per person: 1
[13, 12, 6, 1, 4, 2, 6, 9, 8, 2, 4, 7, 1, 6, 3, 3, 5, 8, 12, 4, 3, 6, 20, 4, 10, 3, 7, 3, 1, 4, 5, 9, 5, 8, 2, 4, 4, 3, 3, 2, 11, 2, 7, 14, 8, 10, 2, 1, 2, 9, 11, 5, 1, 1, 5, 2, 5, 5, 8, 1, 1, 5, 5, 3, 1, 3, 7, 23, 14, 15, 16, 5, 4, 1, 3, 9, 5, 6, 5, 1, 2, 5, 16, 4, 2, 1, 5, 2, 6, 24, 5, 15, 3, 9, 13, 2, 5, 6, 9, 7, 4, 8, 4, 7, 5, 11, 4, 9, 7, 6, 6, 7, 8, 5, 7, 5, 18, 10, 6, 8, 1, 5, 5, 9, 14, 6, 6, 1, 11, 14, 2, 7, 1, 2, 4, 4, 17, 5, 1, 6, 6, 4, 2, 3, 4, 5, 1, 2, 4, 4, 9, 3, 6, 2, 3, 5, 8, 6, 7, 14, 6, 5, 18, 9, 5, 3, 13, 3, 12, 5, 14, 5, 5, 1, 5, 9, 2, 9, 2, 6, 10, 8, 10, 12, 9, 6, 19, 10, 7, 37, 9, 6, 1, 1, 4, 24, 7, 3, 3, 4, 6, 13, 5, 8, 4, 5, 4, 2, 24, 5, 19, 6, 5, 13, 5, 3, 7, 8, 11, 4, 12, 7, 8, 7, 14, 15, 6, 8, 4, 5, 3, 5, 6, 8, 11, 6, 2, 2, 10, 4, 4, 3, 1, 1, 1, 1, 2, 1, 2, 5, 1, 16, 9, 8, 4, 4, 1, 9, 3, 5, 7, 8, 3, 2, 6, 4, 1, 5, 2, 4, 5, 6, 4, 

In [7]:
X_train.shape, y_train.shape, X_test.shape, y_test.shape


(torch.Size([207, 9216]),
 torch.Size([207, 18432]),
 torch.Size([20, 9216]),
 torch.Size([20, 18432]))

In [8]:
regressor = Ridge()
regressor.fit(X_train, y_train)

In [9]:
y_train_hat = regressor.predict(X_train)
y_test_hat = regressor.predict(X_test)

In [10]:
mse_train = mean_squared_error(y_train, y_train_hat)
mse_test = mean_squared_error(y_test, y_test_hat)

In [11]:
mse_train, mse_test

(0.006931318842053702, 1.0244137798614488)

In [12]:
images = []
for i in range(y_test.shape[0]):
    images.append(X_test[i].view(18, 512))
    images.append(y_test[i,:18*512].view(18, 512))
    images.append(y_test[i,18*512:].view(18, 512))
    images.append(torch.from_numpy(y_test_hat[i,:18*512]).to(torch.float32).view(18, 512))
    images.append(torch.from_numpy(y_test_hat[i,18*512:]).to(torch.float32).view(18, 512))
images = torch.stack(images, dim=0)

## Visualize and evaluate results

In [13]:
toTensor = torchvision.transforms.PILToTensor()
toPIL = torchvision.transforms.ToPILImage()
eval = BaseEvaluator()
stylegan = StyleGAN2(tmp_path=output_path)

<docker.models.images.ImageCollection object at 0x7f9f37c06400>
Got StyleGAN2 docker client, building image...
StyleGAN2 Docker image built.


In [14]:
pils = stylegan.generate_from_array(images.detach().cpu().numpy())
pil = toPIL(torchvision.utils.make_grid([toTensor(pil.resize((128,128))) for pil in pils], nrow=5)).convert("RGB")
pil.save(output_path + "/result.png")

Generating 100 images from array...
[INFO] StyleGAN2 - Generating image...
Loading networks from "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl"...

Traceback (most recent call last):

  File "./submodules/stylegan2-ada-pytorch/generate.py", line 127, in <module>

    generate_images() # pylint: disable=no-value-for-parameter

  File "/opt/conda/lib/python3.8/site-packages/click/core.py", line 829, in __call__

    return self.main(*args, **kwargs)

  File "/opt/conda/lib/python3.8/site-packages/click/core.py", line 782, in main

    rv = self.invoke(ctx)

  File "/opt/conda/lib/python3.8/site-packages/click/core.py", line 1066, in invoke

    return ctx.invoke(self.callback, **ctx.params)

  File "/opt/conda/lib/python3.8/site-packages/click/core.py", line 610, in invoke

    return callback(*args, **kwargs)

  File "/opt/conda/lib/python3.8/site-packages/click/decorators.py", line 21, in new_func

    return f(get_current_context(), *args, **kwargs)

  Fi

Exception: Container exited with the status code 1...

In [36]:
pils_reshaped = np.array([np.array(pil).reshape(1024, 1024, 3) for pil in pils])
pils_reshaped = pils_reshaped.reshape(-1, 5, 1024, 1024, 3)
images_eval_f = pils_reshaped[:, 1]
images_eval_m = pils_reshaped[:, 2]
images_hat_eval_f = pils_reshaped[:, 3]
images_hat_eval_m = pils_reshaped[:, 4]

In [37]:
images_eval_arr_f = [Image.fromarray(i).convert("RGB") for i in images_eval_f]
images_eval_arr_m = [Image.fromarray(i).convert("RGB") for i in images_eval_m]
images_hat_eval_arr_f = [Image.fromarray(i).convert("RGB") for i in images_hat_eval_f]
images_hat_eval_arr_m = [Image.fromarray(i).convert("RGB") for i in images_hat_eval_m]

In [38]:
images_eval_pil_f = toPIL(torchvision.utils.make_grid([toTensor(pil.resize((128, 128))) for pil in images_eval_arr_f], nrow=1)).convert("RGB")
images_eval_pil_f.save(output_path + "/images_eval_f.png")

images_eval_pil_m = toPIL(torchvision.utils.make_grid([toTensor(pil.resize((128, 128))) for pil in images_eval_arr_m], nrow=1)).convert("RGB")
images_eval_pil_m.save(output_path + "/images_eval_m.png")

In [45]:
eval_res_f_fn = eval.evaluate_batch(images_eval_f, images_hat_eval_f, model_name='Facenet512')
images_hat_eval_arr_f_labeled_fn = zip(images_hat_eval_arr_f, eval_res_f_fn)
images_hat_eval_pil_f_fn = toPIL(torchvision.utils.make_grid([toTensor(image_add_label(pil, str(round(label, 3)), 40).resize((256, 256))) for pil, label in images_hat_eval_arr_f_labeled_fn], nrow=1)).convert("RGB")
images_hat_eval_pil_f_fn.save(output_path + "/images_eval_hat_f_fn.png")
eval_res_f_fn

eval_res_m_fn = eval.evaluate_batch(images_eval_m, images_hat_eval_m, model_name='Facenet512')
images_hat_eval_arr_m_labeled_fn = zip(images_hat_eval_arr_m, eval_res_m_fn)
images_hat_eval_pil_m_fn = toPIL(torchvision.utils.make_grid([toTensor(image_add_label(pil, str(round(label, 3)), 40).resize((256, 256))) for pil, label in images_hat_eval_arr_m_labeled_fn], nrow=1)).convert("RGB")
images_hat_eval_pil_m_fn.save(output_path + "/images_eval_hat_m_fn.png")
eval_res_m_fn




[0.33402624846914436,
 0.43684275283856516,
 0.159515057262758,
 0.22805098862428852,
 0.47382434400613477,
 0.2976961162220601,
 0.4468262297577832,
 0.4712513135122438,
 0.03458128400108661,
 0.3743983516671815,
 0.3978984547547002,
 0.16578259659631245,
 0.44721421017624513,
 -0.017200315587250783,
 0.5137009203637347,
 0.22086447394089861,
 -0.11366718482966436,
 0.0744433842380785,
 0.20400588503612396,
 -0.1373116028271647]

In [44]:
eval_res_f_af = eval.evaluate_batch(images_eval_f, images_hat_eval_f, model_name='ArcFace')
images_hat_eval_arr_f_labeled_af = zip(images_hat_eval_arr_f, eval_res_f_af)
images_hat_eval_pil_f_af = toPIL(torchvision.utils.make_grid([toTensor(image_add_label(pil, str(round(label, 3)), 40).resize((256, 256))) for pil, label in images_hat_eval_arr_f_labeled_af], nrow=1)).convert("RGB")
images_hat_eval_pil_f_af.save(output_path + "/images_eval_hat_f_af.png")
eval_res_f_af

eval_res_m_af = eval.evaluate_batch(images_eval_m, images_hat_eval_m, model_name='ArcFace')
images_hat_eval_arr_m_labeled_af = zip(images_hat_eval_arr_m, eval_res_m_af)
images_hat_eval_pil_m_af = toPIL(torchvision.utils.make_grid([toTensor(image_add_label(pil, str(round(label, 3)), 40).resize((256, 256))) for pil, label in images_hat_eval_arr_m_labeled_af], nrow=1)).convert("RGB")
images_hat_eval_pil_m_af.save(output_path + "/images_eval_hat_m_af.png")
eval_res_m_af




[0.10377961691985067,
 0.943256346206428,
 -0.008764100807423717,
 0.21682055050580815,
 0.26986823526496284,
 0.1969762862823228,
 0.019916973462405136,
 0.358388616413088,
 0.0038190441572530888,
 0.3309051924515633,
 0.1511933387718012,
 0.03959055423529361,
 0.5192292305804449,
 0.025049313347839344,
 0.5997104295188008,
 0.10366729960530924,
 0.03381036896741243,
 0.05439476054490021,
 0.023167767168489554,
 0.07833462406799792]

In [43]:
eval_res_vgg = eval.evaluate_batch(images_eval_f, images_hat_eval_f, model_name='VGG-Face')
images_hat_eval_arr_f_labeled_vgg = zip(images_hat_eval_arr_f, eval_res_vgg)
images_hat_eval_pil_f_vgg = toPIL(torchvision.utils.make_grid([toTensor(image_add_label(pil, str(round(label, 3)), 40).resize((256, 256))) for pil, label in images_hat_eval_arr_f_labeled_vgg], nrow=1)).convert("RGB")
images_hat_eval_pil_f_vgg.save(output_path + "/images_eval_hat_f_vgg.png")
eval_res_vgg

eval_res_vgg = eval.evaluate_batch(images_eval_m, images_hat_eval_m, model_name='VGG-Face')
images_hat_eval_arr_m_labeled_vgg = zip(images_hat_eval_arr_m, eval_res_vgg)
images_hat_eval_pil_m_vgg = toPIL(torchvision.utils.make_grid([toTensor(image_add_label(pil, str(round(label, 3)), 40).resize((256, 256))) for pil, label in images_hat_eval_arr_m_labeled_vgg], nrow=1)).convert("RGB")
images_hat_eval_pil_m_vgg.save(output_path + "/images_eval_hat_m_vgg.png")
eval_res_vgg




[0.464754694541439,
 0.8274239937676159,
 0.34938368077529997,
 0.576459335828597,
 0.5537390224322192,
 0.2699485708488388,
 0.46675176195489104,
 0.6830751306542533,
 0.046867572628578304,
 0.7086720534966884,
 0.4166164836708298,
 0.5513433182114134,
 0.6468025871821815,
 0.2652512330690248,
 0.6530968067392857,
 0.6092680018630461,
 0.12331453282012529,
 0.25539691219079025,
 0.3061176826490259,
 0.10899034316594093]