In [1]:
import os
import torch
import torchvision

from tqdm import tqdm
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
save_dir = "./runs/GAN"
data_dir = "/home/pervinco/Datasets/torch_mnist"

epochs = 200
batch_size = 64
d_lr = 0.0001
g_lr = 0.0001

latent_dim = 64
hidden_dim = 256
image_size = 784

num_workers = os.cpu_count()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if not os.path.exists(save_dir):
    os.makedirs(save_dir)

In [3]:
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.model(x)

        return x

In [4]:
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, image_size),
            nn.Tanh()
        )

    def forward(self, z):
        z = self.model(z)
        
        return z

In [5]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5], std=[0.5])])

train_dataset = torchvision.datasets.MNIST(root=data_dir,
                                           train=True,
                                           transform=transform,
                                           download=True)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [7]:
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

In [9]:
D = Discriminator(input_dim=image_size)
G = Generator(latent_dim=latent_dim)

D = D.to(device)
G = G.to(device)

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(G.parameters(), lr=g_lr)

for epoch in range(epochs):
    for idx, (images, _) in enumerate(tqdm(train_dataloader, desc="Train", leave=False)):
        images = images.reshape(batch_size, -1).to(device)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #

        d_optimizer.zero_grad()
        g_optimizer.zero_grad()

        outputs = D(images)
        d_real_loss = criterion(outputs, real_labels) ## first term of object func
        real_score = outputs

        z = torch.randn(batch_size, latent_dim).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_fake_loss = criterion(outputs, fake_labels) ## second_term of object func
        fake_score = outputs

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        d_optimizer.zero_grad()
        g_optimizer.zero_grad()

        z = torch.randn(batch_size, latent_dim).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels) ## 가짜 데이터를 입력 받았을 때 D가 진짜라고 분류한 비율을 이용해 오차를 계산한다.
        g_loss.backward()
        g_optimizer.step()

    if epoch % 10 == 0 or epoch == epochs:
        print(f"Epoch[{epoch}/{epochs}], Step[{idx+1}/{len(train_dataloader)}]")
        print(f"D_loss : {d_loss.item():.4f}, G_loss : {g_loss.item():.4f}")
        print(f"Real Score : {real_score.mean().item():.4f}, Fake Score : {fake_score.mean().item():.4f}")

    # Save real images
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), f"{save_dir}/real_images.png")
    
    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), f"{save_dir}/fake_images_{epoch+1:04}.png")

# Save the model checkpoints 
torch.save(G.state_dict(), f'{save_dir}/G.ckpt')
torch.save(D.state_dict(), f'{save_dir}/D.ckpt')

                                                         

Epoch[0/200], Step[937/937]
D_loss : 0.1652, G_loss : 3.2267
Real Score : 0.9463, Fake Score : 0.0975


                                                         

Epoch[10/200], Step[937/937]
D_loss : 0.2049, G_loss : 4.8571
Real Score : 0.9793, Fake Score : 0.1384


                                                         

Epoch[20/200], Step[937/937]
D_loss : 0.0842, G_loss : 5.6074
Real Score : 0.9814, Fake Score : 0.0207


                                                         

Epoch[30/200], Step[937/937]
D_loss : 0.2079, G_loss : 3.7800
Real Score : 0.9182, Fake Score : 0.0529


                                                         

Epoch[40/200], Step[937/937]
D_loss : 0.5208, G_loss : 3.6023
Real Score : 0.8834, Fake Score : 0.1553


                                                         

Epoch[50/200], Step[937/937]
D_loss : 0.4316, G_loss : 2.5888
Real Score : 0.8474, Fake Score : 0.1272


                                                         

Epoch[60/200], Step[937/937]
D_loss : 0.7734, G_loss : 2.7718
Real Score : 0.7660, Fake Score : 0.2044


                                                         

Epoch[70/200], Step[937/937]
D_loss : 0.8416, G_loss : 1.7307
Real Score : 0.7275, Fake Score : 0.2847


                                                         

