In [None]:
import torch
import os
import numpy as np
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (every later `.to(device)` / `device=device` uses this).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Load the MNIST dataset

# Loads the 28*28 digit images together with the label of each digit.
# Preprocessing: resize to 28, convert to tensor, then normalize with
# mean 0.5 / std 0.5.
mnist_transform = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
mnist_train = datasets.MNIST(
    './data/mnist', train=True, download=True, transform=mnist_transform
)
dataloader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)

# Visualization
def visualize(img):
    """Display one grayscale image, or the first (up to 8) images of a batch.

    Args:
        img: array-like exposing ``.shape``, ``.squeeze()`` and indexing
            (e.g. a NumPy array or CPU tensor). Handled layouts:
            * 2-D input: shown directly as a single image.
            * leading dimension == 1: squeezed and shown as a single image.
            * leading dimension > 1 with more than 2 dimensions: treated as
              a batch; at most the first 8 items are drawn on a 2x4 grid.

    Any other layout is ignored, as before.
    """
    if len(img.shape) == 2:
        # Already a plain image; squeezing is unnecessary (and would break a 1xW image).
        plt.figure(figsize=(2, 2))
        plt.imshow(img, cmap='gray')
        plt.show()
    elif img.shape[0] == 1:
        plt.figure(figsize=(2, 2))
        plt.imshow(img.squeeze(), cmap='gray')
        plt.show()
    elif img.shape[0] > 1 and len(img.shape) > 2:
        fig, axes = plt.subplots(2, 4, figsize=(12, 5))
        # The grid has 8 cells; never index past the end of a smaller batch.
        n_shown = min(img.shape[0], 8)
        for i, ax in enumerate(axes.flat):
            ax.axis('off')  # hide axes on every cell, including unused ones
            if i < n_shown:
                ax.imshow(img[i].squeeze(), cmap='gray')
        plt.show()

In [None]:
# Data check: pull one batch from the loader and inspect its contents
data = next(iter(dataloader))
images, labels = data[0], data[1]
print(len(data))
print(images.shape)
print(labels.shape)
print(labels[0])

# Show images from the batch
visualize(images)

In [None]:
# Generator
# Takes a noise variable as input and outputs an image (28*28).
class Generator(torch.nn.Module):
    """
    Generator network of the GAN.

    Maps 100-dimensional latent vectors through a three-layer MLP
    (100 -> 256 -> 512 -> 784, LeakyReLU(0.2) between the linear layers)
    and reshapes every 784-vector into a 1x28x28 image.

    NOTE(review): the last layer has no activation, so outputs are unbounded,
    while the data loader normalizes real images with mean 0.5 / std 0.5.
    A final Tanh may be intended -- confirm.
    """
    def __init__(self):
        super().__init__()
        layers = [
            torch.nn.Linear(100, 256),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Linear(256, 512),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Linear(512, 784),
        ]
        self.model = torch.nn.Sequential(*layers)

    def forward(self, z):
        """Run latent input ``z`` through the MLP and return it viewed as (-1, 1, 28, 28)."""
        flat = self.model(z)
        return flat.view(-1, 1, 28, 28)

# Build the generator, then move its parameters to the selected device.
G = Generator()
G = G.to(device)

In [None]:
# Smoke test: feed a batch of 64 all-ones latent vectors through the
# (untrained) generator, print the output shape and show the first image.
test_latent = torch.ones((64, 100), device=device)

random_out = G(test_latent)
print(random_out.size())
first_image = random_out[0].detach().cpu().numpy()
visualize(first_image)

In [None]:
# Demonstrates how `view(-1, 784)` flattens a batch of 1x28x28 images
# into rows of 784 values.
t = torch.ones((2, 1, 28, 28))
print(t.size())

t.view(-1, 784).shape

In [None]:
# Discriminator
# Takes an image as input and outputs a scalar in (0, 1).
# The training loop labels real data as 1 and generated (fake) data as 0.

