## Import

In [1]:
import os

import torch
import torch.nn as nn

from torchvision.datasets import MNIST
from torchvision.utils import save_image
import torchvision.transforms as transforms

from torch.autograd import Variable
from torch.utils.data import DataLoader

## Dataset & DataLoader

### Transforms

In [2]:
# Preprocessing applied to every MNIST image:
#   1. ToTensor  -> float tensor scaled to [0, 1]
#   2. Normalize -> (x - 0.5) / 0.5, i.e. rescaled to [-1, 1], the same range
#      as the Tanh output of the generator defined later in this notebook.
# mean/std are given as one-element sequences (one entry per channel, MNIST is
# single-channel), which is the form the Normalize API documents; bare scalars
# only work on torchvision versions that broadcast 0-dim statistics.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [12]:
def denorm(x):
    """Undo the (x - 0.5) / 0.5 normalisation so images can be saved.

    Maps values from [-1, 1] back to [0, 1]; anything that falls outside
    that range after rescaling is clipped to the nearest bound.
    """
    rescaled = (x + 1) / 2
    return rescaled.clamp(min=0, max=1)

### Dataset - MNIST

In [3]:
# MNIST training split, downloaded into the current directory if it is not
# already there; each image goes through `transform` when it is accessed.
train_dataset = MNIST(
    root='./',
    train=True,
    download=True,
    transform=transform,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


### DataLoader

In [26]:
# Mini-batches of 64 training images.
# shuffle=True reshuffles the dataset at the start of every epoch, so the
# networks do not see the exact same batches in the exact same order 300 times.
data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)

## Models

* Discriminator
* Generator

### Discriminator

