In [None]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader as DataLoader
import torch.nn as nn
import torch.optim as opt
from torch.autograd import Variable

import numpy as np
import matplotlib.pyplot as plt

In [None]:
def to_var(x):
  """Return ``x`` wrapped in a ``Variable``, moved to the GPU first when CUDA is available."""
  on_device = x.cuda() if torch.cuda.is_available() else x
  return Variable(on_device)

In [None]:
# Hyperparameters
# Seed PyTorch's global RNG first so that random draws made later in the
# notebook (e.g. torch.randn for the latent vectors) are repeatable after
# "Restart Kernel -> Run All".
random_seed = 42
torch.manual_seed(random_seed)

num_epoch = 30            # number of passes over the training loader
d_learning_rate = 0.001   # learning rate intended for the discriminator optimizer
g_learning_rate = 0.0001  # learning rate intended for the generator optimizer

In [None]:
# Define the preprocessing transform applied to every dataset image.
transform = transforms.Compose([  # Compose takes a list of transforms
  transforms.ToTensor(),
  # Normalize is documented to take one mean/std entry per channel, so pass
  # one-element sequences for the single-channel MNIST images instead of bare
  # scalars. Each value becomes (x - 0.5) / 0.5; std is the standard deviation.
  transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [None]:
# MNIST training split stored under ./data (downloaded when missing) and
# preprocessed with `transform`.
train_dataset = dsets.MNIST(
  root='./data',
  train=True,
  transform=transform,
  download=True,
)
# Shuffled mini-batches of 100 samples.
train_loader = DataLoader(
  dataset=train_dataset,
  batch_size=100,
  shuffle=True,
)

In [None]:
class Generator(nn.Module):
  """DCGAN-style generator: maps a latent vector to a 1x28x28 image in [-1, 1].

  Per-sample shapes through the network:
    latent_dim -> Linear -> 256x7x7 -> ConvTranspose(stride 1) -> 128x7x7
               -> ConvTranspose(stride 2) -> 64x14x14
               -> ConvTranspose(stride 2) -> 1x28x28 -> Tanh

  Every hidden layer is followed by (optional) batch normalisation and a
  LeakyReLU(0.2); the output layer is followed by Tanh only.

  Args:
    latent_dim: Size of the input noise vector.
    batch_norm: When True, batch normalisation is applied after every layer
      except the output layer; when False those modules are ``None``.
  """

  def __init__(self, latent_dim=100, batch_norm=True):
    super().__init__()
    self.latent_dim = latent_dim
    self.batch_norm = batch_norm

    # Project the latent vector to a flat 256-channel 7x7 feature map.
    self.linear_1 = nn.Linear(self.latent_dim, 256*7*7, bias=False)
    self.batch_norm_1 = nn.BatchNorm1d(256*7*7) if batch_norm else None
    # Stateless activation, shared by all hidden layers.
    self.leaky_relu = nn.LeakyReLU(0.2)

    # ConvTranspose2d ("deconvolution") layers: the stride-2 ones double the
    # spatial size (7 -> 14 -> 28), the stride-1 one keeps it at 7x7.
    self.conv_1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=5, stride=1, padding=2, bias=False)
    self.batch_norm_2d_1 = nn.BatchNorm2d(128) if batch_norm else None
    self.conv_2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False)
    self.batch_norm_2d_2 = nn.BatchNorm2d(64) if batch_norm else None
    # Output layer: a single image channel.
    self.conv_3 = nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=4, stride=2, padding=1, bias=False)
    self.tanh = nn.Tanh()

  def forward(self, x):
    """Generate images from latent vectors.

    Args:
      x: Tensor of shape (batch_size, latent_dim).

    Returns:
      Tensor of shape (batch_size, 1, 28, 28) with values in [-1, 1].
    """
    out = self.linear_1(x)
    if self.batch_norm:
      out = self.batch_norm_1(out)
    out = self.leaky_relu(out)

    # 1D to 2D: (batch_size, channels, rows, cols)
    out = out.view(-1, 256, 7, 7)

    out = self.conv_1(out)
    if self.batch_norm:
      out = self.batch_norm_2d_1(out)
    # Without a non-linearity here, consecutive transposed convolutions are
    # composed with no activation in between, wasting the extra layers.
    out = self.leaky_relu(out)

    out = self.conv_2(out)
    if self.batch_norm:
      out = self.batch_norm_2d_2(out)
    out = self.leaky_relu(out)

    out = self.conv_3(out)
    return self.tanh(out)


