# 0. Env

In [None]:
import numpy as np
import PIL
from tqdm.auto import trange

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision import transforms, datasets

In [None]:
# Disable gradient tracking globally.
# Autograd is switched off here so the cells below can be run just to check
# that they work; the training cell switches it back on.
torch.set_grad_enabled(False)

# 1. C-GAN

In [None]:
# Use the CUDA device when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
n_class = 10 # number of classes (also used as the label embedding width)
d_latent = 100 # dimension of the latent vector z
s_image = 28 * 28 # number of values in one flattened image

In [None]:
# Data preprocessing: convert each image to a tensor, then normalize with
# mean 0.5 and std 0.5 (do_generate applies the inverse before display)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# MNIST training split stored under ./data (downloaded when requested via
# download=True), served in shuffled batches of 256
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

In [None]:
# Generator model
class Generator(nn.Module):
    """Conditional generator: maps a latent vector and a class label to a flat image.

    The label is embedded into an ``n_class``-wide vector, concatenated with
    the latent vector, and pushed through four fully connected blocks
    followed by a ``Linear`` + ``Tanh`` output layer of width ``s_image``.
    """

    def __init__(self):
        super().__init__()
        make_block = Generator.FcBnRelu
        # Condition (label) embedding
        self.embed = nn.Embedding(n_class, n_class)
        # Fully connected stack; the first block has no batch normalization
        self.layer_1 = make_block(d_latent + n_class, 128, normalize=False)
        self.layer_2 = make_block(128, 256)
        self.layer_3 = make_block(256, 512)
        self.layer_4 = make_block(512, 1024)
        # Output projection, squashed by Tanh
        self.layer_o = nn.Sequential(
            nn.Linear(1024, s_image),
            nn.Tanh(),
        )

    @staticmethod
    def FcBnRelu(d_in, d_out, normalize=True):
        """Build a ``Linear -> [BatchNorm1d] -> LeakyReLU(0.2)`` block.

        Args:
            d_in: Input feature count of the linear layer.
            d_out: Output feature count of the linear layer.
            normalize: When true, insert ``BatchNorm1d`` after the linear layer.

        Returns:
            An ``nn.Sequential`` holding the two or three layers.
        """
        linear = nn.Linear(d_in, d_out)
        activation = nn.LeakyReLU(0.2, inplace=True)
        if not normalize:
            return nn.Sequential(linear, activation)
        # NOTE(review): 0.8 is BatchNorm1d's second positional parameter, i.e.
        # eps (not momentum). Kept as-is to preserve behavior; confirm that an
        # epsilon of 0.8 is really what is wanted here.
        return nn.Sequential(linear, nn.BatchNorm1d(d_out, eps=0.8), activation)

    def forward(self, z, y):
        """Generate flat images from latent vectors ``z`` conditioned on labels ``y``."""
        # Turn the class index into a vector and append it to the latent vector
        condition = self.embed(y)
        hidden = torch.cat((z, condition), dim=-1)
        # Run the hidden blocks in order
        for block in (self.layer_1, self.layer_2, self.layer_3, self.layer_4):
            hidden = block(hidden)
        # Output layer
        return self.layer_o(hidden)

In [None]:
# Discriminator model
class Discriminator(nn.Module):
    """Conditional discriminator: scores a flat image together with its class label.

    The label is embedded into an ``n_class``-wide vector, concatenated with
    the flattened image, and pushed through three fully connected blocks
    followed by a ``Linear`` + ``Sigmoid`` output, giving one value per sample.
    """

    def __init__(self):
        super().__init__()
        make_block = Discriminator.FcDoRelu
        # Condition (label) embedding
        self.embed = nn.Embedding(n_class, n_class)

        # Fully connected stack; the first block has no dropout
        self.layer_1 = make_block(s_image + n_class, 512, dropout=0.0)
        self.layer_2 = make_block(512, 512)
        self.layer_3 = make_block(512, 512)
        # Single sigmoid output per sample
        self.layer_o = nn.Sequential(
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def FcDoRelu(d_in, d_out, dropout=0.4):
        """Build a ``Linear -> [Dropout] -> LeakyReLU(0.2)`` block.

        Args:
            d_in: Input feature count of the linear layer.
            d_out: Output feature count of the linear layer.
            dropout: Dropout probability; the dropout layer is omitted when
                this is not greater than zero.

        Returns:
            An ``nn.Sequential`` holding the two or three layers.
        """
        linear = nn.Linear(d_in, d_out)
        optional_dropout = [nn.Dropout(dropout)] if dropout > 0 else []
        activation = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(linear, *optional_dropout, activation)

    def forward(self, img, y):
        """Score flattened images ``img`` conditioned on labels ``y``."""
        # Turn the class index into a vector and append it to the image vector
        condition = self.embed(y)
        hidden = torch.cat((img, condition), dim=-1)
        # Run the hidden blocks in order
        for block in (self.layer_1, self.layer_2, self.layer_3):
            hidden = block(hidden)
        # Output layer
        return self.layer_o(hidden)

In [None]:
# Create the generator and move it to the selected device
generator = Generator()
generator.to(device)

In [None]:
# Create the discriminator and move it to the selected device
discriminator = Discriminator()
discriminator.to(device)

In [None]:
# Loss function: binary cross-entropy (the discriminator ends in a Sigmoid)
loss_fn = torch.nn.BCELoss()
# Optimizers: one Adam instance per network, same learning rate and betas
optimG = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimD = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))

