## Simple GAN

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [2]:
if not os.path.exists('save/sim_gan'):
    os.mkdir('save/sim_gan')

In [3]:
def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.reshape(-1, 1, 28, 28)
    return out

In [4]:
num_epoches = 100
batch_size = 128
z_dimension = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # whether GPU is supportted

In [5]:
# Image processing
img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

In [6]:
# MNIST dataset
mnist = datasets.MNIST('../_data/mnist', train=True, transform=img_transform, download=True)
# Data loader
dataloader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)

In [7]:
# Discriminator
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid())

    def forward(self, x):
        x = self.dis(x)
        return x.squeeze()

In [8]:
# Generator
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 256), nn.ReLU(True), nn.Linear(256, 784), nn.Tanh())

    def forward(self, x):
        x = self.gen(x)
        return x

In [9]:
D = discriminator().to(device)
G = generator().to(device)
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

In [10]:
# Start training
for epoch in range(num_epoches):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # =================train discriminator
        img = img.view(num_img, -1)
        real_img = img.to(device)
        real_label = torch.ones(num_img).to(device)
        fake_label = torch.zeros(num_img).to(device)
        
        # compute loss of real_img
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better

        # compute loss of fake_img
        z = torch.randn(num_img, z_dimension).to(device)
        fake_img = G(z)
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better

        # bp and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ===============train generator
        # compute loss of fake_img
        z = torch.randn(num_img, z_dimension).to(device)
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)

        # bp and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch}/{num_epoches}], d_loss: {d_loss:.6f}, g_loss: {g_loss:.6f}',
                  f'D real: {real_scores.mean():.6f}, D fake: {fake_scores.mean():.6f}')
    if epoch == 0:
        real_images = to_img(real_img)
        save_image(real_images, 'save/sim_gan/real_images.png')

    fake_images = to_img(fake_img)
    save_image(fake_images, f'save/sim_gan/fake_images-{epoch+1:0>3d}.png')

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [0/100], d_loss: 0.136262, g_loss: 3.331758 D real: 0.976119, D fake: 0.104650
Epoch [0/100], d_loss: 0.054924, g_loss: 3.838619 D real: 0.988118, D fake: 0.041765
Epoch [0/100], d_loss: 0.149396, g_loss: 5.615974 D real: 0.975778, D fake: 0.096019
Epoch [0/100], d_loss: 0.076128, g_loss: 4.687578 D real: 0.963456, D fake: 0.035739


  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [1/100], d_loss: 0.081038, g_loss: 5.715858 D real: 0.978861, D fake: 0.046821
Epoch [1/100], d_loss: 0.042232, g_loss: 5.419573 D real: 0.979436, D fake: 0.015044
Epoch [1/100], d_loss: 0.614751, g_loss: 5.351552 D real: 0.911668, D fake: 0.300150
Epoch [1/100], d_loss: 0.254833, g_loss: 6.752466 D real: 0.913134, D fake: 0.099508
Epoch [2/100], d_loss: 0.171637, g_loss: 5.586302 D real: 0.963307, D fake: 0.102734
Epoch [2/100], d_loss: 0.250669, g_loss: 7.750484 D real: 0.935755, D fake: 0.090065
Epoch [2/100], d_loss: 0.594251, g_loss: 3.805933 D real: 0.787677, D fake: 0.088503
Epoch [2/100], d_loss: 0.512707, g_loss: 4.628797 D real: 0.793658, D fake: 0.059797
Epoch [3/100], d_loss: 0.412490, g_loss: 4.082191 D real: 0.910901, D fake: 0.199935
Epoch [3/100], d_loss: 0.355417, g_loss: 4.383627 D real: 0.919288, D fake: 0.151899
Epoch [3/100], d_loss: 1.286704, g_loss: 3.416663 D real: 0.779357, D fake: 0.385552
Epoch [3/100], d_loss: 0.523269, g_loss: 2.728957 D real: 0.92171

Exception ignored in: <bound method Image.__del__ of <PIL.Image.Image image mode=L size=28x28 at 0x7F9E5D25C7F0>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/PIL/Image.py", line 590, in __del__
    self.fp = None
KeyboardInterrupt


Epoch [4/100], d_loss: 0.258868, g_loss: 3.815355 D real: 0.918375, D fake: 0.121004
Epoch [4/100], d_loss: 0.768087, g_loss: 2.793466 D real: 0.815595, D fake: 0.282213
Epoch [5/100], d_loss: 0.241002, g_loss: 3.836269 D real: 0.907160, D fake: 0.080332
Epoch [5/100], d_loss: 0.396196, g_loss: 3.545696 D real: 0.889451, D fake: 0.139526
Epoch [5/100], d_loss: 0.217994, g_loss: 3.506589 D real: 0.911865, D fake: 0.052661
Epoch [5/100], d_loss: 0.158857, g_loss: 4.374750 D real: 0.936575, D fake: 0.050229
Epoch [6/100], d_loss: 0.091294, g_loss: 3.987708 D real: 0.972714, D fake: 0.053590
Epoch [6/100], d_loss: 0.082062, g_loss: 5.622046 D real: 0.971301, D fake: 0.031754
Epoch [6/100], d_loss: 0.326117, g_loss: 2.916009 D real: 0.877980, D fake: 0.052647
Epoch [6/100], d_loss: 0.251813, g_loss: 2.282584 D real: 0.925859, D fake: 0.088611
Epoch [7/100], d_loss: 0.114940, g_loss: 3.544911 D real: 0.972035, D fake: 0.068966
Epoch [7/100], d_loss: 0.197274, g_loss: 6.993427 D real: 0.94467

KeyboardInterrupt: 

In [None]:
torch.save(G.state_dict(), 'save/sim_gan/generator.pytorch')
torch.save(D.state_dict(), 'save/sim_gan/discriminator.pytorch')