In [None]:
#from experiment_utils import ExperimentRun
import tensorflow as tf
import matplotlib.pyplot as plt
from models import build_generator, build_discriminator
from dataset import get_dataset, get_dataset_with_mask
import os
import time, sys
from tensorflow.keras.applications import VGG19
from tensorflow.keras.applications.vgg19 import preprocess_input

In [None]:
# Weights of the cycle-consistency and identity terms in the generator losses.
lambda_cyc = 5.0
lambda_id = 2.5

# Multiplier applied to the VGG perceptual part of the cycle loss.
alpha = 0.02

# Image resolution, batch size and training length.
img_size = 256
batch_size = 4
num_epochs = 100

In [None]:
# Generators: G takes 4 input channels (RGB + mask) and produces RGB,
# F maps RGB to RGB.
G = build_generator(
    image_size=img_size,
    in_channels=4,
    out_channels=3,
)
F = build_generator(
    image_size=img_size,
    in_channels=3,
    out_channels=3,
)

# One discriminator per image domain, both built with identical settings.
D_X, D_Y = (build_discriminator(image_size=img_size) for _ in range(2))

In [None]:
def _new_adam():
    """Create a fresh Adam optimizer (learning rate 1e-4, beta_1 = 0.5)."""
    return tf.keras.optimizers.Adam(1e-4, beta_1=0.5)


# Each network owns a separate optimizer instance.
G_optimizer = _new_adam()
F_optimizer = _new_adam()
D_X_optimizer = _new_adam()
D_Y_optimizer = _new_adam()

# Squared-error criterion; train_step uses it for the LSGAN terms.
mse = tf.keras.losses.MeanSquaredError()


def L1(a, b):
    """Mean absolute difference between two tensors."""
    return tf.reduce_mean(tf.abs(a - b))

In [None]:
# Name of the VGG19 layer whose activations are used as perceptual features
# (a mid-level conv layer).
vgg_feature_layer = "block3_conv3"

# ImageNet-pretrained VGG19 without the classifier head; frozen so it is
# never updated.
vgg = VGG19(
    weights="imagenet",
    include_top=False,
    input_shape=(img_size, img_size, 3),
)
vgg.trainable = False

# Sub-model mapping a VGG input image to the selected layer's activations.
_perceptual_output = vgg.get_layer(vgg_feature_layer).output
vgg_feature_extractor = tf.keras.Model(inputs=vgg.input, outputs=_perceptual_output)

def perceptual_loss(img1, img2):
    """
    L1 distance between the VGG19 feature maps of two images.

    Args:
        img1, img2: RGB image tensors expected to be scaled to [-1, 1].

    Returns:
        Mean absolute difference between the two extracted feature maps.
    """
    def _vgg_features(img):
        # Rescale [-1, 1] -> [0, 255], apply the VGG19 preprocessing
        # (BGR ordering, mean subtraction, etc.), then run the extractor.
        return vgg_feature_extractor(preprocess_input((img + 1.0) * 127.5))

    feature_gap = _vgg_features(img1) - _vgg_features(img2)
    return tf.reduce_mean(tf.abs(feature_gap))


In [None]:
# Experiment setup
# ExperimentRun is used below, but its import in the notebook's first cell is
# commented out, which raises NameError on a fresh kernel. Import it here so
# this cell works under Restart & Run All.
from experiment_utils import ExperimentRun

# Hyper-parameters recorded with the run.
params = {
    "lambda_cyc": lambda_cyc,
    "lambda_id": lambda_id,
    "img_size": img_size,
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "alpha": alpha,
    # Learning rates are repeated here as literals; keep them in sync with the
    # values passed to the Adam optimizers.
    "G_lr": 1e-4,
    "D_lr": 1e-4,
    # NOTE(review): n_blocks is only recorded here; it is not passed to
    # build_generator in this notebook -- confirm it matches the generator.
    "n_blocks": 9,
}

# The run config is currently just a shallow copy of the hyper-parameters.
config = {
    **params
}

exp = ExperimentRun(params=params, config=config)
logger = exp.get_logger()

# Register all four networks and their optimizers with the experiment's
# checkpoint manager.
ckpt, ckpt_manager = exp.create_checkpoint_manager(
    G=G, F=F, D_X=D_X, D_Y=D_Y,
    G_optimizer=G_optimizer,
    F_optimizer=F_optimizer,
    D_X_optimizer=D_X_optimizer,
    D_Y_optimizer=D_Y_optimizer,
)

