In [1]:
import numpy as np
import math
import os

from collections import OrderedDict
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
# --- Run configuration -------------------------------------------------------
# Everything a reader might want to tune lives in this cell.

# Image geometry: single-channel 28x28 images.
N_CHANNELS = 1
IMG_SHAPE = (N_CHANNELS, 28, 28)

# Length of the noise vector fed to the generator.
LATENT_DIM = 100

# Optimisation settings (learning rate and betas are passed to both Adam
# optimizers in the setup cell).
BATCH_SIZE = 128
NUM_EPOCHS = 200
LEARNING_RATE = 2e-4
BETA1, BETA2 = 0.5, 0.999

# NOTE(review): not referenced by any of the cells visible in this notebook.
SAMPLE_INTERVAL = 500

In [3]:
# Preprocessing: convert each image to a tensor, then normalise the single
# channel with mean 0.5 / std 0.5 (presumably so real images share the value
# range of the generator's Tanh output).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# MNIST training split, downloaded into ../data when it is not there yet.
train_dataset = torchvision.datasets.MNIST(
    root="../data",
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches; drop_last=True discards the final incomplete batch so
# every batch holds exactly BATCH_SIZE samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

In [4]:
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to images.

    The network is a stack of ``Linear -> [BatchNorm1d] -> LeakyReLU(0.2)``
    blocks (100 -> 128 -> 256 -> 512 -> 1024 with the notebook defaults)
    followed by a final ``Linear`` + ``Tanh`` whose output is reshaped to
    ``img_shape``.

    Args:
        latent_dim: Length of the input noise vector. ``None`` (the default)
            falls back to the notebook-level ``LATENT_DIM``.
        img_shape: Shape of one generated image, ``(channels, height, width)``.
            ``None`` (the default) falls back to the notebook-level
            ``IMG_SHAPE``.
    """

    def __init__(self, latent_dim=None, img_shape=None):
        super(Generator, self).__init__()

        # Resolve the defaults at construction time so that ``Generator()``
        # behaves exactly as before, while explicit arguments make the model
        # independent of notebook globals.
        self.latent_dim = LATENT_DIM if latent_dim is None else int(latent_dim)
        self.img_shape = tuple(IMG_SHAPE if img_shape is None else img_shape)

        def linear_block(in_features, out_features, bn=True):
            """Return the layers of one Linear -> [BatchNorm1d] -> LeakyReLU block."""
            layers = [nn.Linear(in_features, out_features)]
            if bn:
                layers.append(nn.BatchNorm1d(out_features))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # The first block is built without batch normalisation.
            *linear_block(self.latent_dim, 128, bn=False),
            *linear_block(128, 256),
            *linear_block(256, 512),
            *linear_block(512, 1024),
            # One output unit per pixel; Tanh bounds the values to [-1, 1].
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from noise.

        Args:
            z: Tensor of shape ``(batch_size, latent_dim)``.

        Returns:
            Tensor of shape ``(batch_size, *img_shape)``.
        """
        img = self.model(z)
        # (batch_size, pixels) -> (batch_size, channels, height, width)
        return img.reshape(img.shape[0], *self.img_shape)
    
class Discriminator(nn.Module):
    """Fully connected discriminator that scores images as real or fake.

    Images are flattened and passed through ``Linear(., 512) -> LeakyReLU ->
    Linear(512, 256) -> LeakyReLU -> Linear(256, 1) -> Sigmoid``.

    Args:
        img_shape: Shape of one input image, ``(channels, height, width)``.
            ``None`` (the default) falls back to the notebook-level
            ``IMG_SHAPE``.
    """

    def __init__(self, img_shape=None):
        super(Discriminator, self).__init__()

        # Resolve the default at construction time so that ``Discriminator()``
        # behaves exactly as before, while an explicit argument makes the
        # model independent of notebook globals.
        self.img_shape = tuple(IMG_SHAPE if img_shape is None else img_shape)

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(self.img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            # Single output unit, squashed by a sigmoid into a probability.
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor whose first dimension is the batch; the remaining
                dimensions are flattened and must total ``prod(img_shape)``.

        Returns:
            Tensor of shape ``(batch_size, 1)`` with values in ``[0, 1]``.
        """
        img_flat = img.reshape(img.shape[0], -1)
        return self.model(img_flat)

In [5]:
# Train on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build both networks and move their parameters to the chosen device.
generator = Generator()
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

# Binary cross-entropy on the discriminator's sigmoid output.
adversarial_loss = torch.nn.BCELoss()

# One Adam optimizer per network, both using the same learning rate and betas.
optimizer_G = torch.optim.Adam(
    generator.parameters(),
    lr=LEARNING_RATE,
    betas=(BETA1, BETA2),
)
optimizer_D = torch.optim.Adam(
    discriminator.parameters(),
    lr=LEARNING_RATE,
    betas=(BETA1, BETA2),
)

In [7]:
# Target labels for the BCE loss. They never change, so they are allocated
# once, directly on the training device, before the loop starts.
real_label = torch.ones((BATCH_SIZE, 1), device=device)
fake_label = torch.zeros((BATCH_SIZE, 1), device=device)

LOG_EVERY_BATCHES = 50   # print the losses every N batches
SAVE_EVERY_EPOCHS = 10   # write a sample grid every N epochs
RESULTS_DIR = "./results"

# Create the output directory once up front; exist_ok makes this idempotent.
os.makedirs(RESULTS_DIR, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    for i, (imgs, _) in enumerate(train_loader):
        real_imgs = imgs.to(device)

        # Size labels and noise from the batch that actually arrived, so the
        # loop does not rely on the loader dropping an incomplete last batch.
        batch_size = real_imgs.size(0)
        valid = real_label[:batch_size]
        fake = fake_label[:batch_size]

        # ---- Generator step ----
        optimizer_G.zero_grad()

        # Sample standard-normal noise directly on the device.
        z = torch.randn(batch_size, LATENT_DIM, device=device)

        gen_imgs = generator(z)
        # The generator is rewarded when the discriminator labels its
        # images as real, hence the "real" targets here.
        gen_loss = adversarial_loss(discriminator(gen_imgs), valid)

        gen_loss.backward()
        optimizer_G.step()

        # ---- Discriminator step ----
        optimizer_D.zero_grad()

        real_preds = discriminator(real_imgs)
        # detach() keeps this loss from back-propagating into the generator.
        fake_preds = discriminator(gen_imgs.detach())

        dis_loss_real = adversarial_loss(real_preds, valid)
        dis_loss_fake = adversarial_loss(fake_preds, fake)
        dis_loss = (dis_loss_real + dis_loss_fake) / 2
        dis_loss.backward()
        optimizer_D.step()

        if i % LOG_EVERY_BATCHES == 0:
            print("EPOCH : [{}/{}], BATCH : [{}/{}], G loss : {}, D loss : {}".format(epoch, NUM_EPOCHS, i, len(train_loader), gen_loss.item(), dis_loss.item()))

    if epoch % SAVE_EVERY_EPOCHS == 0:
        # Save a 5x5 grid built from the first 25 images of the epoch's
        # final generated batch.
        torchvision.utils.save_image(gen_imgs.detach()[:25], os.path.join(RESULTS_DIR, "{:d}.png".format(epoch)), nrow = 5, normalize=True)

gen imgs : torch.Size([128, 1, 28, 28])
EPOCH : [0/200], BATCH : [0/468], G loss : 0.5355530977249146, D loss : 0.6051766872406006
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1, 28, 28])
gen imgs : torch.Size([128, 1

KeyboardInterrupt: 