In [27]:
class Discriminator(nn.Module):
    """Two-layer MLP that scores a flattened image as real or fake.

    Args:
        in_features: length of the flattened input vector (default 28*28,
            one MNIST image).
        hidden_features: width of the hidden layer (default 256).
        negative_slope: slope of the LeakyReLU for negative inputs
            (default 0.2).

    The defaults reproduce the original fixed architecture, and the layer
    names inside ``self.discriminator`` are unchanged, so existing
    ``state_dict`` checkpoints still load.
    """

    def __init__(self, in_features=28 * 28, hidden_features=256, negative_slope=0.2):
        super().__init__()
        self.discriminator = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_features, 1),
            # Sigmoid turns the single logit into a probability in (0, 1).
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a probability for each row of ``x``.

        ``x`` must have ``in_features`` as its last dimension; the output
        has the same leading dimensions and a last dimension of 1.
        """
        return self.discriminator(x)

### Generator

In [36]:
class Generator(nn.Module):
    """Three-layer MLP that maps a noise vector to a flattened image.

    Args:
        latent_dim: length of the input noise vector (default 64).
        hidden_features: width of both hidden layers (default 1024).
        out_features: length of the generated vector (default 28*28,
            one flattened MNIST image).

    The defaults reproduce the original fixed architecture, and the layer
    names inside ``self.generator`` are unchanged, so existing
    ``state_dict`` checkpoints still load.
    """

    def __init__(self, latent_dim=64, hidden_features=1024, out_features=28 * 28):
        super().__init__()
        self.generator = nn.Sequential(
            nn.Linear(latent_dim, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
            # Tanh bounds every output value to (-1, 1).
            nn.Tanh(),
        )

    def forward(self, x):
        """Return generated vectors for each row of noise in ``x``.

        ``x`` must have ``latent_dim`` as its last dimension; the output
        has ``out_features`` as its last dimension.
        """
        return self.generator(x)

In [37]:
# Seed PyTorch's random number generator before any weights are created, so a
# fresh "Restart & Run All" starts from the same initial networks and draws the
# same noise. NOTE(review): GPU kernels may still introduce small run-to-run
# differences; this only fixes the random streams.
torch.manual_seed(0)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Instantiate both networks and move their parameters to the chosen device.
discriminator = Discriminator().to(device)
generator = Generator().to(device)

## Loss Function & Optimizer

In [38]:
# Binary cross-entropy between predicted probabilities and 0/1 targets.
loss_func = nn.BCELoss()

# One Adam optimizer per network, both with a learning rate of 1e-4.
d_opt, g_opt = (
    torch.optim.Adam(network.parameters(), lr=1e-4)
    for network in (discriminator, generator)
)

## Training Model

In [40]:
# Adversarial training loop.
# Every batch performs one discriminator update (real images labelled 1,
# generated images labelled 0) followed by one generator update (generated
# images labelled 1, i.e. the generator is rewarded for fooling the
# discriminator).
num_epochs = 300
latent_dim = 64   # must match the input width of the generator's first layer
log_every = 300   # print the losses every `log_every` batches

# save_image does not create missing directories, so make sure the output
# folder exists before the first image is written.
os.makedirs('./data', exist_ok=True)

for epoch in range(num_epochs):
    for idx, (images, _) in enumerate(data_loader):

        batch_size = images.size(0)

        # Save a single grid of real digits for reference. Doing this only for
        # the very first batch avoids rewriting the same file on every batch
        # of epoch 0.
        if epoch == 0 and idx == 0:
            real_grid = images.view(batch_size, 1, 28, 28)
            save_image(denorm(real_grid), './data/real_images.png')

        # Flatten each image to a vector for the fully connected discriminator.
        real_images = images.view(batch_size, -1).to(device)

        # Targets are created directly on the training device.
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # ---- discriminator update ----
        # The fake batch is generated without autograd tracking: this step
        # only updates the discriminator, so no gradients should be computed
        # for the generator's parameters.
        noise = torch.rand(batch_size, latent_dim, device=device)
        with torch.no_grad():
            fake_images = generator(noise)

        real_outputs = discriminator(real_images)
        fake_outputs = discriminator(fake_images)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_labels, fake_labels), 0)

        d_opt.zero_grad()
        d_loss = loss_func(outputs, targets)
        d_loss.backward()
        d_opt.step()

        # ---- generator update ----
        # Fresh noise, this time with autograd enabled so the loss can flow
        # back through the discriminator into the generator.
        noise = torch.rand(batch_size, latent_dim, device=device)
        fake_images = generator(noise)
        fake_outputs = discriminator(fake_images)
        g_loss = loss_func(fake_outputs, real_labels)

        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()

        if (idx + 1) % log_every == 0:
            # .item() extracts the Python float from the 0-dim loss tensors.
            print('Epoch %d, batch %d, d_loss %.4f g_loss %.4f'
                  % (epoch, idx + 1, d_loss.item(), g_loss.item()))

    # After each epoch, save the last generated batch as an image grid.
    # detach() drops the autograd graph; the view restores (N, C, H, W).
    fake_grid = fake_images.detach().view(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_grid), './data/fake_images-%d.png' % (epoch + 1))

Epoch 0, batch 300, d_loss 0.1410 g_loss 1.9030
Epoch 0, batch 600, d_loss 0.5441 g_loss 0.6985
Epoch 0, batch 900, d_loss 0.4907 g_loss 0.7557
Epoch 1, batch 300, d_loss 0.5835 g_loss 0.7393
Epoch 1, batch 600, d_loss 0.6794 g_loss 0.6926
Epoch 1, batch 900, d_loss 0.6712 g_loss 0.7256
Epoch 2, batch 300, d_loss 0.7098 g_loss 0.7280
Epoch 2, batch 600, d_loss 0.5228 g_loss 0.8896
Epoch 2, batch 900, d_loss 0.5813 g_loss 0.8695
Epoch 3, batch 300, d_loss 0.7585 g_loss 0.6395
Epoch 3, batch 600, d_loss 0.5865 g_loss 0.7970
Epoch 3, batch 900, d_loss 0.6527 g_loss 0.7741
Epoch 4, batch 300, d_loss 0.4932 g_loss 0.9478
Epoch 4, batch 600, d_loss 0.5905 g_loss 0.8568
Epoch 4, batch 900, d_loss 0.4868 g_loss 0.9856
Epoch 5, batch 300, d_loss 0.4224 g_loss 1.0539
Epoch 5, batch 600, d_loss 0.5538 g_loss 0.7413
Epoch 5, batch 900, d_loss 0.4540 g_loss 1.0465
Epoch 6, batch 300, d_loss 0.8604 g_loss 0.4979
Epoch 6, batch 600, d_loss 0.6576 g_loss 0.7903
Epoch 6, batch 900, d_loss 0.7907 g_loss