In [2]:
import tensorflow as tf
import matplotlib.pyplot as plt
from models import build_generator, build_discriminator
from dataset import get_dataset
import os

os.makedirs("results", exist_ok=True)

# ----------------------------
# 0️⃣ Hyperparameters
# ----------------------------
lambda_cyc, lambda_id = 5.0, 2.5    # weights of the cycle-consistency / identity L1 terms in train_step
img_size = 256                      # square side length passed to the models and to get_dataset
batch_size = 2
learning_rate, beta_1 = 1e-4, 0.5   # Adam settings shared by all three optimizers

# ----------------------------
# 1️⃣ Generator & Discriminator
# ----------------------------
# In train_step, G maps X -> Y and F maps Y -> X; D_X scores X images, D_Y scores Y images.
G = build_generator(image_size=img_size, n_blocks=9)
F = build_generator(image_size=img_size, n_blocks=9)
D_X = build_discriminator(image_size=img_size)
D_Y = build_discriminator(image_size=img_size)

# ----------------------------
# 2️⃣ Optimizer & Loss
# ----------------------------
# G_optimizer updates G and F jointly; each discriminator gets its own optimizer.
G_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta_1)
D_X_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta_1)
D_Y_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta_1)

# Least-squares (LSGAN-style) adversarial loss used for generators and discriminators.
mse = tf.keras.losses.MeanSquaredError()


def L1(a, b):
    """Return the mean absolute difference between tensors *a* and *b*."""
    return tf.reduce_mean(tf.abs(a - b))

# ----------------------------
# 3️⃣ 1 training step
# ----------------------------
@tf.function
def train_step(real_x, real_y, lambda_cyc, lambda_id):
    """Run one optimisation step for both CycleGAN generators and both discriminators.

    G translates domain X to Y and F translates Y to X. The two generators are
    updated together through ``G_optimizer`` using least-squares adversarial
    terms plus weighted cycle-consistency and identity L1 terms. ``D_X`` and
    ``D_Y`` are each updated with their own optimizer on the least-squares
    real-vs-fake objective.

    Args:
        real_x: Batch of images from domain X.
        real_y: Batch of images from domain Y.
        lambda_cyc: Weight applied to the cycle-consistency loss.
        lambda_id: Weight applied to the identity loss.

    Returns:
        ``(generator_total_loss, D_X_loss, D_Y_loss)`` for this step.
    """
    def as_real(score):
        # Least-squares distance of the scores from the "real" target (1).
        return mse(tf.ones_like(score), score)

    def as_fake(score):
        # Least-squares distance of the scores from the "fake" target (0).
        return mse(tf.zeros_like(score), score)

    # Persistent: three separate gradient() calls are made on this one tape.
    with tf.GradientTape(persistent=True) as tape:
        # X -> Y and Y -> X translations.
        gen_y = G(real_x, training=True)
        gen_x = F(real_y, training=True)

        # Round trips X -> Y -> X and Y -> X -> Y.
        back_x = F(gen_y, training=True)
        back_y = G(gen_x, training=True)

        # Identity passes: each generator is fed an image already in its output domain.
        ident_x = F(real_x, training=True)
        ident_y = G(real_y, training=True)

        # Discriminator scores on real and generated images.
        score_x_real = D_X(real_x, training=True)
        score_x_fake = D_X(gen_x, training=True)
        score_y_real = D_Y(real_y, training=True)
        score_y_fake = D_Y(gen_y, training=True)

        # Generators are rewarded when their fakes are scored as real.
        adv_G = as_real(score_y_fake)
        adv_F = as_real(score_x_fake)

        # Discriminators push real scores toward 1 and fake scores toward 0.
        loss_dx = 0.5 * (as_real(score_x_real) + as_fake(score_x_fake))
        loss_dy = 0.5 * (as_real(score_y_real) + as_fake(score_y_fake))

        cycle = L1(back_x, real_x) + L1(back_y, real_y)
        identity = L1(ident_x, real_x) + L1(ident_y, real_y)

        loss_gen = adv_G + adv_F + lambda_cyc * cycle + lambda_id * identity

    gen_vars = G.trainable_variables + F.trainable_variables
    dx_vars = D_X.trainable_variables
    dy_vars = D_Y.trainable_variables

    gen_grads = tape.gradient(loss_gen, gen_vars)
    dx_grads = tape.gradient(loss_dx, dx_vars)
    dy_grads = tape.gradient(loss_dy, dy_vars)
    del tape

    G_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    D_X_optimizer.apply_gradients(zip(dx_grads, dx_vars))
    D_Y_optimizer.apply_gradients(zip(dy_grads, dy_vars))

    return loss_gen, loss_dx, loss_dy

