In [None]:
import os
# Presumably a workaround for the "OMP: Error #15 ... already initialized" abort
# seen when two copies of the Intel OpenMP runtime get loaded into one process
# (e.g. by PyTorch and an MKL-backed NumPy) — confirm on the target machine.
# It is set here, before `import torch`, presumably so it is in place when the
# runtime loads. NOTE(review): this flag suppresses the check rather than fixing
# the duplicate runtime, so it can mask real conflicts.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np

In [None]:
class Generator(nn.Module):
    """DCGAN generator: maps a batch of latent vectors to 64x64 images.

    Args:
        nz: length of the latent vector z (input channels of the first layer).
        ngf: base number of generator feature maps; the hidden layers use
            8x, 4x, 2x and 1x this value.
        nc: number of channels of the generated image.

    The defaults (100, 64, 3) reproduce the original hard-coded architecture
    layer for layer, so existing checkpoints still load.

    Input shape: (N, nz, 1, 1), or (N, nz), which is reshaped automatically.
    Output shape: (N, nc, 64, 64), values in [-1, 1] because of the final Tanh.
    """

    def __init__(self, nz=100, ngf=64, nc=3):
        super(Generator, self).__init__()
        # Kept so callers can size their noise tensors from the model itself.
        self.nz = nz
        self.main = nn.Sequential(
            # Input: (N, nz, 1, 1). Kernel 4, stride 1, padding 0 expands 1x1 -> 4x4
            # while projecting nz channels up to ngf*8.
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # State size: (N, ngf*8, 4, 4). Each stride-2 block below doubles H and W.
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # State size: (N, ngf*4, 8, 8)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # State size: (N, ngf*2, 16, 16)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # State size: (N, ngf, 32, 32)
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            # Tanh keeps outputs in [-1, 1], the same range the training
            # images are normalized to.
            nn.Tanh(),
            # Output: (N, nc, 64, 64)
        )

    def forward(self, input):
        """Generate images from latent vectors.

        Args:
            input: latent batch of shape (N, nz, 1, 1). A flat (N, nz) batch is
                also accepted and gets two singleton spatial dims appended.

        Returns:
            Tensor of shape (N, nc, 64, 64) with values in [-1, 1].
        """
        if input.dim() == 2:
            # Allow e.g. torch.randn(N, nz) without a manual reshape.
            input = input[:, :, None, None]
        return self.main(input)

In [None]:
# Discriminator
class Discriminator(nn.Module):
    """DCGAN discriminator: scores 64x64 images with a probability of being real.

    Args:
        nc: number of channels of the input image.
        ndf: base number of discriminator feature maps; the hidden layers use
            1x, 2x, 4x and 8x this value.

    The defaults (3, 64) reproduce the original hard-coded architecture layer
    for layer, so existing checkpoints still load.

    Input shape: (N, nc, 64, 64). Output shape: (N,), values in [0, 1]
    (Sigmoid output, to be paired with nn.BCELoss).
    """

    def __init__(self, nc=3, ndf=64):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # Input: (N, nc, 64, 64). The first layer has no BatchNorm.
            # Each stride-2 conv below halves H and W.
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: (N, ndf, 32, 32)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: (N, ndf*2, 16, 16)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: (N, ndf*4, 8, 8)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: (N, ndf*8, 4, 4). A 4x4 kernel with no padding
            # collapses this to a single score per image: (N, 1, 1, 1).
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        # Flatten (N, 1, 1, 1) -> (N,) so the output lines up with 1-D BCE targets.
        return self.main(input).view(-1, 1).squeeze(1)

In [None]:
def weights_init(m):
    """DCGAN-style weight initialization, meant to be passed to ``nn.Module.apply``.

    - Modules whose class name contains 'Conv' (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02); any conv bias is left untouched.
    - Modules whose class name contains 'BatchNorm': weight ~ N(1, 0.02), bias = 0.
    - Every other module is left unchanged.

    Matching is by class name, so user-defined classes with such names are also
    visited. Modules that have no learnable weight/bias are skipped instead of
    raising AttributeError, e.g. ``BatchNorm2d(affine=False)`` (weight is None)
    or a container named ``ConvBlock``.

    Args:
        m: the module currently being visited by ``apply``.
    """
    classname = type(m).__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    # nn.init functions modify the parameters in place without recording
    # autograd history, so the legacy `.data` access is not needed.
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [None]:
def _grid_to_hwc(images, num_images):
    """Tile the first ``num_images`` images into one grid; return it as an (H, W, C) numpy array.

    ``normalize=True`` asks make_grid to rescale the grid into [0, 1], which is
    what matplotlib expects for float images.
    """
    grid = vutils.make_grid(images[:num_images].detach(), padding=2, normalize=True)
    # make_grid returns channels-first (C, H, W); matplotlib wants (H, W, C).
    return grid.permute(1, 2, 0).cpu().numpy()


def show_generated_images(images, num_images=64):
    """Display up to ``num_images`` images from ``images`` as a single grid.

    Args:
        images: image batch, presumably generator output of shape (N, C, H, W)
            — TODO confirm for other callers.
        num_images: maximum number of images taken from the front of the batch.
    """
    plt.figure(figsize=(10, 10))
    plt.axis("off")
    plt.title("Generated Images")
    plt.imshow(_grid_to_hwc(images, num_images))
    plt.show()


