In [25]:
import os
import sys
import dnnlib
from dnnlib import tflib 
import random
from pathlib import Path
from PIL import Image
import pickle
import numpy as np
import torch
import ipywidgets as widgets
from IPython.display import display
from tqdm import tqdm
from ipywidgets import interact
from ipywidgets import HBox, VBox
import matplotlib.pyplot as plt
import projector
import warnings
import csv

# Silence every Python warning for the rest of the session.
# NOTE(review): this blanket filter also hides warnings that flag real problems
# (deprecations, dtype/device issues); consider narrowing it by category or module.
warnings.filterwarnings(action='ignore')


def load_model(pklfilename):
    """Load the ``'G_ema'`` generator from a network pickle and move it to CUDA.

    Args:
        pklfilename: Path to the network snapshot pickle.

    Returns:
        A tuple ``(Gs, noise_vars, Gs_kwargs)``:
        - ``Gs``: the unpickled ``'G_ema'`` module, on the GPU.
        - ``noise_vars``: dict of the synthesis network's buffers whose name
          contains ``'noise_const'``.
        - ``Gs_kwargs``: an EasyDict with ``output_transform`` (tflib uint8
          conversion, NCHW -> NHWC) and ``randomize_noise=False``. It is not
          used elsewhere in this cell.
    """
    # NOTE(review): pickle.load can execute arbitrary code - only load trusted snapshots.
    with open(pklfilename, 'rb') as handle:
        generator = pickle.load(handle)['G_ema'].cuda()  # torch.nn.Module

    noise_buffers = {}
    for buf_name, buf in generator.synthesis.named_buffers():
        if 'noise_const' in buf_name:
            noise_buffers[buf_name] = buf

    output_kwargs = dnnlib.util.EasyDict(
        output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        randomize_noise=False,
    )
    return generator, noise_buffers, output_kwargs

def calculate_latent_vectors(Gs, noise_vars, seeds, target_device=None):
    """Return the mean z-latent of the given seeds.

    Each seed's z is drawn exactly as ``generate_image_random`` draws it
    (``np.random.RandomState(seed).randn(1, Gs.z_dim)``). The generator itself
    is not run: the synthesized image was previously computed and then
    discarded, which cost one full forward pass per seed.

    Args:
        Gs: Generator; only ``Gs.z_dim`` is read.
        noise_vars: Unused; kept so existing callers keep working.
        seeds: Non-empty iterable of integer seeds.
        target_device: Device for the returned tensor. Defaults to the
            module-level ``device``.

    Returns:
        Tensor of shape ``(1, Gs.z_dim)`` with dtype float64 (numpy ``randn``
        output is float64), holding the element-wise mean of the seeds' z vectors.

    Raises:
        ValueError: If ``seeds`` is empty (previously a ZeroDivisionError).
    """
    seeds = list(seeds)
    if not seeds:
        raise ValueError('calculate_latent_vectors needs at least one seed')
    out_device = device if target_device is None else target_device
    codes = [
        torch.from_numpy(np.random.RandomState(seed).randn(1, Gs.z_dim))
        for seed in seeds
    ]
    # Stack to (n, 1, z_dim) and average over the seed axis -> (1, z_dim).
    return torch.stack(codes).mean(dim=0).to(out_device)

