In [None]:
import functools
import os
import json
import math
import subprocess
import numpy as np
import sys
from ipywidgets import interact
import matplotlib.pyplot as plt
import torch

# import pretorched.visualizers as vutils
import torchvision
import io
import IPython.display
import PIL
from pprint import pprint
import nb_utils

sys.path.insert(0, '../')

import models
import utils


# Plotting
%matplotlib inline
plt.rcParams['font.size'] = 18.0
plt.rcParams['figure.figsize'] = (16.0, 16.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# For MB Pro retina display
%config InlineBackend.figure_format = 'retina'

# For auto-reloading external modules
%load_ext autoreload
%autoreload 2

### List available model checkpoints

In [None]:
# Enumerate the checkpoint directories under WEIGHTS_ROOT.
# - Sorted, because os.listdir returns entries in arbitrary, filesystem-dependent
#   order; sorting keeps the printed indices (and weight_dirs[i]) stable.
# - Plain files (e.g. a stray .DS_Store) are skipped: each model is read below
#   as a directory containing state_dict.pth and G_ema.pth.
WEIGHTS_ROOT = 'weights/'
weight_dirs = [
    os.path.join(WEIGHTS_ROOT, name)
    for name in sorted(os.listdir(WEIGHTS_ROOT))
    if os.path.isdir(os.path.join(WEIGHTS_ROOT, name))
]
for i, path in enumerate(weight_dirs):
    print(i, os.path.basename(path))

In [None]:
# Pick the checkpoint to sample from, by its index in the listing printed above.
model_index = 0
wd = weight_dirs[model_index]

In [None]:
# Rebuild the generator for the selected checkpoint and load its EMA weights.
# Falls back to CPU when CUDA is unavailable, so the notebook still runs
# (slowly) on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# state_dict.pth is read here only for its `config` entry; the generator
# weights come from G_ema.pth below. map_location keeps the load from failing
# on hosts without the device the checkpoint was saved from.
state_dict = torch.load(os.path.join(wd, 'state_dict.pth'), map_location='cpu')
config = state_dict['config']

# config['model'] names an attribute of the project-local `models` package
# that provides the Generator class.
model = getattr(models, config['model'])
G = model.Generator(**config).to(device).eval()
G.load_state_dict(torch.load(os.path.join(wd, 'G_ema.pth'), map_location=device))

In [None]:
# Inspect the checkpoint metadata that the model-loading cell already read.
# Reusing the in-memory dict avoids a second disk read of state_dict.pth.
# The original second read also lacked map_location.
pprint(state_dict)

In [None]:
# Random class-conditional samples with moderate truncation (0.8).
batch_size = 128
z = nb_utils.truncated_z_sample(batch_size, G.dim_z, device=device, truncation=0.8)
y = torch.randint(G.n_classes, (batch_size,), device=device, dtype=torch.long)

# Generator call that takes class indices; embed=True presumably makes G look
# up the class embedding itself (TODO confirm against Generator.forward).
g = functools.partial(G, embed=True)
with torch.no_grad():
    out = utils.elastic_gan(g, z, y)

# make_grid returns (C, H, W); imshow gets an (H, W, C) array scaled by 255.
grid = torchvision.utils.make_grid(out.cpu(), nrow=int(np.sqrt(batch_size)), normalize=True)
nb_utils.imshow(255 * np.transpose(grid.numpy(), (1, 2, 0)))

In [None]:
# Random class-conditional samples with strong truncation.
batch_size = 128
use_trunc = True   # True: truncated-normal z; False: utils.prepare_z_y with variance z_var
truncation = 0.2
z_var = 0.45

# Draw z with exactly one sampler. Drawing from prepare_z_y and then
# overwriting z with a truncated sample would silently discard z_var.
if use_trunc:
    z = nb_utils.truncated_z_sample(batch_size, G.dim_z, device=device, truncation=truncation)
else:
    z, _ = utils.prepare_z_y(batch_size, G.dim_z, G.n_classes, device=device, z_var=z_var)
y = torch.randint(G.n_classes, (batch_size,), device=device)  # int64 by default

g = functools.partial(G, embed=True)
with torch.no_grad():
    out = utils.elastic_gan(g, z, y)

# make_grid returns (C, H, W); imshow gets an (H, W, C) array scaled by 255.
grid = torchvision.utils.make_grid(out.cpu(), nrow=int(np.sqrt(batch_size)), normalize=True)
nb_utils.imshow(255 * np.transpose(grid.numpy(), (1, 2, 0)))

In [None]:
# Random class-conditional samples: truncated z, or prepare_z_y with a small variance.
batch_size = 128
use_trunc = True   # True: truncated-normal z; False: utils.prepare_z_y with variance z_var
truncation = 0.2
z_var = 0.1
# NOTE(review): with use_trunc=True this cell draws z from the same
# distribution as the previous sampling cell. Set use_trunc=False to see the
# z_var=0.1 sampler this cell appears to have been meant for.

# Draw z with exactly one sampler. Drawing from prepare_z_y and then
# overwriting z with a truncated sample would silently discard z_var.
if use_trunc:
    z = nb_utils.truncated_z_sample(batch_size, G.dim_z, device=device, truncation=truncation)
else:
    z, _ = utils.prepare_z_y(batch_size, G.dim_z, G.n_classes, device=device, z_var=z_var)
y = torch.randint(G.n_classes, (batch_size,), device=device)  # int64 by default

g = functools.partial(G, embed=True)
with torch.no_grad():
    out = utils.elastic_gan(g, z, y)

# make_grid returns (C, H, W); imshow gets an (H, W, C) array scaled by 255.
grid = torchvision.utils.make_grid(out.cpu(), nrow=int(np.sqrt(batch_size)), normalize=True)
nb_utils.imshow(255 * np.transpose(grid.numpy(), (1, 2, 0)))

In [None]:
# Intra-class (z only) latent space interpolation: each grid row interpolates
# between two latents (nb_utils.interp) while the class stays fixed along the row.

label = None        # int: use this class for every row; None: random class per row
num_samples = 32    # number of rows (independent interpolations)
num_midpoints = 8   # frames between the two endpoints; each row has num_midpoints + 2
trunc = 0.5
use_trunc = True

# Defined here rather than relying on an earlier sampling cell, so this cell
# also works when run on its own.
g = functools.partial(G, embed=True)

# Choose two coordinates to interpolate between.
if use_trunc:
    z0 = nb_utils.truncated_z_sample(num_samples, G.dim_z, device=device, truncation=trunc)
    z1 = nb_utils.truncated_z_sample(num_samples, G.dim_z, device=device, truncation=trunc)
else:
    z0 = torch.randn(num_samples, G.dim_z, device=device)
    z1 = torch.randn(num_samples, G.dim_z, device=device)

# Interpolate between z0 and z1, then flatten to one latent per image.
zs = nb_utils.interp(z0, z1, num_midpoints, device=device)
zs = zs.view(-1, zs.size(-1))

# One class per row, repeated for every frame of that row.
if label is None:
    row_classes = torch.randint(G.n_classes, (num_samples,), device=device)
else:
    row_classes = torch.full((num_samples,), label, dtype=torch.long, device=device)
ys = row_classes.repeat_interleave(num_midpoints + 2)
# The row-major layout assumes interp yields (num_samples, num_midpoints + 2, dim).
assert zs.size(0) == ys.size(0), 'interp output does not match num_samples * (num_midpoints + 2)'

with torch.no_grad():
    samples = utils.elastic_gan(g, zs, ys)

# Show: one interpolation per row.
grid = torchvision.utils.make_grid(samples.cpu(), nrow=num_midpoints + 2, normalize=True)
nb_utils.imshow(255 * np.transpose(grid.numpy(), (1, 2, 0)))

In [None]:
# Class-wise interpolation
# Inter-class (y only) latent space interpolation: every row blends between
# the embeddings of two distinct classes. With fix_z=True, z is held constant
# along each row, so only the class changes; with fix_z=False, z is
# interpolated too.
num_samples = 32    # number of rows
num_midpoints = 8   # frames between endpoints; each row has num_midpoints + 2
minibatch_size = 8  # images per G forward pass, to bound memory use
fix_z = True

dev = next(G.parameters()).device
assert G.n_classes >= 2, 'inter-class interpolation needs at least two classes'

x0 = torch.randn(num_samples, G.dim_z, device=dev)
if fix_z:
    # Repeat each row's latent for every frame so only y varies along the row.
    zs = x0.repeat_interleave(num_midpoints + 2, dim=0)
else:
    x1 = torch.randn(num_samples, G.dim_z, device=dev)
    zs = nb_utils.interp(x0, x1, num_midpoints, device=dev)
    zs = zs.view(-1, zs.size(-1))

# Two different classes shared by every row; randperm guarantees idx_a != idx_b.
idx_a, idx_b = torch.randperm(G.n_classes, device=dev)[:2].tolist()

with torch.no_grad():
    # The embedding lookups also run under no_grad, so no autograd graph is
    # recorded for them even if G.shared has trainable parameters.
    class_a = G.shared(torch.full((num_samples,), idx_a, dtype=torch.long, device=dev))
    class_b = G.shared(torch.full((num_samples,), idx_b, dtype=torch.long, device=dev))
    ys = nb_utils.interp(class_a, class_b, num_midpoints, device=dev)
    ys = ys.view(-1, ys.size(-1))
    # Assumes interp yields (num_samples, num_midpoints + 2, dim), row-major.
    assert zs.size(0) == ys.size(0), 'z and y batches are misaligned'

    # Split batches into mini-batches so that it fits in memory.
    samples = torch.cat([G(z, y) for z, y in zip(zs.split(minibatch_size), ys.split(minibatch_size))])

# Show: one class-to-class interpolation per row.
grid = torchvision.utils.make_grid(samples.cpu(), nrow=num_midpoints + 2, normalize=True)
nb_utils.imshow(255 * np.transpose(grid.numpy(), (1, 2, 0)))