## Library

In [None]:
import os
import numpy as np
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from PIL import Image
from tqdm.notebook import tqdm

from torchvision import datasets
from torchvision import models
from torchvision import transforms
from torch.utils.data import DataLoader

# Restrict which GPU this process may see; change the value to the index of
# the GPU you want to use.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

## Hyperparameters

In [None]:
class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes.

    Used as a lightweight container for the training hyperparameters.
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Point the instance namespace at the mapping itself so that
        # ``obj.key`` and ``obj['key']`` share the same storage.
        self.__dict__ = self

In [None]:
# Hyperparameters and runtime settings used throughout the notebook.
config = AttrDict(
    data_path='data/',
    save_path='save/',
    dataset='MNIST',
    n_epoch=200,
    log_interval=100,
    save_interval=20,
    batch_size=64,
    learning_rate=2e-4,
    # b1 and b2 for Adam's momentum terms
    b1=0.5,
    b2=0.999,
    img_shape=(1, 28, 28),
    latent_size=100,
)

# Input pipeline: resize to the configured height/width, convert to a tensor,
# then shift/scale with mean 0.5 and std 0.5.
config.augmentation = transforms.Compose([
    transforms.Resize(config.img_shape[1:]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
# Inverse of the Normalize step above (x * std + mean).
config.denormalize = lambda x: x*0.5 + 0.5
# Use the GPU when one is available, otherwise fall back to the CPU.
config.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Make sure the dataset download directory exists.
os.makedirs(config.data_path, exist_ok=True)

# Generated sample grids are written to and read from
# <save_path>/<dataset>/, so that is the directory that must exist
# (previously <save_path>/<data_path>/ was created instead, which nothing uses).
os.makedirs(os.path.join(config.save_path, config.dataset), exist_ok=True)

In [None]:
# Display the device selected in the configuration cell.
config.device 

device(type='cpu')

## Data load

In [None]:
# Build the training dataset selected by config.dataset. The data is
# downloaded into config.data_path if it is not already there, and every
# sample goes through config.augmentation.
if config.dataset == "MNIST":
    train_dataset = datasets.MNIST(config.data_path,
                                   train=True,
                                   download=True,
                                   transform=config.augmentation
                                   )
elif config.dataset == "CIFAR10":
    # NOTE(review): config.img_shape presumably has to be changed to match
    # CIFAR10 images (3 channels) for the models below — confirm before use.
    train_dataset = datasets.CIFAR10(config.data_path,
                                     train=True,
                                     download=True,
                                     transform=config.augmentation
                                     )
else:
    # Fail here with a clear message instead of a NameError on
    # `train_dataset` further down.
    raise ValueError(
        f"Unsupported dataset {config.dataset!r}; expected 'MNIST' or 'CIFAR10'."
    )

# Shuffled mini-batches for training.
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [None]:
# Display a summary of the loaded training dataset.
train_dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: data/
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.5], std=[0.5])
           )