def generate_average_image(rndnumbers, adj, pklfilename, Gs=None, noise_vars=None):
    """Render the image for the mean z-latent of a set of seeds and save it as PNG.

    Args:
        rndnumbers: Iterable of seeds, typically CSV cell strings. Blank or
            whitespace-only cells are skipped; the rest are converted with ``int``.
        adj: Label used for the output path ``chair_image/<adj>/<adj>.png``.
            The directory must already exist.
        pklfilename: Network pickle to load when ``Gs`` is not supplied.
        Gs: Optional already-loaded generator. Passing it avoids re-unpickling
            the network on every call.
        noise_vars: Optional noise buffers matching ``Gs``. They are only
            forwarded to ``calculate_latent_vectors``.

    Raises:
        ValueError: Propagated from ``int`` for non-numeric cells, or from
            ``calculate_latent_vectors`` when no seeds remain.
    """
    if Gs is None:
        Gs, noise_vars, _ = load_model(pklfilename)

    # The old `rndnumber is not ''` tested identity, not equality, so it only
    # worked by accident of string interning. Compare content instead.
    seeds = [int(value) for value in rndnumbers if str(value).strip()]

    with torch.no_grad():  # inference only; no autograd graph needed
        averaged_code = calculate_latent_vectors(Gs, noise_vars, seeds)
        image_avg = Gs(averaged_code, None)  # c=None: no class label is passed
    img = (image_avg.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    out_path = os.path.join('chair_image', adj, adj + '.png')
    Image.fromarray(img[0].cpu().numpy(), 'RGB').save(out_path)

def generate_image_random(Gs, noise_vars, rand_seed):
    """Sample a seeded z-latent, run the generator on it, and return both.

    Args:
        Gs: Generator called as ``Gs(z, None)``; ``Gs.z_dim`` sets the latent size.
        noise_vars: Unused; kept for call compatibility.
        rand_seed: Seed for ``np.random.RandomState``.

    Returns:
        ``(images, z)``: the raw generator output tensor and the ``(1, Gs.z_dim)``
        float64 latent, which lives on the module-level ``device``.
    """
    rng = np.random.RandomState(rand_seed)
    latent_z = torch.from_numpy(rng.randn(1, Gs.z_dim)).to(device)
    generated = Gs(latent_z, None)  # None: no category/class label
    return generated, latent_z
    
def generate_image_from_projected_latents(latent_vector, noise_mode='const'):
    """Render W+ latents with the global generator ``Gs`` and return a PIL RGB image.

    Args:
        latent_vector: Latent tensor passed to ``Gs.synthesis``. It is assumed to
            be ``(1, num_ws, w_dim)`` as loaded from ``projected_w.npz``
            (TODO confirm).
        noise_mode: Forwarded to ``Gs.synthesis``. Defaults to ``'const'`` so
            results are deterministic and consistent with the reference image,
            which this notebook renders with ``noise_mode='const'``. Previously the
            synthesis default was used; in stylegan2-ada-pytorch that is
            ``'random'`` (verify for this checkpoint), so identical slider settings
            could produce different images.

    Returns:
        PIL.Image built from the first image in the batch.
    """
    # Display-only rendering: the latents may carry autograd history
    # (ws_image_to_use derives from w_opt with requires_grad=True).
    # Avoid building a graph on every click.
    with torch.no_grad():
        image = Gs.synthesis(latent_vector, noise_mode=noise_mode)
    # Map the generator's output range to uint8 pixels in NHWC layout.
    image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return Image.fromarray(image[0].cpu().numpy(), 'RGB')

def get_concat_h(im1, im2):
    """Return a new RGB image with ``im2`` pasted to the right of ``im1``.

    The canvas is as tall as the taller input. Previously it used only
    ``im1.height``, which cropped a taller ``im2``. Area not covered by either
    image keeps ``Image.new``'s default fill (black).
    """
    canvas = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height)))
    canvas.paste(im1, (0, 0))
    canvas.paste(im2, (im1.width, 0))
    return canvas

def apply_latent_controls(self):
    """Button callback: render the base latent shifted along each adjective direction.

    ``self`` is the widget passed by ``Button.on_click`` and is not used.

    For slider values ``m_i`` and adjective latents ``w_i`` (``latent_code_list``),
    the rendered code is ``w_base + sum_i (w_i - w_base) * m_i / 100``, where
    ``w_base`` is the global ``ws_image_to_use``. Reads the globals
    ``controller``, ``latent_code_list``, ``ws_image_to_use`` and ``device``.
    The result is displayed in the second image pane.
    """
    image_outputs = controller.children[0]
    feature_sliders = controller.children[1]

    # All children except the last two (Generate / Reset buttons) are
    # HBox(slider, label) rows.
    slider_hboxes = feature_sliders.children[:-2]
    latent_movements = [box.children[0].value for box in slider_hboxes]

    with torch.no_grad():
        # Detach so the edits below do not record autograd history.
        w_base = ws_image_to_use.detach()
        modified_latent_code = w_base.clone()
        # zip pairs each slider with its latent code. It stops at the shorter list
        # instead of raising IndexError on a count mismatch.
        for movement, code in zip(latent_movements, latent_code_list):
            direction = torch.as_tensor(code, device=device) - w_base
            modified_latent_code += direction * (movement / 100)
        latent_img = generate_image_from_projected_latents(modified_latent_code)

    latent_img_output = image_outputs.children[1]
    with latent_img_output:
        latent_img_output.clear_output()
        display(latent_img)

def reset_latent_controls(self):
    """Button callback: zero every slider and show the unmodified base image again.

    ``self`` is the widget passed by ``Button.on_click`` and is not used.
    Reads the globals ``controller`` and ``image_to_use``.
    """
    image_panel, control_panel = controller.children[0], controller.children[1]

    # Every child except the trailing Generate/Reset buttons is an HBox(slider, label).
    for row in control_panel.children[:-2]:
        slider = row.children[0]
        slider.value = 0

    result_view = image_panel.children[1]
    with result_view:
        result_view.clear_output()
        display(image_to_use)

