In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm.auto import tqdm


import os
import sys
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torchvision.utils import save_image
import torchvision.transforms as transforms

from IPython.display import display, HTML

sys.path.append('../../src/')
import modeling_utils
import models
import optimization_utils


In [None]:
# showing all the figures and plots and stuff makes github unhappy :(

In [None]:
# Enable cuDNN's autotuner (per the PyTorch docs it benchmarks convolution algorithms and keeps the fastest).
torch.backends.cudnn.benchmark = True

In [None]:
# Sanity check: display whether PyTorch can see a CUDA device.
torch.cuda.is_available()

In [None]:
# Display the installed PyTorch version for reference.
torch.__version__

In [None]:
def load_model(depth, fade_in, step):
    """Load the EMA generator checkpoint saved at (depth, fade_in, step), ready for inference.

    Relies on the notebook globals `max_depth`, `noise_size` and `device`; they are looked
    up at call time, so they only need to exist before the first call.

    Args:
        depth: depth value encoded in the checkpoint filename.
        fade_in: fade-in value encoded in the checkpoint filename.
        step: step value encoded in the checkpoint filename.

    Returns:
        A `models.StyleGanGenerator` on `device`, with the checkpoint weights loaded and
        switched to eval mode.
    """
    checkpoint_path = 'big_run/models/gen_ema_depth_%d_fade_%d_step_%d.pt'%(depth, fade_in, step)
    # Build the generator with the same `max_depth` used everywhere else in the notebook
    # rather than a second hard-coded copy of that value.
    gen = models.StyleGanGenerator(max_depth, noise_size).to(device)
    # map_location makes the load independent of the device the checkpoint was saved from.
    state_dict = torch.load(checkpoint_path, map_location=device)
    gen.load_state_dict(state_dict)
    return gen.eval()

## Hyperparameters

In [None]:
# Passed to the generator constructor and to `modeling_utils.generate_noise` as the noise size.
noise_size = 512
# Passed to the generator constructor; per-channel noise below is allocated with 2*max_depth channels.
max_depth = 6

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Build a generator on `device`. NOTE: `gen` is reassigned by `load_model(...)` in the next code
# cell, so this instance (no weights loaded here) is not the one used for the plots below.
gen = models.StyleGanGenerator(max_depth, noise_size).to(device)

In [None]:
# Checkpoint to visualise, identified by the (depth, fade_in, step) triple encoded in its filename.
depth = max_depth-1
load_fade_in = 100
load_step = 638
# Older in-place loading kept for reference (`load_path` is not defined in the visible cells):
# gen.load_state_dict(torch.load(load_path + 'models/gen_ema_depth_%d_fade_%d_step_%d.pt'%(depth, load_fade_in, load_step)))
# `load_model` builds a new generator, loads the `gen_ema` weights and returns it in eval mode.
gen = load_model(depth, load_fade_in, load_step)

In [None]:
# Put the generator in eval mode (`load_model` already does this, so it is a harmless repeat).
gen.eval()
# Ending the cell with an empty string makes it display '' instead of the long module repr.
''

# Plot some random samples

In [None]:
# Draw `n_plots` batches of random samples at depth `max_depth-1`, with a ruler line after each batch.
n_plots = 10
with torch.no_grad():
    for _ in range(n_plots):
        sample_batch = modeling_utils.sample_gen_images(gen, noise_size, device, depth=max_depth-1, alpha=1)
        modeling_utils.plot_imgs(sample_batch)
        print('='*100)

# Images w/ fixed noise
Some good seeds are: 14, 15, 32, 36, 41, 46, 66

In [None]:
# `animation.embed_limit` caps (in MB, per the matplotlib docs) how large an animation embedded
# as HTML may be; this huge value effectively removes the cap for the long videos built below.
import matplotlib
matplotlib.rcParams['animation.embed_limit'] = 2**128

