## Simple GAN

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [2]:
# Create the output directory, including the missing parent 'save/'.
# os.mkdir would raise FileNotFoundError when 'save/' does not exist yet;
# exist_ok=True makes re-running this cell a no-op.
os.makedirs('save/sim_gan', exist_ok=True)

In [3]:
def to_img(x):
    """Turn a batch of flattened images into displayable 1x28x28 images.

    The values are shifted and scaled with ``0.5 * (x + 1)`` (so -1 maps to 0
    and 1 maps to 1), clamped into [0, 1], and reshaped to ``(N, 1, 28, 28)``.

    Args:
        x: tensor whose element count is a multiple of 28 * 28.

    Returns:
        Tensor of shape ``(N, 1, 28, 28)`` with values in [0, 1].
    """
    rescaled = (x + 1) * 0.5
    return rescaled.clamp(min=0, max=1).reshape(-1, 1, 28, 28)

In [4]:
# Training hyper-parameters
num_epoches = 100   # number of passes over the data loader
batch_size = 128    # samples per mini-batch
z_dimension = 100   # length of the noise vector fed to the generator
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Image processing: ToTensor yields values in [0, 1]; Normalize with
# mean 0.5 / std 0.5 then rescales them to [-1, 1].
# MNIST images have a single channel, so mean/std must be 1-tuples: recent
# torchvision versions raise a broadcast error when given 3-channel statistics.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [6]:
# MNIST training split stored under ../_data/mnist (downloaded when missing),
# with img_transform applied to every image
mnist = datasets.MNIST(root='../_data/mnist', train=True, transform=img_transform, download=True)
# Shuffled mini-batch iterator over the dataset
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

In [7]:
# Discriminator
class discriminator(nn.Module):
    """Fully connected discriminator producing one sigmoid score per sample.

    Args:
        in_features: length of each flattened input sample (default 784).
        hidden: width of the two hidden layers (default 256).
    """

    def __init__(self, in_features=784, hidden=256):
        super().__init__()
        # Layer order is unchanged, so state_dict keys (dis.0, dis.2, dis.4)
        # stay compatible with previously saved checkpoints.
        self.dis = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of shape ``(batch, in_features)``.

        Returns:
            Tensor of shape ``(batch,)`` with values in (0, 1).
        """
        # Drop only the trailing singleton dimension. A bare squeeze() would
        # also collapse a batch of one sample into a 0-d tensor, which no
        # longer matches label tensors of shape (1,) when computing BCELoss.
        return self.dis(x).squeeze(-1)

In [8]:
# Generator
class generator(nn.Module):
    """Fully connected generator mapping noise vectors to flattened samples.

    Args:
        z_dim: length of the input noise vector (default 100).
        hidden: width of the two hidden layers (default 256).
        out_features: length of each generated flattened sample (default 784).
    """

    def __init__(self, z_dim=100, hidden=256, out_features=784):
        super().__init__()
        # Layer order is unchanged, so state_dict keys (gen.0, gen.2, gen.4)
        # stay compatible with previously saved checkpoints.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, out_features),
            nn.Tanh(),  # bounds every output value to [-1, 1]
        )

    def forward(self, x):
        """Map noise of shape ``(batch, z_dim)`` to ``(batch, out_features)``."""
        return self.gen(x)

In [9]:
# Instantiate both networks on the selected device (discriminator first)
D, G = discriminator().to(device), generator().to(device)
# Binary cross-entropy loss shared by both training steps
criterion = nn.BCELoss()
# One Adam optimizer per network, both with the same learning rate
d_optimizer = torch.optim.Adam(params=D.parameters(), lr=3e-4)
g_optimizer = torch.optim.Adam(params=G.parameters(), lr=3e-4)

In [10]:
# Start training: each iteration performs one discriminator update followed by
# one generator update; one grid of generated images is saved per epoch.
for epoch in range(num_epoches):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # Flatten every image into a vector for the fully connected networks.
        img = img.view(num_img, -1)
        real_img = img.to(device)
        # Build the targets directly on the device instead of copying from CPU.
        real_label = torch.ones(num_img, device=device)
        fake_label = torch.zeros(num_img, device=device)

        # ================= train discriminator
        # loss on real images (target 1)
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out.detach()  # closer to 1 means better

        # loss on generated images (target 0)
        z = torch.randn(num_img, z_dimension, device=device)
        # detach(): this step only updates D, so do not build or backpropagate
        # through G's graph (those gradients were discarded anyway).
        fake_img = G(z).detach()
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out.detach()  # closer to 0 means better

        # bp and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # =============== train generator
        # G is rewarded when D labels its samples as real (target 1)
        z = torch.randn(num_img, z_dimension, device=device)
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)

        # bp and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            # .item() converts to Python floats so logging keeps no graph alive.
            print(f'Epoch [{epoch+1}/{num_epoches}], d_loss: {d_loss.item():.6f}, g_loss: {g_loss.item():.6f}',
                  f'D real: {real_scores.mean().item():.6f}, D fake: {fake_scores.mean().item():.6f}')
    if epoch == 0:
        # Save the last real batch of the first epoch once, for visual reference.
        real_images = to_img(real_img)
        save_image(real_images, 'save/sim_gan/real_images.png')

    # Save the last generated batch of this epoch (detached from the graph).
    fake_images = to_img(fake_img.detach())
    save_image(fake_images, f'save/sim_gan/fake_images-{epoch+1:0>3d}.png')

Epoch [0/100], d_loss: 0.135319, g_loss: 3.752518 D real: 0.963787, D fake: 0.091794
Epoch [0/100], d_loss: 0.031493, g_loss: 5.028942 D real: 0.987179, D fake: 0.018191
Epoch [0/100], d_loss: 0.237485, g_loss: 5.412944 D real: 0.966075, D fake: 0.164920
Epoch [0/100], d_loss: 0.034377, g_loss: 6.113669 D real: 0.996366, D fake: 0.029698
Epoch [1/100], d_loss: 0.085341, g_loss: 4.118917 D real: 0.982024, D fake: 0.046382
Epoch [1/100], d_loss: 0.200163, g_loss: 4.582250 D real: 0.965505, D fake: 0.119338
Epoch [1/100], d_loss: 0.140504, g_loss: 5.526665 D real: 0.950282, D fake: 0.050234
Epoch [1/100], d_loss: 0.547503, g_loss: 5.642103 D real: 0.927914, D fake: 0.292087
Epoch [2/100], d_loss: 0.420671, g_loss: 6.848865 D real: 0.935520, D fake: 0.193164
Epoch [2/100], d_loss: 0.818404, g_loss: 6.327253 D real: 0.859086, D fake: 0.304561
Epoch [2/100], d_loss: 0.219319, g_loss: 3.305739 D real: 0.948900, D fake: 0.132807
Epoch [2/100], d_loss: 0.440326, g_loss: 3.102167 D real: 0.86664

In [11]:
# Persist the learned weights (state_dict only) of both networks, generator first
checkpoints = {
    'save/sim_gan/generator.pytorch': G,
    'save/sim_gan/discriminator.pytorch': D,
}
for path, net in checkpoints.items():
    torch.save(net.state_dict(), path)