In [None]:
# hello world

For clean output

In [None]:
from IPython.display import clear_output

In [None]:
%pip install -r requirements.txt

Imports

In [None]:
import os
import random
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from mpl_toolkits.axes_grid1 import ImageGrid
from PIL import Image
from tensorflow import keras
from tensorflow.keras import callbacks, layers, losses, optimizers

Configuration

In [None]:
# --- Data -------------------------------------------------------------------
IMAGE_SIZE = 256          # images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 16
MAX_TRAIN_IMAGES = 400    # upper bound on the training split; the rest is validation
AUTOTUNE = tf.data.AUTOTUNE

TRAIN_VAL_IMAGE_DIR = "data/lol_dataset/our485/low"
TEST_IMAGE_DIR = "data/lol_dataset/eval15/low"

# --- Training ---------------------------------------------------------------
# NOTE(review): IMAGE_SIZE, BATCH_SIZE, AUTOTUNE, LEARNING_RATE, EPOCHS and
# LOG_INTERVALS are read by later cells but were not defined anywhere in the
# notebook, so a fresh kernel failed with NameError. The values here are
# common Zero-DCE defaults — TODO confirm they match the intended run.
LEARNING_RATE = 1e-4
EPOCHS = 100
LOG_INTERVALS = 10        # plot sample predictions every N epochs

Dataset accessing

In [None]:
# glob() returns files in filesystem order, which differs between machines;
# sorting first makes the seeded shuffle below reproducible everywhere.
train_val_image_files = sorted(glob(os.path.join(TRAIN_VAL_IMAGE_DIR, "*.png")))
test_image_files = sorted(glob(os.path.join(TEST_IMAGE_DIR, "*.png")))

# Fail early with a readable message instead of an obscure tf.data error later.
if not train_val_image_files:
    raise FileNotFoundError(
        f"No PNG images found in {TRAIN_VAL_IMAGE_DIR!r}; check the dataset path."
    )

# A dedicated seeded generator keeps the train/validation split identical
# across kernel restarts without altering the global `random` state.
SPLIT_SEED = 42
random.Random(SPLIT_SEED).shuffle(train_val_image_files)

train_image_files = train_val_image_files[:MAX_TRAIN_IMAGES]
val_image_files = train_val_image_files[MAX_TRAIN_IMAGES:]

print("Number of Training Images:", len(train_image_files))
print("Number of Validation Images:", len(val_image_files))
print("Number of Test Images from LOL Dataset:", len(test_image_files))


Data pairing
Load the data

In [None]:
def load_data(image_path):
    """Read one PNG file and return it as a float RGB tensor.

    The file is decoded with three channels, resized to
    ``IMAGE_SIZE x IMAGE_SIZE`` and divided by 255.

    Args:
        image_path: Path of the PNG file to read.

    Returns:
        The resized image tensor divided by 255.
    """
    raw_bytes = tf.io.read_file(image_path)
    decoded = tf.image.decode_png(raw_bytes, channels=3)
    resized = tf.image.resize(images=decoded, size=[IMAGE_SIZE, IMAGE_SIZE])
    return resized / 255.0


def get_dataset(images):
    """Build a batched, prefetched ``tf.data`` pipeline from image paths.

    Each path is loaded through ``load_data``; batches hold exactly
    ``BATCH_SIZE`` images because the incomplete final batch is dropped.

    Args:
        images: Sequence of image file paths.

    Returns:
        A ``tf.data.Dataset`` yielding image batches.
    """
    return (
        tf.data.Dataset.from_tensor_slices(images)
        .map(load_data, num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE, drop_remainder=True)
        .prefetch(AUTOTUNE)
    )


# Build the training and validation input pipelines from their file lists.
train_dataset, val_dataset = (
    get_dataset(file_list) for file_list in (train_image_files, val_image_files)
)

In [None]:
# Emit the ANSI escape for bright-blue text, then report the element spec
# (shape and dtype of one batch) of each pipeline.
print("\033[94m")
for label, dataset in (
    ("Train Data Elements:", train_dataset),
    ("Validation Data Elements:", val_dataset),
):
    print(label, dataset.element_spec)

Check a few images

In [None]:
# Take the first training batch as a NumPy array.
images = next(iter(train_dataset)).numpy()

fig = plt.figure(figsize=(16, 16))
grid = ImageGrid(fig, 111, nrows_ncols=(4, 4), axes_pad=0.1)

# Sample WITHOUT replacement so the grid never shows the same image twice,
# and never ask for more images than the batch contains.
num_samples = min(16, images.shape[0])
sample_indices = np.random.choice(images.shape[0], size=num_samples, replace=False)
random_images = images[sample_indices]

for ax, image in zip(grid, random_images):
    # Pixels are stored as floats divided by 255; scale back to uint8 for display.
    ax.imshow((image * 255.0).astype(np.uint8))

# plt.title() would attach the title to the last axes of the grid only;
# suptitle() titles the whole figure.
fig.suptitle("Sample Training Images")
plt.show()