# ----------------------------
# 4️⃣ Dataset
# ----------------------------
x_glob = 'data/butterfly_real/*.jpeg'
y_glob = 'data/butterfly_origami/*.jpg'
trainX = get_dataset(x_glob, batch_size=batch_size, img_size=img_size)
trainY = get_dataset(y_glob, batch_size=batch_size, img_size=img_size)

# ----------------------------
# 5️⃣ Training cycle
# ----------------------------
epochs = 100
sample_every = 10  # print losses and save a preview image every N epochs

for epoch in range(epochs):
    # Reset each epoch so an epoch that yields no batches is detected here,
    # instead of surfacing later as a NameError (or printing stale losses).
    G_loss = DX_loss = DY_loss = None
    # zip() stops at the shorter of the two datasets.
    for real_x, real_y in zip(trainX, trainY):
        G_loss, DX_loss, DY_loss = train_step(real_x, real_y, lambda_cyc, lambda_id)
    if G_loss is None:
        raise RuntimeError(
            f"No training batches were produced; check the dataset globs "
            f"{x_glob!r} and {y_glob!r}."
        )

    if (epoch + 1) % sample_every == 0:
        # Losses shown are from the epoch's last batch, not an epoch average.
        print(f"Epoch {epoch+1}: G={G_loss:.3f} DX={DX_loss:.3f} DY={DY_loss:.3f}")

        sample = next(iter(trainX))
        fake_y = G(sample, training=False)
        plt.figure(figsize=(6, 3))
        plt.subplot(1, 2, 1)
        # (x + 1) / 2 maps [-1, 1] to [0, 1]; clip so imshow never has to.
        plt.imshow(tf.clip_by_value((sample[0] + 1) / 2, 0.0, 1.0))
        plt.title("Real X")
        plt.axis("off")
        plt.subplot(1, 2, 2)
        plt.imshow(tf.clip_by_value((fake_y[0] + 1) / 2, 0.0, 1.0))
        plt.title("Fake Y")
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(f"results/epoch_{epoch+1}.png")
        plt.close()
        
print("\nGenerating final comparison on test image...")

test_path = "data/NST_butterfly.jpg"
if os.path.exists(test_path):
    # decode_image accepts JPEG/PNG/BMP/GIF; expand_animations=False guarantees
    # a single 3-D image tensor even for GIF input.
    test_img = tf.io.decode_image(
        tf.io.read_file(test_path), channels=3, expand_animations=False
    )
    test_img = tf.image.resize(test_img, [img_size, img_size])
    # Scale pixels from [0, 255] to [-1, 1], the range that the (x + 1) / 2
    # display conversion below undoes.
    test_img = (tf.cast(test_img, tf.float32) / 127.5) - 1.0
    test_img = tf.expand_dims(test_img, 0)  # add a batch axis of size 1

    fake_y = G(test_img, training=False)
    # Back to [0, 1] for imshow; clip any overshoot rather than relying on
    # matplotlib's implicit clipping.
    fake_y = tf.clip_by_value((fake_y[0] + 1.0) / 2.0, 0.0, 1.0)

    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow((test_img[0] + 1) / 2.0)
    plt.title("Original Image")
    plt.axis("off")
    plt.subplot(1, 2, 2)
    plt.imshow(fake_y)
    plt.title("CycleGAN Output")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig("results/final_comparison.png")
    plt.close()
    print("✅ Final comparison saved at results/final_comparison.png")
else:
    # Report the path that was actually checked (previously a stale, unrelated path).
    print(f"⚠️ Skipped final comparison: test image not found at {test_path}")


Epoch 10: G=3.620 DX=0.173 DY=0.156
Epoch 20: G=3.558 DX=0.259 DY=0.202
Epoch 30: G=3.558 DX=0.120 DY=0.081
Epoch 40: G=2.994 DX=0.126 DY=0.175
Epoch 50: G=2.991 DX=0.362 DY=0.369
Epoch 60: G=3.043 DX=0.226 DY=0.178
Epoch 70: G=4.628 DX=0.163 DY=0.095
Epoch 80: G=3.682 DX=0.180 DY=0.119
Epoch 90: G=3.271 DX=0.242 DY=0.631
Epoch 100: G=4.436 DX=0.148 DY=0.191

Generating final comparison on test image...
✅ Final comparison saved at results/final_comparison.png
