# GAN의 구현

1. import  
2. Generator  
3. Discriminator  
4. 기타  
5. loss 함수  
6. image display  
7. training  

## (1) import

In [31]:
# conda install -c engility pytorch
# https://pytorch.org/ 
# conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch

In [32]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for testing purposes, please do not change!

<torch._C.Generator at 0x7fbdfcb169b0>

## (2) Generator
### (2-1) generator block

In [45]:
def gen_block(input_dim, output_dim):
    """Build one hidden stage of the generator.

    The stage is a fully connected layer, then batch normalization, then an
    in-place ReLU, wrapped in that order in an ``nn.Sequential``.

    Args:
        input_dim: Number of features entering the linear layer.
        output_dim: Number of features produced by the stage.

    Returns:
        An ``nn.Sequential`` containing the three layers.
    """
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

### (2-2) generator의 구조


In [46]:
class Generator(nn.Module):
    """Fully connected generator that maps noise vectors to flat images.

    Four ``gen_block`` stages widen the features from ``z_dim`` to
    ``hidden_dim * 8``. A final linear layer projects to ``im_dim`` values
    and a sigmoid squashes each of them into the 0-1 range.

    Args:
        z_dim: Length of each input noise vector.
        im_dim: Number of values in each generated (flattened) image.
        hidden_dim: Width of the first hidden stage; the following stages
            use 2x, 4x and 8x this width.
    """

    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super().__init__()
        widths = (
            z_dim,
            hidden_dim,
            hidden_dim * 2,
            hidden_dim * 4,
            hidden_dim * 8,
        )
        # Stages are created front to back, so seeded weight initialization
        # consumes random numbers in the same order as the layer order.
        stages = [gen_block(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
        self.gen = nn.Sequential(
            *stages,
            nn.Linear(widths[-1], im_dim),
            nn.Sigmoid(),
        )

    def forward(self, noise):
        """Run ``noise`` through the whole layer stack and return the result."""
        generated = self.gen(noise)
        return generated

    def get_gen(self):
        """Return the underlying ``nn.Sequential`` layer stack."""
        return self.gen

## (3) discriminator 
### (3-1) discriminator block

In [47]:
def dis_block(input_dim, output_dim):
    """Build one hidden stage of the discriminator.

    The stage is a fully connected layer followed by an in-place LeakyReLU
    with negative slope 0.2, wrapped in an ``nn.Sequential``.

    Args:
        input_dim: Number of features entering the linear layer.
        output_dim: Number of features produced by the stage.

    Returns:
        An ``nn.Sequential`` containing the two layers.
    """
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    return nn.Sequential(*layers)

### (3-2) discriminator의 구조

In [48]:
class Discriminator(nn.Module):
    """Fully connected discriminator that scores flattened images.

    Three ``dis_block`` stages narrow the features from ``im_dim`` through
    ``hidden_dim * 4`` and ``hidden_dim * 2`` down to ``hidden_dim``. A final
    linear layer emits one raw score per image; no activation is applied
    to that score inside this module.

    Args:
        im_dim: Number of values in each (flattened) input image.
        hidden_dim: Width of the last hidden stage; earlier stages use
            4x and 2x this width.
    """

    def __init__(self, im_dim=784, hidden_dim=128):
        super().__init__()
        widths = (im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim)
        # Stages are created front to back, so seeded weight initialization
        # consumes random numbers in the same order as the layer order.
        stages = [dis_block(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
        self.disc = nn.Sequential(
            *stages,
            nn.Linear(widths[-1], 1),
        )

    def forward(self, image):
        """Run ``image`` through the whole layer stack and return the scores."""
        scores = self.disc(image)
        return scores

    def get_disc(self):
        """Return the underlying ``nn.Sequential`` layer stack."""
        return self.disc

## (4) 기타
### (4-1) 노이즈 함수


In [49]:
def get_noise(n_samples, z_dim, device='cpu'):
    """Draw a batch of standard-normal noise vectors.

    Args:
        n_samples: Number of noise vectors (rows) to draw.
        z_dim: Length of each noise vector (columns).
        device: Device the tensor is created on, e.g. ``'cpu'`` or ``'cuda'``.

    Returns:
        A tensor of shape ``(n_samples, z_dim)`` sampled with ``torch.randn``.
    """
    shape = (n_samples, z_dim)
    return torch.randn(shape, device=device)

### (4-2) 파라미터 셋업

In [50]:
# Training schedule.
n_epochs = 200        # full passes over the dataset
batch_size = 128      # images per batch
display_step = 500    # batches between progress reports

# Model and optimizer settings.
z_dim = 64            # length of each generator noise vector
lr = 1e-5             # learning rate passed to the Adam optimizers
device = 'cpu'        # 'cuda' or 'cpu'

# Binary cross-entropy that takes raw scores (logits) as input, matching
# the discriminator, whose last layer has no activation.
criterion = nn.BCEWithLogitsLoss()

### (4-3) 데이터 로딩

In [51]:
# MNIST images stored under the current directory (downloaded when needed)
# and converted to tensors; batches are served in shuffled order.
dataloader = DataLoader(
    dataset=MNIST(root='.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True,
)

## (5) loss 함수
### (5-1) optimizer

In [52]:
# Build the two networks (generator first, so seeded initialization is
# unchanged), then give each one its own Adam optimizer.
gen = Generator(z_dim=z_dim).to(device)
disc = Discriminator().to(device)

gen_opt = torch.optim.Adam(params=gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(params=disc.parameters(), lr=lr)

### (5-2) disc loss

In [53]:
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    """Compute the discriminator loss for one batch.

    The discriminator is scored on a freshly generated batch of fakes
    (target 0) and on the real batch (target 1); the two losses are averaged.

    Args:
        gen: Generator model, called as ``gen(noise)``.
        disc: Discriminator model, called as ``disc(images)``.
        criterion: Loss function called as ``criterion(prediction, target)``.
        real: Batch of real images in the form ``disc`` accepts.
        num_images: Number of fake images to generate.
        z_dim: Length of each noise vector.
        device: Device on which the noise is created.

    Returns:
        The mean of the fake-batch loss and the real-batch loss.
    """
    fake_noise = get_noise(num_images, z_dim, device=device)  # z
    # This loss must not train the generator, so its forward pass runs
    # without autograd tracking. That yields the same values as calling
    # detach() afterwards, but never builds the graph that would be thrown
    # away.
    with torch.no_grad():
        fake = gen(fake_noise)  # G(z)

    # Fakes should be classified as 0.
    disc_fake_pred = disc(fake)  # D(G(z))
    disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))

    # Real images should be classified as 1.
    disc_real_pred = disc(real)  # D(x)
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))

    disc_loss = (disc_fake_loss + disc_real_loss) / 2
    return disc_loss

### (5-3) gen loss

In [54]:
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    """Compute the generator loss for one batch.

    A fresh batch of fakes is generated and scored by the discriminator; the
    loss compares those scores against a target of 1, i.e. it is low when
    the discriminator takes the fakes for real images.

    Args:
        gen: Generator model, called as ``gen(noise)``.
        disc: Discriminator model, called as ``disc(images)``.
        criterion: Loss function called as ``criterion(prediction, target)``.
        num_images: Number of fake images to generate.
        z_dim: Length of each noise vector.
        device: Device on which the noise is created.

    Returns:
        The loss of the discriminator's fake predictions against all-ones.
    """
    noise = get_noise(num_images, z_dim, device=device)
    # The fakes are deliberately NOT detached here: detach() cuts a tensor
    # out of the autograd graph (it has nothing to do with CPU/GPU), and the
    # gradient has to flow back through the discriminator into the generator.
    fake_scores = disc(gen(noise))
    wanted = torch.ones_like(fake_scores)
    return criterion(fake_scores, wanted)

## (6) image display

In [43]:
def show_tensor_images(image_tensor, num_images = 25, size=(1, 28, 28)):
    """Display the first images of a batch as a grid with five per row.

    Args:
        image_tensor: Batch of images; it is reshaped to ``(-1, *size)``, so
            flattened images are accepted.
        num_images: Maximum number of images taken from the front of the batch.
        size: Shape of a single image that the batch is reshaped to.
    """
    # Take the data out of autograd and onto the CPU before reshaping.
    images = image_tensor.detach().cpu().view(-1, *size)
    grid = make_grid(images[:num_images], nrow=5)
    # Move the first axis of the grid to the end before handing it to imshow.
    grid_for_plot = grid.permute(1, 2, 0).squeeze()
    plt.imshow(grid_for_plot)
    plt.show()

## (7) training

In [None]:
# Training loop: alternate one discriminator update and one generator update
# per batch, and report averaged losses plus sample images every
# `display_step` batches.
cur_step = 0                   # number of batches processed so far
mean_generator_loss = 0.0      # running mean over the current report window
mean_discriminator_loss = 0.0  # running mean over the current report window
test_generator = True          # run the per-batch generator sanity checks
error = False                  # becomes True if a sanity check ever fails

for epoch in range(n_epochs):
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        # Flatten every image into a vector for the fully connected networks.
        real = real.view(cur_batch_size, -1).to(device)

        # Update the discriminator.
        disc_opt.zero_grad()
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
        # retain_graph is not needed: the generator update below runs its own
        # forward pass and therefore builds a fresh graph.
        disc_loss.backward()
        disc_opt.step()

        if test_generator:
            # Snapshot the first generator layer to verify that it gets updated.
            old_generator_weights = gen.gen[0][0].weight.detach().clone()

        # Update the generator.
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        if test_generator:
            # Explicit checks instead of assert + bare except, so that a
            # failure is reported and unrelated exceptions are not swallowed.
            first_layer_weight = gen.gen[0][0].weight
            first_layer_grad = first_layer_weight.grad
            lr_is_sane = lr > 0.0000002 or (
                first_layer_grad is not None
                and bool(first_layer_grad.abs().max() < 0.0005)
                and epoch == 0
            )
            weights_changed = bool(
                torch.any(first_layer_weight.detach() != old_generator_weights)
            )
            if not (lr_is_sane and weights_changed):
                if not error:
                    # Report only the first failure to keep the output readable.
                    print("Runtime tests have failed.")
                error = True

        # Each batch contributes 1/display_step of the windowed mean.
        mean_discriminator_loss += disc_loss.item() / display_step
        mean_generator_loss += gen_loss.item() / display_step

        # Count the batch before the check so that every report window,
        # including the first one, covers exactly `display_step` batches.
        cur_step += 1

        if cur_step % display_step == 0:
            print(f"step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss:{mean_discriminator_loss}")
            fake_noise = get_noise(cur_batch_size, z_dim, device = device)
            # Preview samples only; no gradients are needed for them.
            with torch.no_grad():
                fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0.0
            mean_discriminator_loss = 0.0

        