In [16]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader as DataLoader
import torch.nn as nn
import torch.optim as opt
from torch.autograd import Variable

import numpy as np
import matplotlib.pyplot as plt

In [13]:
def to_var(x):
  """Place a tensor on the GPU if one is available and wrap it as a Variable.

  Args:
    x: the tensor to move.

  Returns:
    ``Variable(x)``, where ``x`` has first been moved to CUDA when
    ``torch.cuda.is_available()`` is true; otherwise ``x`` stays where it is.
  """
  use_gpu = torch.cuda.is_available()
  tensor = x.cuda() if use_gpu else x
  return Variable(tensor)

In [21]:
# Hyper-parameters
num_epoch = 200
learning_rate = 0.0005
# Mini-batch size for the training DataLoader. It must be defined here, before
# the DataLoader cell uses it; the training loop must not be the only place it
# is assigned, or Restart & Run All fails with NameError.
# 100 is consistent with the 600 batches per epoch seen in the training log.
batch_size = 100
# Seed torch's RNG so weight initialisation and latent noise are reproducible.
seed = 0
torch.manual_seed(seed)

In [6]:
# Define the transform: PIL image -> float tensor in [0, 1] (ToTensor), then
# (t - 0.5) / 0.5 rescales it to [-1, 1], the same range as G's Tanh output.
transform = transforms.Compose([
  transforms.ToTensor(),
  # Normalize takes one mean / std (standard deviation) per channel;
  # MNIST images have a single channel.
  transforms.Normalize(mean=(0.5,), std=(0.5,)),
])  # Compose takes a list of transforms

In [7]:
# MNIST training split, stored under ./data (downloaded on first use) and
# preprocessed with `transform`.
train_dataset = dsets.MNIST(
    './data',
    train=True,
    transform=transform,
    download=True,
)
# Shuffled mini-batches of `batch_size` images, defined in the hyper-parameter cell.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [9]:
D = nn.Sequential(
    nn.Linear(28*28, 256), #input為圖片
    nn.LeakyReLU(0.2),
    nn.Linear(256, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
    nn.Sigmoid()
)

In [10]:
G = nn.Sequential(
    nn.Linear(64, 256), #input custormizer 64 here
    nn.LeakyReLU(0.2),
    nn.Linear(256, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 28*28),
    nn.Tanh() # 因為transform的時候有Normalize在+-1間(mean=0.5)
)

In [12]:
# Move both networks to the GPU when one is present; otherwise they stay on the CPU.
if torch.cuda.is_available():
  for net in (D, G):
    net.cuda()

In [15]:
# Binary cross-entropy between D's probabilities and the 0/1 real/fake labels.
loss_func = nn.BCELoss()
# One Adam optimizer per network (created for D first, then G) so each is updated separately.
d_opt, g_opt = (opt.Adam(net.parameters(), lr=learning_rate) for net in (D, G))

In [19]:
def de_normalize(x):
  """Undo Normalize(mean=0.5, std=0.5): map values from [-1, 1] back to [0, 1].

  The affine map is (x + 1) / 2. Any result below 0 becomes 0 and any result
  above 1 becomes 1, so the output is safe to write out as an image.
  """
  rescaled = x.add(1).div(2)
  return rescaled.clamp(min=0, max=1)

In [None]:
# Training loop: for every mini-batch, one discriminator update then one generator update.
# After each epoch a grid of generated samples is written to ./data.
for epoch in range(num_epoch):
  for i, (images, _) in enumerate(train_loader):
    # Number of images in this batch (the last batch may be smaller). Kept in a
    # local name so the global `batch_size` hyper-parameter is not overwritten.
    cur_batch = images.size(0)
    images = to_var(images).view(cur_batch, -1)  # flatten each image to 28*28 values

    # Targets: real images are labelled 1, generated images 0. Shape (cur_batch, 1).
    real_labels = to_var(torch.ones(cur_batch, 1))
    fake_labels = to_var(torch.zeros(cur_batch, 1))

    # ---- Discriminator step ----
    outputs = D(images)
    d_loss_at_real = loss_func(outputs, real_labels)
    real_score = outputs  # D(x)

    z = to_var(torch.randn(cur_batch, 64))  # latent noise; 64 is G's input size
    fake_images = G(z)
    # detach(): the discriminator update must not backpropagate into G.
    outputs = D(fake_images.detach())
    d_loss_at_fake = loss_func(outputs, fake_labels)  # distance of D(G(z)) from 0
    fake_score = outputs  # D(G(z)); higher means G is fooling D

    d_loss = d_loss_at_real + d_loss_at_fake
    D.zero_grad()
    d_loss.backward()
    d_opt.step()

    # ---- Generator step ----
    # Draw fresh samples and score them with the just-updated D.
    z = to_var(torch.randn(cur_batch, 64))
    fake_images = G(z)
    outputs = D(fake_images)
    g_loss = loss_func(outputs, real_labels)  # push D(G(z)) towards 1
    D.zero_grad()  # g_loss.backward() also writes gradients into D's parameters
    G.zero_grad()
    g_loss.backward()
    g_opt.step()

    if (i + 1) % 300 == 0:
      print('Epoch[%d], Batch[%d], d_loss:%.4f, g_loss:%.4f, D(x): %.2f, D(G(z)): %.2f'
            % (epoch + 1, i + 1, d_loss.item(), g_loss.item(),
               real_score.mean().item(), fake_score.mean().item()))

  if epoch == 0:  # save one grid of real images once, for comparison
    real_grid = images.view(images.size(0), 1, 28, 28)
    save_image(de_normalize(real_grid.detach()), './data/real_images.png')

  # Save the generator's samples from the last batch of this epoch.
  fake_grid = fake_images.view(fake_images.size(0), 1, 28, 28)
  save_image(de_normalize(fake_grid.detach()), './data/fake_images_'+str(epoch+1)+'.png')


Epoch[1], Batch[300], d_loss:0.3041, g_loss:2.7757, D(x): 0.88, D(G(x)): 0.15
Epoch[1], Batch[600], d_loss:0.5936, g_loss:6.7802, D(x): 0.89, D(G(x)): 0.21
Epoch[2], Batch[300], d_loss:0.3842, g_loss:5.9143, D(x): 0.88, D(G(x)): 0.13
Epoch[2], Batch[600], d_loss:5.2407, g_loss:18.5732, D(x): 0.86, D(G(x)): 0.19
Epoch[3], Batch[300], d_loss:1.3538, g_loss:8.2342, D(x): 0.73, D(G(x)): 0.24
Epoch[3], Batch[600], d_loss:0.7886, g_loss:1.7929, D(x): 0.76, D(G(x)): 0.32
Epoch[4], Batch[300], d_loss:0.7515, g_loss:23.8499, D(x): 0.86, D(G(x)): 0.08
Epoch[4], Batch[600], d_loss:0.4190, g_loss:5.0527, D(x): 0.82, D(G(x)): 0.11
Epoch[5], Batch[300], d_loss:1.2832, g_loss:3.6098, D(x): 0.67, D(G(x)): 0.25
Epoch[5], Batch[600], d_loss:0.9079, g_loss:3.5050, D(x): 0.69, D(G(x)): 0.14
Epoch[6], Batch[300], d_loss:0.4775, g_loss:3.2977, D(x): 0.85, D(G(x)): 0.17
Epoch[6], Batch[600], d_loss:1.2422, g_loss:2.5830, D(x): 0.77, D(G(x)): 0.37
Epoch[7], Batch[300], d_loss:0.7057, g_loss:2.4286, D(x): 0.79