A notebook for experimenting with decoding images using a GAN generator. Here, *decoding* means finding the noise vector $z$ for a GAN generator $g$ such that the generated image $g(z)$ is as close as possible to a given image $y$, i.e. $\hat{z} = \arg\min_z \ell(g(z), y)$ for a chosen loss $\ell$.

In [None]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Automatically reload all imported modules (e.g. kbrgan) before each cell runs,
# so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Uncomment one of these to render inline figures as vector graphics.
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'

In [None]:
import kbrgan
import kbrgan.kernel as kernel
import kbrgan.glo as glo
import kbrgan.gen as gen
import kbrgan.main as main
import kbrgan.plot as plot
import kbrgan.net.net as net
import kbrgan.util as util

import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
import scipy.stats as stats
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [None]:
# Global figure styling: larger text and thicker lines for readability.
# Extra font keys (e.g. 'family', 'weight') can be added to this dict.
font = {'size': 18}
plt.rc('font', **font)

matplotlib.rcParams.update({
    'lines.linewidth': 2,
    # Embed TrueType (Type 42) fonts in PDF/PS exports instead of Type 3 fonts.
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

In [None]:
# Set prefer_gpu to False to force CPU execution even when CUDA is present.
prefer_gpu = True
use_cuda = prefer_gpu and torch.cuda.is_available()

if use_cuda:
    device = torch.device("cuda")
    tensor_type = torch.cuda.FloatTensor
    # When deserializing with torch.load, move every storage onto GPU 0.
    load_options = {'map_location': lambda storage, loc: storage.cuda(0)}
else:
    device = torch.device("cpu")
    tensor_type = torch.FloatTensor
    # When deserializing with torch.load, keep storages on the CPU.
    load_options = {'map_location': lambda storage, loc: storage}

# Newly created float tensors default to float32 on the selected device.
# NOTE(review): set_default_tensor_type is deprecated in recent PyTorch releases.
torch.set_default_tensor_type(tensor_type)


## Load an MNIST generator

In [None]:
# mnist_dcgan is not referenced directly below. Presumably it is imported so the
# generator's class definitions are available when the saved model is loaded -- verify.
import kbrgan.mnist.dcgan as mnist_dcgan
import kbrgan.plot as plot

# Saved noise-to-image generator trained on MNIST. Presumably a gen.PTNoiseTransformer,
# going by the 'ptnt' prefix of the file name.
model_dir_name = 'mnist_dcgan'
ptnt_fname = 'ptnt_mnist_dcgan_ep40_bs64.pt'
ptnt_fpath = glo.prob_model_folder(model_dir_name, ptnt_fname)
g = net.SerializableModule.load(ptnt_fpath)
g

In [None]:
# Eyeball a few random generator samples (normalize=True presumably rescales pixels for display).
n_preview = 4
preview_imgs = g.sample(n_preview)
plot.show_torch_imgs(preview_imgs, normalize=True)

In [None]:
# Summary statistics of the pixel values of one raw generator sample.
# The sample is drawn under no_grad: calling .numpy() on a tensor that requires grad
# raises, and no autograd graph is needed just to inspect values.
with torch.no_grad():
    g_pixels = g.sample(1).cpu().numpy().reshape(-1)
stats.describe(g_pixels)

The loaded generator outputs pixel values in $[-1, 1]$. Rescale them to $[0, 1]$ so they match the MNIST images, which `ToTensor` maps to $[0, 1]$.

In [None]:
def to01(x):
    """Linearly map generator outputs from the range [-1, 1] to [0, 1]."""
    source_range = (-1.0, 1.0)
    target_range = (0.0, 1.0)
    return util.linear_range_transform(x, source_range, target_range)


# Generator g with its outputs post-processed by to01.
g01 = gen.PTNTDecPostProcess(g, to01)

In [None]:
# Summary statistics of one rescaled sample. Drawn under no_grad so .numpy() cannot
# fail on a tensor that requires grad.
with torch.no_grad():
    g01_pixels = g01.sample(1).cpu().numpy().reshape(-1)

# Check the claim that rescaled outputs lie in [0, 1] (small tolerance for rounding).
assert g01_pixels.min() >= -1e-5 and g01_pixels.max() <= 1.0 + 1e-5, \
    'g01 output outside [0, 1]'
stats.describe(g01_pixels)

## Decoding: $\arg\min_z \|g(z) - y\|_p^p$ for a target image $y$ ($p=1$: sum of absolute differences; $p=2$: sum of squared differences)

In [None]:
# Load the MNIST test split into mnist_folder, downloading it first if it is missing
# so a fresh Restart & Run All works.
# Only ToTensor is applied, so pixel values are in [0, 1], the same range as g01's
# output. Do not add Normalize here unless g01's output is transformed the same way.
mnist_folder = glo.data_file('mnist')
mnist_dataset = torchvision.datasets.MNIST(
    mnist_folder,
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
    ]),
)