In [None]:
def parse_name(filename):
    """Split a checkpoint filename into `(prefix, depth, fade, step)`.

    The name is expected to look like `<prefix>_depth_<d>_fade_<f>_step_<s>.<ext>`;
    the three numbers are returned as ints and `prefix` is everything before `_depth`.
    """
    tokens = filename.split('_')
    # Counting from the end: ..., <d>, 'fade', <f>, 'step', '<s>.<ext>'
    depth_token, fade_token = tokens[-5], tokens[-3]
    step_token = tokens[-1].split('.')[0]
    prefix_end = filename.find('_depth')
    return filename[:prefix_end], int(depth_token), int(fade_token), int(step_token)

In [None]:
def make_path(d, f, s):
    """Return the `gen_ema` checkpoint filename for depth `d`, fade `f` and step `s`."""
    name_template = 'gen_ema_depth_%d_fade_%d_step_%d.pt'
    return name_template % (d, f, s)

def get_next_gen_ema(depth, fade, step, n_forward_steps=5, n_forward_fade=5):
    """Find the next existing `gen_ema` checkpoint after (depth, fade, step).

    Probes `big_run/models/` for checkpoints at the same depth whose fade is up to
    `n_forward_fade - 1` higher and whose step is up to `n_forward_steps - 1` higher,
    fade offsets in the outer loop. Returns the first filename that exists, or None.
    """
    for fade_offset in range(n_forward_fade):
        for step_offset in range(n_forward_steps):
            if (fade_offset, step_offset) == (0, 0):
                # Zero offset is the checkpoint we started from, not a successor.
                continue
            candidate = make_path(depth, fade + fade_offset, step + step_offset)
            if os.path.exists('big_run/models/' + candidate):
                return candidate
    return None
            
def next_gen_close_time(depth, fade, step, time_thresh=120):
    """Tell whether the next checkpoint was written within `time_thresh` seconds of this one.

    Returns False when `get_next_gen_ema` finds no following checkpoint; otherwise compares
    the modification times of the two files under `big_run/models/`.
    """
    successor = get_next_gen_ema(depth, fade, step)
    if not successor:
        return False
    models_dir = 'big_run/models/'
    current_mtime = os.path.getmtime(models_dir + make_path(depth, fade, step))
    successor_mtime = os.path.getmtime(models_dir + successor)
    return abs(successor_mtime - current_mtime) < time_thresh

In [None]:
def make_fixed_noise_gif(seed):
    """Animate how one fixed latent is rendered by successive `gen_ema` checkpoints.

    Every `gen_ema_depth_*` checkpoint in `big_run/models/` is loaded in order of file
    modification time and fed the same latent noise and the same per-channel noise (both
    derived from `seed`), so the weights are the only thing that changes between frames.

    Args:
        seed: seed passed to `modeling_utils.generate_noise` and `torch.manual_seed`.

    Returns:
        (ani, vid): the `FuncAnimation` and an `HTML` object wrapping `ani.to_jshtml()`.

    Raises:
        ValueError: if no usable `gen_ema` checkpoint is found.
    """
    model_dir = 'big_run/models/'

    fixed_noise = modeling_utils.generate_noise(1, noise_size, device, seed=seed)
    torch.manual_seed(seed)
    per_channel_noise = torch.randn(fixed_noise.size(0), 2*max_depth, 128, 128).to(device)

    # Keep only EMA generator checkpoints *before* parsing, so unrelated files in the
    # directory (other checkpoint kinds, hidden files, ...) cannot crash `parse_name`.
    files = [f for f in os.listdir(model_dir) if f.startswith('gen_ema_depth_')]
    files.sort(key=lambda x: os.path.getmtime(model_dir + x))

    fixed_imgs = []
    # Pure inference: no autograd graph is needed for any of the frames.
    with torch.no_grad():
        for f in tqdm(files):
            _, depth, fade, step = parse_name(f)
            # During training for depth==5, the last checkpoint was saved twice by accident,
            # so skip a checkpoint that is followed almost immediately by another one.
            if depth==5 and next_gen_close_time(depth, fade, step):
                continue
            tmp_model = load_model(depth, fade, step)
            this_img = tmp_model(fixed_noise, depth, fade/100, per_channel_noise=per_channel_noise)
            this_img = modeling_utils.swap_channels_batch(this_img)
            this_img = modeling_utils.post_model_process(this_img).squeeze()
            fixed_imgs.append(this_img)

    if not fixed_imgs:
        raise ValueError('no gen_ema checkpoints found in ' + model_dir)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.set_xticks([])
    ax.set_yticks([])
    # Placeholder image with the final frame's shape; each frame only swaps its data.
    imobj = ax.imshow(np.zeros(fixed_imgs[-1].shape))
    # Called after the axes exist so that it actually adjusts the layout.
    fig.tight_layout()

    def show_frame(i):
        imobj.set_data(fixed_imgs[i])
        return imobj

    ani = animation.FuncAnimation(fig, show_frame, init_func=lambda: None, frames=len(fixed_imgs), repeat_delay=5000, interval=50)
    vid = HTML(ani.to_jshtml())

    return ani, vid