Epoch[80/200], Step[937/937]
D_loss : 0.7958, G_loss : 1.8153
Real Score : 0.7967, Fake Score : 0.3108


                                                         

Epoch[90/200], Step[937/937]
D_loss : 0.7865, G_loss : 2.1533
Real Score : 0.7181, Fake Score : 0.2346


                                                         

Epoch[100/200], Step[937/937]
D_loss : 0.8525, G_loss : 1.8530
Real Score : 0.7390, Fake Score : 0.2730


                                                         

Epoch[110/200], Step[937/937]
D_loss : 0.9031, G_loss : 1.4449
Real Score : 0.8041, Fake Score : 0.3824


                                                         

Epoch[120/200], Step[937/937]
D_loss : 0.7613, G_loss : 1.4835
Real Score : 0.7171, Fake Score : 0.2348


                                                         

Epoch[130/200], Step[937/937]
D_loss : 0.7794, G_loss : 1.4543
Real Score : 0.7581, Fake Score : 0.3079


                                                         

Epoch[140/200], Step[937/937]
D_loss : 1.2812, G_loss : 1.4295
Real Score : 0.5830, Fake Score : 0.3441


                                                         

Epoch[150/200], Step[937/937]
D_loss : 0.7842, G_loss : 1.6454
Real Score : 0.7152, Fake Score : 0.2758


                                                         

Epoch[160/200], Step[937/937]
D_loss : 0.9154, G_loss : 1.7281
Real Score : 0.7959, Fake Score : 0.3540


                                                         

Epoch[170/200], Step[937/937]
D_loss : 0.8091, G_loss : 1.6097
Real Score : 0.7327, Fake Score : 0.2903


                                                         

Epoch[180/200], Step[937/937]
D_loss : 0.8456, G_loss : 1.8595
Real Score : 0.7347, Fake Score : 0.2908


                                                         

Epoch[190/200], Step[937/937]
D_loss : 0.7821, G_loss : 1.6569
Real Score : 0.7935, Fake Score : 0.3152


                                                         

In [11]:
class AdvancedDiscriminator(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),  # 드롭아웃 추가
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),  # 드롭아웃 추가
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.model(x)
        return x

class AdvancedGenerator(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # 배치 정규화 추가
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # 배치 정규화 추가
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Tanh()
        )

    def forward(self, z):
        z = self.model(z)
        return z


In [13]:
D = AdvancedDiscriminator(input_dim=image_size, hidden_dim=hidden_dim)
G = AdvancedGenerator(latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=image_size)

D = D.to(device)
G = G.to(device)

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(G.parameters(), lr=g_lr)

for epoch in range(epochs):
    for idx, (images, _) in enumerate(tqdm(train_dataloader, desc="Train", leave=False)):
        images = images.reshape(batch_size, -1).to(device)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #

        d_optimizer.zero_grad()
        g_optimizer.zero_grad()

        outputs = D(images)
        d_real_loss = criterion(outputs, real_labels) ## first term of object func
        real_score = outputs

        z = torch.randn(batch_size, latent_dim).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_fake_loss = criterion(outputs, fake_labels) ## second_term of object func
        fake_score = outputs

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        d_optimizer.zero_grad()
        g_optimizer.zero_grad()

        z = torch.randn(batch_size, latent_dim).to(device)
        fake_images = G(z)
        outputs = D(fake_images)

        ## 생성자가 생성한 가짜 데이터에 대해 구분자가 이를 진짜로 판별할 확률을 최대화하는 방향으로 학습함.
        ## 가짜 데이터를 입력 받았을 때 D가 진짜라고 분류한 비율을 이용해 오차를 계산한다.
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        g_optimizer.step()

    if epoch % 10 == 0 or epoch == epochs:
        print(f"Epoch[{epoch}/{epochs}], Step[{idx+1}/{len(train_dataloader)}]")
        print(f"D_loss : {d_loss.item():.4f}, G_loss : {g_loss.item():.4f}")
        print(f"Real Score : {real_score.mean().item():.4f}, Fake Score : {fake_score.mean().item():.4f}")

    # Save real images
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), f"{save_dir}/real_images.png")
    
    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), f"{save_dir}/fake_images_{epoch+1:04}.png")

