In [None]:
# reference: https://github.com/pytorch/examples/blob/master/dcgan/main.py
#
import os
import itertools
import argparse
from tqdm import tqdm
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv

import torch
import torch.nn as nn
import torchvision 
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter

from dcgan import Generator, Discriminator, weights_initialization, train

In [None]:
# --- Network shape -----------------------------------------------------------
nz = 100   # passed as first argument to Generator(nz, nfg, nc); latent size
nfg = 128  # second argument to Generator
nfd = 64   # second argument to Discriminator(nc, nfd)
nc = 1     # image channel count shared by Generator and Discriminator

# --- Run naming and output locations -----------------------------------------
# The run name is the base name plus a wall-clock stamp taken when this cell runs.
run_stamp = datetime.now().strftime("%Y.%m.%d-%H:%M:%S")
model_name = 'dcgan_seed=1'
model_name = f'{model_name}_{run_stamp}'
data_root = '../data'
figure_root, model_root, log_root = (
    os.path.join(parent, model_name)
    for parent in ('./figures', './models', './logs')
)

# --- Checkpoints to restore (empty string means "do not load") ---------------
load_weights_generator = 'models/dcgan_seed=1/G_epoch_29.pt'
load_weights_discriminator = 'models/dcgan_seed=1/D_epoch_29.pt'

# --- Data / optimisation settings --------------------------------------------
image_size = 64
batch_size = 64
lr = 0.0002
beta1 = 0.5
n_epochs = 10
n_batches_print = 100
seed = 1
n_workers = 8
gpu_id = 0

In [None]:
def plot_one(x, color_bar=False):
    """Draw a single image tensor on the current matplotlib figure.

    Args:
        x: tensor with three axes; axis 0 is moved to the end before plotting
            and any size-1 axes are dropped (so a 1-channel image becomes 2-D).
        color_bar: when True, add a colorbar extended at both ends.

    Returns:
        The ``matplotlib.pyplot`` module, so callers can chain ``.show()``.
    """
    image = x.detach().cpu().numpy()
    # Reorder axes (0, 1, 2) -> (1, 2, 0), then drop singleton axes.
    image = image.transpose((1, 2, 0)).squeeze()
    plt.imshow(image)
    plt.axis('off')
    if color_bar:
        plt.colorbar(extend='both')
    return plt

In [None]:
# Preprocessing pipeline: resize to `image_size`, convert to a tensor, then
# normalise each value as (x - 0.5) / 0.5.
mnist_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST is downloaded into `data_root` if it is not already there.
trainset = datasets.MNIST(root=data_root, download=True, transform=mnist_transform)

trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_workers,
)

In [None]:
# Pull one batch from the loader; element 0 of the batch holds the images.
x = next(iter(trainloader))
images = x[0]
plot_one(images[0])
# Sanity check of the normalisation: mean and std over the whole batch.
images.mean(), images.std()

In [None]:
# Seed the CPU RNG and, when present, the current CUDA device's RNG.
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generator: freshly initialised, then optionally overwritten by a checkpoint.
# NOTE(review): no .eval() call is made on G or D in this notebook — confirm
# that sampling in the modules' default mode is intended.
G = Generator(nz, nfg, nc).to(device)
G.apply(weights_initialization)
if load_weights_generator:
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved on a GPU can still be loaded on a CPU-only machine.
    G.load_state_dict(torch.load(load_weights_generator, map_location=device))

# Discriminator: same procedure as the generator.
D = Discriminator(nc, nfd).to(device)
D.apply(weights_initialization)
if load_weights_discriminator:
    D.load_state_dict(torch.load(load_weights_discriminator, map_location=device))

## Manipulating Generator Representation

#### Vary a single dimension of the latent vector

In [None]:
def single_dim_varying_z(n_samples, z_range=(-1,1), z_init=None):
    """Yield latent batches in which sample ``i`` has only dimension ``i`` shifted.

    For each of ``n_samples`` evenly spaced offsets ``v`` taken from
    ``np.linspace(*z_range, n_samples)``, a copy of ``z_init`` is yielded with
    ``v`` added to every element ``z[i, i, ...]``.

    Args:
        n_samples: number of offsets, i.e. number of batches yielded.
        z_range: ``(start, stop)`` pair forwarded to ``np.linspace``.
        z_init: base latent batch indexed as ``[sample, latent_dim, ...]``.
            When omitted, zeros of shape ``(nz, nz, 1, 1)`` are created on the
            global ``device``. The tensor passed in is never modified.

    Yields:
        A new tensor shaped like ``z_init`` for each offset.
    """
    if z_init is None:
        # Created at call time rather than as a def-time default, so the
        # current values of the globals `nz` and `device` are used.
        z_init = torch.zeros((nz, nz, 1, 1), device=device)

    # Positions (i, i) that exist along the first two axes.
    n_diag = min(z_init.shape[0], z_init.shape[1])
    diag = torch.arange(n_diag, device=z_init.device)

    for v in np.linspace(*z_range, n_samples):
        z = z_init.clone()
        # One indexed add replaces the per-index Python loop.
        z[diag, diag] += float(v)
        yield z

