In [None]:
import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch import nn
from torchsummary import summary
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import utils as vutils

In [None]:
# Number of DataLoader worker processes: one per CPU core.
# os.cpu_count() returns None when the core count cannot be determined; DataLoader
# does not accept None, so fall back to 0 (load batches in the main process).
num_workers = os.cpu_count() or 0
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# --- Paths ---------------------------------------------------------------
# Directory that receives sample grids and checkpoints.
save_dir = "./runs/WGAN"
# Dataset root (read with torchvision's ImageFolder). The default is the original
# machine-specific location; set the CELEBA_DIR environment variable to point the
# notebook at another copy without editing this cell.
data_dir = os.environ.get("CELEBA_DIR", "/home/pervinco/Datasets/CelebA")

# --- Optimisation --------------------------------------------------------
epochs = 100
batch_size = 128
lr = 0.0002
# Adam-style betas. NOTE(review): the visible training cell uses RMSprop, which
# takes no betas, so these two values are currently unused.
beta1 = 0.5 
beta2 = 0.999

# --- WGAN-specific hyper-parameters --------------------------------------
# Intended number of critic updates per generator update.
n_critics = 5
# Intended bound c for clipping critic weights to [-c, c].
weight_cliping_limit = 0.01

# --- Data / model dimensions ---------------------------------------------
# Side length the training images are resized to.
img_size = 64

# Number of image channels (RGB).
nc = 3
# Number of channels of the latent noise vector fed to the generator.
nz = 100
# NOTE(review): ngf / ndf are not referenced by the model classes in the visible
# cells (their layer widths are hard-coded); kept for compatibility.
ngf = 64
ndf = 64

In [None]:
# Create the output directory (and any missing parents). exist_ok=True makes the
# cell safe to re-run and avoids the check-then-create race of testing
# os.path.exists() first.
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Pre-processing applied to every training image.
# Resize(int) only rescales the *shorter* edge, so non-square source images would
# stay non-square; CenterCrop then cuts them to exactly img_size x img_size.
# Normalize maps the [0, 1] output of ToTensor to [-1, 1], the range of a Tanh
# generator output.
image_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = datasets.ImageFolder(root=data_dir, transform=image_transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# Sanity check: show the first 64 images of one batch.
# The grid is only plotted, so it is built on the CPU instead of being copied to
# the training device and straight back.
one_batch = next(iter(dataloader))
preview_grid = vutils.make_grid(one_batch[0][:64], padding=2, normalize=True)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.title('Training Images')
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))
plt.show()

In [None]:
def weights_init(m):
    """Re-initialise one layer in place; intended for ``model.apply(weights_init)``.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights drawn from N(0, 0.02).
    * ``nn.BatchNorm2d``: scale drawn from N(1, 0.02), shift set to 0.
    * Any other module type is left untouched.

    Args:
        m: the module handed over by ``nn.Module.apply``.
    """
    is_conv = isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d)
    if is_conv:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [None]:
