# GAN 실험해보기
- PyTorch로 GAN 구조를 구현하고 MNIST digit 데이터로 학습시켜, GAN의 generator가 제대로 동작하는지 확인해보겠습니다.

### 1. Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

from generator import Generator
from discriminator import Discriminator
from train import train_model

import matplotlib.pyplot as plt

# Device selection: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device: ", device)

# Switch the cuDNN backend off for the rest of this session.
torch.backends.cudnn.enabled = False

device:  cuda


### 2. Data Preparation
MNIST digit data 를 활용하겠습니다.

In [2]:
# Training hyperparameters: number of epochs, mini-batch size, and the
# dimensionality of the latent vector fed to the generator.
epochs, batch_size, z_dim = 300, 128, 100

# Preprocessing pipeline: convert each image to a tensor, then normalise the
# single channel with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split; downloaded into ./data/ when not already present.
mnist_dataset = datasets.MNIST(
    root='./data/',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches. drop_last=True discards the final incomplete batch,
# so every batch holds exactly `batch_size` samples.
dataloader = DataLoader(
    mnist_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

### 3. Modeling
깔끔한 노트북을 위해, `discriminator.py` 와 `generator.py` 에 각 Discriminator 와 Generator를 정의해 두었습니다. 이 노트북에서는 초기화 선언만 하겠습니다.

In [3]:
# Build the generator first (latent size = z_dim), then the discriminator,
# and move each network onto the selected device.
generator = Generator(latent_dims=z_dim)
generator = generator.to(device)

discriminator = Discriminator()
discriminator = discriminator.to(device)

# Show both architectures for a quick visual check.
print("GENERATOR : ", generator)
print("DISCRIMINATOR : ", discriminator)

GENERATOR :  Generator(
  (fc1): Linear(in_features=100, out_features=128, bias=True)
  (fc1_bn): BatchNorm1d(128, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=128, out_features=256, bias=True)
  (fc2_bn): BatchNorm1d(256, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=256, out_features=512, bias=True)
  (fc3_bn): BatchNorm1d(512, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
  (fc4): Linear(in_features=512, out_features=1024, bias=True)
  (fc4_bn): BatchNorm1d(1024, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
  (fc5): Linear(in_features=1024, out_features=784, bias=True)
)
DISCRIMINATOR :  Discriminator(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=1, bias=True)
)


### 4. Train
#### 4-1. Loss & Optimizer
- GAN의 구조에서 알 수 있듯이, 이 네트워크는 discriminator가 실제 사진과 generator가 생성한 사진을 각각 진짜인지 가짜인지 판별하는 loss로부터 역전파되어 각 구조가 학습됩니다. 따라서 discriminator의 마지막 layer의 출력 크기(1)와 Binary Cross Entropy Loss가 이 구조로부터 정해집니다.
- Optimizer의 경우, discriminator와 generator가 순차적으로 학습하는 구조를 가질 수밖에 없습니다. 따라서 각 구조를 update하기 위한 optimizer는 따로 선언해 줍니다.

In [4]:
# Loss & Optimizer
# Binary cross-entropy over the discriminator's real-vs-fake predictions.
criterion = nn.BCELoss()

# One Adam optimizer per network (identical settings) so that the generator
# and the discriminator can be stepped separately during training.
generator_optim = optim.Adam(
    generator.parameters(),
    lr=2e-4,
    weight_decay=8e-9,
)
discriminator_optim = optim.Adam(
    discriminator.parameters(),
    lr=2e-4,
    weight_decay=8e-9,
)

#### 4-2. Train Model
- discriminator 의 학습을 위해 train 단계에서, 진짜(1)와 가짜(0) 이미지의 label을 붙여줍니다.

In [5]:
# Run the GAN training loop (train_model is imported from train.py).
# Arguments are passed positionally, in the order the call site expects.
# Left as the last expression so the notebook echoes the returned value.
train_model(
    z_dim,
    discriminator,
    generator,
    batch_size,
    discriminator_optim,
    generator_optim,
    criterion,
    dataloader,
    epochs,
    device,
)

EPOCH 0: BATCH: 467, discrim_loss: 0.12506286799907684, generator_loss: 3.440580368041992
EPOCH 20: BATCH: 467, discrim_loss: 1.41682767868042, generator_loss: 2.9967875480651855
EPOCH 40: BATCH: 467, discrim_loss: 0.8488105535507202, generator_loss: 2.2675833702087402
EPOCH 60: BATCH: 467, discrim_loss: 1.3603236675262451, generator_loss: 2.0868096351623535
EPOCH 80: BATCH: 467, discrim_loss: 0.8660893440246582, generator_loss: 3.3595690727233887
EPOCH 100: BATCH: 467, discrim_loss: 1.0487098693847656, generator_loss: 1.2499107122421265
EPOCH 120: BATCH: 467, discrim_loss: 1.1544331312179565, generator_loss: 2.6048731803894043
EPOCH 140: BATCH: 467, discrim_loss: 0.9980860352516174, generator_loss: 1.420907735824585
EPOCH 160: BATCH: 467, discrim_loss: 1.3974876403808594, generator_loss: 2.2171359062194824
EPOCH 180: BATCH: 467, discrim_loss: 1.286292314529419, generator_loss: 1.0285342931747437
EPOCH 200: BATCH: 467, discrim_loss: 1.2533783912658691, generator_loss: 1.085643291473388

(Discriminator(
   (fc1): Linear(in_features=784, out_features=256, bias=True)
   (fc2): Linear(in_features=256, out_features=128, bias=True)
   (fc3): Linear(in_features=128, out_features=1, bias=True)
 ),
 Generator(
   (fc1): Linear(in_features=100, out_features=128, bias=True)
   (fc1_bn): BatchNorm1d(128, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
   (fc2): Linear(in_features=128, out_features=256, bias=True)
   (fc2_bn): BatchNorm1d(256, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
   (fc3): Linear(in_features=256, out_features=512, bias=True)
   (fc3_bn): BatchNorm1d(512, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
   (fc4): Linear(in_features=512, out_features=1024, bias=True)
   (fc4_bn): BatchNorm1d(1024, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
   (fc5): Linear(in_features=1024, out_features=784, bias=True)
 ))