In [1]:
%cd /home/nakamura/ganspace

/home/nakamura/ganspace


In [2]:
from IPython.utils import io
import torch
import PIL
import imageio
import pickle
import os
import numpy as np
import random
import ipywidgets as widgets
import matplotlib.pyplot as plt
from PIL import Image
from models import get_instrumented_model
from decomposition import get_or_compute
from config import Config
from skimage import img_as_ubyte
from ipywidgets import fixed

StyleGAN2: Optimized CUDA op FusedLeakyReLU not available, using native PyTorch fallback.
StyleGAN2: Optimized CUDA op UpFirDn2d not available, using native PyTorch fallback.


### Load model

In [3]:
# Speed up computation: gradients are switched off globally and cuDNN
# benchmark mode is enabled.
torch.autograd.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True

# Specify model to use
config = Config(
  model='StyleGAN2',
  layer='style',
  output_class='grain-final',
  components=80,
  use_w=True,
  n=1_000_000,
  batch_size=10_000, # style layer quite small
)

# Alternative configuration using the 'ffhq' output class; swap it in for the
# block above if needed.
# config = Config(
#   model='StyleGAN2',
#   layer='style',
#   output_class='ffhq',
#   components=80,
#   use_w=True,
#   n=1_000_000,
#   batch_size=10_000, # style layer quite small
# )

# Use the GPU when one is visible, otherwise fall back to the CPU instead of
# failing on a hard-coded 'cuda' device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

inst = get_instrumented_model(config.model, config.output_class,
                              config.layer, device, use_w=config.use_w)

path_to_components = get_or_compute(config, inst)

model = inst.model

# The fitted transformer is stored next to the components file, under the same
# name with '.pickle' appended. The f-string works whether get_or_compute
# returned a str or a path-like object.
# NOTE(security): pickle.load can execute arbitrary code; only load files that
# were produced by a trusted run.
with open(f'{path_to_components}.pickle', mode='rb') as f:
    transformer = pickle.load(f)

named_directions = {} #init named_directions dict to save directions

checkpoint_root /home/nakamura/ganspace/models/checkpoints
checkpoint /home/nakamura/ganspace/models/checkpoints/stylegan2/stylegan2_grain-final_256.pt


## Multiple components
### Load components

In [4]:
# Select which arrays of the components archive to use: the activation ones
# ('act_comp' / 'act_stdev') or the latent ones ('lat_comp' / 'lat_stdev').
load_activations = True
comp_key, stdev_key = ('act_comp', 'act_stdev') if load_activations else ('lat_comp', 'lat_stdev')

# Read each array exactly once and close the archive afterwards. Indexing an
# NpzFile loads the array from disk on every access, so looking it up inside a
# per-component loop re-reads the file once per component.
with np.load(path_to_components) as comps:
    lst = comps.files
    # One list entry per slice along the first axis of the stored array.
    latent_dirs = list(comps[comp_key]) if comp_key in lst else []
    latent_stdevs = list(comps[stdev_key]) if stdev_key in lst else []

# Number of leading components handed to the UI below.
num = 6

available = min(len(latent_dirs), len(latent_stdevs))
if num > available:
    raise IndexError(
        f'Requested {num} components but only {available} are stored under '
        f'{comp_key!r}/{stdev_key!r} in {path_to_components}'
    )

comp_dir = latent_dirs[:num]
comp_dir_stdev = latent_stdevs[:num]

print(f'Loaded Component No. 1~{num}')

Loaded Component No. 1~6


### Run UI

