### GAN 구현

In [None]:
import torch
import torch.nn as nn

# Image size: height and width in pixels.
# NOTE: the Generator and Discriminator layer stacks below hard-code a 64x64
# layout (8 -> 16 -> 32 -> 64) instead of deriving it from this constant.
IMAGE_SIZE = 64
# Length of the latent noise vector z fed to the Generator.
Z_DIM = 100


# Generator
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 64x64 image.

    A fully connected layer projects ``z`` to a 256x8x8 feature map, then three
    stride-2 transposed convolutions upsample it 8 -> 16 -> 32 -> 64.

    Args:
        z_dim: Length of the latent vector ``z``. Defaults to ``Z_DIM``.
        img_channels: Number of channels in the generated image (3 for RGB,
            1 for grayscale). Defaults to 3, matching the original model.

    Note:
        ``BatchNorm1d`` in training mode needs more than one sample per batch.
    """

    def __init__(self, z_dim=Z_DIM, img_channels=3):
        super().__init__()
        # Layer order/indices are kept fixed so state_dict keys stay stable.
        self.model = nn.Sequential(
            nn.Linear(z_dim, 256 * 8 * 8),
            nn.BatchNorm1d(256 * 8 * 8),
            nn.ReLU(True),
            # (B, 256*8*8) -> (B, 256, 8, 8)
            nn.Unflatten(1, (256, 8, 8)),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # (16x16)
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),   # (32x32)
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, img_channels, 4, stride=2, padding=1),  # (64x64)
            nn.Tanh(),  # squash pixel values into [-1, 1]
        )

    def forward(self, z):
        """Generate a batch of images.

        Args:
            z: Latent batch of shape (B, z_dim).

        Returns:
            Tensor of shape (B, img_channels, 64, 64) with values in [-1, 1].
        """
        return self.model(z)

# Discriminator
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring 64x64 images as real or fake.

    Three stride-2 convolutions downsample 64 -> 32 -> 16 -> 8, then a linear
    layer followed by a sigmoid yields one probability per image.

    Args:
        img_channels: Number of channels in the input images (3 for RGB,
            1 for grayscale). Defaults to 3, matching the original model.
    """

    def __init__(self, img_channels=3):
        super().__init__()
        # Layer order/indices are kept fixed so state_dict keys stay stable.
        self.model = nn.Sequential(
            # No BatchNorm on the first layer (DCGAN convention).
            nn.Conv2d(img_channels, 64, 4, stride=2, padding=1),  # (32x32)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # (16x16)
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),  # (8x8)
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 1),
            # Output is already a probability, so the loss must be
            # nn.BCELoss (not BCEWithLogitsLoss, which applies its own sigmoid).
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Image batch of shape (B, img_channels, 64, 64).

        Returns:
            Tensor of shape (B, 1) with values in [0, 1].
        """
        return self.model(img)

### 학습 루프

In [None]:
# Example training loop (summary).
# NOTE(review): this cell relies on names defined elsewhere in the notebook:
# num_epochs, dataloader, batch_size, device, G, D, optimizer_G, optimizer_D
# and BCE (presumably nn.BCELoss(), since D ends in a Sigmoid) — confirm they
# exist before running it.
for epoch in range(num_epochs):
    for real_imgs, _ in dataloader:
        # Put the real batch on the models' device. Without this, D(real_imgs)
        # fails with a device mismatch whenever the models live on a GPU.
        real_imgs = real_imgs.to(device)

        # ---------------------
        # 1. Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        real_validity = D(real_imgs)
        # Sample noise directly on the target device (no host->device copy).
        # NOTE(review): uses the configured batch_size, so the final batch may
        # have a different number of fake vs. real samples unless the
        # DataLoader was built with drop_last=True.
        z = torch.randn(batch_size, Z_DIM, device=device)
        fake_imgs = G(z)
        # detach() so the discriminator loss does not backpropagate into G.
        fake_validity = D(fake_imgs.detach())
        d_loss = (BCE(real_validity, torch.ones_like(real_validity))
                  + BCE(fake_validity, torch.zeros_like(fake_validity)))
        d_loss.backward()
        optimizer_D.step()

        # ---------------------
        # 2. Train Generator
        # ---------------------
        optimizer_G.zero_grad()
        # Re-score the same fakes with the just-updated D, this time without
        # detach() so gradients reach G. This backward also writes gradients
        # into D's parameters; optimizer_D.zero_grad() clears them on the next
        # iteration before D is stepped again.
        fake_validity = D(fake_imgs)
        # Generator wants D(G(z)) = 1.
        g_loss = BCE(fake_validity, torch.ones_like(fake_validity))
        g_loss.backward()
        optimizer_G.step()
