In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
from models import build_generator, build_discriminator
from dataset import get_dataset
import os
os.makedirs("results", exist_ok=True)

#lambda
def update_hyperparams(epoch):
    #if epoch < 10:
    #    return 5.0, 0.5 * 5.0
    #elif epoch < 30:
    #    return 7.5, 0.5 * 7.5
    #else:
    #    return 10.0, 0.5 * 10.0
    return 3.0, 0.5 * 3.0

# ----------------------------
# 1️⃣ Build the generator & discriminator
# ----------------------------
G = build_generator()   # X → Y
F = build_generator()   # Y → X
D_X = build_discriminator()  # 判别 X 域
D_Y = build_discriminator()  # 判别 Y 域

# ----------------------------
# 2️⃣ optimizer and loss
# ----------------------------
G_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)
D_X_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)
D_Y_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)

# LSGAN: MSE Loss (Mean Squared Error)
mse = tf.keras.losses.MeanSquaredError()
L1 = lambda a, b: tf.reduce_mean(tf.abs(a - b))

# ----------------------------
# 3️⃣ training step
# ----------------------------
@tf.function
def train_step(real_x, real_y):
    with tf.GradientTape(persistent=True) as tape:
        # 生成假样本
        fake_y = G(real_x, training=True)
        fake_x = F(real_y, training=True)

        # 循环一致性重建
        cyc_x = F(fake_y, training=True)
        cyc_y = G(fake_x, training=True)

        # 身份映射
        same_x = F(real_x, training=True)
        same_y = G(real_y, training=True)

        # 判别器输出 (logits)
        D_X_real = D_X(real_x, training=True)
        D_X_fake = D_X(fake_x, training=True)
        D_Y_real = D_Y(real_y, training=True)
        D_Y_fake = D_Y(fake_y, training=True)

        # ---- LSGAN 损失 ----
        # 真实样本 → label=1，假样本 → label=0
        G_GAN_loss = mse(tf.ones_like(D_Y_fake), D_Y_fake)
        F_GAN_loss = mse(tf.ones_like(D_X_fake), D_X_fake)

        # 判别器 loss
        D_X_loss = 0.5 * (mse(tf.ones_like(D_X_real), D_X_real) +
                          mse(tf.zeros_like(D_X_fake), D_X_fake))
        D_Y_loss = 0.5 * (mse(tf.ones_like(D_Y_real), D_Y_real) +
                          mse(tf.zeros_like(D_Y_fake), D_Y_fake))

        # 循环与身份损失
        cycle_loss = L1(cyc_x, real_x) + L1(cyc_y, real_y)
        id_loss = L1(same_x, real_x) + L1(same_y, real_y)

        # 生成器总损失
        G_total_loss = G_GAN_loss + F_GAN_loss + lambda_cyc * cycle_loss + lambda_id * id_loss

    # 梯度与优化
    G_grads = tape.gradient(G_total_loss, G.trainable_variables + F.trainable_variables)
    D_X_grads = tape.gradient(D_X_loss, D_X.trainable_variables)
    D_Y_grads = tape.gradient(D_Y_loss, D_Y.trainable_variables)

    G_optimizer.apply_gradients(zip(G_grads, G.trainable_variables + F.trainable_variables))
    D_X_optimizer.apply_gradients(zip(D_X_grads, D_X.trainable_variables))
    D_Y_optimizer.apply_gradients(zip(D_Y_grads, D_Y.trainable_variables))

    del tape
    return G_total_loss, D_X_loss, D_Y_loss

# ----------------------------
# 4️⃣ training cycle
# ----------------------------
trainX = get_dataset('data/butterfly_real/*.jpeg', batch_size=2)
trainY = get_dataset('data/butterfly_origami/*.jpg', batch_size=2)

for epoch in range(70):
    lambda_cyc, lambda_id = update_hyperparams(epoch)
    for real_x, real_y in zip(trainX, trainY):
        G_loss, DX_loss, DY_loss = train_step(real_x, real_y)
    
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: G={G_loss:.3f} DX={DX_loss:.3f} DY={DY_loss:.3f}")
    
        # save each 10 epoches
        sample = next(iter(trainX))
        fake_y = G(sample, training=False)
        plt.figure(figsize=(6,3))
        plt.subplot(1,2,1); plt.imshow((sample[0]+1)/2); plt.title("Real X"); plt.axis("off")
        plt.subplot(1,2,2); plt.imshow((fake_y[0]+1)/2); plt.title("Fake Y"); plt.axis("off")
        plt.tight_layout()
        plt.savefig(f"results/epoch_{epoch+1}.png")
        plt.close()


Epoch 10: G=2.896 DX=0.217 DY=0.198
Epoch 20: G=3.289 DX=0.198 DY=0.120
Epoch 30: G=2.659 DX=0.248 DY=0.201
Epoch 40: G=2.679 DX=0.204 DY=0.271
Epoch 50: G=3.024 DX=0.151 DY=0.126
Epoch 60: G=3.566 DX=0.171 DY=0.106
Epoch 70: G=3.178 DX=0.078 DY=0.191