In [5]:
def display_sample_pytorch(seed, truncation, dir0, dir1, dir2, dir3, dir4, dir5, dim0, dim1, dim2, dim3, dim4, dim5, scale, start, end, disp=True, save=None, noise_spec=None):
    """Render one image whose edited layers use a weighted sum of six directions.

    A latent is sampled from ``seed`` and repeated once per layer. For every
    layer index in ``[start, end)`` that latent is *replaced* by
    ``sum(dir_i * dim_i * scale)``; layers outside the range keep the sample.

    Args:
        seed: Passed to ``model.sample_latent``; also part of the saved file name.
        truncation: Assigned to ``model.truncation`` before rendering.
        dir0..dir5: Direction arrays.
        dim0..dim5: Weight applied to the matching direction.
        scale: Common multiplier applied to every weighted direction.
        start, end: Half-open range of layer indices to edit.
        disp: When true, show the image with ``display``.
        save: When not None, number used in ``out/{seed}_{save:05}.png``;
            it is printed first if ``disp`` is false.
        noise_spec: Unused.
    """
    w = model.sample_latent(1, seed=seed).cpu().numpy()

    model.truncation = truncation
    w = [w]*model.get_max_latents() # one per layer

    # The edit does not depend on the layer, so build it once. Nothing else is
    # computed per layer: the PCA coordinates that used to be calculated here
    # were never read.
    # NOTE(review): the sampled latent is overwritten rather than offset, so
    # the seed has no visible effect on edited layers -- confirm this is intended.
    edit = dir0 * dim0 * scale + dir1 * dim1 * scale + dir2 * dim2 * scale + dir3 * dim3 * scale + dir4 * dim4 * scale + dir5 * dim5 * scale
    for l in range(start, end):
        w[l] = edit

    #save image and display
    out = model.sample_np(w)
    final_im = Image.fromarray((out * 255).astype(np.uint8)).resize((256,256),Image.LANCZOS)

    if disp:
        display(final_im)

    if save is not None:
        if not disp:
            print(save)
        # Create the output folder on first use so saving cannot fail on a
        # missing directory.
        os.makedirs('out', exist_ok=True)
        final_im.save(f'out/{seed}_{save:05}.png')


# Style shared by the labelled sliders and the text box.
style = {'description_width': 'initial'}


def _weight_slider(label):
    """Return a float slider in [-10, 10] (step 0.1, starting at 0) labelled `label`."""
    return widgets.FloatSlider(min=-10, max=10, step=0.1, value=0, description=label, continuous_update=False, style=style)


def _layer_slider(label, initial):
    """Return an integer slider running from 0 to model.get_max_latents()."""
    return widgets.IntSlider(min=0, max=model.get_max_latents(), step=1, value=initial, description=label, continuous_update=False)


# The seed slider starts at a random value.
seed = widgets.IntSlider(min=0, max=100000, step=1, value=np.random.randint(0,100000), description='Seed: ', continuous_update=False)
truncation = widgets.FloatSlider(min=0, max=2, step=0.1, value=0.7, description='Truncation: ', continuous_update=False)
# One weight slider per direction, labelled 'd1: ' .. 'd6: '.
dim0, dim1, dim2, dim3, dim4, dim5 = [_weight_slider(f'd{k}: ') for k in range(1, 7)]
scale = widgets.FloatSlider(min=0, max=10, step=0.05, value=1, description='Scale: ', continuous_update=False)
start_layer = _layer_slider('start layer: ', 0)
end_layer = _layer_slider('end layer: ', 18)

# Make sure layer range is valid: each slider bounds the other one.
def update_range_start(*args):
  """Raise the lower bound of the end slider to the current start layer."""
  end_layer.min = start_layer.value
def update_range_end(*args):
  """Lower the upper bound of the start slider to the current end layer."""
  start_layer.max = end_layer.value
start_layer.observe(update_range_start, names='value')
end_layer.observe(update_range_end, names='value')

# Created but not placed in the layout below.
text = widgets.Text(description="Name component here", style=style, width=200)

dim_sliders = [dim0, dim1, dim2, dim3, dim4, dim5]
bot_box = widgets.VBox([*dim_sliders, scale, start_layer, end_layer, seed])
ui = widgets.VBox([bot_box])

# Keyword arguments for display_sample_pytorch: live widgets plus the six
# loaded directions, which are passed through unchanged via fixed().
controls = {'seed': seed, 'truncation': truncation}
controls.update({f'dir{k}': fixed(comp_dir[k]) for k in range(6)})
controls.update({f'dim{k}': slider for k, slider in enumerate(dim_sliders)})
controls.update({'scale': scale, 'start': start_layer, 'end': end_layer})
out = widgets.interactive_output(display_sample_pytorch, controls)

