# Vanilla GAN for MNIST (PyTorch)

In [1]:
% matplotlib inline
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

if torch.cuda.is_available():
    import torch.cuda as t
else:
    import torch as t

from torchvision import datasets, models, transforms, utils
import numpy as np
from numpy.random import uniform
import matplotlib.pyplot as plt
import os

## mnist datasetの準備

In [50]:
bs,nz = 100,100

In [51]:
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('data/mnist', train=True, download=False,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=bs
)

## Model

In [52]:
'''Discriminater'''
class netD(nn.Module):
    def __init__(self):
        super(netD, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 300),
            nn.ReLU(),
            nn.Linear(300, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(x.size(0), 784)
        x = self.main(x)
        return x

'''Generator'''
class netG(nn.Module):
    def __init__(self):
        super(netG, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 200),
            nn.ReLU(),
            nn.Linear(200, 400),
            nn.ReLU(),
            nn.Linear(400, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(bs,nz)
        x = self.main(x)
        return x

In [53]:
criteion = nn.BCELoss()
net_D = netD()
net_G = netG()

if torch.cuda.is_available():
    D = net_D.cuda()
    G = net_G.cuda()
    criteion = criteion.cuda()    

In [54]:
optimizerD = optim.Adam(net_D.parameters(), lr = 1e-4)
optimizerG = optim.Adam(net_G.parameters(), lr = 1e-4)

## Train

In [55]:
def plot_gen(G, n_ex=16):
    plot_multi(G.predict(noise(n_ex)).reshape(n_ex, 28,28), cmap='gray')

In [56]:
def noise(bs): return np.random.rand(bs,100)

In [57]:
input = t.FloatTensor(bs, 1, 28, 28)
noise = t.FloatTensor(uniform(0,1,(bs, 100, 1, 1)))
fixed_noise = t.FloatTensor(bs, 100, 1, 1).normal_(0, 1)
label = t.FloatTensor(bs)

real_label = 1
fake_label = 0

input = Variable(input)
label = Variable(label)
noise = Variable(noise)
fixed_noise = Variable(fixed_noise)

In [58]:
niter = 4000

In [59]:
for epoch in range(niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real (data)
        net_D.zero_grad()
        real, _ = data
        input.data.resize_(real.size()).copy_(real)
        label.data.resize_(bs).fill_(real_label)

        output = net_D(input)
        errD_real = criteion(output, label)
        errD_real.backward()
        D_x = output.data.mean()

        #train with fake (generated)
        noise.data.resize_(bs, 100, 1, 1)
        noise.data.normal_(0, 1)
        fake = net_G(noise)
        label.data.fill_(fake_label)
        output = net_D(fake.detach())
        errD_fake = criteion(output, label)
        errD_fake.backward()
        D_G_z1 = output.data.mean()

        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        net_G.zero_grad()
        label.data.fill_(real_label)
        output = net_D(fake)
        errG = criteion(output, label)
        errG.backward()
        D_G_z2 = output.data.mean()
        optimizerG.step()
        if i % 100 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                 % (epoch, niter, i, len(dataloader),
                   errD.data[0], errG.data[0],  D_x, D_G_z1, D_G_z2))
    if epoch % 10 == 0:
        fake = net_G(fixed_noise)
        vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png'
                              % ('output', epoch),normalize=True)

[0/4000][0/600] Loss_D: 1.3861 Loss_G: 0.7312 D(x): 0.4991 D(G(z)): 0.4990 / 0.4814
[0/4000][100/600] Loss_D: 0.4276 Loss_G: 2.4578 D(x): 0.7695 D(G(z)): 0.1361 / 0.1018
[0/4000][200/600] Loss_D: 0.4214 Loss_G: 2.1539 D(x): 0.8091 D(G(z)): 0.1721 / 0.1161
[0/4000][300/600] Loss_D: 0.2059 Loss_G: 3.3204 D(x): 0.8802 D(G(z)): 0.0600 / 0.0364
[0/4000][400/600] Loss_D: 0.0780 Loss_G: 3.9904 D(x): 0.9672 D(G(z)): 0.0423 / 0.0195
[0/4000][500/600] Loss_D: 0.0637 Loss_G: 4.3640 D(x): 0.9655 D(G(z)): 0.0244 / 0.0130


NameError: name 'vutils' is not defined