<a href="https://colab.research.google.com/github/KeisukeShimokawa/papers-challenge/blob/master/src/gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
!nvidia-smi

Mon May  4 08:51:38 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    26W / 250W |     10MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [0]:
%load_ext tensorboard

In [0]:
import yaml
import dataclasses
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision import transforms
from torchsummary import summary

In [42]:
!pip freeze | grep torch

torch==1.5.0+cu101
torchsummary==1.5.1
torchtext==0.3.1
torchvision==0.6.0+cu101


In [0]:
class Config:
    in_dim = 784
    ndf = 64
    height = 28
    width = 28
    channels = 1
    zdim = 100
    cuda = True
    n_epochs = 200
    bs = 128
    lr = 1e-4
    b1 = 0.5
    b2 = 0.999

    img_shape = (channels, height, width)

In [68]:
if Config.cuda and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

## Model

In [0]:
class Generator(nn.Module):

    def __init__(self, in_dim, ndf, out_shape):
        super(Generator, self).__init__()
        self.out_shape = out_shape

        def block (in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.ReLU(inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(in_dim, ndf, normalize=False),
            *block(ndf,    ndf*2),
            *block(ndf*2,  ndf*4),
            *block(ndf*4,  ndf*8),
            nn.Linear(ndf*8, int(np.prod(out_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.out_shape)
        return img

In [0]:
class Discriminator(nn.Module):
    
    def __init__(self, out_shape, ndf):
        super(Discriminator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(int(np.prod(out_shape)), ndf*8, normalize=False),
            *block(ndf*8, ndf*4),
            *block(ndf*4, ndf*2),
            *block(ndf*2, ndf  ),
            nn.Linear(ndf, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

In [0]:
generator = Generator(Config.in_dim, Config.ndf, Config.img_shape)
discriminator = Discriminator(Config.img_shape, Config.ndf)

if Config.cuda:
    generator.to(device)
    discriminator.to(device)

In [74]:
summary(generator, (int(np.prod(Config.img_shape)),))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                   [-1, 64]          50,240
              ReLU-2                   [-1, 64]               0
            Linear-3                  [-1, 128]           8,320
       BatchNorm1d-4                  [-1, 128]             256
              ReLU-5                  [-1, 128]               0
            Linear-6                  [-1, 256]          33,024
       BatchNorm1d-7                  [-1, 256]             512
              ReLU-8                  [-1, 256]               0
            Linear-9                  [-1, 512]         131,584
      BatchNorm1d-10                  [-1, 512]           1,024
             ReLU-11                  [-1, 512]               0
           Linear-12                  [-1, 784]         402,192
             Tanh-13                  [-1, 784]               0
Total params: 627,152
Trainable params:

In [75]:
summary(discriminator, Config.img_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 512]         401,920
         LeakyReLU-2                  [-1, 512]               0
            Linear-3                  [-1, 256]         131,328
       BatchNorm1d-4                  [-1, 256]             512
         LeakyReLU-5                  [-1, 256]               0
            Linear-6                  [-1, 128]          32,896
       BatchNorm1d-7                  [-1, 128]             256
         LeakyReLU-8                  [-1, 128]               0
            Linear-9                   [-1, 64]           8,256
      BatchNorm1d-10                   [-1, 64]             128
        LeakyReLU-11                   [-1, 64]               0
           Linear-12                    [-1, 1]              65
          Sigmoid-13                    [-1, 1]               0
Total params: 575,361
Trainable params:

In [0]:
criterion = torch.nn.BCELoss()
optimizerG = torch.optim.Adam(generator.parameters(), lr=Config.lr, betas=(Config.b1, Config.b2))
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=Config.lr, betas=(Config.b1, Config.b2))

In [78]:
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize((Config.height, Config.width)), 
             transforms.ToTensor(), 
             transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=Config.bs,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw
Processing...
Done!




In [81]:
x, y = next(iter(dataloader))
print(x.shape)
print(y.shape)
print(y[0:10])

torch.Size([128, 1, 28, 28])
torch.Size([128])
tensor([3, 7, 4, 1, 0, 4, 6, 5, 6, 9])


In [87]:
for epoch in range(Config.n_epochs):
    for i, (imgs, _) in tqdm(enumerate(dataloader)):

        real_label = torch.ones(imgs.size(0), 1, dtype=torch.float32).to(device)
        fake_label = torch.zeros(imgs.size(0), 0, dtype=torch.float32).to(device)

        real_imgs = imgs.to(device)

        # -----------------
        #  Train Generator
        # -----------------

        optimizerG.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

469it [00:09, 51.49it/s]
346it [00:06, 51.67it/s]

KeyboardInterrupt: ignored