### 모듈, dataset 다운로드

In [None]:
import os
import numpy as np
from PIL import Image
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras.utils import Sequence
import kagglehub

# Fetch the UTKFace dataset through kagglehub; the returned value is used
# below as the local directory that contains the download.
dataset_path = kagglehub.dataset_download("jangedoo/utkface-new")
# Images are read from the "UTKFace" subdirectory of that download.
image_dir = os.path.join(dataset_path, "UTKFace")

### hyperparameter, 기본 설정

In [None]:
IMG_SIZE = 64  # side length (pixels) every image is resized to
MAX_AGE = 116  # number of age classes, i.e. width of the one-hot condition vector
BATCH_SIZE = 64  # images per batch produced by the data generator
LATENT_DIM = 100  # length of the random noise vector fed to the generator
EPOCHS = 20000  # number of full passes over the data generator
SAVE_INTERVAL = 1000  # a sample image grid is saved every SAVE_INTERVAL epochs

### 파일 이름에서 나이 추출

In [None]:
def extract_age(filename, max_age=116):
    """Parse the leading age field of a file name into a zero-based label.

    The file name is expected to begin with the age followed by an
    underscore, e.g. ``"25_0_0_2017.jpg"``.

    Args:
        filename: Base name of the image file.
        max_age: Largest age (inclusive) that is accepted. Defaults to 116.

    Returns:
        ``age - 1`` (labels therefore span ``0 .. max_age - 1``) when the
        prefix is an integer in ``[1, max_age]``; ``None`` when the prefix is
        missing, not an integer, or out of range.
    """
    try:
        age = int(filename.split('_')[0])
    except (AttributeError, TypeError, ValueError):
        # Not a string, or the prefix is not an integer: no usable label.
        # Only parsing errors are swallowed; anything else propagates.
        return None

    if 1 <= age <= max_age:
        return age - 1  # shift range to 0 .. max_age - 1
    return None

### 이미지 불러오기, 나이 설정

In [None]:
# Collect every .jpg file whose name yields a usable age label.
image_paths = []
ages = []

for fname in os.listdir(image_dir):
    if not fname.endswith('.jpg'):
        continue
    age = extract_age(fname)
    if age is None:
        continue
    image_paths.append(os.path.join(image_dir, fname))
    ages.append(age)

# Labels as a column vector, then one-hot encoded over MAX_AGE fixed classes.
ages = np.array(ages).reshape(-1, 1)
encoder = OneHotEncoder(categories=[np.arange(MAX_AGE)], sparse_output=False)
age_onehot = encoder.fit_transform(ages)

### dataset 배치 단위로 분할

In [None]:
class UTKFaceDataGenerator(Sequence):
    """Batch loader that reads images from disk and pairs them with age labels.

    Each batch is a two-element list ``[images, conditions]`` where ``images``
    holds the resized RGB images scaled to ``[-1, 1]`` and ``conditions``
    holds the matching rows of ``age_labels``.

    Args:
        image_paths: Sequence of image file paths.
        age_labels: Per-image condition rows, indexable by the same positions
            as ``image_paths``.
        batch_size: Maximum number of samples per batch.
        img_size: Side length the images are resized to (square).
        **kwargs: Forwarded unchanged to the base class constructor.
    """

    def __init__(self, image_paths, age_labels, batch_size, img_size, **kwargs):
        # Let the base class initialise itself before any attribute is set.
        super().__init__(**kwargs)
        self.image_paths = image_paths
        self.age_labels = age_labels
        self.batch_size = batch_size
        self.img_size = img_size
        self.indexes = np.arange(len(self.image_paths))

    def __len__(self):
        """Return the number of batches (the last one may be smaller)."""
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def __getitem__(self, index):
        """Load and return batch number ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``. Raising
                (instead of returning an empty batch) lets plain ``for``
                iteration over the object terminate even when it falls back
                to the ``__getitem__`` protocol.
        """
        if not 0 <= index < len(self):
            raise IndexError(f"batch index {index} out of range")

        start = index * self.batch_size
        batch_idx = self.indexes[start:start + self.batch_size]

        target_size = (self.img_size, self.img_size)
        batch_images = []
        batch_conditions = []
        for i in batch_idx:
            # The context manager closes the underlying file handle as soon as
            # the pixel data has been converted.
            with Image.open(self.image_paths[i]) as raw:
                img = raw.convert('RGB').resize(target_size)
            batch_images.append(np.array(img) / 127.5 - 1.0)  # scale to [-1, 1]
            batch_conditions.append(self.age_labels[i])

        return [np.array(batch_images), np.array(batch_conditions)]

# Batched loader over all collected images and their one-hot age conditions.
train_generator = UTKFaceDataGenerator(
    image_paths=image_paths,
    age_labels=age_onehot,
    batch_size=BATCH_SIZE,
    img_size=IMG_SIZE,
)

### Discriminator, Generator 정의

In [None]:
from tensorflow.keras import layers, models