In [None]:
@tf.function
def train_step(real_x, real_y, lambda_cyc, lambda_id):
    """
    Run one optimisation step for both generators and both discriminators.

    Args:
        real_x: domain-X batch with 4 channels; the first three are used as
            RGB and the remaining channel as the mask fed to G.
        real_y: domain-Y batch (RGB).
        lambda_cyc: weight of the cycle-consistency term.
        lambda_id: weight of the identity term.

    Returns:
        Tuple ``(G_loss, F_loss, D_X_loss, D_Y_loss, cycle_G_pixel,
        cycle_F_pixel, cycle_G_perc, cycle_F_perc)``. The two perceptual
        entries are already multiplied by ``alpha``.
    """
    def _lsgan_real(pred):
        # Squared error against an all-ones target.
        return mse(tf.ones_like(pred), pred)

    def _lsgan_fake(pred):
        # Squared error against an all-zeros target.
        return mse(tf.zeros_like(pred), pred)

    # Split the 4-channel X batch: RGB for losses/discriminators, mask for G.
    x_rgb = real_x[..., :3]
    x_mask = real_x[..., 3:]

    # Persistent tape: four separate gradient computations are taken from it.
    with tf.GradientTape(persistent=True) as tape:
        # Forward translations (G: 4 -> 3 channels, F: 3 -> 3 channels).
        fake_y = G(real_x, training=True)
        fake_x = F(real_y, training=True)

        # Cycles. G always needs a mask channel, so the X mask is re-attached.
        cyc_y = G(tf.concat([fake_x, x_mask], axis=-1), training=True)
        cyc_x = F(fake_y, training=True)

        # Identity mappings (G again receives the X mask; F gets RGB only).
        same_y = G(tf.concat([real_y, x_mask], axis=-1), training=True)
        same_x = F(x_rgb, training=True)

        # Discriminator scores, always on RGB images.
        D_X_real = D_X(x_rgb, training=True)
        D_X_fake = D_X(fake_x, training=True)
        D_Y_real = D_Y(real_y, training=True)
        D_Y_fake = D_Y(fake_y, training=True)

        # LSGAN objectives: generators push fakes towards "real",
        # discriminators separate real from fake.
        G_GAN_loss = _lsgan_real(D_Y_fake)
        F_GAN_loss = _lsgan_real(D_X_fake)
        D_X_loss = 0.5 * (_lsgan_real(D_X_real) + _lsgan_fake(D_X_fake))
        D_Y_loss = 0.5 * (_lsgan_real(D_Y_real) + _lsgan_fake(D_Y_fake))

        # Cycle reconstruction: pixel L1 plus alpha-weighted perceptual term.
        cycle_G_pixel = L1(cyc_y, real_y)
        cycle_F_pixel = L1(cyc_x, x_rgb)
        cycle_G_perc = alpha * perceptual_loss(cyc_y, real_y)
        cycle_F_perc = alpha * perceptual_loss(cyc_x, x_rgb)

        # Identity terms (RGB only).
        id_G = L1(same_y, real_y)
        id_F = L1(same_x, x_rgb)

        # Each generator gets its own total loss.
        G_loss = G_GAN_loss + lambda_cyc * (cycle_G_pixel + cycle_G_perc) + lambda_id * id_G
        F_loss = F_GAN_loss + lambda_cyc * (cycle_F_pixel + cycle_F_perc) + lambda_id * id_F

    # (loss, network, optimizer) triples, in the order they are processed.
    updates = (
        (G_loss, G, G_optimizer),
        (F_loss, F, F_optimizer),
        (D_X_loss, D_X, D_X_optimizer),
        (D_Y_loss, D_Y, D_Y_optimizer),
    )

    # Compute every gradient first, then release the tape, then apply them.
    all_grads = [
        tape.gradient(loss, net.trainable_variables)
        for loss, net, _ in updates
    ]
    del tape

    for grads, (_, net, optimizer) in zip(all_grads, updates):
        optimizer.apply_gradients(zip(grads, net.trainable_variables))

    return (
        G_loss, F_loss,
        D_X_loss, D_Y_loss,
        cycle_G_pixel, cycle_F_pixel,
        cycle_G_perc, cycle_F_perc
    )

In [None]:
# Settings shared by both dataset loaders.
loader_kwargs = dict(batch_size=batch_size, img_size=img_size)

# Domain X: images plus their masks, selected by the two glob patterns.
trainX = get_dataset_with_mask(
    image_pattern='../../data/split/segmented/trainB/butterfly/*.jpg',
    mask_pattern='../../data/split/segmented/trainB/butterfly/*_mask.png',
    **loader_kwargs,
)

# Domain Y: images selected by a single glob pattern.
trainY = get_dataset(
    '../../data/split/origami/train/butterfly/*.png',
    **loader_kwargs,
)

In [None]:
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
# Static number of steps per epoch, used only for the progress display.
# cardinality() can return negative sentinels (infinite / unknown), which must
# not be mistaken for a step count.
cardX = tf.data.experimental.cardinality(trainX).numpy()
cardY = tf.data.experimental.cardinality(trainY).numpy()

finite_cards = [int(c) for c in (cardX, cardY) if c >= 0]
has_unknown = any(
    c == tf.data.experimental.UNKNOWN_CARDINALITY for c in (cardX, cardY)
)
if not finite_cards and not has_unknown:
    # zip() over two infinite datasets would never finish an epoch.
    raise ValueError("Both training datasets are infinite; an epoch would never end.")