class Generator(torch.nn.Module):
    """DCGAN-style generator: latent tensor (N, nz, 1, 1) -> image (N, channels, img_size, img_size).

    The latent vector is first projected to a 1024 x 4 x 4 feature map. Each
    following transposed convolution (kernel 4, stride 2, padding 1) doubles the
    spatial size; hidden stages halve the channel count (1024 -> 512 -> 256 -> ...)
    and are followed by BatchNorm + ReLU. The last stage maps to ``channels`` and
    is squashed to [-1, 1] by Tanh.

    Args:
        channels: number of channels of the generated image (e.g. 3 for RGB).
        nz: number of channels of the latent input. Defaults to 100, the value
            that used to be hard-coded.
        img_size: side length of the generated square image; must be a power of
            two >= 8. Defaults to 64. Earlier versions of this class always
            produced 32 x 32 images; pass ``img_size=32`` to get exactly that
            architecture back (same layers, same ``state_dict`` keys).

    Raises:
        ValueError: if ``img_size`` is not a power of two >= 8.
    """

    def __init__(self, channels, nz=100, img_size=64):
        super().__init__()
        if img_size < 8 or img_size & (img_size - 1):
            raise ValueError(f"img_size must be a power of two >= 8, got {img_size}")

        # The feature map starts at 4 x 4 (= 2**2) and every stride-2 stage doubles
        # it, so log2(img_size) - 2 stride-2 stages are needed in total.
        n_upsample = img_size.bit_length() - 3

        # Z latent vector (nz x 1 x 1) -> State (1024 x 4 x 4)
        width = 1024
        layers = [
            nn.ConvTranspose2d(in_channels=nz, out_channels=width, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(num_features=width),
            nn.ReLU(True),
        ]

        # Hidden up-sampling stages: all but the last stride-2 stage.
        # For img_size=64: 1024x4x4 -> 512x8x8 -> 256x16x16 -> 128x32x32
        for _ in range(n_upsample - 1):
            layers += [
                nn.ConvTranspose2d(in_channels=width, out_channels=width // 2, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(num_features=width // 2),
                nn.ReLU(True),
            ]
            width //= 2

        # Final up-sampling stage -> Image (channels x img_size x img_size); no BatchNorm here.
        layers.append(
            nn.ConvTranspose2d(in_channels=width, out_channels=channels, kernel_size=4, stride=2, padding=1))

        self.main_module = nn.Sequential(*layers)
        self.output = nn.Tanh()

    def forward(self, x):
        """Generate images from latent input ``x`` of shape (N, nz, 1, 1)."""
        x = self.main_module(x)
        return self.output(x)

In [None]:
# Build the generator and print a layer-by-layer summary for a single latent input.
# `channels` is the number of *image* channels (nc); passing the latent size (nz)
# here would build a generator that emits nz-channel "images".
G = Generator(channels=nc).to(device)
G.apply(weights_init)
summary(G, (nz, 1, 1))

In [None]:
class Discriminator(torch.nn.Module):
    """WGAN critic: scores images with an unbounded real value (no sigmoid).

    Three stride-2 convolutions (256 -> 512 -> 1024 filters, each followed by
    BatchNorm and LeakyReLU(0.2)) reduce a square S x S input to a
    1024 x S/8 x S/8 feature map; a final 4 x 4 convolution without padding turns
    it into the critic score map of size 1 x (S/8 - 3) x (S/8 - 3):

    * S = 32 -> features 1024 x 4 x 4, score 1 x 1 x 1
    * S = 64 -> features 1024 x 8 x 8, score 1 x 5 x 5

    Args:
        channels: number of channels of the input images.
    """

    def __init__(self, channels):
        super().__init__()
        # Filters [256, 512, 1024]
        # Input_dim = channels (C x S x S)
        # Output_dim = 1 score channel
        self.main_module = nn.Sequential(
            # Image (C x S x S)
            nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.LeakyReLU(0.2, inplace=True),

            # State (256 x S/2 x S/2)
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.LeakyReLU(0.2, inplace=True),

            # State (512 x S/4 x S/4)
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=1024),
            nn.LeakyReLU(0.2, inplace=True))
            # output of main module --> State (1024 x S/8 x S/8)

        self.output = nn.Sequential(
            # The output of D is no longer a probability, we do not apply sigmoid at the output of D.
            nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0))


    def forward(self, x):
        """Return the critic score map for a batch of images ``x``."""
        x = self.main_module(x)
        return self.output(x)

    def feature_extraction(self, x):
        """Return the flattened ``main_module`` features, one row per input image.

        The result has shape (N, 1024 * H/8 * W/8), e.g. (N, 16384) for 32 x 32
        inputs. Flattening from dim 1 keeps the batch dimension intact for any
        input resolution; a fixed ``view(-1, 1024*4*4)`` is only correct for
        32 x 32 inputs.
        """
        x = self.main_module(x)
        return torch.flatten(x, 1)

In [None]:
# Build the critic and print a layer-by-layer summary.
# The summary input shape is derived from the config (nc, img_size) so it stays in
# sync with the data pipeline instead of repeating the literal (3, 64, 64).
D = Discriminator(channels=nc).to(device)
D.apply(weights_init)
summary(D, (nc, img_size, img_size))

