### CGAN (Conditional GAN) 구현

In [None]:
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    """Map a noise vector plus an age condition vector to a 3-channel image.

    The condition vector goes through a linear layer (a learned 50-dim
    embedding), is concatenated with the noise, projected and reshaped to a
    256x8x8 feature map, then upsampled three times (8 -> 16 -> 32 -> 64) by
    stride-2 transposed convolutions. The final Tanh bounds pixels to [-1, 1].

    Args:
        z_dim: Length of the noise vector ``z``.
        age_dim: Width of the condition vector (number of one-hot age classes).
    """

    def __init__(self, z_dim=100, age_dim=116):
        super().__init__()
        emb_width = 50              # size of the learned age embedding
        seed_shape = (256, 8, 8)    # feature map the dense projection is reshaped into
        seed_numel = seed_shape[0] * seed_shape[1] * seed_shape[2]

        # A linear layer over the one-hot condition acts as a learned embedding.
        self.label_emb = nn.Linear(age_dim, emb_width)

        # Layers are created in the same order as a single literal Sequential
        # would create them, so indices model.0 ... model.11 (state_dict keys)
        # and the parameter-initialisation order are preserved.
        layers = [
            nn.Linear(z_dim + emb_width, seed_numel),
            nn.BatchNorm1d(seed_numel),
            nn.ReLU(True),
            nn.Unflatten(1, seed_shape),
        ]
        for c_in, c_out in ((256, 128), (128, 64)):
            # ConvTranspose2d(kernel=4, stride=2, padding=1) doubles height/width.
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        layers += [nn.ConvTranspose2d(64, 3, 4, 2, 1), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, z, age_onehot):
        """Generate images for noise ``z`` (B, z_dim) and condition (B, age_dim).

        Returns:
            Tensor of shape (B, 3, 64, 64) with values in [-1, 1].
        """
        conditioned = torch.cat((z, self.label_emb(age_onehot)), 1)
        return self.model(conditioned)

class ConditionalDiscriminator(nn.Module):
    """Classify (image, age condition) pairs as real or generated.

    The condition vector is linearly projected to one ``image_size x image_size``
    plane and stacked onto the RGB image as a 4th channel. Three stride-2
    convolutions and a linear head then reduce the result to a single Sigmoid
    probability per sample.

    Args:
        age_dim: Width of the condition vector (number of one-hot age classes).
        image_size: Side length of the square input images. When omitted, the
            notebook-level ``IMAGE_SIZE`` constant is used, so existing
            ``ConditionalDiscriminator()`` calls behave as before.

    Raises:
        ValueError: if ``image_size`` is smaller than 8, which would leave no
            spatial features after the three downsampling convolutions.
    """

    def __init__(self, age_dim=116, image_size=None):
        super().__init__()
        # Resolved lazily so the global IMAGE_SIZE is only needed when the
        # caller does not pass a size explicitly.
        self.image_size = IMAGE_SIZE if image_size is None else image_size
        if self.image_size < 8:
            raise ValueError(f"image_size must be >= 8, got {self.image_size}")

        # Broadcast the condition to one (image_size x image_size) plane.
        self.label_emb = nn.Linear(age_dim, self.image_size * self.image_size)

        # Each Conv2d(kernel=4, stride=2, padding=1) floor-halves the spatial
        # size, so three of them leave image_size // 8 per side. Deriving the
        # head width from image_size (instead of hard-coding 8 * 8) keeps the
        # head consistent with the conditioning plane for any image size.
        feat_side = self.image_size // 8
        self.model = nn.Sequential(
            nn.Conv2d(4, 64, 4, 2, 1),  # 3 image channels + 1 condition channel
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(256 * feat_side * feat_side, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, age_onehot):
        """Return the probability that each image is real given its condition.

        Args:
            img: Image batch of shape (B, 3, image_size, image_size); the
                channel concat below requires this spatial size.
            age_onehot: Condition batch of shape (B, age_dim).

        Returns:
            Tensor of shape (B, 1) with values in (0, 1).
        """
        label = self.label_emb(age_onehot).view(-1, 1, self.image_size, self.image_size)
        x = torch.cat([img, label], dim=1)  # concat along the channel axis
        return self.model(x)


### 학습 루프

In [None]:
# One pass over the dataset: a discriminator step followed by a generator step
# for every mini-batch.
# NOTE(review): relies on dataloader, device, G, D, optimizer_D, optimizer_G,
# BCE and Z_DIM from earlier cells; BCE is presumably nn.BCELoss because D ends
# in a Sigmoid — confirm.
for real_imgs, age_labels in dataloader:
    # Put the real batch on the same device as the models and the condition;
    # otherwise D(real_imgs, age_onehot) mixes CPU and GPU tensors.
    real_imgs = real_imgs.to(device)
    # Use the actual batch size: the last batch can be smaller than
    # `batch_size` unless the DataLoader was built with drop_last=True.
    cur_batch = real_imgs.size(0)

    # 1. Age one-hot encoding (one_hot requires integer class indices)
    age_onehot = torch.nn.functional.one_hot(
        age_labels.long(), num_classes=116
    ).float().to(device)

    # 2. Train Discriminator: real -> 1, fake -> 0
    optimizer_D.zero_grad()
    real_validity = D(real_imgs, age_onehot)
    z = torch.randn(cur_batch, Z_DIM, device=device)
    fake_imgs = G(z, age_onehot)
    # detach() keeps the discriminator loss from backpropagating into G.
    fake_validity = D(fake_imgs.detach(), age_onehot)
    d_loss = BCE(real_validity, torch.ones_like(real_validity)) + \
             BCE(fake_validity, torch.zeros_like(fake_validity))
    d_loss.backward()
    optimizer_D.step()

    # 3. Train Generator: push the just-updated D to label the fakes as real
    optimizer_G.zero_grad()
    fake_validity = D(fake_imgs, age_onehot)
    g_loss = BCE(fake_validity, torch.ones_like(fake_validity))
    g_loss.backward()
    optimizer_G.step()

### 시각화

In [None]:
# Generate a face for a specific target age.
target_age = 25
age_cond = torch.nn.functional.one_hot(torch.tensor([target_age]), num_classes=116).float().to(device)
z = torch.randn(1, Z_DIM, device=device)

# G starts with a BatchNorm1d layer, which raises in training mode for a
# single-sample batch ("Expected more than 1 value per channel when
# training"). Sample in eval mode (running statistics) without autograd, then
# restore whatever mode G was in so later training cells are unaffected.
was_training = G.training
G.eval()
try:
    with torch.no_grad():
        fake_img = G(z, age_cond)
finally:
    G.train(was_training)
show_tensor_image(fake_img)