## GAN Model

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator.

    Flattens each input image and maps it through three linear layers to a
    single sigmoid output per sample.
    """

    def __init__(self, config):
        """Build the network; the input width is the product of config.img_shape."""
        super().__init__()
        n_pixels = int(np.prod(config.img_shape))
        layers = [
            nn.Linear(n_pixels, 512),
            # LeakyReLU to mitigate vanishing gradients
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return one validity score per image in the batch."""
        # Flatten everything except the batch dimension.
        flat = img.reshape(img.size(0), -1)
        return self.model(flat)

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to images.

    The network widens the latent vector through four linear blocks and ends
    with a Tanh, so every output value lies in (-1, 1).
    """

    def __init__(self, config):
        """Build the network.

        Args:
            config: object exposing ``latent_size`` (input width) and
                ``img_shape`` (shape of one generated image).
        """
        super(Generator, self).__init__()
        # Keep the output shape on the instance so forward() does not depend
        # on a global `config` object being defined.
        self.img_shape = tuple(config.img_shape)
        self.model = nn.Sequential(
            *self.block(config.latent_size, 128, batchnorm=False),
            *self.block(128, 256),
            *self.block(256, 512),
            *self.block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            # Tanh to put the output range in (-1, 1)
            nn.Tanh()
        )

    def forward(self, z):
        """Generate a batch of images from the latent batch ``z``.

        Returns a tensor of shape ``(z.shape[0], *img_shape)``.
        """
        img = self.model(z)
        img = img.reshape(img.shape[0], *self.img_shape)
        return img

    def block(self, input_size, output_size, batchnorm=True):
        """Return the layers of one Linear -> [BatchNorm1d] -> LeakyReLU block as a list."""
        layers = [nn.Linear(input_size, output_size)]
        if batchnorm:
            layers.append(nn.BatchNorm1d(output_size))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

### Binary Cross Entropy Loss between the target and the input probabilities

In [None]:
# Binary cross-entropy on the discriminator's output.
criterion = nn.BCELoss()

# Instantiate both networks on the configured device.
generator = Generator(config).to(config.device)
discriminator = Discriminator(config).to(config.device)

# One Adam optimizer per network, sharing the same learning rate and betas.
optimizer_g = torch.optim.Adam(
    generator.parameters(),
    lr=config.learning_rate,
    betas=(config.b1, config.b2),
)
optimizer_d = torch.optim.Adam(
    discriminator.parameters(),
    lr=config.learning_rate,
    betas=(config.b1, config.b2),
)

In [None]:
# Display the generator's layer stack.
generator.model

Sequential(
  (0): Linear(in_features=100, out_features=128, bias=True)
  (1): LeakyReLU(negative_slope=0.2, inplace=True)
  (2): Linear(in_features=128, out_features=256, bias=True)
  (3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (4): LeakyReLU(negative_slope=0.2, inplace=True)
  (5): Linear(in_features=256, out_features=512, bias=True)
  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): LeakyReLU(negative_slope=0.2, inplace=True)
  (8): Linear(in_features=512, out_features=1024, bias=True)
  (9): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (10): LeakyReLU(negative_slope=0.2, inplace=True)
  (11): Linear(in_features=1024, out_features=784, bias=True)
  (12): Tanh()
)

In [None]:
# Display the discriminator's layer stack.
discriminator.model

Sequential(
  (0): Linear(in_features=784, out_features=512, bias=True)
  (1): LeakyReLU(negative_slope=0.2, inplace=True)
  (2): Linear(in_features=512, out_features=256, bias=True)
  (3): LeakyReLU(negative_slope=0.2, inplace=True)
  (4): Linear(in_features=256, out_features=1, bias=True)
  (5): Sigmoid()
)

## Training

In [None]:
# Loss values recorded every config.log_interval batches (plotted later).
g_loss_list = []
d_loss_list = []

# Directory for the periodically saved sample grids; created here so that
# save_image cannot fail on a missing directory.
sample_dir = os.path.join(config.save_path, config.dataset)
os.makedirs(sample_dir, exist_ok=True)

for epoch in tqdm(range(config.n_epoch)):
    for i, (real_img, _) in enumerate(train_loader):
        # ---- Train Discriminator ----
        real_img = real_img.to(config.device)
        batch_size = real_img.shape[0]

        # Targets: 1 for real images, 0 for generated images.
        valid_label = torch.ones((batch_size, 1), device=config.device, dtype=torch.float32)
        fake_label = torch.zeros((batch_size, 1), device=config.device, dtype=torch.float32)

        z = torch.randn((batch_size, config.latent_size), device=config.device, dtype=torch.float32)
        gen_img = generator(z)

        real_loss = criterion(discriminator(real_img), valid_label)

        # The generator must not be updated while training the discriminator,
        # so detach() cuts the generated images out of the computation graph.
        fake_loss = criterion(discriminator(gen_img.detach()), fake_label)

        d_loss = (real_loss + fake_loss) * 0.5

        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # ---- Train Generator ----
        z = torch.randn((batch_size, config.latent_size), device=config.device, dtype=torch.float32)
        gen_img = generator(z)

        # The generator is trained to make the discriminator output "real".
        g_loss = criterion(discriminator(gen_img), valid_label)

        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

        # Log on the logging interval (the sample-saving interval was used
        # here before, leaving config.log_interval unused).
        if (i+1) % config.log_interval == 0:
            g_loss_list.append(g_loss.item())
            d_loss_list.append(d_loss.item())
            print(f"Epoch [{epoch+1}/{config.n_epoch}] Batch [{i+1}/{len(train_loader)}] d_loss: {d_loss.item():.4f} g_loss: {g_loss.item():.4f}")

    # Every config.save_interval epochs, save a 5x5 grid built from the last
    # generated batch of the epoch.
    if (epoch+1) % config.save_interval == 0:
        save_path = os.path.join(sample_dir, f"epoch_[{epoch+1}].png")
        samples = config.denormalize(gen_img.detach())
        torchvision.utils.save_image(samples[:25], save_path, nrow=5, normalize=True)

In [None]:
# Plot the recorded generator and discriminator losses on one set of axes.
fig, ax = plt.subplots()
ax.set_title(f"GAN training loss on {config.dataset} data")
for losses, name in ((g_loss_list, 'generator loss'), (d_loss_list, 'discriminator loss')):
    ax.plot(losses, label=name)
ax.legend()
plt.show()

## Qualitative results

In [None]:
# Show every saved sample grid, one figure per file.
save_path = os.path.join(config.save_path, config.dataset)
png_names = [name for name in os.listdir(save_path) if name.endswith('.png')]

# os.listdir returns names in arbitrary order. Sorting by (length, name)
# puts same-pattern names with a growing number in numeric order, e.g.
# "epoch_[20].png" before "epoch_[100].png".
for image_path in sorted(png_names, key=lambda name: (len(name), name)):
    plt.figure(figsize=(5,5))
    # The context manager closes the file handle once the image is drawn.
    with Image.open(os.path.join(save_path, image_path)) as image:
        plt.title(image_path)
        plt.imshow(image)
    plt.show()