In [2]:
import os
import random
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import save_image

In [65]:
# Directory holding the aligned face images; every entry is treated as one image.
img_dir = 'imgs_align'

# os.listdir() returns entries in arbitrary, filesystem-dependent order, so sort
# them to keep the printed indices stable from one run to the next.
img_files = sorted(os.listdir(img_dir))
for i, file in enumerate(img_files):
    print(f'{i:2}: {file}')

# Number of directory entries found in img_dir.
num_files = len(img_files)
print(f'\ntotal: {num_files}')

 0: azuki_default.png
 1: ceo_default.png
 2: check1.png
 3: check2.png
 4: danda_default.png
 5: detective.png
 6: detective_02.png
 7: idPhoto.png
 8: mio_happy.png
 9: mio_shock.png
10: mio_silence.png
11: mio_u.png
12: nanko_default.png
13: ookawa_angry.png
14: ookawa_angry2.png
15: ookawa_default.png
16: ookawa_high.png
17: ookawa_regret.png
18: ookawa_smile.png
19: ookawa_surprised.png
20: pharmacist.png
21: saki.png
22: saki_glasses.png
23: sandy.png
24: takebe_default.png
25: test01.png
26: test01_02.png
27: test01_03.png
28: test01_04.png
29: test02.png
30: woman_default.png
31: yotaka_angry.png
32: yotaka_angry2.png
33: yotaka_bald.png
34: yotaka_bushy.png
35: yotaka_default.png
36: yotaka_gj.png
37: yotaka_smile.png
38: yotaka_smile2.png

total: 39


In [66]:
# Preprocessing pipeline applied to every PIL image before averaging:
#   1. ToTensor  - turns the image into a float tensor.
#   2. Normalize - per channel (x - mean) / std; with mean = std = 0.5 this
#                  is 2 * x - 1, which later cells undo via (x + 1) / 2.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [67]:
# Arithmetic-mean face, computed two ways over every file in img_dir:
#   avg_face1 - integer pixel mean of the arrays returned by cv2.imread
#   avg_face2 - float mean of the normalized tensors produced by `transform`
# NOTE(review): both accumulators hard-code a 256x256, 3-channel image size.
face_files = os.listdir(img_dir)
if not face_files:
    raise RuntimeError(f'no images found in {img_dir!r}')

# uint64 accumulator: summing many uint8 images would overflow a uint8 buffer.
avg_face1 = np.zeros((256, 256, 3), dtype=np.uint64)
avg_face2 = torch.zeros(3, 256, 256)
for file in face_files:
    path = os.path.join(img_dir, file)

    img1 = cv2.imread(path)
    if img1 is None:
        # cv2.imread signals an unreadable file by returning None, not by raising.
        raise ValueError(f'cv2 could not read image: {path}')
    avg_face1 += img1

    with Image.open(path) as pil_img:
        # Force 3 channels so alpha / palette PNGs match the 3-channel
        # Normalize and the (3, 256, 256) accumulator.
        avg_face2 += transform(pil_img.convert('RGB'))