In [None]:
def save_fake_images(epoch, G, fixed_noise, num_images=64):
    """Render generator samples for ``fixed_noise`` and save them as a PNG grid.

    The image is written to ``{save_dir}/Epoch_{epoch}_Fake.png``; nothing is
    displayed and the figure is closed afterwards.

    Args:
        epoch: epoch number, used in the plot title and in the file name.
        G: generator module that is called on ``fixed_noise``.
        fixed_noise: latent batch passed to ``G``.
        num_images: how many of the generated images (from the start of the
            batch) go into the grid.
    """
    # Sampling only: no gradients are needed (under no_grad the output is already
    # detached from any graph).
    with torch.no_grad():
        samples = G(fixed_noise).cpu()

    grid = vutils.make_grid(samples[:num_images], padding=2, normalize=True)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.axis("off")
    ax.set_title(f"Fake Images at Epoch {epoch}")
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    ax.imshow(np.transpose(grid, (1, 2, 0)))
    # Save to an image file, then release the figure.
    fig.savefig(f"{save_dir}/Epoch_{epoch}_Fake.png")
    plt.close(fig)

In [None]:
# ---------------------------------------------------------------------------
# WGAN training (weight-clipping variant).
#
# Sign convention (unchanged from the previous version of this cell):
#   critic    minimises  D(real).mean() - D(fake).mean()
#   generator minimises  D(fake).mean()
# i.e. the critic learns to give real images low scores and fakes high scores,
# and the generator tries to lower the score of its fakes.
# ---------------------------------------------------------------------------

# Fresh networks and optimisers for this run.
D = Discriminator(channels=nc).to(device)
G = Generator(channels=nc).to(device)
D.apply(weights_init)
G.apply(weights_init)

optimizerD = torch.optim.RMSprop(D.parameters(), lr=lr)
optimizerG = torch.optim.RMSprop(G.parameters(), lr=lr)

G.train()
D.train()

# Same latent batch every epoch so the saved sample grids are comparable.
fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)

# Placeholders so the epoch summary can be printed even if an epoch contains
# fewer than n_critics batches (no generator update) or no batch at all.
d_loss = torch.tensor(float("nan"))
g_loss = torch.tensor(float("nan"))

for epoch in range(epochs):
    for idx, (images, _) in enumerate(tqdm(dataloader, desc="Train", leave=False)):
        images = images.to(device)
        # The last batch of an epoch can be smaller than batch_size; draw exactly
        # as many fakes as there are real images.
        cur_batch_size = images.size(0)

        # ---- Critic update (every batch) ----
        D.zero_grad()

        # Generate fake images; the generator is not trained in this step, so no
        # graph is built for it.
        z = torch.randn(cur_batch_size, nz, 1, 1, device=device)
        with torch.no_grad():
            fake_images = G(z)

        # Score real and fake images.
        d_loss_real = D(images).mean()
        d_loss_fake = D(fake_images).mean()

        # The reported loss is exactly the quantity that is minimised.
        d_loss = d_loss_real - d_loss_fake
        d_loss.backward()
        optimizerD.step()

        # Weight clipping: keep every critic parameter in [-c, c]. This is what
        # constrains the critic in the original WGAN; without it the critic
        # scores are unbounded.
        with torch.no_grad():
            for p in D.parameters():
                p.clamp_(-weight_cliping_limit, weight_cliping_limit)

        # ---- Generator update (once every n_critics critic updates) ----
        if (idx + 1) % n_critics == 0:
            # Freeze the critic so no gradients are computed for its weights.
            for p in D.parameters():
                p.requires_grad_(False)

            G.zero_grad()
            z = torch.randn(cur_batch_size, nz, 1, 1, device=device)
            fake_images = G(z)
            g_loss = D(fake_images).mean()
            g_loss.backward()
            optimizerG.step()

            # Unfreeze the critic for its next update.
            for p in D.parameters():
                p.requires_grad_(True)

    print(f'Epoch: {epoch}, D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}')        
    save_fake_images(epoch, G, fixed_noise)

# Save the final weights; `epoch` is the index of the last completed epoch.
torch.save(G.state_dict(), f'{save_dir}/generator_epoch_{epoch}.pth')
torch.save(D.state_dict(), f'{save_dir}/discriminator_epoch_{epoch}.pth')