In [4]:
import tensorflow as tf
import matplotlib.pyplot as plt
from models import build_generator, build_discriminator
from dataset import get_dataset
import os
# Ensure the directory that receives the per-epoch preview images exists;
# exist_ok=True makes re-running this cell harmless.
os.makedirs("results", exist_ok=True)

# Hyperparameter tuning
def update_hyperparams(epoch):
    """Return the ``(lambda_cyc, lambda_id)`` loss weights for ``epoch``.

    The weights are currently constant: the cycle-consistency weight is 3.0
    and the identity weight is half of it, so ``epoch`` is accepted but not
    consulted.

    NOTE: a staged schedule was sketched here earlier and is disabled:
    cycle weight 3.0 for epoch < 30, 5.0 for epoch < 50, and 7.5 afterwards,
    always with the identity weight at half the cycle weight.
    """
    cyc_weight = 3.0
    id_weight = 0.5 * cyc_weight
    return cyc_weight, id_weight

# ----------------------------
# 1️⃣ Build the generator & discriminator
# ----------------------------
G = build_generator()   # maps X → Y
F = build_generator()   # maps Y → X
D_X = build_discriminator()  # discriminator for domain X
D_Y = build_discriminator()  # discriminator for domain Y

# ----------------------------
# 2️⃣ Optimizer & Loss
# ----------------------------
# The generator optimizer uses twice the learning rate of the two
# discriminator optimizers; all three share beta_1=0.5.
G_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
D_X_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)
D_Y_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)

# LSGAN: adversarial terms use MSE (Mean Squared Error) against 0/1 targets.
mse = tf.keras.losses.MeanSquaredError()


def L1(a, b):
    """Return the mean absolute difference between ``a`` and ``b``."""
    return tf.reduce_mean(tf.abs(a - b))


# Initial loss weights. The training loop reassigns both names from
# update_hyperparams(epoch) at the start of every epoch.
lambda_cyc = 10.0
lambda_id = lambda_cyc * 0.5

# ----------------------------
# 3️⃣ Training_step
# ----------------------------
@tf.function
def _train_step_graph(real_x, real_y, cyc_weight, id_weight):
    """Run one traced optimisation step for both generators and discriminators.

    The loss weights are passed in as tensors rather than read from Python
    globals: a ``tf.function`` freezes the value of any Python global at
    trace time, so later reassignments would otherwise be ignored.

    Args:
        real_x: Batch of images from domain X.
        real_y: Batch of images from domain Y.
        cyc_weight: Scalar tensor weighting the cycle-consistency loss.
        id_weight: Scalar tensor weighting the identity loss.

    Returns:
        Tuple ``(G_total_loss, D_X_loss, D_Y_loss)``.
    """
    # persistent=True because three separate gradient calls reuse this tape.
    with tf.GradientTape(persistent=True) as tape:
        # Translate each batch into the other domain.
        fake_y = G(real_x, training=True)
        fake_x = F(real_y, training=True)

        # Cycle-consistency reconstructions: X → Y → X and Y → X → Y.
        cyc_x = F(fake_y, training=True)
        cyc_y = G(fake_x, training=True)

        # Identity mapping: each generator is fed an image from the domain
        # it outputs.
        same_x = F(real_x, training=True)
        same_y = G(real_y, training=True)

        # Discriminator outputs for real and generated images.
        D_X_real = D_X(real_x, training=True)
        D_X_fake = D_X(fake_x, training=True)
        D_Y_real = D_Y(real_y, training=True)
        D_Y_fake = D_Y(fake_y, training=True)

        # ---- LSGAN losses ----
        # Generators want their fakes to be scored as 1.
        G_GAN_loss = mse(tf.ones_like(D_Y_fake), D_Y_fake)
        F_GAN_loss = mse(tf.ones_like(D_X_fake), D_X_fake)

        # Discriminators: real samples → target 1, fake samples → target 0.
        D_X_loss = 0.5 * (mse(tf.ones_like(D_X_real), D_X_real) +
                          mse(tf.zeros_like(D_X_fake), D_X_fake))
        D_Y_loss = 0.5 * (mse(tf.ones_like(D_Y_real), D_Y_real) +
                          mse(tf.zeros_like(D_Y_fake), D_Y_fake))

        # Cycle-consistency and identity losses (mean absolute error).
        cycle_loss = L1(cyc_x, real_x) + L1(cyc_y, real_y)
        id_loss = L1(same_x, real_x) + L1(same_y, real_y)

        # Match the weights to the loss dtype so the multiplication below
        # does not fail on a dtype mismatch.
        cyc_weight = tf.cast(cyc_weight, cycle_loss.dtype)
        id_weight = tf.cast(id_weight, id_loss.dtype)

        # Total generator objective, shared by G and F.
        G_total_loss = (G_GAN_loss + F_GAN_loss
                        + cyc_weight * cycle_loss
                        + id_weight * id_loss)

    # Both generators are updated together by a single optimizer.
    generator_vars = G.trainable_variables + F.trainable_variables
    G_grads = tape.gradient(G_total_loss, generator_vars)
    D_X_grads = tape.gradient(D_X_loss, D_X.trainable_variables)
    D_Y_grads = tape.gradient(D_Y_loss, D_Y.trainable_variables)

    G_optimizer.apply_gradients(zip(G_grads, generator_vars))
    D_X_optimizer.apply_gradients(zip(D_X_grads, D_X.trainable_variables))
    D_Y_optimizer.apply_gradients(zip(D_Y_grads, D_Y.trainable_variables))

    # A persistent tape holds its resources until explicitly dropped.
    del tape
    return G_total_loss, D_X_loss, D_Y_loss