In [None]:
def build_dce_net(image_size=None) -> keras.Model:
    """Assemble the DCE-Net curve estimator as a functional Keras model.

    The network is a stack of seven 3x3, stride-1, same-padded convolutions.
    The first six use 32 filters with ReLU; the outputs of layers 4/3, 5/2
    and 6/1 are concatenated along the channel axis before the following
    convolution. The last layer has 24 filters with tanh activation.

    Args:
        image_size: Side length of the square RGB input. ``None`` leaves the
            spatial dimensions unspecified.

    Returns:
        A ``keras.Model`` mapping the input image to the 24-channel output.
    """

    def conv3x3(filters, activation):
        # Every convolution in the network shares these hyper-parameters.
        return layers.Conv2D(
            filters, (3, 3), strides=(1, 1), activation=activation, padding="same"
        )

    def join(first, second):
        return layers.Concatenate(axis=-1)([first, second])

    inputs = keras.Input(shape=[image_size, image_size, 3])

    # Plain feed-forward part.
    feat1 = conv3x3(32, "relu")(inputs)
    feat2 = conv3x3(32, "relu")(feat1)
    feat3 = conv3x3(32, "relu")(feat2)
    feat4 = conv3x3(32, "relu")(feat3)

    # Symmetric skip connections back to the earlier feature maps.
    feat5 = conv3x3(32, "relu")(join(feat4, feat3))
    feat6 = conv3x3(32, "relu")(join(feat5, feat2))

    curve_maps = conv3x3(24, "tanh")(join(feat6, feat1))

    return keras.Model(inputs=inputs, outputs=curve_maps)


In [None]:
def color_constancy_loss(x):
    """Penalise disagreement between the mean values of the colour channels.

    The spatial mean of channels 0, 1 and 2 is taken per image. The three
    pairwise differences are squared, squared again, summed, and the square
    root of that sum is returned.

    Args:
        x: Rank-4 tensor whose last axis holds at least three channels.

    Returns:
        Tensor with one value per image (spatial axes kept with size 1,
        channel axis removed by the indexing).
    """
    channel_means = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
    red = channel_means[..., 0]
    green = channel_means[..., 1]
    blue = channel_means[..., 2]

    d_rg = tf.square(red - green)
    d_rb = tf.square(red - blue)
    d_gb = tf.square(blue - green)

    sum_of_squares = tf.square(d_rg) + tf.square(d_rb) + tf.square(d_gb)
    return tf.sqrt(sum_of_squares)


In [None]:
def exposure_loss(x, mean_val=0.6):
    """Mean squared gap between local average intensity and ``mean_val``.

    The input is averaged over axis 3, then over non-overlapping 16x16
    patches; each patch average is compared with ``mean_val``.

    Args:
        x: Rank-4 image tensor.
        mean_val: Target intensity for every patch.

    Returns:
        Scalar tensor: the mean of the squared deviations.
    """
    intensity = tf.reduce_mean(x, axis=3, keepdims=True)
    patch_means = tf.nn.avg_pool2d(intensity, ksize=16, strides=16, padding="VALID")
    deviation = patch_means - mean_val
    return tf.reduce_mean(tf.square(deviation))

In [None]:
def illumination_smoothness_loss(x):
    """Total-variation style smoothness penalty on ``x``.

    Inspired from https://github.com/tuvovan/Zero_DCE_TF/blob/master/src/loss.py#L28

    Sums the squared differences between neighbouring entries along axis 1
    and along axis 2, divides each sum by a normaliser, and returns twice
    their total divided by the batch size.

    NOTE(review): both normalisers are built from ``shape[2]`` and
    ``shape[3]``. If ``x`` is laid out as (batch, height, width, channels),
    as the slicing along axes 1 and 2 suggests, those are width and channels
    rather than height and width. Confirm the scaling is intended before
    changing it, because it directly affects the effective loss weight.

    Args:
        x: Rank-4 tensor.

    Returns:
        Scalar float32 tensor.
    """
    shape = tf.shape(x)
    batch_size = tf.cast(shape[0], dtype=tf.float32)
    norm_height = tf.cast((shape[2] - 1) * shape[3], dtype=tf.float32)
    norm_width = tf.cast(shape[2] * (shape[3] - 1), dtype=tf.float32)

    # Differences between each entry and its predecessor along axes 1 and 2.
    diff_axis1 = x[:, 1:, :, :] - x[:, :-1, :, :]
    diff_axis2 = x[:, :, 1:, :] - x[:, :, :-1, :]
    tv_height = tf.reduce_sum(tf.square(diff_axis1))
    tv_width = tf.reduce_sum(tf.square(diff_axis2))

    return 2 * (tv_height / norm_height + tv_width / norm_width) / batch_size