display(out, ui)

Output()

VBox(children=(VBox(children=(FloatSlider(value=0.0, continuous_update=False, description='d1: ', max=10.0, mi…

In [53]:
def display_sample_pytorch(seed, truncation, dir0, dir1, dir2, dir3, dir4, dir5, dim0, dim1, dim2, dim3, dim4, dim5, scale, start, end, disp=True, save=None, noise_spec=None):
    """Render one image from random transformer coordinates (experiment).

    A latent is sampled from ``seed`` and repeated once per layer. For every
    layer index in ``[start, end)`` the latent is first offset by
    ``sum(dir_i * dim_i * scale)``; then its first row is overwritten with
    ``transformer.inverse_transform`` of a fresh ``np.random.rand(1, 6)`` draw.
    The first draw is printed.

    Args:
        seed: Passed to ``model.sample_latent``; also part of the saved file name.
        truncation: Assigned to ``model.truncation`` before rendering.
        dir0..dir5: Direction arrays.
        dim0..dim5: Weight applied to the matching direction.
        scale: Common multiplier applied to every weighted direction.
        start, end: Half-open range of layer indices to edit.
        disp: When true, show the image with ``display``.
        save: When not None, number used in ``out/{seed}_{save:05}.png``;
            it is printed first if ``disp`` is false.
        noise_spec: Unused.
    """
    param = []
    inv = []

    w = model.sample_latent(1, seed=seed).cpu().numpy()

    model.truncation = truncation
    w = [w]*model.get_max_latents() # one per layer
    for l in range(start, end):
        w[l] = w[l] + dir0 * dim0 * scale + dir1 * dim1 * scale + dir2 * dim2 * scale + dir3 * dim3 * scale + dir4 * dim4 * scale + dir5 * dim5 * scale
        # param.append(transformer.transform(w[l]))
        # NOTE(review): the draw is not seeded, so the image changes between
        # calls even for the same seed. It also assumes the transformer accepts
        # 6 coordinates -- TODO confirm against the fitted component count.
        coords = np.random.rand(1,6)
        restored = transformer.inverse_transform(coords)
        # The lists grow by one entry per iteration, so they must not be
        # indexed with the absolute layer number `l` (that raised IndexError
        # whenever start > 0). Use the values just computed instead.
        param.append(coords)
        inv.append(restored)
        # Overwrites row 0 of the offset latent; if the latent has a single
        # row this discards the slider edit above -- presumably intended for
        # this experiment.
        w[l][0] = restored[0].astype(np.float32)

    # Nothing was drawn when the layer range is empty.
    if param:
        print(param[0][0][:num])

    #save image and display
    out = model.sample_np(w)
    final_im = Image.fromarray((out * 255).astype(np.uint8)).resize((256,256),Image.LANCZOS)

    if disp:
        display(final_im)

    if save is not None:
        if not disp:
            print(save)
        # Create the output folder on first use so saving cannot fail on a
        # missing directory.
        os.makedirs('out', exist_ok=True)
        final_im.save(f'out/{seed}_{save:05}.png')


# Style shared by the labelled sliders and the text box.
style = {'description_width': 'initial'}

# Options common to every direction-weight slider.
_weight_opts = dict(min=-10, max=10, step=0.1, value=0)

# The seed slider starts at a random value.
seed = widgets.IntSlider(min=0, max=100000, step=1, value=np.random.randint(0,100000), description='Seed: ', continuous_update=False)
truncation = widgets.FloatSlider(min=0, max=2, step=0.1, value=0.7, description='Truncation: ', continuous_update=False)
# One weight slider per direction, labelled 'Comp1: ' .. 'Comp6: '.
dim_sliders = [
    widgets.FloatSlider(**_weight_opts, description=f'Comp{k}: ', continuous_update=False, style=style)
    for k in range(1, 7)
]
dim0, dim1, dim2, dim3, dim4, dim5 = dim_sliders
scale = widgets.FloatSlider(min=0, max=10, step=0.05, value=1, description='Scale: ', continuous_update=False)
start_layer, end_layer = (
    widgets.IntSlider(min=0, max=model.get_max_latents(), step=1, value=initial, description=label, continuous_update=False)
    for label, initial in (('start layer: ', 0), ('end layer: ', 18))
)

# Make sure layer range is valid: each slider bounds the other one.
def update_range_start(*args):
  """Raise the lower bound of the end slider to the current start layer."""
  end_layer.min = start_layer.value
def update_range_end(*args):
  """Lower the upper bound of the start slider to the current end layer."""
  start_layer.max = end_layer.value
start_layer.observe(update_range_start, names='value')
end_layer.observe(update_range_end, names='value')

# Created but not placed in the layout below.
text = widgets.Text(description="Name component here", style=style, width=200)

bot_box = widgets.VBox([seed, *dim_sliders, scale, start_layer, end_layer])
ui = widgets.VBox([bot_box])

# Keyword arguments for display_sample_pytorch: live widgets plus the six
# loaded directions, which are passed through unchanged via fixed().
controls = {'seed': seed, 'truncation': truncation}
for k in range(6):
    controls[f'dir{k}'] = fixed(comp_dir[k])
for k, slider in enumerate(dim_sliders):
    controls[f'dim{k}'] = slider
controls['scale'] = scale
controls['start'] = start_layer
controls['end'] = end_layer
out = widgets.interactive_output(display_sample_pytorch, controls)

display(out, ui)

Output()

VBox(children=(VBox(children=(IntSlider(value=6692, continuous_update=False, description='Seed: ', max=100000)…

## Single component
### Load a component

In [5]:
# Select which arrays of the components archive to use: the activation ones
# ('act_comp' / 'act_stdev') or the latent ones ('lat_comp' / 'lat_stdev').
load_activations = True
comp_key, stdev_key = ('act_comp', 'act_stdev') if load_activations else ('lat_comp', 'lat_stdev')

# Read each array exactly once and close the archive afterwards. Indexing an
# NpzFile loads the array from disk on every access, so looking it up inside a
# per-component loop re-reads the file once per component.
with np.load(path_to_components) as comps:
    lst = comps.files
    # One list entry per slice along the first axis of the stored array.
    latent_dirs = list(comps[comp_key]) if comp_key in lst else []
    latent_stdevs = list(comps[stdev_key]) if stdev_key in lst else []

#load one at random 
# num = np.random.randint(20)
# if num in named_directions.values():
#   print(f'Direction already named: {list(named_directions.keys())[list(named_directions.values()).index(num)]}')

# Zero-based index of the component handed to the UI below.
num = 10

available = min(len(latent_dirs), len(latent_stdevs))
if num >= available:
    raise IndexError(
        f'Component index {num} is out of range: only {available} are stored under '
        f'{comp_key!r}/{stdev_key!r} in {path_to_components}'
    )

# comp_dir = latent_dirs[num]
# comp_dir_stdev = latent_stdevs[num]

# Kept as one-element lists so the UI can address the direction as comp_dir[0].
comp_dir = [latent_dirs[num]]
comp_dir_stdev = [latent_stdevs[num]]

print(f'Loaded Component No. {num}')


Loaded Component No. 10


### Run UI

In [7]:
def display_sample_pytorch(seed, truncation, dir0, dim0, scale, start, end, disp=True, save=None, noise_spec=None):
    """Render one image with a single direction added to a range of layers.

    A latent is sampled from ``seed`` and repeated once per layer. For every
    layer index in ``[start, end)`` the latent is offset by
    ``dir0 * dim0 * scale``; layers outside the range keep the sample.

    Args:
        seed: Passed to ``model.sample_latent``; also part of the saved file name.
        truncation: Assigned to ``model.truncation`` before rendering.
        dir0: Direction array.
        dim0: Weight applied to the direction.
        scale: Extra multiplier applied to the weighted direction.
        start, end: Half-open range of layer indices to edit.
        disp: When true, show the image with ``display``.
        save: When not None, number used in ``out/{seed}_{save:05}.png``;
            it is printed first if ``disp`` is false.
        noise_spec: Unused.
    """
    w = model.sample_latent(1, seed=seed).cpu().numpy()

    model.truncation = truncation
    w = [w]*model.get_max_latents() # one per layer

    # The offset does not depend on the layer, so compute it once. The PCA
    # coordinates that used to be calculated per layer were never read and are
    # no longer computed.
    shift = dir0 * dim0 * scale
    for l in range(start, end):
        # Builds a new array, so the shared sampled latent is left untouched.
        w[l] = w[l] + shift

    #save image and display
    out = model.sample_np(w)
    final_im = Image.fromarray((out * 255).astype(np.uint8)).resize((256,256),Image.LANCZOS)

    if disp:
        display(final_im)
    if save is not None:
        if not disp:
            print(save)
        # Create the output folder on first use so saving cannot fail on a
        # missing directory.
        os.makedirs('out', exist_ok=True)
        final_im.save(f'out/{seed}_{save:05}.png')


# Style shared by the labelled slider and the text box.
style = {'description_width': 'initial'}

# Every slider only fires once the handle is released.
_on_release = {'continuous_update': False}

# The seed slider starts at a random value.
seed = widgets.IntSlider(min=0, max=100000, step=1, value=np.random.randint(0,100000), description='Seed: ', **_on_release)
truncation = widgets.FloatSlider(min=0, max=2, step=0.1, value=0.7, description='Truncation: ', **_on_release)
dim0 = widgets.FloatSlider(min=-10, max=10, step=0.1, value=0, description='SelectedComp: ', **_on_release, style=style)
scale = widgets.FloatSlider(min=0, max=10, step=0.05, value=1, description='Scale: ', **_on_release)
start_layer = widgets.IntSlider(min=0, max=model.get_max_latents(), step=1, value=0, description='start layer: ', **_on_release)
end_layer = widgets.IntSlider(min=0, max=model.get_max_latents(), step=1, value=18, description='end layer: ', **_on_release)

# Make sure layer range is valid: each slider bounds the other one.
def update_range_start(*args):
  """Raise the lower bound of the end slider to the current start layer."""
  end_layer.min = start_layer.value
def update_range_end(*args):
  """Lower the upper bound of the start slider to the current end layer."""
  start_layer.max = end_layer.value
start_layer.observe(update_range_start, names='value')
end_layer.observe(update_range_end, names='value')

# Created but not placed in the layout below.
text = widgets.Text(description="Name component here", style=style, width=200)

bot_box = widgets.VBox([seed, dim0, scale, start_layer, end_layer])
ui = widgets.VBox([bot_box])

# Keyword arguments for display_sample_pytorch: live widgets plus the loaded
# direction, which is passed through unchanged via fixed().
controls = dict(
    seed=seed,
    truncation=truncation,
    dir0=fixed(comp_dir[0]),
    dim0=dim0,
    scale=scale,
    start=start_layer,
    end=end_layer,
)
out = widgets.interactive_output(display_sample_pytorch, controls)

display(out, ui)


Output()

VBox(children=(VBox(children=(IntSlider(value=47596, continuous_update=False, description='Seed: ', max=100000…

# Convert w to param

In [5]:
# Load the stored latents and echo their shape and contents.
w_test = np.load('questionnaire/w_test.npy')
print(w_test.shape, w_test, sep='\n')

# Map them through the fitted transformer, echo the result the same way, and
# store it next to the input file.
param_test = transformer.transform(w_test)
print(param_test.shape, param_test, sep='\n')
np.save('questionnaire/param_test.npy', param_test)

(34, 512)
[[ 0.3440561  -0.950136    0.5900445  ... -0.62184376 -0.43264222
   1.2390425 ]
 [ 0.37139365 -0.6246364   0.80698174 ... -1.2422309  -1.9276294
  -0.23519118]
 [-0.09118952 -0.20676489  1.1669356  ... -0.7095085   1.6524553
   0.98981756]
 ...
 [ 0.79115605 -0.30226016 -0.20428316 ... -0.8007261   0.14960529
   0.93486595]
 [ 1.0991617  -0.961632    0.8275591  ...  0.46156552 -1.6141889
   0.42750245]
 [ 0.14710009 -1.0717025   0.20085426 ... -0.7783195   0.9358142
   0.01562361]]