def train_step(real_x, real_y):
    """Perform one CycleGAN training step on a pair of batches.

    Reads the module-level ``lambda_cyc`` and ``lambda_id`` on every call
    and forwards them to the traced step as float32 tensors. Because they
    are tensor arguments, changing the weights between epochs takes effect
    immediately and does not trigger a retrace.

    Args:
        real_x: Batch of images from domain X.
        real_y: Batch of images from domain Y.

    Returns:
        Tuple ``(G_total_loss, D_X_loss, D_Y_loss)``.
    """
    return _train_step_graph(
        real_x,
        real_y,
        tf.convert_to_tensor(lambda_cyc, dtype=tf.float32),
        tf.convert_to_tensor(lambda_id, dtype=tf.float32),
    )

# ----------------------------
# 4️⃣ Training_cycle
# ----------------------------
# One dataset per domain; the two are iterated in lockstep below, so each
# epoch ends when the shorter one is exhausted.
trainX = get_dataset('data/butterfly_real/*.jpeg', batch_size=2)
trainY = get_dataset('data/butterfly_origami/*.jpg', batch_size=2)

for epoch in range(70):
    # Refresh the loss weights used by train_step for this epoch.
    lambda_cyc, lambda_id = update_hyperparams(epoch)

    for real_x, real_y in zip(trainX, trainY):
        G_loss, DX_loss, DY_loss = train_step(real_x, real_y)
        print(f"Epoch {epoch+1}: G={G_loss:.3f} DX={DX_loss:.3f} DY={DY_loss:.3f}")

    # Save a side-by-side preview at the end of every epoch: the first image
    # of a batch from X next to its translation by G.
    sample = next(iter(trainX))
    fake_y = G(sample, training=False)

    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    panels = ((sample[0], "Real X"), (fake_y[0], "Fake Y"))
    for ax, (image, title) in zip(axes, panels):
        # (image + 1) / 2 rescales before display; this presumably maps a
        # [-1, 1] image range to [0, 1] — confirm against get_dataset.
        ax.imshow((image + 1) / 2)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    fig.savefig(f"results/epoch_{epoch+1}.png")
    plt.close(fig)


Epoch 1: G=6.161 DX=0.209 DY=0.213
Epoch 2: G=4.620 DX=0.247 DY=0.198
Epoch 3: G=4.527 DX=0.199 DY=0.200
Epoch 4: G=6.442 DX=0.188 DY=0.144
Epoch 5: G=5.027 DX=0.149 DY=0.342
Epoch 6: G=6.989 DX=0.117 DY=0.180
Epoch 7: G=5.111 DX=0.212 DY=0.113
Epoch 8: G=4.833 DX=0.195 DY=0.362
Epoch 9: G=4.976 DX=0.244 DY=0.132
Epoch 10: G=5.630 DX=0.174 DY=0.091
Epoch 11: G=7.903 DX=0.159 DY=0.325
Epoch 12: G=5.425 DX=0.188 DY=0.096
Epoch 13: G=5.317 DX=0.245 DY=0.070
Epoch 14: G=5.840 DX=0.112 DY=0.144
Epoch 15: G=5.874 DX=0.294 DY=0.071
Epoch 16: G=4.677 DX=0.148 DY=0.101
Epoch 17: G=4.699 DX=0.368 DY=0.241
Epoch 18: G=4.422 DX=0.078 DY=0.026
Epoch 19: G=5.261 DX=0.113 DY=0.242
Epoch 20: G=4.647 DX=0.030 DY=0.171
Epoch 21: G=5.985 DX=0.156 DY=0.190
Epoch 22: G=6.648 DX=0.109 DY=0.073
Epoch 23: G=4.961 DX=0.057 DY=0.235
Epoch 24: G=5.424 DX=0.070 DY=0.044
Epoch 25: G=6.706 DX=0.078 DY=0.148
Epoch 26: G=4.069 DX=0.186 DY=0.020
Epoch 27: G=4.732 DX=0.226 DY=0.093
Epoch 28: G=5.570 DX=0.161 DY=0.139
E