def create_interactive_latent_controller():
    """Build the interactive widget.

    The layout is an image column next to a slider column:
    ``HBox([VBox([original, modified]), VBox([HBox([slider, label])..., Generate, Reset])])``.

    Both image panes start out showing the global ``image_to_use``. There is one
    slider (range -50..50) per entry of the global ``latent_controls``. The
    buttons are wired to ``apply_latent_controls`` / ``reset_latent_controls``.
    """
    def _image_pane():
        pane = widgets.Output()
        with pane:
            pane.clear_output()
            display(image_to_use)
        return pane

    def _button(caption, callback):
        # Each button gets its own Layout instance, as before.
        btn = widgets.Button(description=caption, layout=widgets.Layout(width='50%', height='30%'))
        btn.on_click(callback)
        return btn

    original_pane = _image_pane()
    modified_pane = _image_pane()
    image_column = widgets.VBox([original_pane, modified_pane])

    # collapse-hide
    generate_btn = _button('Generate', apply_latent_controls)
    reset_btn = _button('Reset Controls', reset_latent_controls)

    slider_rows = [
        widgets.HBox([widgets.FloatSlider(min=-50, max=50), widgets.Label(feature)])
        for feature in latent_controls
    ]
    control_column = widgets.VBox(slider_rows + [generate_btn, reset_btn])

    return widgets.HBox([image_column, control_column])

device = torch.device('cuda')
################ Input - Start ####################
# path='C:/PythonWorkspace/stylegan2-ada-pytorch-main/out/chair/'
path = ''
pklfilename = 'network-snapshot-003400.pkl'
csvfilename = 'mapping_result.csv'
results_size = 64
ref_size = 64
# Each CSV row is: <adjective>, <seed>, <seed>, ...
# utf-8-sig strips a leading BOM if present.
adj = []
rndnumbers = []
with open(csvfilename, 'r', encoding='utf-8-sig', newline='') as f:
    for line in csv.reader(f):
        if not any(cell.strip() for cell in line):
            continue  # skip blank rows instead of failing on line[0]
        adj.append(line[0])
        rndnumbers.append(line[1:])

# The slider UI below uses the first five adjectives. Fail fast, before the
# expensive generation/projection work, if the CSV has fewer rows.
if len(adj) < 5:
    raise ValueError(f'{csvfilename} must contain at least 5 adjective rows, found {len(adj)}')

# adj = ['딱딱한', '심플한', '모던한', '폭신한', '앤틱한']  # hard, simple, modern, soft, antique
################ Input - End ####################
Gs, noise_vars, Gs_kwargs = load_model(pklfilename)

# Average image of the latent space: the mean of mapped W vectors.
num_w_samples = 1000
with torch.no_grad():
    z_samples = np.random.RandomState(123).randn(num_w_samples, Gs.z_dim)
    w_samples = Gs.mapping(torch.from_numpy(z_samples).to(device), None)
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)
w_avg = np.mean(w_samples, axis=0, keepdims=True)
# Divide by the number of samples actually drawn (previously a hard-coded
# 10000, although only 1000 samples were drawn).
w_std = (np.sum((w_samples - w_avg) ** 2) / num_w_samples) ** 0.5
w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True)

# The original 1000-iteration loop only recomputed this value. Only the final
# step's noise scale matters, and since 1 - t / 0.75 < 0 at t = 0.999, it is 0.0.
num_steps = 1000
step = num_steps - 1
t = step / num_steps
w_noise_scale = w_std * 0.05 * max(0.0, 1.0 - t / 0.75) ** 2

with torch.no_grad():  # no gradients needed; keeps later edits graph-free
    w_noise = torch.randn_like(w_opt) * w_noise_scale
    ws_image_to_use = (w_opt + w_noise).repeat([1, Gs.mapping.num_ws, 1])
    synth_images = Gs.synthesis(ws_image_to_use, noise_mode='const')
image_to_use = (synth_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
image_to_use = Image.fromarray(image_to_use[0].cpu().numpy(), 'RGB')

# Generate and save the average image for each adjective.
for name, seeds in zip(adj, rndnumbers):
    os.makedirs(os.path.join('chair_image', name), exist_ok=True)
    generate_average_image(seeds, name, pklfilename)

# Project each average image back into latent space. The output directory is
# where projected_w.npz is read from below. NOTE(review): the positional args
# are presumably (network, target, outdir, save_video, seed, num_steps);
# confirm against projector.py.
for name in adj:
    out_dir = 'chair_image/' + name + '/'
    projector.run_projection(pklfilename, out_dir + name + '.png',
                             out_dir, False, 303, 1000)

targetName = 'projected_w.npz'

latent_controls = adj[:5]
latent_code_list = []
for name in latent_controls:
    # Context manager closes the NpzFile handle once 'w' has been read.
    with np.load('chair_image/' + name + '/' + targetName) as projected:
        latent_code_list.append(projected['w'])
latent_code1, latent_code2, latent_code3, latent_code4, latent_code5 = latent_code_list


controller = create_interactive_latent_controller()
controller



File 2 Loading networks from "network-snapshot-003400.pkl"...
Elapsed: 61.7 s
File 2 Loading networks from "network-snapshot-003400.pkl"...
Elapsed: 61.2 s
File 2 Loading networks from "network-snapshot-003400.pkl"...
Elapsed: 61.9 s
File 2 Loading networks from "network-snapshot-003400.pkl"...
Elapsed: 60.9 s
File 2 Loading networks from "network-snapshot-003400.pkl"...
Elapsed: 60.9 s


HBox(children=(VBox(children=(Output(), Output())), VBox(children=(HBox(children=(FloatSlider(value=0.0, max=5…