# Save the model checkpoints 
torch.save(G.state_dict(), f'{save_dir}/G.ckpt')
torch.save(D.state_dict(), f'{save_dir}/D.ckpt')

                                                         

Epoch[0/200], Step[937/937]
D_loss : 0.6711, G_loss : 1.8578
Real Score : 0.7862, Fake Score : 0.2853


                                                         

Epoch[10/200], Step[937/937]
D_loss : 0.5473, G_loss : 2.4590
Real Score : 0.7992, Fake Score : 0.1592


                                                         

Epoch[20/200], Step[937/937]
D_loss : 0.8279, G_loss : 1.6135
Real Score : 0.7199, Fake Score : 0.2669


                                                         

Epoch[30/200], Step[937/937]
D_loss : 1.0755, G_loss : 1.3621
Real Score : 0.6025, Fake Score : 0.3305


                                                         

Epoch[40/200], Step[937/937]
D_loss : 1.0013, G_loss : 1.4986
Real Score : 0.6280, Fake Score : 0.2709


                                                         

Epoch[50/200], Step[937/937]
D_loss : 1.0589, G_loss : 1.2150
Real Score : 0.6896, Fake Score : 0.3872


                                                         

Epoch[60/200], Step[937/937]
D_loss : 0.9230, G_loss : 1.6381
Real Score : 0.6876, Fake Score : 0.2938


                                                         

Epoch[70/200], Step[937/937]
D_loss : 1.0628, G_loss : 1.5331
Real Score : 0.6235, Fake Score : 0.3280


                                                         

Epoch[80/200], Step[937/937]
D_loss : 1.0901, G_loss : 1.2250
Real Score : 0.6533, Fake Score : 0.3871


                                                         

Epoch[90/200], Step[937/937]
D_loss : 1.3654, G_loss : 0.9739
Real Score : 0.5536, Fake Score : 0.4588


                                                         

Epoch[100/200], Step[937/937]
D_loss : 1.2288, G_loss : 0.9998
Real Score : 0.5461, Fake Score : 0.3942


                                                         

Epoch[110/200], Step[937/937]
D_loss : 1.2280, G_loss : 1.0569
Real Score : 0.5313, Fake Score : 0.3914


                                                         

Epoch[120/200], Step[937/937]
D_loss : 1.2942, G_loss : 0.9153
Real Score : 0.5412, Fake Score : 0.4366


                                                         

Epoch[130/200], Step[937/937]
D_loss : 1.3524, G_loss : 0.7226
Real Score : 0.5815, Fake Score : 0.4952


                                                         

Epoch[140/200], Step[937/937]
D_loss : 1.2203, G_loss : 1.0147
Real Score : 0.5466, Fake Score : 0.4047


                                                         

Epoch[150/200], Step[937/937]
D_loss : 1.3265, G_loss : 0.9456
Real Score : 0.5296, Fake Score : 0.4531


                                                         

Epoch[160/200], Step[937/937]
D_loss : 1.1656, G_loss : 0.8242
Real Score : 0.5670, Fake Score : 0.4034


                                                         

Epoch[170/200], Step[937/937]
D_loss : 1.3170, G_loss : 0.8261
Real Score : 0.5525, Fake Score : 0.4534


                                                         

Epoch[180/200], Step[937/937]
D_loss : 1.2808, G_loss : 0.9478
Real Score : 0.5486, Fake Score : 0.4282


                                                         

Epoch[190/200], Step[937/937]
D_loss : 1.2877, G_loss : 0.9010
Real Score : 0.5261, Fake Score : 0.4299


                                                         