In [None]:
# ani_14, vid_14 = make_fixed_noise_gif(14)

In [None]:
# ani_15, vid_15 = make_fixed_noise_gif(15)

In [None]:
# ani_32, vid_32 = make_fixed_noise_gif(32)

In [None]:
# ani_36, vid_36 = make_fixed_noise_gif(36)

In [None]:
# ani_41, vid_41 = make_fixed_noise_gif(41)

In [None]:
# ani_46, vid_46 = make_fixed_noise_gif(46)

In [None]:
# ani_66, vid_66 = make_fixed_noise_gif(66)

In [None]:
# vid_14

In [None]:
# vid_15

In [None]:
# vid_32

In [None]:
# vid_36

In [None]:
# vid_41

In [None]:
# vid_46

In [None]:
# vid_66

### Mixing

In [None]:
def show_style_mixing(gen, noise_size, depth, src_seeds, dest_seeds, swap_range, title=None):
    """Draw a style-mixing grid for the given source and destination seeds.

    Row 0 shows the images generated from `src_seeds` and column 0 those from `dest_seeds`.
    Cell (i+1, j+1) shows destination i re-synthesised after the entries selected by
    `swap_range` (along dim 1 of both the mapped latents and the per-channel noise) have
    been replaced by those of source j. Cells where the source and destination seed are
    equal are drawn half-transparent.

    Args:
        gen: generator exposing `mapping_layers` and `synthesis_layers`.
        noise_size: noise size passed to `modeling_utils.generate_noise`.
        depth: depth passed to `gen.synthesis_layers`.
        src_seeds: seeds for the column headers (sources of the swapped entries).
        dest_seeds: seeds for the row headers (images being modified).
        swap_range: indices along dim 1 to copy from source to destination.
        title: optional figure title.
    """
    def add_img_to_ax(ax, img, dim=False):
        if dim:
            ax.imshow(img, alpha=.5)
        else:
            ax.imshow(img)
        ax.set_xticks([])
        ax.set_yticks([])

    n_src = len(src_seeds)    # grid columns, excluding the header column
    n_dest = len(dest_seeds)  # grid rows, excluding the header row

    fig, ax = plt.subplots(n_dest + 1, n_src + 1)
    # set_size_inches takes (width, height): width follows the columns, height the rows.
    fig.set_size_inches(2*n_src + 2, 2*n_dest + 2)
    ax[0,0].axis("off")

    # One per-channel noise sample per distinct seed, so a seed gets the same noise whether
    # it appears as a source or as a destination. list(...) lets callers pass tuples too.
    unique_seeds = list(set(list(src_seeds) + list(dest_seeds)))
    torch.manual_seed(0)
    per_channel_noise = torch.randn(len(unique_seeds), 2*max_depth, 128, 128).to(device)

    per_channel_noise_src_indexer = torch.tensor([unique_seeds.index(s) for s in src_seeds]).to(device).long()
    per_channel_noise_dest_indexer = torch.tensor([unique_seeds.index(s) for s in dest_seeds]).to(device).long()
    per_channel_noise_src = per_channel_noise[per_channel_noise_src_indexer]
    per_channel_noise_dest = per_channel_noise[per_channel_noise_dest_indexer]

    if title is not None:
        fig.suptitle(title)

    with torch.no_grad():
        src_latents = torch.cat([modeling_utils.generate_noise(1, noise_size, device, seed=s) for s in src_seeds])
        dest_latents = torch.cat([modeling_utils.generate_noise(1, noise_size, device, seed=s) for s in dest_seeds])
        src_latents_for_synth = gen.mapping_layers(src_latents)
        dest_latents_for_synth = gen.mapping_layers(dest_latents)

        src_images = modeling_utils.swap_channels_batch(gen.synthesis_layers(src_latents_for_synth, depth=depth, alpha=1, per_channel_noise = per_channel_noise_src))
        dest_images = modeling_utils.swap_channels_batch(gen.synthesis_layers(dest_latents_for_synth, depth=depth, alpha=1, per_channel_noise = per_channel_noise_dest))

        src_images = modeling_utils.post_model_process(src_images)
        dest_images = modeling_utils.post_model_process(dest_images)

        # Header row / column: the unmixed source and destination images (each drawn once).
        for j in range(n_src):
            add_img_to_ax(ax[0, j+1], src_images[j])
        for i in range(n_dest):
            add_img_to_ax(ax[i+1, 0], dest_images[i])

        for i in range(len(dest_latents_for_synth)):
            # repeat() copies, so overwriting the swapped entries leaves the originals intact.
            new_dest_latents_for_synth = dest_latents_for_synth[i].unsqueeze(0).repeat(n_src, 1, 1)
            new_dest_latents_for_synth[:, swap_range] = src_latents_for_synth[:, swap_range]

            new_dest_per_channel_noise = per_channel_noise_dest[i].unsqueeze(0).repeat(n_src, 1, 1, 1)
            new_dest_per_channel_noise[:, swap_range] = per_channel_noise_src[:, swap_range]

            row_images = modeling_utils.swap_channels_batch(gen.synthesis_layers(new_dest_latents_for_synth, depth=depth, alpha=1, per_channel_noise = new_dest_per_channel_noise))
            row_images = modeling_utils.post_model_process(row_images)
            for j, img in enumerate(row_images):
                # Dim the cells that mix a seed with itself.
                add_img_to_ax(ax[i+1, j+1], img, dest_seeds[i]==src_seeds[j])

