In [1]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.utils as vutils
import time
from torch.autograd import Variable
from dataset import get_dataloader
import numpy as np

In [2]:
from dcgan import NetG, NetD

In [3]:
ngpu = 1
nz = 100
ngf = 64
nc = 3
batch_size = 1
cuda = True

# load netG
netG = NetG(ngpu, nz, ngf, nc)
netG.load_state_dict(torch.load('dcgan_out/netG_epoch_10.pth'))

In [4]:
# generate ground-truth noise
noise = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1)
noise = Variable(noise)
noise.data.resize_(batch_size, nz, 1, 1)
noise.data.normal_(0, 1)

for param in netG.parameters():
    param.requires_grad = False

In [7]:
# fix fake, and try to find noise_approx
noise_approx = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1)

mse_loss = nn.MSELoss()
mse_loss_ = nn.MSELoss()

if cuda:
    netG.cuda()
    mse_loss.cuda()
    mse_loss_.cuda()
    noise, noise_approx = noise.cuda(), noise_approx.cuda()
    
noise_approx = Variable(noise_approx)
noise_approx.requires_grad = True
    
# generate ground-truth fake
fake = netG(noise)

optimizer_approx = optim.Adam([noise_approx], lr=0.1)

for i in range(10000):
    fake_approx = netG(noise_approx)
    mse_g_z = mse_loss(fake_approx, fake)
    mse_z = mse_loss_(noise_approx, noise)
    if i % 100 == 0:
        print("[Iter {}] MSE_FAKE: {}, MSE_Z: {}".format(i, mse_g_z.data[0], mse_z.data[0]))
    
    optimizer_approx.zero_grad()
    mse_g_z.backward()
    optimizer_approx.step()

[Iter 0] MSE_FAKE: 0.20300279557704926, MSE_Z: 2.026460647583008
[Iter 100] MSE_FAKE: 0.021984094753861427, MSE_Z: 2.971254348754883
[Iter 200] MSE_FAKE: 0.014727490954101086, MSE_Z: 2.8681323528289795
[Iter 300] MSE_FAKE: 0.006044231355190277, MSE_Z: 2.454547166824341
[Iter 400] MSE_FAKE: 0.0031996462494134903, MSE_Z: 1.9438693523406982
[Iter 500] MSE_FAKE: 0.0014635728439316154, MSE_Z: 1.3652663230895996
[Iter 600] MSE_FAKE: 0.0007068703998811543, MSE_Z: 0.9875016808509827
[Iter 700] MSE_FAKE: 0.0003372490464244038, MSE_Z: 0.7456369996070862
[Iter 800] MSE_FAKE: 0.0001678826374700293, MSE_Z: 0.611947238445282
[Iter 900] MSE_FAKE: 6.771142216166481e-05, MSE_Z: 0.527716338634491
[Iter 1000] MSE_FAKE: 2.557785228418652e-05, MSE_Z: 0.4840486943721771
[Iter 1100] MSE_FAKE: 9.290783964388538e-06, MSE_Z: 0.4624810814857483
[Iter 1200] MSE_FAKE: 3.103694552919478e-06, MSE_Z: 0.45244094729423523
[Iter 1300] MSE_FAKE: 9.336624771094648e-07, MSE_Z: 0.4478849172592163
[Iter 1400] MSE_FAKE: 2.719

fake = netG(noise)
fake_approx = netG(noise_approx)

In [8]:
vutils.save_image(fake.data, 'fake.png', normalize=True)
vutils.save_image(fake_approx.data, 'fake_approx.png', normalize=True)

In [15]:
noise.cpu().data.numpy().squeeze()

array([ 1.69100118,  0.72466552,  0.73702824,  2.3856442 , -0.94328904,
        0.05708402,  0.13811016,  0.77042025, -0.46943244,  1.24417961,
       -2.16498637, -1.42277205,  1.46144092,  0.17080203, -2.41030002,
       -0.40689799, -1.07115328,  0.418304  , -1.22909939,  0.34344131,
        3.14133811,  0.06180882,  1.3216548 ,  1.10382843,  1.38872552,
        0.57337695, -0.79826051,  0.28043532, -0.83162022, -0.36407673,
       -1.2889415 ,  2.24460101, -0.74751437,  0.78405386,  0.68708318,
       -0.61732829,  0.82146758,  1.19340098,  2.65454221,  0.12988228,
       -1.64468169,  0.99226892,  0.77125621,  0.49970794, -0.14034425,
        0.14686015,  0.16135581,  0.76836437, -0.28110358, -0.29089868,
       -1.00630951, -0.17442678,  0.41141662,  0.12794891,  0.16601108,
        0.0290424 ,  0.27796671,  0.62717307,  0.2483553 , -1.18546915,
       -1.55874419,  0.9953292 ,  0.42532387, -0.97265053, -0.27604049,
        1.63162339,  0.65108162,  0.32728192, -0.6402244 ,  1.03

In [16]:
noise_approx.cpu().data.numpy().squeeze()

array([ 3.83413363,  1.64004314,  1.72916567,  5.46390581, -2.1696012 ,
        0.10305223,  0.31953248,  1.76448631, -1.10417485,  2.82099509,
       -4.9724288 , -3.2646203 ,  3.3718214 ,  0.35646182, -5.50440025,
       -0.90741354, -2.43938327,  0.97123891, -2.78461576,  0.74874979,
        7.23465443,  0.15557227,  3.02705884,  2.541502  ,  3.14032292,
        1.30251169, -1.83116555,  0.64468175, -1.92389059, -0.82483053,
       -2.97059417,  5.12450314, -1.74269056,  1.8109107 ,  1.55665874,
       -1.42892706,  1.86208618,  2.74175882,  6.05569935,  0.31600872,
       -3.7654779 ,  2.29190397,  1.79309726,  1.14913845, -0.3139883 ,
        0.28735164,  0.36075822,  1.7362262 , -0.64193821, -0.65429568,
       -2.31632185, -0.41044536,  0.96744823,  0.30887654,  0.39183661,
        0.07242063,  0.63668644,  1.43162835,  0.54962742, -2.72324181,
       -3.5811348 ,  2.26040077,  0.96572745, -2.21005559, -0.57419413,
        3.76024342,  1.50881732,  0.73172683, -1.45098174,  2.41