def all_dim_varying_z(n_samples, z_range=(-1,1), z_init=None):
    """Yield latent batches in which every element is shifted by the same offset.

    For each of ``n_samples`` evenly spaced offsets ``v`` taken from
    ``np.linspace(*z_range, n_samples)``, ``z_init + v`` is yielded.

    Args:
        n_samples: number of offsets, i.e. number of batches yielded.
        z_range: ``(start, stop)`` pair forwarded to ``np.linspace``.
        z_init: base latent batch. When omitted, a standard-normal tensor of
            shape ``(nz, nz, 1, 1)`` is drawn once on the global ``device`` and
            reused for every offset. The tensor passed in is never modified.

    Yields:
        A new tensor shaped like ``z_init`` for each offset.
    """
    if z_init is None:
        z_init = torch.randn((nz, nz, 1, 1), device=device)
    for v in np.linspace(*z_range, n_samples):
        # Adding a scalar broadcasts over the whole batch; the result is a
        # fresh tensor, so `z_init` stays untouched.
        yield z_init + float(v)

In [None]:
# Two output folders: raw image grids, and the same grids with the frame
# index drawn in the corner.
out_dir1 = './figures/single_dim_varying_z'
os.makedirs(out_dir1, exist_ok=True)
out_dir2 = './figures/single_dim_varying_z_labeled'
os.makedirs(out_dir2, exist_ok=True)

# Inference only: disable autograd so no graph is recorded for each of the
# 100 forward passes.
with torch.no_grad():
    for i, z in enumerate(single_dim_varying_z(100, z_range=(-0.05,0.05))):
        x = G(z)

        # Unlabeled frame, 10 images per row.
        torchvision.utils.save_image(x, os.path.join(out_dir1,f'{i}.png'), normalize=True, nrow=10)

        # Labeled frame: rebuild the grid, move axis 0 last for imshow.
        x = torchvision.utils.make_grid(x, normalize=True, nrow=10)
        x = x.detach().cpu().numpy().transpose((1,2,0)).squeeze()
        plt.imshow(x)
        plt.text(0, 0, f'{i}', horizontalalignment='left', verticalalignment='bottom', color='red')
        plt.axis('off')
        plt.savefig(os.path.join(out_dir2,f'{i}.png'))
        # Close the figure so figures do not pile up across iterations.
        plt.close()

In [None]:
# Two output folders: raw image grids, and the same grids with the frame
# index drawn in the corner.
out_dir1 = './figures/all_dim_varying_z'
os.makedirs(out_dir1, exist_ok=True)
out_dir2 = './figures/all_dim_varying_z_labeled'
os.makedirs(out_dir2, exist_ok=True)

# Inference only: disable autograd so no graph is recorded for each of the
# 100 forward passes.
with torch.no_grad():
    for i, z in enumerate(all_dim_varying_z(100)):
        x = G(z)

        # Unlabeled frame, 10 images per row.
        torchvision.utils.save_image(x, os.path.join(out_dir1,f'{i}.png'), normalize=True, nrow=10)

        # Labeled frame: rebuild the grid, move axis 0 last for imshow.
        x = torchvision.utils.make_grid(x, normalize=True, nrow=10)
        x = x.detach().cpu().numpy().transpose((1,2,0)).squeeze()
        plt.imshow(x)
        plt.text(0, 0, f'{i}', horizontalalignment='left', verticalalignment='bottom', color='red')
        plt.axis('off')
        plt.savefig(os.path.join(out_dir2,f'{i}.png'))
        # Close the figure so figures do not pile up across iterations.
        plt.close()

#### Given two latent vectors, interpolate between them and visualize the results in the model's distribution space

In [None]:
def interpolate(z1, z2, alpha):
    """Linearly blend two latent minibatches element by element.

    ``alpha = 0`` gives ``z1`` and ``alpha = 1`` gives ``z2``; values outside
    ``[0, 1]`` extrapolate along the same line.

    Args:
        z1, z2: operands of matching (or broadcastable) shape; the notebook
            uses latent batches of shape (N, nz, 1, 1).
        alpha: blend weight applied to ``z2``.

    Returns:
        ``(1 - alpha) * z1 + alpha * z2``.
    """
    weight_z1 = 1 - alpha
    return weight_z1 * z1 + alpha * z2

# Quick check on a tiny example: the midpoint between a matrix and ten times
# that matrix.
z1 = torch.tensor([[1, 2], [3, 4]])
z2 = 10 * torch.tensor([[1, 2], [3, 4]])

interpolate(z1, z2, 0.5)

