In [None]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.utils as vutils
import time
from torch.autograd import Variable
from dataset import get_dataloader
import numpy as np

In [None]:
from dcgan import NetG, NetD

In [None]:
# --- Configuration -----------------------------------------------------------
ngpu = 1                   # passed through to NetG
nz = 100                   # latent size: z has shape (batch_size, nz, 1, 1)
ngf = 64                   # passed through to NetG (presumably generator feature width)
nc = 3                     # passed through to NetG (presumably image channels)
batch_size = 1             # number of latent vectors inverted at once
cuda = torch.cuda.is_available()   # use the GPU when present instead of crashing without one
z_distribution = 'uniform'         # 'uniform' -> U(-1, 1), 'normal' -> N(0, 1)
seed = 0                   # makes the sampled ground-truth z and the initial guess reproducible
netG_checkpoint = 'dcgan_out/netG_epoch_10.pth'

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)    # noise tensors are sampled on the CPU, so this seeds them

if cuda:
    torch.cuda.set_device(0)

# load netG
netG = NetG(ngpu, nz, ngf, nc)
# map_location keeps every stored tensor on the CPU, so a checkpoint saved from a GPU
# run also loads on a CPU-only machine; the model is moved to the GPU later if `cuda`.
netG.load_state_dict(torch.load(netG_checkpoint, map_location=lambda storage, loc: storage))
# Inference only: eval() switches layers such as BatchNorm/Dropout (if NetG has any —
# TODO confirm in dcgan.py) to inference behaviour. It also stops BatchNorm running
# statistics from being updated on every one of the many forward passes below.
netG.eval()

In [None]:
# generate ground-truth noise z from the configured distribution.
# The sample must NOT be re-drawn afterwards: the initial guess noise_approx (next cell)
# is drawn from the same distribution, and the comparison is only meaningful if both match.
if z_distribution == 'uniform':
    noise = torch.FloatTensor(batch_size, nz, 1, 1).uniform_(-1, 1)
elif z_distribution == 'normal':
    noise = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1)
else:
    raise ValueError("unknown z_distribution: {!r} (expected 'uniform' or 'normal')".format(z_distribution))
noise = Variable(noise)

# Freeze the generator: only the latent estimate is optimised, so netG's weights
# need no gradients.
for param in netG.parameters():
    param.requires_grad = False

In [None]:
# fix fake, and try to find noise_approx such that netG(noise_approx) reproduces fake
n_iters = 100000   # number of optimisation steps
log_every = 100    # print progress every `log_every` steps
lr = 0.01          # Adam learning rate for the latent estimate

# Initial guess for the latent, drawn from the same distribution as the ground truth.
if z_distribution == 'uniform':
    noise_approx = torch.FloatTensor(batch_size, nz, 1, 1).uniform_(-1, 1)
elif z_distribution == 'normal':
    noise_approx = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1)
else:
    raise ValueError("unknown z_distribution: {!r} (expected 'uniform' or 'normal')".format(z_distribution))

# MSELoss is stateless, so one criterion serves both the image and latent distances.
mse_loss = nn.MSELoss()

if cuda:
    netG.cuda()
    mse_loss.cuda()
    noise, noise_approx = noise.cuda(), noise_approx.cuda()

# Wrap only after the device move so noise_approx is a leaf the optimiser can update.
noise_approx = Variable(noise_approx)
noise_approx.requires_grad = True

# generate ground-truth fake (the fixed target image)
fake = netG(noise)

optimizer_approx = optim.Adam([noise_approx], lr=lr)

for i in range(n_iters):
    fake_approx = netG(noise_approx)
    mse_g_z = mse_loss(fake_approx, fake)
    if i % log_every == 0:
        # The distance to the true latent is diagnostic only (never back-propagated),
        # so it is computed just when it is logged.
        mse_z = mse_loss(noise_approx, noise)
        # .item() turns the 0-dim loss into a Python float; the old `.data[0]`
        # indexing of a 0-dim tensor fails on modern PyTorch.
        print("[Iter {}] MSE_FAKE: {}, MSE_Z: {}".format(i, mse_g_z.item(), mse_z.item()))

    optimizer_approx.zero_grad()
    mse_g_z.backward()
    optimizer_approx.step()

# Final target image and reconstruction from the optimised latent.
fake = netG(noise)
fake_approx = netG(noise_approx)

In [None]:
# Save the target image and its reconstruction to separate PNG files;
# normalize=True lets save_image rescale pixel values before writing.
for image_tensor, out_path in ((fake, 'fake.png'), (fake_approx, 'fake_approx.png')):
    vutils.save_image(image_tensor.data, out_path, normalize=True)

In [None]:
# Ground-truth latent as a NumPy array (cell output); with batch_size = 1 the
# (1, nz, 1, 1) tensor squeezes down to shape (nz,).
noise.data.cpu().numpy().squeeze()

In [None]:
# Recovered latent as a NumPy array (cell output), for comparison with the ground truth;
# .data drops the autograd history so .numpy() is allowed.
noise_approx.data.cpu().numpy().squeeze()