def Discriminator(img_shape, condition_dim):
    """Build the conditional discriminator.

    The image is flattened, concatenated with the condition vector and passed
    through two 256-unit dense layers with LeakyReLU, ending in one sigmoid
    unit.

    Args:
        img_shape: Shape of a single input image.
        condition_dim: Length of the condition vector.

    Returns:
        A Keras model taking ``[image, condition]`` and returning one score
        per sample.
    """
    image_in = layers.Input(shape=img_shape)
    condition_in = layers.Input(shape=(condition_dim,))

    flattened = layers.Flatten()(image_in)
    hidden = layers.Concatenate()([flattened, condition_in])

    # Two identical Dense -> LeakyReLU stages.
    for _ in range(2):
        hidden = layers.Dense(256)(hidden)
        hidden = layers.LeakyReLU(alpha=0.2)(hidden)

    validity = layers.Dense(1, activation='sigmoid')(hidden)
    return models.Model([image_in, condition_in], validity)

def Generator(latent_dim, condition_dim, img_shape):
    """Build the conditional generator.

    Noise and condition vectors are concatenated, passed through two 256-unit
    dense layers with ReLU, projected with ``tanh`` to one value per image
    element and reshaped to ``img_shape``.

    Args:
        latent_dim: Length of the noise vector.
        condition_dim: Length of the condition vector.
        img_shape: Shape of a single output image.

    Returns:
        A Keras model taking ``[noise, condition]`` and returning images with
        values in the ``tanh`` range.
    """
    noise_in = layers.Input(shape=(latent_dim,))
    condition_in = layers.Input(shape=(condition_dim,))

    hidden = layers.Concatenate()([noise_in, condition_in])

    # Two identical Dense -> ReLU stages.
    for _ in range(2):
        hidden = layers.Dense(256)(hidden)
        hidden = layers.ReLU()(hidden)

    n_values = np.prod(img_shape)
    flat_image = layers.Dense(n_values, activation='tanh')(hidden)
    image_out = layers.Reshape(img_shape)(flat_image)

    return models.Model([noise_in, condition_in], image_out)

# Square 3-channel images; one condition slot per age class.
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
AGE_CLASSES = MAX_AGE
D = Discriminator(img_shape=IMG_SHAPE, condition_dim=AGE_CLASSES)
G = Generator(latent_dim=LATENT_DIM, condition_dim=AGE_CLASSES, img_shape=IMG_SHAPE)

### loss, optimizer 설정

In [None]:
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy

# Binary cross-entropy applied to the discriminator's sigmoid output.
bce = BinaryCrossentropy()

# Identical Adam settings for both networks, but one optimizer instance each
# (discriminator first, then generator).
d_optimizer, g_optimizer = (
    Adam(learning_rate=0.0002, beta_1=0.5) for _ in range(2)
)

### 학습 및 샘플 생성

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt

# Both networks stay trainable. Each update in `train_step` passes an explicit
# variable list to `tape.gradient` / `apply_gradients`, so only the intended
# network is changed. Toggling `D.trainable` inside a `tf.function` would run
# only while the function is being traced, so it is not used.
D.trainable = True


# Training step definition
@tf.function
def train_step(real_images, age_conditions):
    """Run one discriminator update followed by one generator update.

    Args:
        real_images: Batch of real images.
        age_conditions: Condition vectors for the batch, one row per image.

    Returns:
        Tuple ``(d_loss, g_loss)``: the discriminator loss (real + fake terms)
        and the generator loss computed during this step.
    """
    batch_size = tf.shape(real_images)[0]
    real_labels = tf.ones((batch_size, 1))
    fake_labels = tf.zeros((batch_size, 1))

    # ---------------------
    #  Train Discriminator
    # ---------------------
    # Fakes are produced outside the tape: this update must not reach G.
    noise = tf.random.normal([batch_size, LATENT_DIM])
    fake_images = G([noise, age_conditions], training=True)

    d_variables = D.trainable_variables
    with tf.GradientTape() as tape:
        d_real = D([real_images, age_conditions], training=True)
        d_fake = D([fake_images, age_conditions], training=True)
        d_loss_real = bce(real_labels, d_real)
        d_loss_fake = bce(fake_labels, d_fake)
        d_loss = d_loss_real + d_loss_fake

    d_grads = tape.gradient(d_loss, d_variables)
    d_optimizer.apply_gradients(zip(d_grads, d_variables))

    # ---------------------
    #  Train Generator
    # ---------------------
    noise = tf.random.normal([batch_size, LATENT_DIM])
    # The generator is rewarded when D labels its output as real.
    misleading_labels = tf.ones((batch_size, 1))

    g_variables = G.trainable_variables
    with tf.GradientTape() as tape:
        generated_images = G([noise, age_conditions], training=True)
        d_generated = D([generated_images, age_conditions], training=True)
        g_loss = bce(misleading_labels, d_generated)

    # Gradients are taken w.r.t. G's variables only, so D is left untouched.
    g_grads = tape.gradient(g_loss, g_variables)
    g_optimizer.apply_gradients(zip(g_grads, g_variables))

    return d_loss, g_loss

