In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import utils as vutils
from torchvision import datasets, transforms

In [2]:
# ---- Filesystem locations -------------------------------------------------
# NOTE(review): data_dir is a machine-specific absolute path; adjust it when
# running on another host.
data_dir = "/home/pervinco/Datasets/torch_mnist"
save_dir = "./runs/cDCGAN"

# ---- Optimisation schedule (Adam) -----------------------------------------
epochs = 20
batch_size = 128
lr = 2e-4
beta1 = 0.5

# ---- Data / model geometry ------------------------------------------------
n_classes = 10    # number of label classes used for conditioning
image_size = 64   # images are resized to this size before training
x_dim = 1         # image channels
z_dim = 100       # length of the latent noise vector
d_dim = 64        # discriminator feature-map size
g_dim = 64        # generator feature-map size

# ---- Runtime resources ----------------------------------------------------
num_workers = os.cpu_count()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Create the output directory for sample grids and checkpoints.
# exist_ok=True makes the cell idempotent and avoids the check-then-create race.
os.makedirs(save_dir, exist_ok=True)

In [4]:
# Preprocessing: resize to `image_size`, convert to a tensor, then normalise
# the single grey channel with mean 0.5 / std 0.5.
# To train on an RGB dataset instead (e.g. datasets.CIFAR10 or an ImageFolder),
# swap the dataset below and use a three-channel
# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# Training split: shuffled; the last incomplete batch is dropped so every
# step sees exactly `batch_size` samples.
train_dataset = datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Held-out split, iterated in a fixed order.
test_dataset = datasets.MNIST(root=data_dir, train=False, transform=transform, download=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
def weights_init(m):
    """Re-initialise one layer in place, selected by its class name.

    Intended to be passed to ``nn.Module.apply``.

    * Class names containing ``'Conv'``: weights drawn from N(0.0, 0.02).
    * Class names containing ``'BatchNorm'``: weights drawn from N(1.0, 0.02),
      bias set to 0.
    * Any other layer is left untouched.

    Args:
        m: the module to initialise; it must expose ``weight`` (and ``bias``
            for batch-norm layers) when its class name matches.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [6]:
## cGAN + DCGAN
class Generator(nn.Module):
    """Conditional DCGAN generator.

    A class label is embedded into an ``n_classes``-long vector, concatenated
    with the noise along the channel axis, and up-sampled by five transposed
    convolutions. For noise of spatial size 1x1 the output is 64x64
    (1 -> 4 -> 8 -> 16 -> 32 -> 64) with values in [-1, 1] (final Tanh).

    Args:
        n_classes: number of label classes; also the embedding width.
        z_dim: number of noise channels.
        img_size: kept for interface compatibility only. The output resolution
            is fixed by the layer stack, so this value is not used.
        g_dim: base number of generator feature maps.
        x_dim: number of channels of the generated image.
    """

    def __init__(self, n_classes, z_dim, img_size, g_dim, x_dim):
        super().__init__()
        self.label_embedding = nn.Embedding(n_classes, n_classes)

        input_dim = z_dim + n_classes
        self.model = nn.Sequential(
            nn.ConvTranspose2d(input_dim, (g_dim * 8), 4, 1, 0, bias=False), ## (g_dim*8, 4, 4)
            nn.BatchNorm2d(g_dim * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d((g_dim * 8), (g_dim * 4), 4, 2, 1, bias=False), ## (g_dim*4, 8, 8)
            nn.BatchNorm2d(g_dim * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d((g_dim * 4), (g_dim * 2), 4, 2, 1, bias=False), ## (g_dim*2, 16, 16)
            nn.BatchNorm2d(g_dim * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d((g_dim * 2), g_dim, 4, 2, 1, bias=False), ## (g_dim, 32, 32)
            nn.BatchNorm2d(g_dim),
            nn.ReLU(True),

            # The previous layer emits g_dim channels, so that is the input
            # width here. (Using img_size only worked while img_size == g_dim.)
            nn.ConvTranspose2d(g_dim, x_dim, 4, 2, 1, bias=False), ## (x_dim, 64, 64)
            nn.Tanh()
        )

    def forward(self, noise, labels):
        """Generate images for the given noise and class labels.

        Args:
            noise: tensor of shape [batch_size, z_dim, H, W]; H = W = 1 gives
                64x64 output.
            labels: integer class indices of shape [batch_size].

        Returns:
            Tensor of generated images with ``x_dim`` channels.
        """
        labels = self.label_embedding(labels) ## [batch_size, n_classes]
        labels = labels.unsqueeze(-1).unsqueeze(-1) ## [batch_size, n_classes, 1, 1]
        # Broadcast the label map to the noise's spatial size so the two can
        # be concatenated along the channel axis.
        x = torch.cat([noise, labels.expand(-1, -1, noise.size(2), noise.size(3))], 1)

        return self.model(x)

In [8]:
# Smoke test: build a generator from the config constants (so it cannot drift
# from the noise created with the global z_dim) and check one forward pass.
G = Generator(n_classes=n_classes, z_dim=z_dim, img_size=image_size, g_dim=g_dim, x_dim=x_dim)
z = torch.randn(1, z_dim, 1, 1)
y = torch.randint(low=0, high=n_classes, size=(1,), dtype=torch.int32)

# No gradients are needed for a shape check.
with torch.no_grad():
    Gz = G(z, y)

assert Gz.shape == (1, x_dim, image_size, image_size), f"unexpected generator output shape: {tuple(Gz.shape)}"
print(Gz.shape)

torch.Size([1, 1, 64, 64])


In [9]:
class Discriminator(nn.Module):
    """Conditional DCGAN discriminator.

    The class label is embedded into an ``n_classes``-long vector, tiled over
    the image plane and stacked onto the image channels. Four stride-2
    convolutions followed by a final 4x4 convolution and a Sigmoid reduce a
    64x64 input to one probability per sample.

    Args:
        n_classes: number of label classes; also the embedding width.
        x_dim: number of image channels.
        d_dim: base number of discriminator feature maps.
    """

    def __init__(self, n_classes, x_dim, d_dim):
        super().__init__()
        self.label_embedding = nn.Embedding(n_classes, n_classes)

        in_channels = x_dim + n_classes

        # Stem: stride-2 conv without batch-norm.  -> (d_dim, 32, 32)
        layers = [
            nn.Conv2d(in_channels, d_dim, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three down-sampling stages, each doubling the channel count:
        # (d_dim*2, 16, 16) -> (d_dim*4, 8, 8) -> (d_dim*8, 4, 4)
        width = d_dim
        for _ in range(3):
            layers += [
                nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            width = width * 2

        # Head: collapse the 4x4 map to a single score in (0, 1).  -> (1, 1, 1)
        layers += [
            nn.Conv2d(width, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Score images as real or fake given their class labels.

        Args:
            img: image batch of shape [batch_size, x_dim, H, W].
            labels: integer class indices of shape [batch_size].

        Returns:
            1-D tensor of scores (one per sample for 64x64 inputs).
        """
        height, width = img.size(2), img.size(3)

        embedded = self.label_embedding(labels)  # [batch_size, n_classes]
        label_planes = embedded.unsqueeze(2).unsqueeze(3)  # [batch_size, n_classes, 1, 1]
        label_planes = label_planes.expand(-1, -1, height, width)  # [batch_size, n_classes, H, W]

        stacked = torch.cat([img, label_planes], 1)
        scores = self.model(stacked)
        return scores.view(-1, 1).squeeze(1)

In [10]:
# Smoke test: build a discriminator from the config constants and check that
# one forward pass yields a single score per input image.
D = Discriminator(n_classes=n_classes, x_dim=x_dim, d_dim=d_dim)
x = torch.randn(1, x_dim, image_size, image_size)
y = torch.randint(low=0, high=n_classes, size=(1,), dtype=torch.int32)

# No gradients are needed for a shape check.
with torch.no_grad():
    Dx = D(x, y)

assert Dx.shape == (1,), f"unexpected discriminator output shape: {tuple(Dx.shape)}"
print(Dx.shape)

torch.Size([1])


In [10]:
def train(D, G, dataloader, d_optimizer, g_optimizer, criterion, device):
    """Run one epoch of alternating discriminator / generator updates.

    For every batch the discriminator is updated on real images (target 1) and
    on detached generated images (target 0). The generator is then updated so
    that the freshly stepped discriminator scores its images as real.

    Note: the noise width is read from the module-level ``z_dim``.

    Args:
        D: discriminator, called as ``D(images, labels)``.
        G: generator, called as ``G(noise, labels)``.
        dataloader: iterable yielding ``(images, labels)`` batches.
        d_optimizer: optimizer for ``D``.
        g_optimizer: optimizer for ``G``.
        criterion: loss taking ``(scores, targets)``.
        device: device the batches and targets are moved to.

    Returns:
        dict of per-epoch averages with keys ``D_loss``, ``G_loss``,
        ``D_real_acc``, ``D_fake_acc_before`` and ``D_fake_acc_after``.
        ``D_fake_acc_after`` is the fraction of generated images scored above
        0.5 (i.e. taken for real) after the discriminator step.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    history = {
        'D_loss': [],
        'G_loss': [],
        'D_real_acc': [],
        'D_fake_acc_before': [],
        'D_fake_acc_after': [],
    }

    def _rate(mask):
        # Fraction of True entries in a boolean tensor, as a Python float.
        return mask.float().mean().item()

    for images, labels in tqdm(dataloader, desc="Train", leave=True):
        real_x, real_y = images.to(device), labels.to(device)
        n_samples = real_x.size(0)

        target_real = torch.ones(n_samples).to(device)
        target_fake = torch.zeros(n_samples).to(device)

        # ---------------- discriminator step ----------------
        D.zero_grad()
        real_scores = D(real_x, real_y).view(-1)
        loss_on_real = criterion(real_scores, target_real)
        loss_on_real.backward()
        history['D_real_acc'].append(_rate(real_scores > 0.5))

        noise = torch.randn(n_samples, z_dim, 1, 1, device=device)  # [batch_size, z_dim, 1, 1]
        fakes = G(noise, real_y)

        # detach(): this pass must not send gradients into the generator.
        fake_scores = D(fakes.detach(), real_y).view(-1)
        loss_on_fake = criterion(fake_scores, target_fake)
        loss_on_fake.backward()
        history['D_fake_acc_before'].append(_rate(fake_scores < 0.5))

        d_total = loss_on_real + loss_on_fake
        d_optimizer.step()

        # ------------------ generator step ------------------
        G.zero_grad()
        rescored = D(fakes, real_y).view(-1)
        g_loss = criterion(rescored, target_real)
        g_loss.backward()
        history['D_fake_acc_after'].append(_rate(rescored > 0.5))
        g_optimizer.step()

        history['D_loss'].append(d_total.item())
        history['G_loss'].append(g_loss.item())

    return {name: sum(values) / len(values) for name, values in history.items()}

In [11]:
def save_fake_images(epoch, G, fixed_noise, target_labels, save_dir, num_images=64):
    """Generate a grid of samples and save it as ``Epoch_{epoch}_Fake.png``.

    The grid width is derived from ``target_labels`` (images // distinct
    labels), so the function no longer depends on a module-level
    ``images_per_class`` being defined before it is called. With labels laid
    out class by class in equal-sized groups, each grid row shows one class.

    Args:
        epoch: epoch number used in the title and file name.
        G: generator, called as ``G(fixed_noise, target_labels)``.
        fixed_noise: noise batch fed to the generator.
        target_labels: class index for every noise vector.
        save_dir: directory the PNG is written to.
        num_images: accepted for backward compatibility; currently unused
            (all images produced from ``fixed_noise`` are drawn).
    """
    with torch.no_grad():
        # Generate one image per (noise, label) pair.
        fake_images = G(fixed_noise, target_labels).detach().cpu()

    # Images per row = images per class.
    n_distinct = max(1, int(torch.unique(target_labels).numel()))
    per_row = max(1, fake_images.size(0) // n_distinct)

    fig = plt.figure(figsize=(8, 8))
    try:
        plt.axis("off")
        plt.title(f"Fake Images at Epoch {epoch}")

        # Arrange the generated images in a grid; convert CHW -> HWC for imshow.
        grid = vutils.make_grid(fake_images, nrow=per_row, padding=2, normalize=True)
        plt.imshow(grid.permute(1, 2, 0).numpy())

        plt.savefig(f"{save_dir}/Epoch_{epoch}_Fake.png")  # write the figure to disk
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

# ================================================================== #
#                    Model, Optimizer, Cost func                     #
# ================================================================== #
# Generator.__init__ takes (n_classes, z_dim, img_size, g_dim, x_dim); x_dim
# has no default, so omitting it raises TypeError on a fresh kernel.
G = Generator(n_classes, z_dim, image_size, g_dim, x_dim).to(device)
D = Discriminator(n_classes, x_dim, d_dim).to(device)
D.apply(weights_init)
G.apply(weights_init)

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, 0.999))

# ================================================================== #
#                       Training Iterations                          #
# ================================================================== #
# Fixed noise and labels, reused every epoch so the saved grids are comparable.
n_sample = 100
images_per_class = n_sample // n_classes  # number of images per class
fixed_noise = torch.randn(n_sample, z_dim, 1, 1, device=device)
# Labels laid out class by class: images_per_class copies of 0, then of 1, ...
target_labels = torch.tensor([num for num in range(n_classes) for _ in range(images_per_class)]).to(device)

for epoch in range(epochs):
    print(f"Epoch [{epoch+1}/{epochs}]")
    metrics = train(D, G, train_dataloader, d_optimizer, g_optimizer, criterion, device)
    print(f'Discriminator Loss: {metrics["D_loss"]:.4f}')
    print(f'Generator Loss: {metrics["G_loss"]:.4f}')
    print(f'Discriminator Real Accuracy: {metrics["D_real_acc"]:.4f}')
    print(f'Discriminator Fake Accuracy (Before G Update): {metrics["D_fake_acc_before"]:.4f}')
    # Despite its label, the next value is the fraction of fake images the
    # discriminator misclassified as "real".
    print(f'Discriminator Fake Accuracy (After G Update): {metrics["D_fake_acc_after"]:.4f}\n')

    save_fake_images(epoch+1, G, fixed_noise, target_labels, save_dir, num_images=n_sample)

# Save the model checkpoints 
torch.save(G.state_dict(), f'{save_dir}/G.ckpt')
torch.save(D.state_dict(), f'{save_dir}/D.ckpt')

Epoch [1/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 43.35it/s]


Discriminator Loss: 1.0770
Generator Loss: 1.9969
Discriminator Real Accuracy: 0.7307
Discriminator Fake Accuracy (Before G Update): 0.7453
Discriminator Fake Accuracy (After G Update): 0.1240

Epoch [2/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 45.65it/s]


Discriminator Loss: 1.3191
Generator Loss: 1.1071
Discriminator Real Accuracy: 0.6041
Discriminator Fake Accuracy (Before G Update): 0.6189
Discriminator Fake Accuracy (After G Update): 0.2249

Epoch [3/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 45.94it/s]


Discriminator Loss: 1.3621
Generator Loss: 0.9186
Discriminator Real Accuracy: 0.5698
Discriminator Fake Accuracy (Before G Update): 0.5799
Discriminator Fake Accuracy (After G Update): 0.2554

Epoch [4/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 45.72it/s]


Discriminator Loss: 1.3740
Generator Loss: 0.8737
Discriminator Real Accuracy: 0.5575
Discriminator Fake Accuracy (Before G Update): 0.5633
Discriminator Fake Accuracy (After G Update): 0.2706

Epoch [5/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 46.06it/s]


Discriminator Loss: 1.3724
Generator Loss: 0.8504
Discriminator Real Accuracy: 0.5584
Discriminator Fake Accuracy (Before G Update): 0.5633
Discriminator Fake Accuracy (After G Update): 0.2891

Epoch [6/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 46.37it/s]


Discriminator Loss: 1.3688
Generator Loss: 0.8556
Discriminator Real Accuracy: 0.5605
Discriminator Fake Accuracy (Before G Update): 0.5686
Discriminator Fake Accuracy (After G Update): 0.2988

Epoch [7/20]


Train: 100%|██████████| 468/468 [00:09<00:00, 47.00it/s]


Discriminator Loss: 1.3526
Generator Loss: 0.9009
Discriminator Real Accuracy: 0.5845
Discriminator Fake Accuracy (Before G Update): 0.5857
Discriminator Fake Accuracy (After G Update): 0.2749

Epoch [8/20]


Train: 100%|██████████| 468/468 [00:09<00:00, 46.88it/s]


Discriminator Loss: 1.2923
Generator Loss: 1.1142
Discriminator Real Accuracy: 0.6307
Discriminator Fake Accuracy (Before G Update): 0.6318
Discriminator Fake Accuracy (After G Update): 0.2012

Epoch [9/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 46.33it/s]


Discriminator Loss: 1.0238
Generator Loss: 1.9405
Discriminator Real Accuracy: 0.7519
Discriminator Fake Accuracy (Before G Update): 0.7519
Discriminator Fake Accuracy (After G Update): 0.0973

Epoch [10/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 46.02it/s]


Discriminator Loss: 0.6935
Generator Loss: 2.9187
Discriminator Real Accuracy: 0.8496
Discriminator Fake Accuracy (Before G Update): 0.8494
Discriminator Fake Accuracy (After G Update): 0.0621

Epoch [11/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 45.99it/s]


Discriminator Loss: 0.5078
Generator Loss: 3.6397
Discriminator Real Accuracy: 0.9030
Discriminator Fake Accuracy (Before G Update): 0.9084
Discriminator Fake Accuracy (After G Update): 0.0494

Epoch [12/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 46.03it/s]


Discriminator Loss: 0.4637
Generator Loss: 3.5809
Discriminator Real Accuracy: 0.9155
Discriminator Fake Accuracy (Before G Update): 0.9202
Discriminator Fake Accuracy (After G Update): 0.0528

Epoch [13/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 46.14it/s]


Discriminator Loss: 0.4329
Generator Loss: 3.8121
Discriminator Real Accuracy: 0.9211
Discriminator Fake Accuracy (Before G Update): 0.9254
Discriminator Fake Accuracy (After G Update): 0.0515

Epoch [14/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 46.75it/s]


Discriminator Loss: 0.3521
Generator Loss: 4.0524
Discriminator Real Accuracy: 0.9406
Discriminator Fake Accuracy (Before G Update): 0.9435
Discriminator Fake Accuracy (After G Update): 0.0389

Epoch [15/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 46.15it/s]


Discriminator Loss: 0.3661
Generator Loss: 4.0900
Discriminator Real Accuracy: 0.9379
Discriminator Fake Accuracy (Before G Update): 0.9423
Discriminator Fake Accuracy (After G Update): 0.0405

Epoch [16/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 45.99it/s]


Discriminator Loss: 0.3948
Generator Loss: 4.0604
Discriminator Real Accuracy: 0.9329
Discriminator Fake Accuracy (Before G Update): 0.9358
Discriminator Fake Accuracy (After G Update): 0.0469

Epoch [17/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 46.02it/s]


Discriminator Loss: 0.3379
Generator Loss: 4.1110
Discriminator Real Accuracy: 0.9406
Discriminator Fake Accuracy (Before G Update): 0.9434
Discriminator Fake Accuracy (After G Update): 0.0378

Epoch [18/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 46.03it/s]


Discriminator Loss: 0.3793
Generator Loss: 4.0923
Discriminator Real Accuracy: 0.9321
Discriminator Fake Accuracy (Before G Update): 0.9360
Discriminator Fake Accuracy (After G Update): 0.0442

Epoch [19/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 45.99it/s]


Discriminator Loss: 0.4343
Generator Loss: 3.8855
Discriminator Real Accuracy: 0.9254
Discriminator Fake Accuracy (Before G Update): 0.9302
Discriminator Fake Accuracy (After G Update): 0.0479

Epoch [20/20]


Train: 100%|██████████| 468/468 [00:10<00:00, 45.84it/s]


Discriminator Loss: 0.4017
Generator Loss: 4.0386
Discriminator Real Accuracy: 0.9302
Discriminator Fake Accuracy (Before G Update): 0.9331
Discriminator Fake Accuracy (After G Update): 0.0465

