In [None]:
from torchvision import transforms
from torchvision.datasets import MNIST
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from torch.autograd.variable import Variable

# MNIST Dataset

In [None]:
# Directory in which the MNIST files are stored (downloaded there when missing, since download=True below).
download_root = "./data"

In [None]:
# Preprocessing pipeline applied to every MNIST image:
#   1. ToTensor converts the image to a PyTorch tensor.
#   2. Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5, rescaling values
#      to the range -1 ~ 1 (matching the Generator's tanh output range).
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [None]:
# One MNIST dataset per split, created in order: train, valid, test.
# NOTE(review): valid_dataset and test_dataset both load the MNIST test split
# (train=False), so "validation" and "test" are the same images.
train_dataset, valid_dataset, test_dataset = (
    MNIST(download_root, transform=mnist_transform, train=is_train, download=True)
    for is_train in (True, False, False)
)

# Hyperparameters

- **batch_size** : 미니배치 크기 (mini-batch size)
- **epochs** : 전체 학습 epoch 수 (number of training epochs)
- **lr_D** : Discriminator optimizer의 learning rate
- **lr_G** : Generator optimizer의 learning rate
- **dropout_ratio** : Discriminator와 Generator hidden layer의 dropout 비율
- **betas** : Adam optimizer의 $(\beta_1, \beta_2)$ coefficients

In [None]:
batch_size = 128
epochs = 100
lr_D = 2e-4                 # Adam learning rate for the Discriminator
lr_G = 2e-4                 # Adam learning rate for the Generator
dropout_ratio = 0.3         # dropout probability in both networks' hidden layers
betas = (0.5, 0.999)        # Adam (beta1, beta2) coefficients

# Seed torch's RNG so the later random steps are reproducible across
# Restart & Run All: weight init, DataLoader shuffling, latent noise and dropout.
seed = 42
torch.manual_seed(seed)

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Only the training loader needs shuffling (fresh mini-batch composition each epoch).
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Evaluation does not benefit from shuffling; a fixed order keeps validation
# and test passes repeatable from run to run.
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)