# Divide by the number of images actually accumulated in this cell rather than
# by a count computed elsewhere, so the mean stays correct if the directory changed.
n_faces = len(face_files)
avg_face1 = (avg_face1 // n_faces).astype(np.uint8)
avg_face2 = avg_face2 / n_faces
# Undo Normalize(0.5, 0.5): map [-1, 1] back to [0, 1] for saving.
avg_face2 = (avg_face2 + 1) / 2

In [69]:
def cos(a, b):
    """Cosine similarity between two tensors, each treated as one flat vector.

    Args:
        a: Tensor of any shape.
        b: Tensor with the same number of elements as ``a``.

    Returns:
        A 0-dim tensor holding the dot product of the L2-normalized,
        flattened inputs.
    """
    # reshape (unlike view) also accepts non-contiguous tensors such as
    # transposed or sliced ones; for contiguous input it is still a no-copy view.
    a = F.normalize(a.reshape(-1), dim=0)
    b = F.normalize(b.reshape(-1), dim=0)
    return (a * b).sum()

def interpolate(img1, img2):
    """Spherical midpoint (slerp at t = 0.5) of two same-shaped tensors.

    Both tensors are treated as flat vectors; with theta the angle between
    them, the result is ``sin(theta / 2) / sin(theta) * (img1 + img2)``.

    Args:
        img1: Tensor of any shape.
        img2: Tensor with the same shape as ``img1``.

    Returns:
        Tensor with the same shape as the inputs. When the inputs are
        (numerically) parallel or anti-parallel the slerp weights are
        undefined, so the plain arithmetic midpoint is returned instead.
    """
    # Round-off can push the cosine marginally outside [-1, 1], where arccos
    # returns NaN; clamp before taking the angle.
    theta = torch.arccos(cos(img1, img2).clamp(-1.0, 1.0))
    sin_theta = torch.sin(theta)
    if sin_theta.abs() < 1e-6:
        # theta ~ 0 or ~ pi: dividing by sin(theta) would yield NaN / inf.
        return (img1 + img2) / 2
    # The weight is a scalar, so it broadcasts over any input shape and no
    # flatten / view round-trip is needed.
    weight = torch.sin(0.5 * theta) / sin_theta
    return weight * (img1 + img2)

def make_average_face(imgs):
    """Reduce a collection of images to one by repeated pairwise interpolation.

    Each round shuffles the remaining images, merges them two at a time with
    ``interpolate`` and, when the count is odd, carries the unpaired last image
    into the next round unchanged. Rounds repeat until one image is left.
    Because of the shuffle and the carried-over image, the result depends on
    the state of the ``random`` module.

    Args:
        imgs: Sequence of images accepted by ``interpolate``. It is not modified.

    Returns:
        The single remaining image (the input itself if only one was given).

    Raises:
        ValueError: If ``imgs`` is empty.
    """
    if len(imgs) == 0:
        raise ValueError('make_average_face() requires at least one image')

    # Work on a copy so the caller's list is never reordered by the shuffle.
    pool = list(imgs)
    while len(pool) > 1:
        random.shuffle(pool)
        merged = [
            interpolate(pool[i], pool[i + 1])
            for i in range(0, len(pool) - 1, 2)
        ]
        # Only an odd-sized round leaves the last image unpaired; carrying it
        # over in even-sized rounds would count that image twice.
        if len(pool) % 2 == 1:
            merged.append(pool[-1])
        pool = merged
    return pool[0]

In [79]:
# Load every file in img_dir as a normalized tensor. The listing is sorted and
# the RNG seeded because make_average_face() shuffles its input and the pairing
# order influences the result; together they make avg_face3 reproducible.
imgs = []
for file in sorted(os.listdir(img_dir)):
    with Image.open(os.path.join(img_dir, file)) as img:
        # Force 3 channels so alpha / palette PNGs match the 3-channel Normalize
        # and all tensors share one shape.
        imgs.append(transform(img.convert('RGB')))

random.seed(0)
avg_face3 = make_average_face(imgs)
# Undo Normalize(0.5, 0.5): map [-1, 1] back to [0, 1] for saving.
avg_face3 = (avg_face3 + 1) / 2

In [80]:
# Write the three average faces as PNG files into dst_dir.
dst_dir = 'imgs_test/imgs_avg/'
# exist_ok=True replaces the check-then-create pattern, which can race with
# another process creating the directory in between.
os.makedirs(dst_dir, exist_ok=True)

avg_face1_path = os.path.join(dst_dir, 'avg_face1.png')
# cv2.imwrite reports failure through its return value instead of raising.
if not cv2.imwrite(avg_face1_path, avg_face1):
    raise OSError(f'cv2 could not write image: {avg_face1_path}')

save_image(avg_face2, os.path.join(dst_dir, 'avg_face2.png'), format='PNG')
save_image(avg_face3, os.path.join(dst_dir, 'avg_face3.png'), format='PNG')