In [None]:
# Number of passes over the training set
n_epoch = 100
# Total number of training steps (batches per epoch times epochs)
n_total_train = len(train_loader) * n_epoch
n_total_train

In [None]:
# Re-enable gradient tracking for training (it was disabled globally earlier)
torch.set_grad_enabled(True)
# One progress-bar tick per training step across all epochs
p_bar = trange(n_total_train)

# Iterate over n_epoch (not a literal) so the loop always matches the
# progress-bar total computed from the same value.
for epoch in range(n_epoch):
    # Per-epoch loss history, used for the running means shown in the bar
    train_d_loss, train_g_loss = [], []
    # train
    generator.train()
    discriminator.train()
    for images, labels in train_loader:
        # real labels and images, flattened to one row per sample
        real_img, real_y = images.to(device), labels.to(device)
        n_batch = real_y.shape[0]
        real_img = real_img.view(n_batch, -1)

        # random class labels and latent vectors, created directly on the
        # target device (no host-side float64 array, no extra copy)
        fake_y = torch.randint(0, n_class, (n_batch,), device=device)
        z = torch.randn(n_batch, d_latent, device=device)

        # discriminator targets: 1 for "real", 0 for "fake"
        real_labels = torch.ones(n_batch, device=device)
        fake_labels = torch.zeros(n_batch, device=device)

        ##########################################################
        # train discriminator
        ##########################################################
        optimD.zero_grad()

        # real images should be classified as real
        real_logits = discriminator(real_img, real_y)
        d_real_loss = loss_fn(real_logits.view(-1), real_labels)

        # generated images should be classified as fake; detach so this
        # backward pass does not reach the generator's parameters
        fake_img = generator(z, fake_y)
        fake_logits = discriminator(fake_img.detach(), fake_y)
        d_fake_loss = loss_fn(fake_logits.view(-1), fake_labels)

        # total discriminator loss
        d_loss = d_real_loss + d_fake_loss
        train_d_loss.append(d_loss.item())

        # update
        d_loss.backward()
        optimD.step()

        ##########################################################
        # train generator
        ##########################################################
        optimG.zero_grad()

        # the generator wants its images to be classified as real by the
        # freshly updated discriminator
        fake_img = generator(z, fake_y)
        fake_logits = discriminator(fake_img, fake_y)
        g_loss = loss_fn(fake_logits.view(-1), real_labels)
        train_g_loss.append(g_loss.item())

        # update
        g_loss.backward()
        optimG.step()

        # display progress (running mean of this epoch's losses)
        p_bar.set_description(f'train epoch: {epoch + 1:3d}, d_loss: {np.mean(train_d_loss):.4f}, g_loss: {np.mean(train_g_loss):.4f}')
        p_bar.update(1)

# Release the progress bar once training is finished
p_bar.close()

In [None]:
def do_generate(generator, y):
    """Generate one 28x28 image for class ``y`` and display it.

    Puts ``generator`` into eval mode (it is left in eval mode afterwards),
    samples one latent vector, runs the generator without gradient tracking,
    and shows the result as an 8-bit image.

    Args:
        generator: Model called as ``generator(z, labels)``.
        y: Integer class index to condition on.
    """
    generator.eval()
    with torch.no_grad():
        # A batch of one: a single label and a single latent vector
        label = torch.tensor([y]).to(device)
        noise = torch.tensor(np.random.normal(0, 1, (1, d_latent))).type(torch.float).to(device)
        # Generate, reshape to 2-D and hand over to NumPy
        pixels = generator(noise, label).view(28, 28).cpu().detach().numpy()
    # Undo the (0.5, 0.5) normalization, then scale to the 0..255 range
    pixels = ((pixels * 0.5) + 0.5) * 255.
    # Convert to unsigned bytes and show the image
    # (display is not defined in this file; presumably the notebook built-in)
    display(PIL.Image.fromarray(pixels.astype(np.ubyte)))

In [None]:
# Interactive prompt: generate an image for each digit entered.
# An empty line ends the loop; anything that is not an integer in 0..9 is
# ignored and the prompt is shown again.
while True:
    string = input('번호 (0 ~ 9) > ').strip()
    if not string:
        break
    try:
        y = int(string)
    except ValueError:
        # Non-numeric input: ask again instead of aborting the cell
        continue
    if 0 <= y <= 9:
        do_generate(generator, y)