In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torch.nn import functional as F
from torchvision.utils import save_image

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 100
# num_epochs = 5
batch_size = 100
image_dir = './image/gan'
model_path = "./model"

In [3]:
# Create a directory if not exists
if not os.path.exists(image_dir):
    os.makedirs(image_dir)
if not os.path.exists(model_path):
    os.makedirs(model_path)

# 数据加载

In [4]:
# Image processing
# transform = transforms.Compose([
#                 transforms.ToTensor(),
#                 transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
#                                      std=(0.5, 0.5, 0.5))])
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5],               # 1 for greyscale channels
                                     std=[0.5])])

In [5]:
# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./data',
                                   train=True,
                                   transform=transform,
                                   download=True)

In [6]:
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size, 
                                          shuffle=True)

# 模型定义

## 判别器定义

In [7]:
class Discriminator(nn.Module):
    def __init__(self, image_size: int, hidden_size: int):
        super(Discriminator, self).__init__()
        self.linear1 = nn.Linear(image_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out = self.linear1(x)
        out = F.leaky_relu(out, negative_slope=0.2, inplace=True)
        out = self.linear2(out)
        out = F.leaky_relu(out, negative_slope=0.2, inplace=True)
        out = self.linear3(out)
        return F.sigmoid(out)

In [8]:
# Discriminator
D = Discriminator(image_size=image_size, hidden_size=hidden_size)

## 生成器定义

In [10]:
class Generator(nn.Module):
    def __init__(self, image_size: int, latent_size: int, hidden_size: int):
        super(Generator, self).__init__()
        self.linear1 = nn.Linear(latent_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, image_size)

    def forward(self, x):
        out = self.linear1(x)
        out = F.relu(out)
        out = self.linear2(out)
        out = F.relu(out)
        out = self.linear3(out)
        return F.tanh(out)

In [11]:
# Generator 
G = Generator(image_size=image_size, hidden_size=hidden_size, latent_size=latent_size)

In [12]:
# Device setting
D = D.to(device)
G = G.to(device)

# 模型训练

In [13]:
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

In [14]:
# 图像反归一化
def denorm(x):
    min_value = -1
    max_value = 1
    out = (x - min_value) / (max_value - min_value)
    # plt expects values in [0, 1]
    return out.clamp(0, 1)  

In [15]:
def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

In [16]:
def train(D, G, data_loader, epoch):
    total_step = len(data_loader)
    # Create the labels which are later used as input for the BCE loss
    real_labels = torch.ones(batch_size, 1).to(device)
    fake_labels = torch.zeros(batch_size, 1).to(device)
    for i, (images, _) in enumerate(data_loader):
        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #
        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1 - y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1
        real_images = images.reshape(batch_size, -1).to(device)
        d_outputs_real = D(real_images)
        d_loss_real = criterion(d_outputs_real, real_labels)
        real_score = d_outputs_real
        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        d_outputs_fake = D(fake_images)
        d_loss_fake = criterion(d_outputs_fake, fake_labels)
        fake_score = d_outputs_fake
        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #
        # Compute loss with fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        d_outputs_fake = D(fake_images)
        # We train G to maximize log(D(G(z)) instead of minimizing log(1 - D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(d_outputs_fake, real_labels)
        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        if (i + 1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i + 1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
    # Save real images
    if (epoch + 1) == 1:
        real_images = real_images.reshape(images.size(0), 1, 28, 28)
        real_images_name = 'real_images.png'
        save_image(denorm(real_images), os.path.join(image_dir, real_images_name))
        print('epoch {} real images {} saved'.format(epoch + 1, real_images_name))
    # Save fake images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    fake_images_name = 'fake_images-{}.png'.format(epoch+1)
    save_image(denorm(fake_images), os.path.join(image_dir, fake_images_name))
    print('epoch {} fake images {} saved'.format(epoch + 1, fake_images_name))

# 模型保存

In [17]:
def save_mode():
    # Save model
    torch.save(G.state_dict(), os.path.join(model_path, "Generator.pt"))
    torch.save(D.state_dict(), os.path.join(model_path, "Discriminator.pt"))
    print("models saved")

In [18]:
def main():
    # Start training
    for epoch in range(num_epochs):
        train(D, G, data_loader, epoch)
    save_mode()

In [19]:
if __name__ == '__main__':
    main()

Epoch [0/100], Step [200/600], d_loss: 0.0628, g_loss: 3.8714, D(x): 0.99, D(G(z)): 0.05
Epoch [0/100], Step [400/600], d_loss: 0.0669, g_loss: 5.8127, D(x): 0.98, D(G(z)): 0.05
Epoch [0/100], Step [600/600], d_loss: 0.0421, g_loss: 5.6172, D(x): 0.98, D(G(z)): 0.02
epoch 1 real images real_images.png saved
epoch 1 fake images fake_images-1.png saved
Epoch [1/100], Step [200/600], d_loss: 0.0446, g_loss: 4.9776, D(x): 1.00, D(G(z)): 0.04
Epoch [1/100], Step [400/600], d_loss: 0.0505, g_loss: 5.3517, D(x): 0.97, D(G(z)): 0.02
Epoch [1/100], Step [600/600], d_loss: 0.5368, g_loss: 3.8375, D(x): 0.84, D(G(z)): 0.16
epoch 2 fake images fake_images-2.png saved
Epoch [2/100], Step [200/600], d_loss: 0.1171, g_loss: 3.8993, D(x): 0.96, D(G(z)): 0.05
Epoch [2/100], Step [400/600], d_loss: 0.2336, g_loss: 5.0401, D(x): 0.91, D(G(z)): 0.05
Epoch [2/100], Step [600/600], d_loss: 0.3689, g_loss: 4.0837, D(x): 0.86, D(G(z)): 0.07
epoch 3 fake images fake_images-3.png saved
Epoch [3/100], Step [200/

Epoch [61/100], Step [400/600], d_loss: 0.6660, g_loss: 1.5530, D(x): 0.87, D(G(z)): 0.33
Epoch [61/100], Step [600/600], d_loss: 0.5907, g_loss: 2.4484, D(x): 0.84, D(G(z)): 0.21
Epoch [62/100], Step [200/600], d_loss: 0.3851, g_loss: 2.8583, D(x): 0.91, D(G(z)): 0.19
Epoch [62/100], Step [400/600], d_loss: 0.4070, g_loss: 2.8479, D(x): 0.85, D(G(z)): 0.14
Epoch [62/100], Step [600/600], d_loss: 0.7081, g_loss: 3.2311, D(x): 0.93, D(G(z)): 0.33
Epoch [63/100], Step [200/600], d_loss: 0.5276, g_loss: 2.6238, D(x): 0.81, D(G(z)): 0.14
Epoch [63/100], Step [400/600], d_loss: 0.5914, g_loss: 3.0684, D(x): 0.83, D(G(z)): 0.24
Epoch [63/100], Step [600/600], d_loss: 0.5139, g_loss: 2.2964, D(x): 0.84, D(G(z)): 0.19
Epoch [64/100], Step [200/600], d_loss: 0.6675, g_loss: 2.5774, D(x): 0.75, D(G(z)): 0.14
Epoch [64/100], Step [400/600], d_loss: 0.7898, g_loss: 2.6074, D(x): 0.70, D(G(z)): 0.17
Epoch [64/100], Step [600/600], d_loss: 0.7838, g_loss: 2.6269, D(x): 0.75, D(G(z)): 0.17
Epoch [65/

Epoch [92/100], Step [200/600], d_loss: 0.8001, g_loss: 1.6516, D(x): 0.71, D(G(z)): 0.25
Epoch [92/100], Step [400/600], d_loss: 0.6889, g_loss: 2.2359, D(x): 0.77, D(G(z)): 0.24
Epoch [92/100], Step [600/600], d_loss: 0.6726, g_loss: 1.6754, D(x): 0.81, D(G(z)): 0.27
Epoch [93/100], Step [200/600], d_loss: 0.7507, g_loss: 2.0496, D(x): 0.71, D(G(z)): 0.19
Epoch [93/100], Step [400/600], d_loss: 0.7378, g_loss: 2.1993, D(x): 0.77, D(G(z)): 0.25
Epoch [93/100], Step [600/600], d_loss: 0.8526, g_loss: 2.5157, D(x): 0.70, D(G(z)): 0.23
Epoch [94/100], Step [200/600], d_loss: 0.8022, g_loss: 1.8209, D(x): 0.75, D(G(z)): 0.25
Epoch [94/100], Step [400/600], d_loss: 0.8231, g_loss: 1.4436, D(x): 0.79, D(G(z)): 0.32
Epoch [94/100], Step [600/600], d_loss: 0.7198, g_loss: 1.7707, D(x): 0.77, D(G(z)): 0.25
Epoch [95/100], Step [200/600], d_loss: 1.0768, g_loss: 1.4693, D(x): 0.68, D(G(z)): 0.28
Epoch [95/100], Step [400/600], d_loss: 0.9034, g_loss: 1.7187, D(x): 0.67, D(G(z)): 0.26
Epoch [95/