def save_generated_images(images, num_images, epoch, idx, out_dir='./output'):
    """Write up to ``num_images`` images from ``images`` as one JPEG grid.

    The file is named ``<out_dir>/image_<epoch>_<idx>.jpg``. ``out_dir`` is
    created if it does not exist yet.

    Args:
        images: image batch, presumably (N, C, H, W) — TODO confirm.
        num_images: maximum number of images taken from the front of the batch.
        epoch: epoch number, used only in the file name.
        idx: iteration index, used only in the file name.
        out_dir: destination directory (default './output', as before).

    Returns:
        The path of the written file.
    """
    # Previously a missing ./output directory made plt.imsave raise.
    os.makedirs(out_dir, exist_ok=True)
    fname = out_dir + '/image_' + str(epoch) + '_' + str(idx) + '.jpg'
    # plt.imsave writes the array directly, so no figure needs to be created.
    plt.imsave(fname, _grid_to_hwc(images, num_images))
    return fname

In [None]:
# Seed PyTorch's global RNG before the models are built. This makes weight
# initialization and the torch.randn noise drawn later repeatable across
# "Restart & Run All". NOTE: cuDNN kernels on GPU may still be non-deterministic.
SEED = 999
torch.manual_seed(SEED)

# Create the generator and the discriminator
netG = Generator()
netD = Discriminator()

In [None]:
# Re-initialize every Conv / BatchNorm layer of both networks with the random
# weights chosen by weights_init (apply visits each submodule recursively).
for net in (netG, netD):
    net.apply(weights_init)

In [None]:
# Data preparation: build a 64x64 image pipeline over the ImageFolder at
# ./data/celeba and batch it for training.
IMAGE_SIZE = 64

transform = transforms.Compose(
    [
        # Scale the image, then keep the central IMAGE_SIZE x IMAGE_SIZE square;
        # faces are usually centred, so the centre crop keeps the useful region.
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps each channel into [-1, 1], the same range as the
        # generator's Tanh output.
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

dataset = datasets.ImageFolder(root='./data/celeba', transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=128)

In [None]:
# Training setup
# Use the first CUDA GPU if one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

for net in (netG, netD):
    net.to(device)

# Binary cross-entropy: the discriminator's sigmoid output is scored against
# 1 (real) / 0 (fake) targets.
criterion = nn.BCELoss()

# Both networks use Adam with lr=2e-4 and betas=(0.5, 0.999).
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), **adam_kwargs)
optimizerG = optim.Adam(netG.parameters(), **adam_kwargs)

num_epochs = 10  # For demonstration purposes; increase for better results

# One fixed batch of 64 latent vectors, reused throughout training so the
# saved sample grids show the generator's progress on identical inputs.
fixed_noise = torch.randn(64, 100, 1, 1, device=device)

In [None]:
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        # ---- (1) Update D: maximize log(D(x)) + log(1 - D(G(z))) ----
        netD.zero_grad()
        real_data = data[0].to(device)
        # Read the size per batch: the final batch can be smaller than 128.
        batch_size = real_data.size(0)
        # Target tensors of shape (batch_size,): 1 for real images, 0 for fakes.
        real_label = torch.full((batch_size,), 1, dtype=torch.float, device=device)
        fake_label = torch.full((batch_size,), 0, dtype=torch.float, device=device)

        # Loss of D on the real batch.
        output = netD(real_data).view(-1)
        errD_real = criterion(output, real_label)
        errD_real.backward()

        # Generate a fake batch; detach it so D's backward pass does not reach G.
        noise = torch.randn(batch_size, 100, 1, 1, device=device)
        fake_data = netG(noise)
        output = netD(fake_data.detach()).view(-1)
        errD_fake = criterion(output, fake_label)
        errD_fake.backward()

        # Both backward calls accumulated into D's gradients; the summed loss
        # is only used for logging.
        errD = errD_real + errD_fake
        optimizerD.step()

        # ---- (2) Update G: maximize log(D(G(z))) ----
        # Score the same fakes with the freshly updated D, using "real" targets
        # (non-saturating generator loss). This backward also writes gradients
        # into D's parameters; they are cleared by netD.zero_grad() on the next
        # iteration before D steps again.
        netG.zero_grad()
        output = netD(fake_data).view(-1)
        errG = criterion(output, real_label)
        errG.backward()
        optimizerG.step()

        # Output training stats and a sample grid every 50 iterations.
        if i % 50 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                  % (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item()))
            # Visualization only: skip building an autograd graph for this pass.
            with torch.no_grad():
                fake_images = netG(fixed_noise)
            save_generated_images(fake_images, 64, epoch=epoch, idx=i)

In [None]:
# Visualize samples from the trained generator on the fixed latent batch.
# no_grad: this is an inference-only pass, so no autograd graph is needed.
# NOTE(review): netG is never switched to eval() in this notebook, so its
# BatchNorm layers still run in training mode here.
with torch.no_grad():
    fake_images = netG(fixed_noise)
show_generated_images(fake_images)