# Sample generation helper
def generate_sample_images(epoch, sample_dir="samples"):
    """Generate a 4x4 grid of images and save it as a PNG file.

    Args:
        epoch: Value embedded in the output file name.
        sample_dir: Output directory; created when it does not exist.
    """
    os.makedirs(sample_dir, exist_ok=True)

    rows, cols = 4, 4  # 16 samples in total
    n_samples = rows * cols
    noise = tf.random.normal([n_samples, LATENT_DIM])

    # One zero-based age class per image, evenly spaced over
    # 0 .. AGE_CLASSES - 1, then one-hot encoded.
    sample_ages = np.linspace(0, AGE_CLASSES - 1, n_samples, dtype=int).reshape(-1, 1)
    age_conditions = encoder.transform(sample_ages)

    gen_imgs = G.predict([noise, age_conditions])
    gen_imgs = 0.5 * gen_imgs + 0.5  # [-1, 1] -> [0, 1]

    fig, axs = plt.subplots(rows, cols, figsize=(10, 10))
    # axs.flat walks the grid row by row, matching the sample order.
    for ax, img, age_class in zip(axs.flat, gen_imgs, sample_ages[:, 0]):
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(f"Age: {age_class+1}")
    fig.savefig(f"{sample_dir}/generated_{epoch}.png")
    plt.close()

# Run the training loop
steps_per_epoch = len(train_generator)

for epoch in range(1, EPOCHS + 1):
    # Batches are fetched by explicit index so the loop always ends after
    # `steps_per_epoch` steps and never requests an out-of-range (empty)
    # batch, regardless of how the generator object handles iteration.
    for step in range(1, steps_per_epoch + 1):
        batch_imgs, batch_conds = train_generator[step - 1]
        d_loss, g_loss = train_step(batch_imgs, batch_conds)

        # Report every 125 steps and on the last step of the epoch.
        if step % 125 == 0 or step == steps_per_epoch:
            # Fresh forward passes in inference mode, used only for logging.
            d_real = D([batch_imgs, batch_conds], training=False)
            noise = tf.random.normal([batch_imgs.shape[0], LATENT_DIM])
            fake_imgs = G([noise, tf.convert_to_tensor(batch_conds)], training=False)
            d_fake = D([fake_imgs, tf.convert_to_tensor(batch_conds)], training=False)

            print(f"Epoch [{epoch}/{EPOCHS}], Step [{step}/{steps_per_epoch}], "
                  f"d_loss: {d_loss:.4f}, g_loss: {g_loss:.4f}, "
                  f"D(x): {tf.reduce_mean(d_real):.2f}, D(G(z)): {tf.reduce_mean(d_fake):.2f}")

    if epoch % SAVE_INTERVAL == 0:
        generate_sample_images(epoch)

### 나이별 이미지 생성

In [None]:
def generate_images_grid_by_ages(target_ages=(5, 15, 25, 40, 60), samples_per_age=7):
    """Show a grid of generated images with one row per requested age.

    Args:
        target_ages: Iterable of ages in years (1-based, as they appear in
            the file names). Ages outside the encoder's classes are rejected
            by ``encoder.transform``.
        samples_per_age: Number of images generated for each age.

    Raises:
        ValueError: If ``target_ages`` is empty or ``samples_per_age`` is
            smaller than 1.
    """
    target_ages = list(target_ages)
    if not target_ages or samples_per_age < 1:
        raise ValueError("target_ages must be non-empty and samples_per_age >= 1")

    n_rows = len(target_ages)
    total_images = n_rows * samples_per_age
    noise = tf.random.normal([total_images, LATENT_DIM])

    # Build the one-hot condition vectors. Training labels are stored as
    # `age - 1` (age 1 -> class 0), so the requested ages are shifted the same
    # way before encoding.
    age_classes = np.repeat(np.asarray(target_ages, dtype=int) - 1, samples_per_age)
    age_conditions = encoder.transform(age_classes.reshape(-1, 1))

    # Generate the images
    gen_imgs = G.predict([noise, age_conditions], verbose=0)
    gen_imgs = 0.5 * gen_imgs + 0.5  # [-1, 1] -> [0, 1]

    # Plot. squeeze=False keeps `axs` two-dimensional even for a single row
    # or a single column.
    fig, axs = plt.subplots(
        n_rows,
        samples_per_age,
        figsize=(samples_per_age * 2, n_rows * 2),
        squeeze=False,
    )
    for row, age in enumerate(target_ages):
        for col in range(samples_per_age):
            ax = axs[row, col]
            ax.imshow(gen_imgs[row * samples_per_age + col])
            # Hide ticks and frame instead of calling axis('off'), which
            # would also hide the row label set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if col == 0:
                ax.set_ylabel(f'Age {age}', fontsize=12)

    plt.tight_layout()
    plt.show()

# Display the grid for the default ages and sample count, using generator G.
generate_images_grid_by_ages()