# DCGAN for MNIST (PyTorch)

arXiv: https://arxiv.org/abs/1511.06434

<img src="https://lh3.googleusercontent.com/M0D8jzjxyELjUFLVyZYF0UQVuYyLywwIorNRhyOwSeOsATqK_tOZS3cHAnHZSP5I-fmjSCFGB5Ztrxmd7SJpY2amIXOox7H_OVZhGZMqdGk5yUcrDK4Qh5941rMLfW_N67eNAyetS0lB6LlAShsdtqGJCfBrW5a7EBXJnmpM0v640_5-xFDgsxIaC5IliQqSxoz5cYhK14d3tG2V_Qa39CZ4yKlQNnZb0OhWRTezWoFQ3uZdK2RQHQ22PA3CrEW27kTVl95688fhA1EKvZloEZ7YZolxErdsx9Lc3zcE8kiI0gmqErKSsOOImwnL9ESC6rPosze4CSvs8-bSUOLXMyO8fQIDvObuRFfYT6TAflAYSp8PXCevA2LhvUCFat5YV5N6XpFmKoZam6w-lX8koEEeZI6djv8z5eETO3ab1Fhn7_UB1hper6kScrtj_4CRdeP_mvuALhRQvEB81_jV6N5KIFTJ2b7oCoo8to32XZvmk42w6IR8WasMtOQJB3yeKWY5mGAtoLOpuJ4H7r0KAXxbaHpDvltlIR227Ria2-7FSfSJTnaA3bv1MjHpLALaNnelZAwySJynzP0lof8kz9EiQCsjSL2TN-pudLC7jWHnyy-3FAKuj6sK=w807-h208-no">

<img src="https://lh3.googleusercontent.com/CqWKmMExGj3Ws9_vYROf2hnnV-HH5hlzvga0Dv_NrHXZKTwoSFWNZ4GO3ee9dVar-kCZmNkpLxkTZ9dLQj9uKeY__ys608i7jfWQqoIRrF2NU-ESZBiWbSxqeEwOICy-W-oWYDAK9Cxse4JPOoWOxFrfaMaNg-cC7UsJCa-PTtt8G11wENorShWIaa2_2dAwn92kaPj8bABqKLRuDEcOxwLdZpnGQOWZPRNgVprefoStGMnfpyRXb9fx-yoBhgSExMv-ybPZLbbymZ8PdMY8bCiNZZ0MRvjvRe-hYGJ_4fWwjmNRaB2ixKLjaaM0Ha2xLA3G_zbpfEo9ygLGIufntEWX6M3mXz-ctK3QrninW3Bfia1IYNQFKJUVnaGvZf6-bFswghn-ody073j89m6fYYysYrVzxxE9HQJeKGuwQKBkG1e7mG1PPwnr-hovxLiHRduMmi0BjNYxOtSYkffBPYzR9n8h66qC-AepoKAGO-Z7LBAuBKpwflhih1Ho9ttBRxBYTwkVeYOMIV8aQkNNY9CKbC-VazHnN7x-eWV0us2tDrJCOcYU3wjBMi-76oalxldFxQQ9MQy1x5SWkiI6vEW_hCbaAlWGJB3bud91PbciIuTk-LIH9Nh4=w805-h350-no">

Deep Convolutional GAN (DCGAN) では、通常の GAN に以下の改善を行う。
- すべてのプーリングレイヤを strided convolutions(discriminator)と fractional-strided convolutions(generator)に置き換える。
- generator と discriminator に batchnormを使う。
- 全結合隠れ層を取り除く。
- ReLU 活性関数を generatorで使う。ただし、output層は tanhを使う。
- LeakyReLU活性関数をdiscriminatorのすべての層で使う。

もとい！