# With an unknown cardinality the true count is only known after one pass
# (the unknown dataset may be the shorter one), so start with None.
num_steps_per_epoch = None if has_unknown else min(finite_cards)

bar_len = 30
ckpt_interval = 5      # save a checkpoint every N epochs (and after the last one)
sample_interval = 10   # log a loss summary and save a sample figure every N epochs

for epoch in range(num_epochs):
    epoch_no = epoch + 1
    epoch_start = time.time()

    logger.info(f"Epoch {epoch_no}/{num_epochs}")
    step = 0

    # zip() stops at the shorter dataset; the X and Y batches are unpaired.
    for step, (real_x, real_y) in enumerate(zip(trainX, trainY), start=1):
        (G_loss, F_loss, DX_loss, DY_loss,
         cycle_G_pixel, cycle_F_pixel,
         cycle_G_perc, cycle_F_perc) = train_step(real_x, real_y, lambda_cyc, lambda_id)

        # In-place progress line; the bar needs a known total.
        if num_steps_per_epoch:
            bar = exp.progress_bar(step, num_steps_per_epoch, bar_len=bar_len)
            progress = f"{bar}  step {step}/{num_steps_per_epoch}"
        else:
            progress = f"step {step}"
        sys.stdout.write(f"\rEpoch {epoch_no}/{num_epochs} {progress}")
        sys.stdout.flush()

        loss_dict = {
            "G_loss": G_loss,
            "F_loss": F_loss,
            "D_X": DX_loss,
            "D_Y": DY_loss,
            "cycle_G_pixel": cycle_G_pixel,
            "cycle_F_pixel": cycle_F_pixel,
            "cycle_G_perc": cycle_G_perc,
            "cycle_F_perc": cycle_F_perc,
        }
        exp.log_losses(epoch_no, step, loss_dict)

    sys.stdout.write("\n")
    sys.stdout.flush()

    if step == 0:
        # Without this guard the code below would fail with ZeroDivisionError
        # (time per step) or NameError (losses never assigned).
        raise RuntimeError(
            "No training batches were produced; check the dataset glob patterns."
        )
    if num_steps_per_epoch is None:
        # Cardinality was unknown up front; reuse the measured count from now on.
        num_steps_per_epoch = step

    epoch_time = time.time() - epoch_start
    time_per_step = epoch_time / step

    logger.info(
        f"Epoch {epoch_no} completed in {epoch_time:.2f}s "
        f"({time_per_step:.4f}s per step)"
    )

    if epoch_no % ckpt_interval == 0 or epoch_no == num_epochs:
        exp.save_checkpoint()

    # Periodic summary + qualitative sample. The sample batch and generator
    # forward pass are only computed when they are actually used.
    if epoch_no % sample_interval == 0:
        logger.info(
            "Epoch %d: G=%.3f F=%.3f DX=%.3f DY=%.3f",
            epoch_no, float(G_loss.numpy()), float(F_loss.numpy()),
            float(DX_loss.numpy()), float(DY_loss.numpy())
        )
        sample = next(iter(trainX))
        fake_y = G(sample, training=False)

        fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(6, 3))
        # sample carries a 4th (mask) channel; show RGB only so imshow does
        # not interpret the mask as an alpha channel.
        ax_real.imshow((sample[0, ..., :3] + 1) / 2)
        ax_real.set_title("Real X")
        ax_real.axis("off")
        ax_fake.imshow((fake_y[0] + 1) / 2)
        ax_fake.set_title("Fake Y")
        ax_fake.axis("off")
        fig.tight_layout()

        out_path = exp.results_dir / f"epoch_{epoch_no}.png"
        fig.savefig(out_path)
        plt.close(fig)
        
# ---------------------------------------------------------------------------
# Export the trained generators
# ---------------------------------------------------------------------------
# The status line describes what this cell actually does (it exports models;
# it does not generate any comparison image).
print("\nExporting trained generators...")

# Artifacts go into an "exported" folder next to the results directory.
experiment_root = exp.results_dir.parent
export_dir = experiment_root / "exported"
export_dir.mkdir(parents=True, exist_ok=True)

generators = {"G": G, "F": F}
saved_artifacts = []  # names reported at the end, in the order written

# Full models in the .keras format.
for tag, net in generators.items():
    filename = f"{tag}_full.keras"
    net.save(str(export_dir / filename))
    saved_artifacts.append(filename)

# Weights only.
for tag, net in generators.items():
    filename = f"{tag}_full.weights.h5"
    net.save_weights(str(export_dir / filename))
    saved_artifacts.append(filename)

# Model.export() output, one directory per generator.
for tag, net in generators.items():
    dirname = f"{tag}_savedmodel"
    net.export(str(export_dir / dirname))
    saved_artifacts.append(f"{dirname}/")

print("\nSaved model artifacts:")
for artifact in saved_artifacts:
    print(f"  {artifact}")