In [None]:
# Swap entries 0-3 between every pair of the four seeds.
show_style_mixing(
    gen,
    noise_size,
    max_depth-1,
    [2,5,32,33],
    [2,5,32,33],
    swap_range=np.arange(0, 4),
    title='Style Swap Resolutions: 4, 8',
)

In [None]:
# Swap entries 4-7 between every pair of the four seeds.
show_style_mixing(
    gen,
    noise_size,
    max_depth-1,
    [2,5,32,33],
    [2,5,32,33],
    swap_range=np.arange(4, 8),
    title='Style Swap Resolutions: 16, 32',
)

In [None]:
# Swap entries 8-11 between every pair of the four seeds.
show_style_mixing(
    gen,
    noise_size,
    max_depth-1,
    [2,5,32,33],
    [2,5,32,33],
    swap_range=np.arange(8, 12),
    title='Style Swap Resolutions: 64, 128',
)

### Interpolation

In [None]:
from IPython.core.display import HTML
import matplotlib.animation as animation
import matplotlib.pyplot as plt
def show_interpolation(gen, noise_size, depth, seeds, mode='first', steps_per_interp=50, save_path=None, loop=False, **ani_kwargs):
    """Animate a linear interpolation between the latents generated from `seeds`.

    Args:
        gen: generator; called directly for `mode='first'`, otherwise through its
            `mapping_layers` / `synthesis_layers`.
        noise_size: noise size passed to `modeling_utils.generate_noise`.
        depth: depth passed to the generator.
        seeds: seeds of the key frames, visited in order.
        mode: `'first'` blends the noise vectors fed to `gen`. Anything else is used as an
            index along dim 1 of the `mapping_layers` output: only those entries are
            blended towards the next key frame, the others keep the start value.
        steps_per_interp: number of frames per pair of consecutive seeds.
        save_path: if given, the animation is also written there via `ani.save`.
        loop: if True, add a final segment from the last seed back to the first.
        **ani_kwargs: forwarded to `animation.ArtistAnimation`.

    Returns:
        (html, ani): an `HTML` object wrapping `ani.to_jshtml()` and the animation itself.

    Raises:
        ValueError: if there are not enough seeds to build at least one segment.
    """
    # Decide the mode once. Comparing an ndarray/tensor `mode` to a string with `==` is
    # ambiguous (or an error), so the string test must come first.
    interpolate_input = isinstance(mode, str) and mode == 'first'
    if not interpolate_input:
        layer_idx = mode if isinstance(mode, torch.Tensor) else torch.tensor(mode).to(device).long()

    latents = torch.cat([modeling_utils.generate_noise(1, noise_size, device, seed=s) for s in seeds])
    per_channel_noise = torch.randn(1, 2*max_depth, 128, 128).to(device)

    n_segments = len(latents) if loop else len(latents)-1
    if n_segments < 1:
        raise ValueError('need at least two seeds (or loop=True) to interpolate')

    fig = plt.gcf()
    fig.set_size_inches(4,4)

    imgs = []
    with torch.no_grad():
        pbar = tqdm(total = steps_per_interp*n_segments, leave=False)
        for i in range(n_segments):
            start = latents[i]
            # Wraps back to the first latent on the extra segment added by loop=True.
            end = latents[(i+1) % len(latents)]
            if not interpolate_input:
                start_latents_for_synth = gen.mapping_layers(start.unsqueeze(0))
                end_latents_for_synth = gen.mapping_layers(end.unsqueeze(0))
            for interp_coeff in np.linspace(0,1,steps_per_interp):
                if interpolate_input:
                    model_input = ((1-interp_coeff)*start + interp_coeff*end).unsqueeze(0)
                    img = gen(model_input, depth=depth, alpha=1, per_channel_noise=per_channel_noise)
                else:
                    # Work on a copy: writing into the start latents themselves would make
                    # every frame blend from the previous frame instead of from the start.
                    synth_input = start_latents_for_synth.clone()
                    synth_input[:,layer_idx] = (1-interp_coeff)*start_latents_for_synth[:,layer_idx] + interp_coeff*end_latents_for_synth[:,layer_idx]
                    img = gen.synthesis_layers(synth_input, depth, alpha=1, per_channel_noise=per_channel_noise)

                img = modeling_utils.swap_channels_batch(img)
                img = modeling_utils.post_model_process(img).squeeze()

                imgs.append(img)
                pbar.update(1)
        pbar.close()

    ax = fig.gca()
    ax.set_xticks([])
    ax.set_yticks([])
    frames = [[ax.imshow(im, animated=True)] for im in imgs]
    ani = animation.ArtistAnimation(fig, frames, **ani_kwargs)
    if save_path is not None:
        ani.save(save_path)
    return HTML(ani.to_jshtml()), ani