公式チュートリアルにサンプルコードが公開されているので、それを参考に実装する。
- [examples/dcgan at master · pytorch/examples](https://github.com/pytorch/examples/tree/master/dcgan)

In [1]:
% matplotlib inline
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

if torch.cuda.is_available():
    import torch.cuda as t
else:
    import torch as t

from torchvision import datasets, models, transforms, utils
import torchvision.utils as vutils

import numpy as np
from numpy.random import normal
import matplotlib.pyplot as plt
import os

ModuleNotFoundError: No module named 'torch'

## MNIST データセットの準備

In [67]:
# Mini-batch size, and the side length (in pixels) every MNIST digit is resized to.
bs, sz = 100, 64

In [68]:
# MNIST training loader.
# - Resize(sz): upscale the digits to sz x sz to fit the 64x64 DCGAN networks
#   (transforms.Scale is the deprecated alias of Resize).
# - Normalize((0.5,), (0.5,)): map ToTensor's [0, 1] range onto [-1, 1], the
#   same range as the generator's Tanh output. Without it real and fake images
#   occupy different value ranges and D can tell them apart trivially.
# - shuffle=True: visit the training set in a fresh order every epoch.
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('data/mnist', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.Resize(sz),
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,), (0.5,)),
                   ])),
    batch_size=bs,
    shuffle=True,
)

## Model

In [69]:
# Model hyper-parameters, in the order: latent vector length (input channels
# of the generator), generator base width, discriminator base width, and
# number of image channels (1 for grayscale MNIST).
nz, ngf, ndf, nc = 100, 64, 64, 1

