In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras import Model, Input, layers, models, Sequential
from tensorflow.keras.preprocessing.image import load_img, img_to_array, ImageDataGenerator
from PIL import UnidentifiedImageError, ImageFile

In [None]:
# Saved Keras model files for the two stages. They are the intended arguments of
# combine_encoder_decoder (the only call to it in this notebook is commented out).
encoder_path = 'model/model_final.keras'
decoder_path = 'model/decoder.keras'
# Directory whose image files are listed and split by load_data.
DATA_DIR = "Ground Truth"

In [None]:
def combine_encoder_decoder(encoder_path, decoder_path):
    """Load a saved encoder and a saved decoder and chain them into one model.

    Args:
        encoder_path: Path passed to ``load_model`` for the first stage.
        decoder_path: Path passed to ``load_model`` for the second stage.

    Returns:
        A Keras ``Model`` named "DualStageNetwork" that takes the encoder's
        per-sample input shape and outputs the decoder applied to the
        encoder's output. It is compiled with the 'adam' optimizer and
        'mse' loss.
    """
    stage_one = load_model(encoder_path)
    stage_two = load_model(decoder_path)

    # input_shape starts with the batch dimension; Input() wants it removed.
    net_in = Input(shape=stage_one.input_shape[1:])
    net_out = stage_two(stage_one(net_in))

    pipeline = Model(inputs=net_in, outputs=net_out, name="DualStageNetwork")
    pipeline.compile(optimizer='adam', loss='mse')
    return pipeline

In [None]:
def load_data(data_dir, n_train=500, n_val=108):
    """List the images in ``data_dir`` in numeric order and split them sequentially.

    Only files ending in .png, .jpg or .jpeg (case-insensitive) are kept. Files
    are ordered by the integer formed from all digits in the file name; names
    that contain no digit are placed after the numbered ones. Ties are broken
    by file name so the order does not depend on ``os.listdir``.

    The split is positional, not shuffled: the first ``n_train`` paths are the
    training set, the next ``n_val`` the validation set, and the remainder the
    test set.

    Args:
        data_dir: Directory to scan (not recursive).
        n_train: Number of leading files assigned to the training split.
        n_val: Number of files after the training split assigned to validation.

    Returns:
        Tuple ``(train_files, val_files, test_files)`` of lists of paths.
    """
    image_exts = ('.png', '.jpg', '.jpeg')

    def _numeric_key(path):
        # Sort key: numbered files first (by number), then un-numbered files by name.
        name = os.path.basename(path)
        digits = ''.join(filter(str.isdigit, name))
        if digits:
            return (0, int(digits), name)
        # int('') would raise ValueError, so digit-free names get their own bucket.
        return (1, 0, name)

    all_imgs = sorted(
        (os.path.join(data_dir, f) for f in os.listdir(data_dir)
         if f.lower().endswith(image_exts)),
        key=_numeric_key,
    )

    val_end = n_train + n_val
    train_files = all_imgs[:n_train]
    val_files   = all_imgs[n_train:val_end]
    test_files  = all_imgs[val_end:]

    print(f"Train: {len(train_files)}, Val: {len(val_files)}, Test: {len(test_files)}")
    return train_files, val_files, test_files

In [None]:
def load_images(file_list, target_size=(128, 128)):
    """Load image files as RGB arrays scaled by 1/255 and stack them.

    Files that raise ``OSError`` or ``UnidentifiedImageError`` while loading
    are reported and skipped, so the result can hold fewer rows than
    ``file_list`` has entries.

    Args:
        file_list: Iterable of image paths.
        target_size: ``(height, width)`` every image is resized to on load.

    Returns:
        Array of shape ``(num_loaded, height, width, 3)``. When nothing could
        be loaded, an empty array with that same trailing shape is returned
        instead of a 1-D ``(0,)`` array.
    """
    data = []
    skipped = 0
    for f in file_list:
        try:
            img = load_img(f, target_size=target_size, color_mode="rgb")
            arr = img_to_array(img) / 255.0
            data.append(arr)
        except (OSError, UnidentifiedImageError) as e:
            skipped += 1
            print(f"⚠️ Skipping corrupted image: {f} ({e})")

    if skipped:
        # Callers pair the result with other arrays by position; make the gap visible.
        print(f"⚠️ {skipped} image(s) skipped; rows no longer line up index-for-index with file_list")

    if not data:
        height, width = target_size
        # float32 presumably matches img_to_array's default dtype — TODO confirm.
        return np.empty((0, height, width, 3), dtype="float32")
    return np.array(data)

In [None]:
def plot_results(model, X_input, Y_target, n=5):
    """Show input, target and model output side by side for the first ``n`` samples.

    Each row of the figure holds three panels titled "Input", "Target" and
    "Output". All images are clipped to [0, 1] before display. When an input
    sample has exactly two channels, only its first channel is drawn, using a
    gray colormap.

    Args:
        model: Object with a ``predict`` method; called on ``X_input[:n]``.
        X_input: Indexable collection of model inputs.
        Y_target: Indexable collection of target images, paired by index.
        n: Maximum number of rows to draw; reduced to the number of available
            input/target pairs when fewer exist.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Asking for more rows than there are samples used to raise IndexError.
    n = min(n, len(X_input), len(Y_target))
    if n <= 0:
        print("Nothing to plot: no samples available.")
        return

    preds = model.predict(X_input[:n])

    # squeeze=False keeps `axes` two-dimensional even when n == 1.
    fig, axes = plt.subplots(n, 3, figsize=(10, 6), squeeze=False)

    for i in range(n):
        inp = X_input[i]
        if inp.shape[-1] == 2:
            # Draw only the first channel, as a 2-D grayscale map.
            inp = inp[..., 0]

        panels = (
            ("Input", inp, "gray"),
            ("Target", Y_target[i], None),
            ("Output", preds[i], None),
        )
        for ax, (title, image, cmap) in zip(axes[i], panels):
            ax.imshow(np.clip(image, 0, 1), cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

    fig.tight_layout()
    plt.show()

In [None]:
def create_cnn(input_shape, latent_dim):
  """Build the convolutional encoder as a Keras ``Sequential`` model.

  Layout: three stages of 3x3 Conv2D (ReLU, stride 1) followed by 2x2 max
  pooling, with 16, 32 and 64 filters; then Flatten, a 256-unit ReLU Dense
  layer with L2(0.01) kernel regularization, and a linear Dense layer of
  ``latent_dim`` units.

  Args:
    input_shape: Per-sample input shape given to ``Input``.
    latent_dim: Number of units in the final linear layer.

  Returns:
    The uncompiled ``Sequential`` model.
  """
  stack = [Input(shape=input_shape)]

  # One conv + pool pair per filter count.
  for n_filters in (16, 32, 64):
    stack.append(layers.Conv2D(n_filters, (3,3), activation='relu', strides=1))
    stack.append(layers.MaxPooling2D((2,2)))

  stack.append(layers.Flatten())
  stack.append(
      layers.Dense(256, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(0.01))
  )
  stack.append(layers.Dense(latent_dim, activation='linear'))

  return Sequential(stack)

In [None]:
def build_decoder(img_size, latent_dim):
  """Build the decoder that maps a latent vector to a 3-channel image.

  Layout: a ReLU Dense layer reshaped to ``(img_size//8, img_size//8, 64)``,
  three stride-2 transposed convolutions (64, 32, 16 filters; ReLU; "same"
  padding), and a final 3-filter transposed convolution with sigmoid
  activation.

  NOTE(review): with stride 2 and "same" padding each of the three layers is
  expected to double the spatial size, so the output side would be
  ``8 * (img_size // 8)`` — equal to ``img_size`` only when it is a multiple
  of 8. Confirm before using other sizes.

  Args:
    img_size: Target image side length; only ``img_size // 8`` is used.
    latent_dim: Length of the input latent vector.

  Returns:
    An uncompiled Keras ``Model`` named "Decoder".
  """
  base = img_size // 8
  latent_in = layers.Input(shape=(latent_dim,))

  # Project the vector to a small feature map.
  feat = layers.Dense(base * base * 64, activation="relu")(latent_in)
  feat = layers.Reshape((base, base, 64))(feat)

  # Upsampling stack, one transposed conv per filter count.
  for n_filters in (64, 32, 16):
    feat = layers.Conv2DTranspose(
        n_filters, (3,3), strides=2, activation="relu", padding="same"
    )(feat)

  rgb = layers.Conv2DTranspose(3, (3,3), activation="sigmoid", padding="same")(feat)

  return models.Model(latent_in, rgb, name="Decoder")

In [None]:
def ae_loss(y_true, y_pred):
    """Weighted sum of mean absolute error and an SSIM-based dissimilarity.

    Computes ``0.8 * MAE + 0.2 * (1 - mean SSIM)``, where SSIM is evaluated by
    ``tf.image.ssim`` with ``max_val=1.0``.

    Args:
        y_true: Target image tensor.
        y_pred: Predicted image tensor.

    Returns:
        Scalar loss tensor.
    """
    l1_term = tf.reduce_mean(tf.abs(y_true - y_pred))
    mean_ssim = tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))
    # SSIM grows with similarity, so flip it into something to minimise.
    dissimilarity = 1 - mean_ssim
    return 0.8 * l1_term + 0.2 * dissimilarity

In [None]:
# Model inputs, loaded from disk.
# NOTE(review): presumably row i corresponds to the i-th image in load_data's
# numeric ordering — confirm against how the .npy file was produced.
X_signal = np.load("latent_rep/final_dataset_baru.npy")

# load_data returns (train, val, test) in that order. Unpacking it as
# (train, test, val) silently swapped the validation and test splits.
train_files, val_files, test_files = load_data(DATA_DIR)
train_data = load_images(train_files)
val_data   = load_images(val_files)
test_data  = load_images(test_files)

In [None]:
# Sizes shared by the encoder, the decoder and the wrapping model. They are
# named once so the two places that need the same value cannot drift apart.
SIGNAL_SHAPE = (24, 24, 2)
IMG_SIZE = 128
LATENT_DIM = 256

encoder = create_cnn(input_shape=SIGNAL_SHAPE, latent_dim=LATENT_DIM)
decoder = build_decoder(img_size=IMG_SIZE, latent_dim=LATENT_DIM)

# Chain encoder and decoder into a single functional model.
inputs = tf.keras.Input(shape=SIGNAL_SHAPE)
latent = encoder(inputs)
outputs = decoder(latent)

autoencoder = tf.keras.Model(inputs, outputs)
autoencoder.compile(optimizer='adam', loss=ae_loss)

In [None]:
# Train on the first 500 rows of X_signal against the training images
# (500 is the training-split size used by load_data).
# NOTE(review): assumes X_signal[i] pairs with train_data[i]. load_images drops
# unreadable files, which would shift that pairing — confirm none were skipped.
autoencoder.fit(X_signal[:500], train_data, epochs=30, batch_size=8)

In [None]:
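# Optional alternative (disabled): chain the saved encoder/decoder files with
# combine_encoder_decoder and fit that model. Uncomment to define `full_model`.
# full_model = combine_encoder_decoder(encoder_path, decoder_path)
# full_model.summary()
# full_model.fit(X_signal[:500], train_data, epochs=10, batch_size=16)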

In [None]:
# `full_model` is only assigned in commented-out code, so referencing it raises
# NameError on a fresh kernel. `autoencoder` is the model this notebook trains.
plot_results(autoencoder, X_signal, train_data, n=5)