In [None]:
# video_latent, ani_latent = show_interpolation(gen, noise_size, max_depth-1, [14, 15, 36, 41, 46, 66], mode=np.arange(12), steps_per_interp=50, 
#                          interval=128, blit=True, repeat_delay=100, loop=True)

In [None]:
# video_latent

In [None]:
# video_first, ani_first = show_interpolation(gen, noise_size, max_depth-1, [14, 15, 36, 41, 46, 66], mode='first', steps_per_interp=200, 
#                          interval=33, blit=True, repeat_delay=400, loop=True)

In [None]:
# video_first

In [None]:
# `ani_first` only exists once the (currently commented-out) `show_interpolation(..., mode='first')`
# cell has been run, so guard the export instead of raising NameError on "Restart & Run All".
if 'ani_first' in globals():
    ani_first.save('../../results/latent_interpolation/noise_vector_interpolation.gif')
else:
    print('Skipping GIF export: run the `ani_first` interpolation cell first.')

## Plot the FID curve during training

In [None]:
# Load the FID log written during training; the cells below read its `fid` column.
fid_df = pd.read_csv('big_run/fid_v2/fid_v2.csv')

In [None]:
# Plot the `fid` column against the row index using the explicit Axes interface.
ax = fid_df['fid'].plot()
ax.set_title('FID during Training')
ax.set_xlabel('Quarter Epoch (x=0: First Epoch @ 128x128 resolution)')
ax.set_ylabel('FID')
plt.show()

In [None]:
# Show the row(s) where the FID reaches its minimum.
fid_df.loc[fid_df['fid'] == fid_df['fid'].min()]