In [None]:
class SpatialConsistencyLoss(losses.Loss):
    """Compare local intensity gradients of two images.

    Both inputs are reduced to a single intensity channel (mean over axis 3)
    and average-pooled over non-overlapping 4x4 regions. For every region the
    difference to its left, right, upper and lower neighbour is computed in
    both images, and the squared discrepancies of the four directions are
    summed.

    The default reduction is ``"none"``, so ``call`` results are returned
    per region and the caller decides how to aggregate them.
    """

    def __init__(self, **kwargs):
        # Forward caller options (name, reduction, ...) to the base class
        # instead of discarding them; per-element output stays the default.
        kwargs.setdefault("reduction", "none")
        super().__init__(**kwargs)

        # Each stencil is "centre minus one neighbour".
        self.left_kernel = self._make_kernel([[0, 0, 0], [-1, 1, 0], [0, 0, 0]])
        self.right_kernel = self._make_kernel([[0, 0, 0], [0, 1, -1], [0, 0, 0]])
        self.up_kernel = self._make_kernel([[0, -1, 0], [0, 1, 0], [0, 0, 0]])
        self.down_kernel = self._make_kernel([[0, 0, 0], [0, 1, 0], [0, -1, 0]])

    @staticmethod
    def _make_kernel(stencil):
        """Turn a 3x3 stencil into a ``tf.nn.conv2d`` filter.

        ``tf.nn.conv2d`` expects filters laid out as
        ``[filter_height, filter_width, in_channels, out_channels]``, so the
        3x3 stencil must be shaped (3, 3, 1, 1). A literal of shape
        (1, 3, 1, 3) would be read as a 1x3 filter with three output channels
        and would not compute the intended neighbour differences.
        """
        return tf.reshape(tf.constant(stencil, dtype=tf.float32), [3, 3, 1, 1])

    @staticmethod
    def _pooled_intensity(image):
        """Average over channels, then over non-overlapping 4x4 regions."""
        intensity = tf.reduce_mean(image, 3, keepdims=True)
        return tf.nn.avg_pool2d(intensity, ksize=4, strides=4, padding="VALID")

    def call(self, y_true, y_pred):
        """Return the per-region sum of squared gradient discrepancies.

        Args:
            y_true: Rank-4 image tensor.
            y_pred: Rank-4 image tensor with the same shape as ``y_true``.

        Returns:
            Tensor with one channel and the pooled spatial size of the inputs.
        """
        original_pool = self._pooled_intensity(y_true)
        enhanced_pool = self._pooled_intensity(y_pred)

        squared_gaps = []
        for kernel in (
            self.left_kernel,
            self.right_kernel,
            self.up_kernel,
            self.down_kernel,
        ):
            d_original = tf.nn.conv2d(
                original_pool, kernel, strides=[1, 1, 1, 1], padding="SAME"
            )
            d_enhanced = tf.nn.conv2d(
                enhanced_pool, kernel, strides=[1, 1, 1, 1], padding="SAME"
            )
            squared_gaps.append(tf.square(d_original - d_enhanced))

        d_left, d_right, d_up, d_down = squared_gaps
        return d_left + d_right + d_up + d_down