class Discriminative(torch.nn.Module):
    """
    Discriminator network of the GAN.

    Flattens its input to rows of ``res`` values and passes them through a
    three-layer MLP (res -> 256 -> 32 -> 1, LeakyReLU(0.2) between the linear
    layers) followed by a sigmoid.

    Args:
        res: number of values per flattened input sample (default 784).
    """
    def __init__(self, res=784):
        super().__init__()
        self.res = res
        layers = [
            torch.nn.Linear(res, 256),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Linear(256, 32),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Linear(32, 1),
            torch.nn.Sigmoid(),  # restricts the output to the 0~1 range
        ]
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` to (-1, res) and return the sigmoid score of each row."""
        flat = x.view(-1, self.res)
        return self.model(flat)

# Build the discriminator, then move its parameters to the selected device.
D = Discriminative()
D = D.to(device)

In [None]:
# Run the discriminator on the generator's test output and check the shape.
pred = D(random_out)
print(pred.size())

In [None]:
from tqdm import tqdm

# Training configuration
#   epochs: number of passes over the data loader
#   k:      number of discriminator updates per batch
#   bs:     batch size used by the training loop for labels and noise
epochs, k, bs = 5, 1, 64

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = torch.nn.BCELoss()

# One Adam optimizer per network, both with the same settings.
optimizer_d = torch.optim.Adam(params=D.parameters(), lr=0.001, betas=(0.9, 0.999))
optimizer_g = torch.optim.Adam(params=G.parameters(), lr=0.001, betas=(0.9, 0.999))

In [None]:
# Sanity check: score a real batch with the discriminator and evaluate the
# loss on the first four predictions against "real" (all-ones) labels.
data[0].shape  # bare expression; not displayed because it is not the last line

real_batch = data[0].to(device)
pred = D(real_batch)
print(pred.size())

real_labels = torch.ones(bs, device=device)
criterion(pred.squeeze()[:4], real_labels[:4])

In [None]:
# GAN training loop.
# For every batch: k discriminator updates (real images labelled 1, generated
# images labelled 0), followed by one generator update that tries to make the
# discriminator output 1 for generated images. Running-average losses are
# logged every `log_every` batches and plotted at the end of each epoch.
discriminator_losses = []
generative_losses = []

log_every = 100  # number of batches averaged into each logged loss value

# One progress bar for the whole run; it is closed only after the last epoch.
pbar = tqdm(total=epochs * len(dataloader), desc="Training Progress")
for epoch in range(epochs):
    # Accumulate plain Python floats (via .item()) so that no autograd graph
    # is kept alive across iterations.
    d_loss = 0.0
    g_loss = 0.0

    for idx, batch in enumerate(dataloader):
        real_images = batch[0].to(device)
        # Size labels and noise from the actual batch so the final (smaller)
        # batch is trained on as well, and so the loop does not silently skip
        # every batch if the loader's batch size differs from `bs`.
        cur_bs = real_images.shape[0]
        real_labels = torch.ones(cur_bs, device=device)
        fake_labels = torch.zeros(cur_bs, device=device)

        # ---- Discriminator: k updates ----
        for _ in range(k):
            D.zero_grad()
            noise = torch.randn(cur_bs, 100, device=device)
            # detach(): the discriminator step needs no gradients through G.
            fake_images = G(noise).detach()
            # squeeze(1) keeps a 1-D prediction even for a batch of one.
            real_pred = D(real_images).squeeze(1)
            fake_pred = D(fake_images).squeeze(1)

            loss_real_samples = criterion(real_pred, real_labels)
            loss_fake_samples = criterion(fake_pred, fake_labels)

            # BCELoss already averages over the batch, so dividing by 2 gives
            # the mean over all real + fake samples.
            loss = (loss_real_samples + loss_fake_samples) / 2

            # backward and optimizer D
            loss.backward()
            optimizer_d.step()
            d_loss += loss.item()

        # ---- Generator: one update ----
        G.zero_grad()
        noise = torch.randn(cur_bs, 100, device=device)
        generated_images = G(noise)
        preds = D(generated_images).squeeze(1)
        # The generator is rewarded when D labels its samples as real (1).
        loss = criterion(preds, real_labels)

        # backward and optimizer G
        loss.backward()
        optimizer_g.step()
        g_loss += loss.item()

        # Log once per `log_every` batches, after both updates, so each
        # average covers exactly `log_every` batches (k D-steps per batch).
        if (idx + 1) % log_every == 0:
            d_avg = d_loss / (log_every * k)
            g_avg = g_loss / log_every
            tqdm.write(f"D_loss : {d_avg}")
            discriminator_losses.append(d_avg)
            tqdm.write(f"G_loss : {g_avg}")
            generative_losses.append(g_avg)
            d_loss = 0.0
            g_loss = 0.0

        pbar.update(1)

    plt.title("Discriminator loss")
    plt.plot(discriminator_losses)
    plt.show()
    plt.close()

    plt.title("Generator loss")
    plt.plot(generative_losses)
    plt.show()
    plt.close()

    # evaluate: sample a few images without building an autograd graph
    with torch.no_grad():
        noise = torch.randn(12, 100, device=device)
        generated_images = G(noise)
    visualize(generated_images.cpu().numpy())

    del noise
    del generated_images

pbar.close()