test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
def plot_digit(x, y=None, n_rows=2, n_cols=5):
    """Show up to n_rows*n_cols 28x28 digit images from a batch in a grid.

    Args:
        x: image tensor whose values are expected in [-1, 1] (the range produced
           by mnist_transform and by the Generator's tanh). Any shape whose
           elements form consecutive 28x28 images works, e.g. [B, 1, 28, 28]
           or [B, 1, 784]. Tensors on the GPU or requiring grad are accepted.
        y: optional labels; when given, each panel is titled "Label: <y[i]>".
        n_rows: number of grid rows.
        n_cols: number of grid columns.
    """
    n = n_rows*n_cols

    # detach()/cpu() so generator outputs or CUDA tensors can be plotted directly.
    x = x.detach().cpu()
    # Undo Normalize((0.5,), (0.5,)): [-1, 1] -> [0, 1], then scale to 0-255.
    x = x / 2 + 0.5
    # reshape (not view) so non-contiguous batches are accepted.
    x = x.reshape(-1, 28, 28)
    x = x * 255

    x = x.numpy()[:n]
    if y is not None:
        y = y[:n]

    # squeeze=False keeps `axes` 2-D even when n_rows or n_cols is 1,
    # so axes[row, col] indexing always works.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(1.5*n_cols, 2*n_rows), squeeze=False)

    for i in range(n):
        if len(x) <= i:
            break
        ax = axes[i//n_cols, i%n_cols]
        # Fixed 0-255 intensity scale; imshow would otherwise autoscale each
        # image to its own min/max and hide low-contrast (poor) generations.
        ax.imshow(x[i], cmap='gray', vmin=0, vmax=255)
        if y is not None:
            ax.set_title(f'Label: {y[i]}')

    plt.show()

In [None]:
# Preview a single (shuffled) training batch with its labels in a 2x5 grid.
x, y = next(iter(train_loader))
plot_digit(x, y, 2, 5)

# Model

In [None]:
def init_weights(m):
    """Initialise an nn.Linear module in place: Xavier-uniform weight, bias = 0.01.

    Meant to be used via ``module.apply(init_weights)``; modules that are not
    nn.Linear are left untouched.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        # A Linear built with bias=False has m.bias = None; skip it instead of crashing.
        if m.bias is not None:
            # nn.init.constant_ fills under no_grad, avoiding the deprecated .data access.
            torch.nn.init.constant_(m.bias, 0.01)

In [None]:
class Generator(nn.Module):
    """MLP generator mapping latent vectors to samples in [-1, 1].

    Each hidden layer is Linear -> LeakyReLU(0.2) -> Dropout; the output layer
    is Linear -> Tanh.

    Args:
        input_dim: length of the latent vector passed to ``forward``.
        hidden_layers: widths of the hidden Linear layers, in order.
        output_dim: per-channel output size. An int gives output of shape
            [batch, output_channel, output_dim]; a tuple/list shape such as
            (28, 28) gives [batch, output_channel, *output_dim].
        output_channel: number of output channels.
        dropout_ratio: dropout probability after every hidden layer.
    """

    def __init__(self, input_dim, hidden_layers, output_dim, output_channel, dropout_ratio):
        super(Generator, self).__init__()

        widths = [input_dim, *hidden_layers]

        self.output_dim = output_dim
        self.output_channel = output_channel

        # A shaped output_dim such as (28, 28) must be flattened to its element
        # count to size the final Linear layer (tuple * int would repeat the tuple).
        if isinstance(output_dim, (tuple, list)):
            flat_dim = int(np.prod(output_dim))
        else:
            flat_dim = output_dim

        self.hidden_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(widths[i], widths[i+1]),
                nn.LeakyReLU(0.2),
                nn.Dropout(p=dropout_ratio)
            ) for i in range(len(widths)-1)])
        self.output_layer = nn.Linear(widths[-1], flat_dim*output_channel)
        self.tanh = nn.Tanh()

        # Only the hidden layers receive init_weights; output_layer keeps
        # PyTorch's default nn.Linear initialisation.
        for layer in self.hidden_layers:
            layer.apply(init_weights)

    def forward(self, x):
        """Map latent vectors x ([batch_size, input_dim]) to generated samples."""
        # x: [batch_size, input_dim] -> [batch_size, hidden_layers[-1]]
        for layer in self.hidden_layers:
            x = layer(x)

        # x: [batch_size, flat_output_dim*output_channel], squashed into [-1, 1]
        x = self.tanh(self.output_layer(x))

        # x: [batch_size, output_channel, output_dim] for an int output_dim,
        #    [batch_size, output_channel, *output_dim] for a shaped output_dim
        if isinstance(self.output_dim, (tuple, list)):
            return x.view(-1, self.output_channel, *self.output_dim)
        return x.view(-1, self.output_channel, self.output_dim)

In [None]:
class Discriminator(nn.Module):
    """MLP discriminator mapping a sample to a probability in (0, 1).

    Each hidden layer is Linear -> LeakyReLU(0.2) -> Dropout; the output layer
    is Linear -> Sigmoid, so the output can be fed directly to nn.BCELoss.

    Args:
        input_dim: per-channel input size, either an int or a tuple/list shape
            such as (28, 28), which is flattened to its element count.
        input_channel: number of input channels.
        hidden_layers: widths of the hidden Linear layers, in order.
        dropout_ratio: dropout probability after every hidden layer.
    """

    def __init__(self, input_dim, input_channel, hidden_layers, dropout_ratio):
        super(Discriminator, self).__init__()

        # Flatten a shaped input_dim to a plain Python int (np.prod returns a
        # numpy scalar; a list * int would otherwise repeat the list).
        if isinstance(input_dim, (tuple, list)):
            input_dim = int(np.prod(input_dim))

        widths = [input_dim*input_channel, *hidden_layers]

        self.input_dim = input_dim
        self.input_channel = input_channel

        self.hidden_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(widths[i], widths[i+1]),
                nn.LeakyReLU(0.2),
                nn.Dropout(p=dropout_ratio)
            ) for i in range(len(widths)-1)])
        self.output_layer = nn.Linear(widths[-1], 1)
        self.sigmoid = nn.Sigmoid()

        # Only the hidden layers receive init_weights; output_layer keeps
        # PyTorch's default nn.Linear initialisation.
        for layer in self.hidden_layers:
            layer.apply(init_weights)

    def forward(self, x):
        """Return per-sample probabilities of shape [batch_size, 1]."""
        # x: any tensor holding input_channel*input_dim values per sample
        #    (e.g. [batch, C, D] or [batch, C, H, W]) -> [batch_size, C*D].
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(-1, self.input_channel*self.input_dim)

        # x: [batch_size, hidden_layers[-1]]
        for layer in self.hidden_layers:
            x = layer(x)

        # x: [batch_size, 1], squashed into (0, 1)
        x = self.output_layer(x)
        x = self.sigmoid(x)

        return x

# Training

In [None]:
# Two MLPs over flattened 28x28 MNIST images (28*28 = 784 values per channel).
# Note: (28*28) is just the int 784, not a tuple.
# The generator's input_dim (100) is the latent size the training loop samples.
generator = Generator(
    input_dim=100,
    hidden_layers=[256, 512, 1024],
    output_dim=28 * 28,
    output_channel=1,
    dropout_ratio=dropout_ratio,
)
discriminator = Discriminator(
    input_dim=28 * 28,
    input_channel=1,
    hidden_layers=[1024, 512, 256],
    dropout_ratio=dropout_ratio,
)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the models before building their optimizers (as the torch.optim docs
# recommend), so each optimizer is constructed over the on-device parameters.
generator = generator.to(device)
discriminator = discriminator.to(device)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr_G, betas=betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr_D, betas=betas)
# BCELoss takes probabilities; the Discriminator's final Sigmoid provides them.
loss_func = nn.BCELoss()

In [None]:
# Place both networks on the selected device (GPU when available);
# Module.to() returns the module itself, so the names keep pointing at the same models.
generator, discriminator = generator.to(device), discriminator.to(device)

In [None]:
# Per-epoch average losses, filled in by the loop below.
train_loss_D, valid_loss_D = [], []
train_loss_G, valid_loss_G = [], []

# Length of the latent noise vector; must equal the Generator's input_dim.
latent_dim = 100

for epoch in tqdm(range(epochs), desc='Epoch'):
    print(f'[Epoch: {epoch}]')

    generator.train()
    discriminator.train()

    losses_D = []
    losses_G = []

    # Train Network
    for xb, _ in tqdm(train_loader, leave=False):
        bs = len(xb)
        xb = xb.to(device)

        # BCE targets: 1 = real image, 0 = generated image. Plain tensors are
        # created directly on the device (torch.autograd.Variable is deprecated).
        target_real = torch.ones(bs, 1, device=device)
        target_fake = torch.zeros(bs, 1, device=device)

        # Update Discriminator: push D(real) -> 1 and D(fake) -> 0.
        predicted = discriminator(xb)
        loss_D_real = loss_func(predicted, target_real)

        noise = torch.randn(bs, latent_dim, device=device)
        # detach(): the discriminator step must not back-propagate into the
        # generator (those gradients were discarded by optimizer_G.zero_grad()).
        xb_fake = generator(noise).detach()
        predicted = discriminator(xb_fake)
        loss_D_fake = loss_func(predicted, target_fake)

        loss_D = loss_D_real + loss_D_fake
        losses_D.append(loss_D.item())

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # Update Generator: score fresh fakes against the "real" target so the
        # generator is rewarded for fooling the discriminator.
        noise = torch.randn(bs, latent_dim, device=device)
        xb_fake = generator(noise)
        predicted = discriminator(xb_fake)
        loss_G = loss_func(predicted, target_real)
        losses_G.append(loss_G.item())

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

    avg_loss_D = sum(losses_D) / len(losses_D)
    avg_loss_G = sum(losses_G) / len(losses_G)
    train_loss_D.append(avg_loss_D)
    train_loss_G.append(avg_loss_G)
    print(f'[TRAIN] avg_loss_G: {avg_loss_G:.5f}, avg_loss_D: {avg_loss_D:.5f}')

    generator.eval()
    discriminator.eval()

    losses_D = []
    losses_G = []

    # Validate Network
    with torch.no_grad():
        for i, (xb, _) in enumerate(tqdm(valid_loader, leave=False)):
            bs = len(xb)
            xb = xb.to(device)

            target_real = torch.ones(bs, 1, device=device)
            target_fake = torch.zeros(bs, 1, device=device)

            predicted = discriminator(xb)
            loss_D_real = loss_func(predicted, target_real)

            noise = torch.randn(bs, latent_dim, device=device)
            xb_fake = generator(noise)
            predicted = discriminator(xb_fake)
            loss_D_fake = loss_func(predicted, target_fake)

            loss_D = loss_D_real + loss_D_fake
            # Generator loss on the same fakes: how often they pass as real.
            loss_G = loss_func(predicted, target_real)

            losses_D.append(loss_D.item())
            losses_G.append(loss_G.item())

            # Show a few generated samples once per epoch.
            if i == 0:
                plot_digit(xb_fake.cpu(), n_rows=2, n_cols=5)

    avg_loss_D = sum(losses_D) / len(losses_D)
    avg_loss_G = sum(losses_G) / len(losses_G)
    valid_loss_D.append(avg_loss_D)
    valid_loss_G.append(avg_loss_G)
    print(f'[VALID] avg_loss_G: {avg_loss_G:.5f}, avg_loss_D: {avg_loss_D:.5f}')