In [None]:
# Sanity check: after ToTensor, the pixels of a raw MNIST image should lie in [0, 1].
sample_index = 3
xy = mnist_dataset[sample_index]
x = xy[0]  # the dataset yields (image, label) pairs
stats.describe(x.flatten())

In [None]:
import kbrgan.mnist.util as mnist_util

# Target image for decoding: example number `variation` of class `digit`.
# Presumably pt_sample_by_labels returns `variation` images of `digit`; verify.
digit, variation = 8, 10
label_spec = [(digit, variation)]
input_imgs = mnist_util.pt_sample_by_labels(mnist_dataset, label_spec)
input_img = input_imgs[variation - 1]

print('Conditioned image')
plot.show_torch_imgs(input_img)

In [None]:
def squared_loss(x, y):
    """Sum of squared pixel differences between x and y."""
    return torch.sum((x - y) ** 2)


def l1_loss(x, y):
    """Sum of absolute pixel differences between x and y."""
    return torch.sum(torch.abs(x - y))


# Loss used for decoding; set to squared_loss to compare.
loss_fn = l1_loss
opts = {
    'n_opt_iter': 800,
    'lr': 1e-2,
}
input_img = input_img.to(device)

# Seed so the random starting point z0, and hence the decoded result, is reproducible.
decode_seed = 1
torch.manual_seed(decode_seed)

# Initial noise vector: a leaf tensor on `device` with requires_grad set, presumably
# so decode_generator can optimize it in place.
z0 = g01.sample_noise(1).to(device)
z0.requires_grad_(True)

# Image generated from the starting point (not used below; useful for comparison).
with torch.no_grad():
    y0 = g01(z0)

losses, Zs = gen.decode_generator(g01, z0, input_img, loss_fn, **opts)

Plot the loss recorded at each optimization iteration.

In [None]:
# Loss trajectory of the decoding optimization; the title records the iteration budget.
n_opt_iter = opts['n_opt_iter']
fig, ax = plt.subplots()
ax.plot(losses, 'b-', label='Loss')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
ax.set_title('Decoding loss ({} iterations)'.format(n_opt_iter));

Visualize the noise vectors recorded during the optimization by generating their images with `g01`.

In [None]:
# Subsample the optimization trajectory: keep every n_every-th recorded noise vector.
n_every = 20
Zs_toshow = Zs[::n_every]
n_toshow = len(Zs_toshow)

# Generate images from the kept noise vectors. Presumably each entry of Zs has a
# leading batch dimension of 1, so cat along dim 0 forms a batch -- verify.
# Call g01(...) instead of g01.forward(...) so any nn.Module hooks run, matching
# how g01 is invoked elsewhere.
with torch.no_grad():
    Z_cat = torch.cat(Zs_toshow, dim=0).to(device)
    gen_toshow = g01(Z_cat)
    

In [None]:
# Images along the optimization path, read left-to-right, top-to-bottom.
caption = 'Optimized images. Every {} iterations.'.format(n_every)
print(caption)
plot.show_torch_imgs(gen_toshow, figsize=(10, 6), nrow=10)

In [None]:
# Target image, shown again for side-by-side comparison with the trajectory above.
caption = 'Conditioned image'
print(caption)
plot.show_torch_imgs(input_img)