In [None]:
class Discriminator(nn.Module):
  """Convolutional discriminator: scores each image with a probability in [0, 1].

  Two stride-2 convolutions reduce a single-channel 28x28 input to 128x7x7,
  which is flattened and mapped to one sigmoid output per image. The training
  loop in this notebook uses 1 as the target for real images and 0 for
  generated ones.
  """

  def __init__(self):
    super().__init__()
    self.conv_1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, stride=2, padding=2, bias=True)
    # Stateless activation, reused after both convolutions.
    self.leaky_relu = nn.LeakyReLU(0.2)
    self.dropout_2d = nn.Dropout2d(p=0.3)
    self.conv_2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2, bias=True)
    self.linear_1 = nn.Linear(128*7*7, 1, bias=True)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Score a batch of images.

    Args:
      x: Tensor of shape (batch_size, 1, H, W); H and W must reduce to 7x7
        after the two stride-2 convolutions (e.g. 28x28).

    Returns:
      Tensor of shape (batch_size, 1) with values in [0, 1].

    Raises:
      RuntimeError: If the flattened feature size does not match ``linear_1``.
    """
    out = self.conv_1(x)
    out = self.leaky_relu(out)
    out = self.dropout_2d(out)

    out = self.conv_2(out)
    # Non-linearity between conv_2 and the linear layer; without it the two
    # layers collapse into a single linear map.
    out = self.leaky_relu(out)

    # 2D to 1D. Flatten per sample (rather than view(-1, 128*7*7)) so a wrongly
    # sized input fails in linear_1 instead of silently changing the batch size.
    out = out.view(out.size(0), -1)
    out = self.linear_1(out)
    return self.sigmoid(out)


In [None]:
# Build both networks (generator first), then move them to the GPU when CUDA is available.
G, D = Generator(), Discriminator()
if torch.cuda.is_available():
  D, G = D.cuda(), G.cuda()

In [None]:
# Show the layer layout of both networks, discriminator first.
print(D, G, sep='\n')

Discriminator(
  (conv_1): Conv2d(1, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (leaky_relu): LeakyReLU(negative_slope=0.2)
  (dropout_2d): Dropout2d(p=0.3, inplace=False)
  (conv_2): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (linear_1): Linear(in_features=6272, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)
Generator(
  (linear_1): Linear(in_features=100, out_features=12544, bias=False)
  (batch_norm_1): BatchNorm1d(12544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (leaky_relu): LeakyReLU(negative_slope=0.2)
  (conv_1): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
  (batch_norm_2d_1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (batch_norm_2d_2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_3): 

In [None]:
# Binary cross-entropy between D's sigmoid output and the 0/1 targets.
loss_func = nn.BCELoss()
# Take the learning rates from the hyperparameter cell rather than repeating
# the literals here, so tuning them there actually takes effect.
# betas = (beta1, beta2) are Adam's moment-decay coefficients.
d_opt = opt.Adam(D.parameters(), lr=d_learning_rate, betas=(0.5, 0.999))
g_opt = opt.Adam(G.parameters(), lr=g_learning_rate, betas=(0.5, 0.999))

In [None]:
def de_normalize(x):
  """Map values from [-1, 1] back to [0, 1], clipping anything outside that range."""
  # Values below 0 become 0 and values above 1 become 1.
  return ((x + 1) / 2).clamp(min=0, max=1)

In [None]:
# Alternating GAN training: one discriminator update, then one generator
# update, per mini-batch. Sample images are written to ./data once per epoch.
for epoch in range(num_epoch):
  for i, (images, _) in enumerate(train_loader):
    batch_size = images.size(0)  # number of samples in this batch
    images = to_var(images)

    # Targets with shape (batch_size, 1) to match D's output:
    # real images are labelled 1, generated images 0.
    real_labels = to_var(torch.ones(batch_size, 1))
    fake_labels = to_var(torch.zeros(batch_size, 1))

    # ---- Discriminator step ----
    outputs = D(images)
    d_loss_at_real = loss_func(outputs, real_labels)
    real_score = outputs

    # Random latent vectors fed to G; the width comes from the generator
    # itself instead of a hard-coded 100.
    z = to_var(torch.randn(batch_size, G.latent_dim))
    fake_images = G(z)
    # detach(): the D update does not need gradients w.r.t. G, so cut the graph
    # here rather than backpropagating through the generator for nothing.
    outputs = D(fake_images.detach())

    d_loss_at_fake = loss_func(outputs, fake_labels)  # distance of D(G(z)) from 0
    fake_score = outputs  # a higher fake score means G is fooling D

    d_loss = d_loss_at_real + d_loss_at_fake  # total discriminator loss
    D.zero_grad()
    d_loss.backward()
    d_opt.step()

    # ---- Generator step ----
    # Draw fresh fake images, this time keeping the graph through G.
    z = to_var(torch.randn(batch_size, G.latent_dim))
    fake_images = G(z)
    outputs = D(fake_images)

    g_loss = loss_func(outputs, real_labels)  # distance of D(G(z)) from 1
    D.zero_grad()  # D also receives gradients through D(G(z)); only g_opt steps here
    G.zero_grad()
    g_loss.backward()
    g_opt.step()

    if (i+1)%300 == 0:
      # .item() extracts Python floats from the 0-dim tensors for formatting.
      print('Epoch[%d], Batch[%d], d_loss:%.4f, g_loss:%.4f, D(x): %.2f, D(G(x)): %.2f'%(epoch+1, i+1, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()))

  if epoch == 0:  # save a grid of real images once, for reference
    images = images.view(images.size(0), 1, 28, 28)
    save_image(de_normalize(images.detach()), './data/real_images.png')

  # Save the last generated batch of this epoch.
  fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
  save_image(de_normalize(fake_images.detach()), './data/fake_images_'+str(epoch+1)+'.png')


Epoch[1], Batch[300], d_loss:0.0000, g_loss:50.7319, D(x): 1.00, D(G(x)): 0.00
Epoch[1], Batch[600], d_loss:0.0000, g_loss:47.4435, D(x): 1.00, D(G(x)): 0.00
Epoch[2], Batch[300], d_loss:1.0271, g_loss:1.7022, D(x): 0.61, D(G(x)): 0.29
Epoch[2], Batch[600], d_loss:1.2399, g_loss:0.9792, D(x): 0.48, D(G(x)): 0.28
Epoch[3], Batch[300], d_loss:0.9877, g_loss:1.7790, D(x): 0.72, D(G(x)): 0.44
Epoch[3], Batch[600], d_loss:1.0953, g_loss:1.1450, D(x): 0.53, D(G(x)): 0.25
Epoch[4], Batch[300], d_loss:1.0050, g_loss:1.5582, D(x): 0.65, D(G(x)): 0.33
Epoch[4], Batch[600], d_loss:0.7049, g_loss:1.3121, D(x): 0.69, D(G(x)): 0.19
Epoch[5], Batch[300], d_loss:0.8731, g_loss:3.4256, D(x): 0.87, D(G(x)): 0.41
Epoch[5], Batch[600], d_loss:0.6345, g_loss:2.5783, D(x): 0.76, D(G(x)): 0.15
Epoch[6], Batch[300], d_loss:0.5951, g_loss:3.0970, D(x): 0.84, D(G(x)): 0.24
Epoch[6], Batch[600], d_loss:0.5301, g_loss:2.7596, D(x): 0.86, D(G(x)): 0.19
Epoch[7], Batch[300], d_loss:0.4147, g_loss:2.4727, D(x): 0.87