In [None]:
out_dir = './figures/interpolate'
os.makedirs(out_dir, exist_ok=True)

n_samples = 100

# Two independent batches of random latents; every frame blends them pairwise.
z1 = torch.randn((batch_size, nz, 1, 1), device=device)
z2 = torch.randn((batch_size, nz, 1, 1), device=device)

# Inference only: disable autograd so no graph is recorded per frame.
with torch.no_grad():
    # alpha runs from -0.5 to 1.5, so the sequence also extrapolates beyond
    # both endpoints rather than stopping at z1 and z2.
    for i, alpha in enumerate(np.linspace(-0.5,1.5,n_samples)):
        # Use a plain Python float so the arithmetic stays inside torch,
        # whichever side of the operator the scalar ends up on.
        z = interpolate(z1,z2,float(alpha))
        x = G(z)
        torchvision.utils.save_image(x, os.path.join(out_dir,f'{i}.png'), normalize=True)

#### Vector arithmetic

Compute `z1 - z2 + z3`, where:

+ z1: class=3, slanted
+ z2: class=3, straight
+ z3: class=6, arbitrary

In [None]:
# From here on everything runs on the CPU: rebind `device` and move both
# networks across.
device = torch.device('cpu')
G = G.to(device)
D = D.to(device)

In [None]:
# Draw 64 latents from a normal distribution shifted by +1 and generate images.
# no_grad keeps `xx` free of an autograd graph; `xx` is reused by later cells,
# so a recorded graph would otherwise stay alive for the rest of the session.
with torch.no_grad():
    zz = torch.randn((64, nz, 1, 1)) + 1
    xx = G(zz)

plot_one(torchvision.utils.make_grid(xx,normalize=True)).show()

In [None]:
# Hand-picked samples for the first group. One shared index list keeps the
# images and their latents aligned.
idx1 = [12,29,32,55,57]
x1, z1 = xx[idx1], zz[idx1]
plot_one(torchvision.utils.make_grid(x1,normalize=True)).show()

In [None]:
# Draw 100 standard-normal latents and generate images, shown 10 per row.
# no_grad keeps `xxx` free of an autograd graph; `xxx` is reused by later
# cells, so a recorded graph would otherwise stay alive for the session.
with torch.no_grad():
    zzz = torch.randn((100,nz,1,1))
    xxx = G(zzz)

plot_one(torchvision.utils.make_grid(xxx,normalize=True,nrow=10)).show()

In [None]:
# Hand-picked samples for the second group.
# The latents must use the same indices as the images they produced. The
# previous code indexed `zzz` with the first group's indices, so `z2` did not
# correspond to the images in `x2`.
idx2 = [32,89,49,88,38]
x2 = xxx[idx2]
z2 = zzz[idx2]
plot_one(torchvision.utils.make_grid(x2,normalize=True)).show()

In [None]:
# Hand-picked samples for the third group. One shared index list keeps the
# images and their latents aligned.
idx3 = [0,1,27,44,77]
x3, z3 = xxx[idx3], zzz[idx3]
plot_one(torchvision.utils.make_grid(x3,normalize=True)).show()

In [None]:
# Combine the three groups sample by sample, then average over the batch axis
# while keeping that axis, so the result is a batch of one latent.
latent_mix = z1 - z2 + z3
zout = latent_mix.mean(dim=0, keepdim=True)
# Show the overall mean of each operand and of the result.
tuple(t.mean() for t in (z1, z2, z3, zout))

In [None]:
# Inference only: disable autograd so `xout` / `xxout` carry no graph.
with torch.no_grad():
    # Image generated from the averaged latent (a batch of one).
    # NOTE(review): G has not been put in eval mode anywhere in this notebook;
    # if it contains batch-dependent layers, a batch of one may behave
    # differently from the larger batches used elsewhere — confirm.
    xout = G(zout)
    # Images generated from each per-sample combination z1 - z2 + z3.
    xxout = G(z1-z2+z3)

plot_one(torchvision.utils.make_grid(xout,normalize=True)).show()

plot_one(torchvision.utils.make_grid(xxout,normalize=True)).show()

In [None]:
# Build a 4-row summary, 6 images per row: each row is a group's images
# followed by one extra image — the pixel-wise mean for x1/x2/x3, and the
# image generated from the averaged latent for the result row.
im = torch.cat((
    x1, x1.mean(0).unsqueeze(0),
    x2, x2.mean(0).unsqueeze(0),
    x3, x3.mean(0).unsqueeze(0),
    xxout, xout,
), dim=0)
plot_one(torchvision.utils.make_grid(im,normalize=True,nrow=6)).show()

# Create the target folder first, as the other cells do for their output
# directories; otherwise saving fails when `gifs/` does not exist yet.
os.makedirs('gifs', exist_ok=True)
torchvision.utils.save_image(im,'gifs/arithmetics.png',normalize=True,nrow=6)