In [None]:
class ZeroDCE(keras.Model):
    """Trainable wrapper around DCE-Net using the four zero-reference losses.

    The wrapped network (``self.dce_model``) predicts a 24-channel output
    that is applied to the input as eight successive curve adjustments.
    Training, evaluation, summary and weight (de)serialisation all operate
    on the wrapped network.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dce_model = build_dce_net(IMAGE_SIZE)

    def compile(self, learning_rate, **kwargs):
        """Compile the model and create the Adam optimizer and spatial loss."""
        super().compile(**kwargs)
        self.optimizer = optimizers.Adam(learning_rate=learning_rate)
        self.spatial_constancy_loss = SpatialConsistencyLoss(reduction="none")

    def summary(self, *args, **kwargs):
        """Print the summary of the wrapped DCE-Net."""
        self.dce_model.summary(*args, **kwargs)

    def get_enhanced_image(self, data, output):
        """Apply the eight curve maps in ``output`` to ``data`` one by one.

        Each step takes three consecutive channels ``r`` of ``output`` and
        updates the image as ``x + r * (x**2 - x)``.
        """
        enhanced = data
        for step in range(8):
            first = 3 * step
            curve_map = output[:, :, :, first: first + 3]
            enhanced = enhanced + curve_map * (tf.square(enhanced) - enhanced)
        return enhanced

    def call(self, data):
        """Run DCE-Net on ``data`` and return the enhanced image."""
        curve_maps = self.dce_model(data)
        return self.get_enhanced_image(data, curve_maps)

    def compute_losses(self, data, output):
        """Return the weighted individual losses and their total as a dict.

        Weights: illumination smoothness x200, colour constancy x5,
        exposure x10, spatial constancy x1.
        """
        enhanced_image = self.get_enhanced_image(data, output)

        illumination = 200 * illumination_smoothness_loss(output)
        spatial = tf.reduce_mean(self.spatial_constancy_loss(enhanced_image, data))
        color = 5 * tf.reduce_mean(color_constancy_loss(enhanced_image))
        exposure = 10 * tf.reduce_mean(exposure_loss(enhanced_image))

        total = illumination + spatial + color + exposure

        return {
            "total_loss": total,
            "illumination_smoothness_loss": illumination,
            "spatial_constancy_loss": spatial,
            "color_constancy_loss": color,
            "exposure_loss": exposure,
        }

    def train_step(self, data):
        """One optimisation step on the DCE-Net weights; returns the loss dict."""
        weights = self.dce_model.trainable_weights

        with tf.GradientTape() as tape:
            curve_maps = self.dce_model(data)
            loss_values = self.compute_losses(data, curve_maps)

        gradients = tape.gradient(loss_values["total_loss"], weights)
        self.optimizer.apply_gradients(zip(gradients, weights))

        return loss_values

    def test_step(self, data):
        """Compute the loss dict for one batch without updating weights."""
        return self.compute_losses(data, self.dce_model(data))

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """Save only the DCE-Net weights."""
        self.dce_model.save_weights(
            filepath,
            overwrite=overwrite,
            save_format=save_format,
            options=options,
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """Load weights into the wrapped DCE-Net."""
        self.dce_model.load_weights(
            filepath=filepath,
            by_name=by_name,
            skip_mismatch=skip_mismatch,
            options=options,
        )


In [None]:
# Instantiate the model. ZeroDCE.summary() is overridden to print the layers
# of the wrapped DCE-Net rather than the outer wrapper.
zero_dce_model = ZeroDCE()
zero_dce_model.summary()

In [None]:
def plot_results(images, titles, figure_size=(12, 12)):
    """Show ``images`` side by side in one row, each under its own title.

    Args:
        images: Sequence of images accepted by ``plt.imshow``.
        titles: One title per image, in the same order.
        figure_size: Figure size passed to ``plt.figure``.
    """
    fig = plt.figure(figsize=figure_size)
    panel_count = len(images)
    for position, picture in enumerate(images, start=1):
        fig.add_subplot(1, panel_count, position).set_title(titles[position - 1])
        plt.imshow(picture)
        plt.axis("off")
    plt.show()


def infer(original_image):
    """Enhance a PIL image with the global ``zero_dce_model``.

    The image is resized to ``IMAGE_SIZE x IMAGE_SIZE``, converted to a float
    array (channels beyond the first three are dropped), divided by 255 and
    passed through the model as a batch of one. The result is multiplied by
    255 and cast to ``uint8``.

    Args:
        original_image: PIL image to enhance.

    Returns:
        The enhanced image as a PIL image.
    """
    resized = original_image.resize((IMAGE_SIZE, IMAGE_SIZE))

    pixels = keras.preprocessing.image.img_to_array(resized)
    if pixels.shape[-1] > 3:
        pixels = pixels[:, :, :3]

    batch = np.expand_dims(pixels.astype("float32") / 255.0, axis=0)

    enhanced = zero_dce_model(batch, training=False)
    enhanced_uint8 = tf.cast(enhanced[0] * 255, dtype=tf.uint8)

    return Image.fromarray(enhanced_uint8.numpy())



In [None]:
class LogPredictionCallback(callbacks.Callback):
    """Plot original vs. enhanced images at regular epoch intervals.

    Args:
        image_files: Paths of the images to enhance and plot.
        log_interval: Plot whenever the epoch index is a multiple of this.
    """

    def __init__(self, image_files, log_interval, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.image_files = image_files
        self.log_interval = log_interval

    def on_epoch_end(self, epoch, logs=None):
        # Only act on epochs 0, log_interval, 2 * log_interval, ...
        if epoch % self.log_interval != 0:
            return

        for path in self.image_files:
            source = Image.open(path).convert("RGB")
            enhanced = infer(source)
            plot_results(
                [source, enhanced],
                ["Original", "Enhanced_Image"],
                (15, 7),
            )


In [None]:
zero_dce_model.compile(learning_rate=LEARNING_RATE)

# Four randomly chosen validation images are enhanced and plotted by the
# callback whenever the epoch index is a multiple of LOG_INTERVALS.
preview_callback = LogPredictionCallback(
    image_files=random.sample(val_image_files, 4),
    log_interval=LOG_INTERVALS,
)

history = zero_dce_model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=EPOCHS,
    callbacks=[preview_callback],
)

## 📊 Performance Evaluation

Now that the model is trained, let's **quantitatively measure** how well it performs.

We compare the **enhanced images** (model output) with the **ground truth** (well-lit images from `eval15/high/`).

### Metrics Used:
| Metric | What it Measures | Good Value |
|--------|-----------------|------------|
| **PSNR** (Peak Signal-to-Noise Ratio) | Pixel-level accuracy in dB | > 20 dB |
| **SSIM** (Structural Similarity Index) | Brightness, contrast & structure similarity | > 0.80 |
| **MAE** (Mean Absolute Error) | Average pixel difference | Close to 0 |

Preprocessing

Dataset loader class or function

Model Architecture (Convolutional Autoencoder)

Output Activation

Loss function (MAE)

Training configuration

In [None]:
# ============================================================================
# CELL: Plot Training Loss Curves
# ============================================================================
# WHY: the loss curves show whether the model actually learned during training.
#   - Curves that fall and then flatten → the model converged (good)
#   - Curves that oscillate wildly      → training is unstable
#   - Curves that rise                  → the model diverged (bad)
#
# All five losses are plotted:
#   1. Total Loss         — the combined loss the optimizer minimizes
#   2. Illumination Loss  — keeps the curve maps smooth, no patchy artifacts
#   3. Spatial Loss       — preserves edges and structure from the original
#   4. Color Loss         — prevents unnatural color tints (e.g. too blue/green)
#   5. Exposure Loss      — pushes average brightness toward the target (0.6)
# ============================================================================

loss_names = [
    "total_loss",
    "illumination_smoothness_loss",
    "spatial_constancy_loss",
    "color_constancy_loss",
    "exposure_loss",
]
colors = ["#e74c3c", "#3498db", "#2ecc71", "#f39c12", "#9b59b6"]

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle("Training & Validation Loss Curves", fontsize=16, fontweight="bold")
axes = axes.flatten()

for ax, key, color in zip(axes, loss_names, colors):
    # Keras stores validation metrics under the training name prefixed by "val_".
    curves = (
        (key, {"label": "Train"}),
        (f"val_{key}", {"linestyle": "--", "alpha": 0.7, "label": "Validation"}),
    )
    for series_name, line_style in curves:
        if series_name in history.history:
            ax.plot(
                history.history[series_name],
                color=color,
                linewidth=2,
                **line_style,
            )

    ax.set_title(key.replace("_", " ").title(), fontsize=12)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True, alpha=0.3)

axes[5].axis("off")  # five losses on a 2x3 grid leave the sixth panel unused
plt.tight_layout()
plt.show()

# Report the value of every loss after the last epoch; the validation value is
# appended only when the history contains it.
print("\n" + "=" * 60)
print("FINAL LOSS VALUES (Last Epoch)")
print("=" * 60)
for key in loss_names:
    if key not in history.history:
        continue
    report_line = f"  {key:40s} | Train: {history.history[key][-1]:.6f}"
    val_series = history.history.get(f"val_{key}")
    if val_series is not None:
        report_line += f" | Val: {val_series[-1]:.6f}"
    print(report_line)

In [None]:
# ============================================================================
# CELL: Define Performance Metric Functions
# ============================================================================
# These functions compute how similar the ENHANCED image is to the GROUND TRUTH.
# We use three standard metrics from image processing research.
# ============================================================================


def compute_psnr(enhanced, ground_truth):
    """Peak Signal-to-Noise Ratio between two images, in decibels.

    PSNR is derived from the mean squared error (MSE) between the images:
    ``PSNR = 10 * log10(MAX^2 / MSE)``. Squaring the per-pixel error
    penalises large errors more than small ones. ``MAX`` is passed as 1.0,
    i.e. pixel values are expected to be normalised to [0, 1].

    Rule of thumb (higher is better):
        > 30 dB   excellent, nearly identical to the ground truth
        25-30 dB  good
        20-25 dB  acceptable
        < 20 dB   poor

    Args:
        enhanced: Image tensor to evaluate.
        ground_truth: Reference image tensor of the same shape.

    Returns:
        The result of ``tf.image.psnr`` converted to NumPy.
    """
    psnr_tensor = tf.image.psnr(enhanced, ground_truth, max_val=1.0)
    return psnr_tensor.numpy()


def compute_ssim(enhanced, ground_truth):
    """Structural Similarity Index between two images.

    Rather than comparing individual pixels, SSIM compares small windows of
    the two images on three aspects:

    1. Luminance — are both windows equally bright?
       ``l = (2 * mean_x * mean_y + C1) / (mean_x^2 + mean_y^2 + C1)``
    2. Contrast — do both windows vary by a similar amount?
       ``c = (2 * std_x * std_y + C2) / (std_x^2 + std_y^2 + C2)``
    3. Structure — do both windows show similar patterns?
       ``s = (covariance_xy + C3) / (std_x * std_y + C3)``

    The final score is the average of ``l * c * s`` over all windows.

    Interpretation: 1.0 means identical images; above 0.90 is excellent,
    0.80-0.90 good, below 0.60 poor. Scores normally fall between 0 and 1.

    Args:
        enhanced: Image tensor to evaluate.
        ground_truth: Reference image tensor of the same shape.

    Returns:
        The result of ``tf.image.ssim`` (called with ``max_val=1.0``)
        converted to NumPy.
    """
    ssim_tensor = tf.image.ssim(enhanced, ground_truth, max_val=1.0)
    return ssim_tensor.numpy()


def compute_mae(enhanced, ground_truth):
    """Mean Absolute Error between two images.

    ``MAE = (1 / N) * sum(|enhanced - ground_truth|)`` over every element.
    Unlike MSE it does not over-penalise large errors. Lower is better;
    0.0 means the images are identical.

    Args:
        enhanced: Image tensor to evaluate.
        ground_truth: Reference image tensor of the same shape.

    Returns:
        The mean absolute difference as a NumPy scalar.
    """
    absolute_error = tf.abs(enhanced - ground_truth)
    return tf.reduce_mean(absolute_error).numpy()


# Confirmation message (the leading check mark was mis-encoded in the source).
print("✅ Metric functions defined: compute_psnr, compute_ssim, compute_mae")

In [None]:
# ============================================================================
# CELL: Evaluate Model on Test Set (eval15)
# ============================================================================
# This is the core evaluation cell. For each test image:
#
# 1. Load the LOW-LIGHT image (input) from eval15/low/
# 2. Load the GROUND TRUTH image (target) from eval15/high/
# 3. Pass the low-light image through the trained model → ENHANCED image
# 4. Compute PSNR, SSIM, MAE between:
#    a) ORIGINAL (low) vs GROUND TRUTH  → baseline (how bad the input is)
#    b) ENHANCED vs GROUND TRUTH        → model performance
#
# The IMPROVEMENT = (b) - (a) tells us how much the model helped.
# ============================================================================

# Paths to the test set — these images were NOT used during training
TEST_LOW_DIR = "data/lol_dataset/eval15/low"
TEST_HIGH_DIR = "data/lol_dataset/eval15/high"


def _load_eval_image(path):
    """Read a PNG, resize it to IMAGE_SIZE x IMAGE_SIZE, return float32 / 255."""
    image = tf.image.decode_png(tf.io.read_file(path), channels=3)
    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
    return tf.cast(image, tf.float32) / 255.0


# Pair low/high images by FILE NAME rather than by list position: with
# positional pairing a single missing file would silently shift every later
# pair and compare unrelated images.
low_by_name = {
    os.path.basename(path): path
    for path in glob(os.path.join(TEST_LOW_DIR, "*.png"))
}
high_by_name = {
    os.path.basename(path): path
    for path in glob(os.path.join(TEST_HIGH_DIR, "*.png"))
}
paired_names = sorted(low_by_name.keys() & high_by_name.keys())
unpaired_names = sorted(low_by_name.keys() ^ high_by_name.keys())
if unpaired_names:
    print(f"WARNING: skipping {len(unpaired_names)} image(s) without a "
          f"low/high counterpart: {unpaired_names}")

# Aligned lists: test_low_files[i] and test_high_files[i] share a file name.
test_low_files = [low_by_name[name] for name in paired_names]
test_high_files = [high_by_name[name] for name in paired_names]

print(f"Evaluating on {len(test_low_files)} test images...\n")

# Store results for each image
eval_results = {
    "filenames": [],
    "psnr_original": [], "ssim_original": [], "mae_original": [],
    "psnr_enhanced": [], "ssim_enhanced": [], "mae_enhanced": [],
}

for low_path, high_path in zip(test_low_files, test_high_files):
    filename = os.path.basename(low_path)

    # --- Steps 1 & 2: load the input and its ground truth ---
    low_img = _load_eval_image(low_path)
    high_img = _load_eval_image(high_path)

    # --- Step 3: enhance the low-light image with the trained model ---
    # The model expects a batch, so add a dimension: (H, W, 3) → (1, H, W, 3)
    low_batch = tf.expand_dims(low_img, axis=0)
    enhanced_img = zero_dce_model(low_batch, training=False)
    # Clip to the valid range [0, 1] before computing metrics
    enhanced_img = tf.clip_by_value(enhanced_img[0], 0.0, 1.0)

    # --- Step 4: compute metrics (the TF metric functions expect batches) ---
    enhanced_batch = tf.expand_dims(enhanced_img, 0)
    high_batch = tf.expand_dims(high_img, 0)

    # Original (low-light) vs ground truth — the BASELINE
    psnr_o = compute_psnr(low_batch, high_batch)[0]
    ssim_o = compute_ssim(low_batch, high_batch)[0]
    mae_o = compute_mae(low_img, high_img)

    # Enhanced vs ground truth — the MODEL'S PERFORMANCE
    psnr_e = compute_psnr(enhanced_batch, high_batch)[0]
    ssim_e = compute_ssim(enhanced_batch, high_batch)[0]
    mae_e = compute_mae(enhanced_img, high_img)

    # Store results
    eval_results["filenames"].append(filename)
    eval_results["psnr_original"].append(psnr_o)
    eval_results["ssim_original"].append(ssim_o)
    eval_results["mae_original"].append(mae_o)
    eval_results["psnr_enhanced"].append(psnr_e)
    eval_results["ssim_enhanced"].append(ssim_e)
    eval_results["mae_enhanced"].append(mae_e)

    # Per-image results: baseline → enhanced
    print(f"  {filename:20s} | "
          f"PSNR: {psnr_o:.2f} → {psnr_e:.2f} dB | "
          f"SSIM: {ssim_o:.4f} → {ssim_e:.4f} | "
          f"MAE: {mae_o:.4f} → {mae_e:.4f}")

# --- Print Summary ---
avg_psnr_o = np.mean(eval_results["psnr_original"])
avg_psnr_e = np.mean(eval_results["psnr_enhanced"])
avg_ssim_o = np.mean(eval_results["ssim_original"])
avg_ssim_e = np.mean(eval_results["ssim_enhanced"])
avg_mae_o = np.mean(eval_results["mae_original"])
avg_mae_e = np.mean(eval_results["mae_enhanced"])

# The "+" format flag prints the real sign of each change. A hard-coded "+"
# prefix would render a regression as "+-1.23".
print("\n" + "=" * 70)
print("  PERFORMANCE SUMMARY — Enhanced vs Ground Truth (Test Set)")
print("=" * 70)
print(f"  {'Metric':<25s} | {'Original (Low)':<15s} | {'Enhanced':<15s} | {'Improvement':<15s}")
print(f"  {'-'*25}-+-{'-'*15}-+-{'-'*15}-+-{'-'*15}")
print(f"  {'PSNR (dB) ↑ better':<25s} | {avg_psnr_o:>13.2f}  | {avg_psnr_e:>13.2f}  | {avg_psnr_e - avg_psnr_o:+.2f} dB")
print(f"  {'SSIM (0-1) ↑ better':<25s} | {avg_ssim_o:>13.4f}  | {avg_ssim_e:>13.4f}  | {avg_ssim_e - avg_ssim_o:+.4f}")
print(f"  {'MAE (0-1) ↓ better':<25s} | {avg_mae_o:>13.4f}  | {avg_mae_e:>13.4f}  | {avg_mae_e - avg_mae_o:+.4f}")
print("\n  Interpretation:")
print("  • PSNR > 20 dB: Acceptable  |  > 25 dB: Good  |  > 30 dB: Excellent")
print("  • SSIM > 0.60: Moderate  |  > 0.80: Good  |  > 0.90: Excellent")
print("  • MAE closer to 0 = better pixel-level accuracy")

In [None]:
# ============================================================================
# CELL: Bar Chart — Per-Image Metric Comparison
# ============================================================================
# Side-by-side bars compare ORIGINAL vs ENHANCED for each test image.
# For PSNR and SSIM the green (enhanced) bars should be taller than the red
# (original) ones; for MAE the green bars should be shorter.
# ============================================================================

fig, axes = plt.subplots(1, 3, figsize=(18, 6))
fig.suptitle("Per-Image Metrics: Original vs Enhanced",
             fontsize=16, fontweight="bold", y=1.02)

x = np.arange(len(eval_results["filenames"]))
width = 0.35  # Width of each bar

# splitext removes only the extension, so names containing dots stay intact.
tick_labels = [os.path.splitext(name)[0] for name in eval_results["filenames"]]

# (key prefix in eval_results, panel title, y-axis label)
chart_specs = [
    ("psnr", "PSNR (dB) — Higher is Better ↑", "PSNR (dB)"),
    ("ssim", "SSIM — Higher is Better ↑", "SSIM"),
    ("mae", "MAE — Lower is Better ↓", "MAE"),
]

for ax, (metric, title, y_label) in zip(axes, chart_specs):
    ax.bar(x - width / 2, eval_results[f"{metric}_original"], width,
           label="Original (Low)", color="#e74c3c", alpha=0.8)
    ax.bar(x + width / 2, eval_results[f"{metric}_enhanced"], width,
           label="Enhanced", color="#2ecc71", alpha=0.8)
    ax.set_title(title, fontsize=13, fontweight="bold")
    ax.set_xlabel("Test Image")
    ax.set_ylabel(y_label)
    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels, rotation=45, ha="right")
    ax.legend()
    ax.grid(axis="y", alpha=0.3)

# Only the SSIM panel gets a fixed 0-1 axis.
axes[1].set_ylim(0, 1.0)

plt.tight_layout()
plt.show()

In [None]:
# ============================================================================
# CELL: Side-by-Side Visual Comparison with Annotated Metrics
# ============================================================================
# Shows up to 5 test images as:
#   [LOW-LIGHT input]  →  [ENHANCED by model]  →  [GROUND TRUTH target]
# with PSNR, SSIM and MAE printed above each enhanced image.
#
# WHAT TO LOOK FOR:
# - The enhanced image should be BRIGHTER than the original
# - The enhanced image should look SIMILAR to the ground truth
# - Colors should be natural (no weird tints)
# - Details/edges should be preserved (not blurry)
# ============================================================================

num_display = min(5, len(test_low_files))  # Show up to 5 images
assert num_display > 0, "No test images found; check the eval15 dataset paths."

# squeeze=False keeps `axes` two-dimensional even when only one row is shown,
# so axes[row, col] below works for any num_display >= 1.
fig, axes = plt.subplots(num_display, 3, figsize=(15, 5 * num_display),
                         squeeze=False)
fig.suptitle("Visual Comparison: Original → Enhanced → Ground Truth",
             fontsize=18, fontweight="bold", y=1.01)

for row in range(num_display):
    low_path = test_low_files[row]
    high_path = test_high_files[row]

    # Load both images, resize, and scale by 1/255
    low_img = tf.cast(tf.image.resize(
        tf.image.decode_png(tf.io.read_file(low_path), channels=3),
        [IMAGE_SIZE, IMAGE_SIZE]), tf.float32) / 255.0

    high_img = tf.cast(tf.image.resize(
        tf.image.decode_png(tf.io.read_file(high_path), channels=3),
        [IMAGE_SIZE, IMAGE_SIZE]), tf.float32) / 255.0

    # Enhance (the model expects a batch) and clip to the displayable range
    enhanced_img = zero_dce_model(tf.expand_dims(low_img, 0), training=False)
    enhanced_img = tf.clip_by_value(enhanced_img[0], 0.0, 1.0)

    # Metrics of the enhanced image against the ground truth
    e_batch = tf.expand_dims(enhanced_img, 0)
    h_batch = tf.expand_dims(high_img, 0)
    psnr_val = compute_psnr(e_batch, h_batch)[0]
    ssim_val = compute_ssim(e_batch, h_batch)[0]
    mae_val = compute_mae(enhanced_img, high_img)

    # Plot the three versions side by side
    axes[row, 0].imshow(low_img.numpy())
    axes[row, 0].set_title("Low-Light (Input)", fontsize=12)
    axes[row, 0].axis("off")

    axes[row, 1].imshow(enhanced_img.numpy())
    axes[row, 1].set_title(
        f"Enhanced\nPSNR: {psnr_val:.2f} dB | SSIM: {ssim_val:.4f} | MAE: {mae_val:.4f}",
        fontsize=11, color="green"
    )
    axes[row, 1].axis("off")

    axes[row, 2].imshow(high_img.numpy())
    axes[row, 2].set_title("Ground Truth", fontsize=12)
    axes[row, 2].axis("off")

plt.tight_layout()
plt.show()

In [None]:
# ============================================================================
# CELL: Brightness Histogram Comparison
# ============================================================================
# A histogram shows the DISTRIBUTION of brightness values in an image.
#
# WHAT TO LOOK FOR:
# - Original (red): clustered on the LEFT (near 0 = dark pixels)
#   → confirms the image is dark/low-light
#
# - Enhanced (green): should SHIFT TO THE RIGHT (brighter)
#   → the model is brightening the image
#
# - Ground Truth (blue): the target brightness distribution
#   → the closer the green histogram is to the blue one, the better the
#     model matches the correct brightness
# ============================================================================

num_hist = min(3, len(test_low_files))  # Show histograms for up to 3 images
assert num_hist > 0, "No test images found; check the eval15 dataset paths."

# squeeze=False keeps `axes` two-dimensional even when only one row is shown,
# so axes[row, col] below works for any num_hist >= 1.
fig, axes = plt.subplots(num_hist, 2, figsize=(16, 5 * num_hist), squeeze=False)
fig.suptitle("Brightness Histogram Analysis",
             fontsize=18, fontweight="bold", y=1.01)

for row in range(num_hist):
    low_path = test_low_files[row]
    high_path = test_high_files[row]

    # Load both images, resize, and scale by 1/255
    low_img = tf.cast(tf.image.resize(
        tf.image.decode_png(tf.io.read_file(low_path), channels=3),
        [IMAGE_SIZE, IMAGE_SIZE]), tf.float32) / 255.0

    high_img = tf.cast(tf.image.resize(
        tf.image.decode_png(tf.io.read_file(high_path), channels=3),
        [IMAGE_SIZE, IMAGE_SIZE]), tf.float32) / 255.0

    # Enhance (the model expects a batch) and clip to the displayable range
    enhanced_img = zero_dce_model(tf.expand_dims(low_img, 0), training=False)
    enhanced_img = tf.clip_by_value(enhanced_img[0], 0.0, 1.0)

    # Brightness = mean over the RGB channels, one value per pixel
    low_gray = tf.reduce_mean(low_img, axis=-1).numpy().flatten()
    enhanced_gray = tf.reduce_mean(enhanced_img, axis=-1).numpy().flatten()
    high_gray = tf.reduce_mean(high_img, axis=-1).numpy().flatten()

    # LEFT: the three image versions joined horizontally
    combined = np.concatenate([
        low_img.numpy(), enhanced_img.numpy(), high_img.numpy()
    ], axis=1)
    axes[row, 0].imshow(combined)
    axes[row, 0].set_title("Original  |  Enhanced  |  Ground Truth", fontsize=12)
    axes[row, 0].axis("off")

    # RIGHT: overlapping brightness histograms
    axes[row, 1].hist(low_gray, bins=100, alpha=0.5, color="#e74c3c",
                      label="Original (Low)", density=True)
    axes[row, 1].hist(enhanced_gray, bins=100, alpha=0.5, color="#2ecc71",
                      label="Enhanced", density=True)
    axes[row, 1].hist(high_gray, bins=100, alpha=0.5, color="#3498db",
                      label="Ground Truth", density=True)
    axes[row, 1].set_title(f"Brightness Distribution — {os.path.basename(low_path)}",
                           fontsize=12)
    axes[row, 1].set_xlabel("Pixel Brightness (0 = black, 1 = white)")
    axes[row, 1].set_ylabel("Density")
    axes[row, 1].legend()
    axes[row, 1].grid(True, alpha=0.3)
    axes[row, 1].set_xlim(0, 1)

plt.tight_layout()
plt.show()

In [None]:
# ============================================================================
# CELL: Final Evaluation Report
# ============================================================================
# A compact summary of the run configuration and the test-set metrics,
# suitable for pasting into a project report.
# ============================================================================

avg_psnr_o = np.mean(eval_results["psnr_original"])
avg_psnr_e = np.mean(eval_results["psnr_enhanced"])
avg_ssim_o = np.mean(eval_results["ssim_original"])
avg_ssim_e = np.mean(eval_results["ssim_enhanced"])
avg_mae_o = np.mean(eval_results["mae_original"])
avg_mae_e = np.mean(eval_results["mae_enhanced"])

print("=" * 70)
print("   FINAL EVALUATION REPORT — Zero-DCE Low-Light Enhancement")
print("=" * 70)
print()
print("  Model Architecture : Zero-DCE (Zero-Reference Deep Curve Estimation)")
print("  Training Approach  : Unsupervised (no paired data needed for training)")
# MAX_TRAIN_IMAGES is only an upper bound; report the number actually used.
print(f"  Training Images    : {len(train_image_files)}")
print(f"  Test Images        : {len(eval_results['filenames'])}")
print(f"  Image Size         : {IMAGE_SIZE} x {IMAGE_SIZE}")
print(f"  Epochs             : {EPOCHS}")
print(f"  Learning Rate      : {LEARNING_RATE}")
print(f"  Batch Size         : {BATCH_SIZE}")
print()
# The "+" format flag prints the real sign of each change. A hard-coded "+"
# prefix would render a regression as "+-1.23".
print(f"  {'Metric':<15s} | {'Before (Low)':<14s} | {'After (Enhanced)':<16s} | {'Improvement':<15s}")
print(f"  {'-'*15}-+-{'-'*14}-+-{'-'*16}-+-{'-'*15}")
print(f"  {'PSNR (dB)':<15s} | {avg_psnr_o:>12.2f}  | {avg_psnr_e:>14.2f}  | {avg_psnr_e - avg_psnr_o:+.2f} dB")
print(f"  {'SSIM':<15s} | {avg_ssim_o:>12.4f}  | {avg_ssim_e:>14.4f}  | {avg_ssim_e - avg_ssim_o:+.4f}")
print(f"  {'MAE':<15s} | {avg_mae_o:>12.4f}  | {avg_mae_e:>14.4f}  | {avg_mae_e - avg_mae_o:+.4f}")
print()
print("  Interpretation Guide:")
print("  " + "-" * 50)
print("  PSNR  → > 20 dB: Acceptable | > 25: Good | > 30: Excellent")
print("  SSIM  → > 0.60: Moderate | > 0.80: Good | > 0.90: Excellent")
print("  MAE   → Closer to 0.0 is better")
print()
print("  Key Observations:")
print("  " + "-" * 50)
print("  • The model enhances low-light images WITHOUT using paired training data.")
print("  • It learns enhancement curves purely from unsupervised loss functions:")
print("    - Exposure Loss: targets a well-lit brightness level")
print("    - Color Constancy: prevents unnatural color shifts")
print("    - Illumination Smoothness: ensures smooth, artifact-free enhancement")
print("    - Spatial Consistency: preserves edges and structural details")
print("  • Ground truth images are used ONLY for evaluation, not training.")
print("=" * 70)