In [78]:
# Discriminator
class netD(nn.Module):
    """DCGAN discriminator for (nc) x 64 x 64 images.

    Follows the DCGAN guidelines: strided convolutions instead of pooling,
    BatchNorm on every layer except the first, LeakyReLU(0.2) activations and
    no fully connected layers. A final Sigmoid turns the single output
    channel into the probability that an image is real.
    """

    def __init__(self):
        super(netD, self).__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
            # state size. 1 x 1 x 1
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: image batch shaped (N, nc, 64, 64); the 64x64 size is what
                makes the final 4x4 convolution collapse to 1x1.

        Returns:
            A 1-D tensor of N probabilities. The conv stack yields
            (N, 1, 1, 1); it is flattened so the output has the same shape as
            the (N,) label vector used with nn.BCELoss, which recent PyTorch
            versions reject when input and target sizes differ.
        """
        return self.main(x).view(-1)

# Generator
class netG(nn.Module):
    """DCGAN generator: latent vectors -> (nc) x 64 x 64 images in [-1, 1].

    Upsamples with fractional-strided (transposed) convolutions, applies
    BatchNorm + ReLU after every hidden layer and Tanh on the output layer,
    and uses no fully connected layers.
    """

    def __init__(self):
        super(netG, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, x):
        """Generate a batch of images.

        Args:
            x: latent batch shaped (N, nz, 1, 1).

        Returns:
            Images shaped (N, nc, 64, 64), exactly as the transposed-conv
            stack produces them. No reshape is applied. Viewing the result as
            (-1, 1, sz, sz) would hard-code a single channel and depend on
            the global ``sz``, silently re-splitting the batch whenever
            nc != 1 or sz != 64.
        """
        return self.main(x)

In [79]:
def weights_init(m):
    """Initialise one layer in place with the DCGAN paper's scheme.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02). BatchNorm2d
    scale is drawn from N(1, 0.02) and its shift is set to 0. Any other module
    is left untouched. Intended to be passed to ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        m.weight.data.normal_(0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


# Binary cross-entropy on D's sigmoid output. The (misspelled) name is kept
# because the training cell refers to it.
criteion = nn.BCELoss()
net_D = netD()
net_G = netG()
net_D.apply(weights_init)
net_G.apply(weights_init)

if torch.cuda.is_available():
    # Module.cuda() moves the parameters in place, so net_D / net_G
    # themselves end up on the GPU. The input buffers are created as CUDA
    # tensors via `t` in this case.
    net_D.cuda()
    net_G.cuda()
    criteion = criteion.cuda()

In [80]:
# One Adam optimizer per network. The DCGAN paper reports that Adam's default
# beta1 = 0.9 made training oscillate and that beta1 = 0.5 stabilised it, so
# beta1 is lowered here. The learning rate chosen for this notebook is kept.
optimizerD = optim.Adam(net_D.parameters(), lr=0.00005, betas=(0.5, 0.999))
optimizerG = optim.Adam(net_G.parameters(), lr=0.00005, betas=(0.5, 0.999))

## Train

In [81]:
# Buffers reused by the training loop, which resizes and refills them in
# place every iteration. `t` is torch.cuda when a GPU is available, else torch.
# `input` shadows the builtin; the name is kept because the training cell uses it.
input = t.FloatTensor(bs, nc, sz, sz)                    # real-image batch
noise = t.FloatTensor(bs, nz, 1, 1).normal_(0, 1)        # latent batch, resampled each step
fixed_noise = t.FloatTensor(bs, nz, 1, 1).normal_(0, 1)  # fixed latents for comparable snapshots
label = t.FloatTensor(bs)                                # BCE targets, one per sample

# Target values written into `label`.
real_label = 1
fake_label = 0

# Variable is a no-op wrapper in PyTorch >= 0.4; kept for older versions.
input = Variable(input)
label = Variable(label)
noise = Variable(noise)
fixed_noise = Variable(fixed_noise)

In [82]:
niter = 4_000  # number of outer training iterations (epochs) over the dataloader

In [83]:
# save_image does not create directories, so make sure the snapshot folder exists.
os.makedirs('results', exist_ok=True)

for epoch in range(niter):
    for i, data in enumerate(dataloader, 0):
        real, _ = data
        # Size the buffers from the actual batch: the last batch of an epoch
        # is smaller than bs whenever the dataset size is not a multiple of it.
        batch_size = real.size(0)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real (data)
        net_D.zero_grad()
        input.data.resize_(real.size()).copy_(real)
        label.data.resize_(batch_size).fill_(real_label)

        # view(-1) flattens D's output to (batch_size,) so it matches `label`;
        # nn.BCELoss requires input and target of the same size.
        output = net_D(input).view(-1)
        errD_real = criteion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake (generated)
        noise.data.resize_(batch_size, nz, 1, 1)
        noise.data.normal_(0, 1)
        fake = net_G(noise)
        label.data.fill_(fake_label)
        # detach(): this loss must update D only, not back-propagate into G.
        output = net_D(fake.detach()).view(-1)
        errD_fake = criteion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()

        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        net_G.zero_grad()
        # G is trained against "real" targets (non-saturating generator loss).
        label.data.fill_(real_label)
        output = net_D(fake).view(-1)
        errG = criteion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        if i % 100 == 0:
            # .item() extracts the Python float; indexing a 0-dim loss tensor
            # with .data[0] raises IndexError in PyTorch >= 0.5.
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, niter, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
    if epoch % 10 == 0:
        # Snapshot from the same fixed latents so epochs are comparable;
        # no autograd graph is needed for this.
        with torch.no_grad():
            fake = net_G(fixed_noise)
        vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png'
                          % ('results', epoch), normalize=True)

[0/4000][0/600] Loss_D: 1.3842 Loss_G: 0.7827 D(x): 0.4847 D(G(z)): 0.4785 / 0.4597


KeyboardInterrupt: 

In [None]:
# Generate images from the fixed latents and save the first 64 as one grid.
# The results directory may not exist yet if training was interrupted
# before its first snapshot, so create it first.
os.makedirs('results', exist_ok=True)
with torch.no_grad():  # inference only; no autograd graph needed
    fake = net_G(fixed_noise)
vutils.save_image(fake.data[:64], '%s/fake_samples2.png' % 'results' ,normalize=True)

In [None]:
from PIL import Image

# Load the sample grid written by the previous cell and display it inline.
im = Image.open("results/fake_samples2.png", "r")
grid_pixels = np.array(im)
plt.imshow(grid_pixels)