# Underwater Image Enhancement: Colour Correction and Detail Enhancement using Hybrid Real-ESRGAN

This notebook implements a pipeline for enhancing underwater images by combining **color correction techniques** with **Real-ESRGAN** for super-resolution. The process addresses common underwater imaging issues such as color distortion, low contrast, and loss of detail due to water attenuation. The pipeline consists of two main stages:

1. **Color Correction**: Adjusts color balance and compensates for underwater-specific distortions using techniques like guided filtering and LAB color space stretching.
2. **Detail Enhancement**: Applies Real-ESRGAN, a state-of-the-art super-resolution model, to enhance image details and sharpness.

Additionally, the notebook includes a **performance evaluation** step to compute metrics (PSNR and MSE) by comparing processed images with reference images.

## Objectives
- Correct color casts in underwater images caused by light absorption and scattering.
- Enhance image contrast and detail using advanced image processing and deep learning techniques.
- Evaluate the quality of enhanced images using quantitative metrics.

## Prerequisites
- A Python 3.12 environment such as Google Colab or GitHub Codespaces; a GPU (T4 or better) is recommended for the Real-ESRGAN step.
- Input images or videos in the `input/` directory.
- Reference images in the `reference/` directory for evaluation.
- Pre-trained Real-ESRGAN model weights (`net_g_5000.pth`) placed in `../RealESRGAN/model/`, the path that Step 5 loads them from.

Let's proceed with the implementation.

## Step 1: Install Dependencies

This section sets up all necessary dependencies for image processing, video handling, and Real-ESRGAN-based super-resolution.

### If **not** using Conda:

Ensure that you have the following system-level dependencies installed via your OS package manager:

* `ffmpeg`
* `libGL`

You can install the required Python packages manually using `pip`:

In [None]:
# !pip install -q basicsr facexlib gfpgan numpy opencv-python Pillow torch torchvision tqdm realesrgan natsort scipy scikit-image ffmpeg-python

In [None]:
# !sudo apt update && sudo apt install -y libopencv-dev ffmpeg

### If using GitHub Codespaces or a Conda environment (Python 3.12):

All required Python libraries are automatically installed via the Conda environment. System dependencies are auto-configured using the `postCreateCommand` in the `devcontainer.json` file. No additional steps are required unless you want to manually verify or update packages.

#### To create a new Conda environment in GitHub Codespaces:

1. Click on the **Select Kernel** dropdown in the top-right corner.
2. Select **Another Kernel**.
3. Go to **Python Environments** > **Create Python Environment**.
4. Choose **Conda** and set the version to **Python 3.12**.

---

### Optional: Use `uv` for Dependency Sync (pip-based)

If Conda is not preferred, the project also provides a `pyproject.toml` compatible with `uv`. You can use `uv` instead of Conda to manage dependencies.

```bash
pip install uv
uv sync
```

This will create and sync a `.venv` environment based on `pyproject.toml` and `uv.lock`.

Make sure to select the `.venv` as the environment interpreter from the **Select Kernel** menu if you're using this approach.
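
Whichever installation route you choose, it can help to confirm the environment before continuing. The sketch below is not part of the original notebook. It assumes the import names `cv2`, `skimage`, and `ffmpeg` for the `opencv-python`, `scikit-image`, and `ffmpeg-python` packages listed above, and the helper names `missing_modules` and `ffmpeg_on_path` are hypothetical.

```python
import importlib.util
import shutil

# Import names for the packages installed in this step (assumed mapping:
# opencv-python -> cv2, scikit-image -> skimage, ffmpeg-python -> ffmpeg).
REQUIRED_MODULES = [
    "cv2", "numpy", "scipy", "skimage", "natsort",
    "ffmpeg", "torch", "torchvision", "basicsr", "realesrgan",
]


def missing_modules(names):
    """Return the top-level module names that cannot be located (nothing is imported)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def ffmpeg_on_path():
    """Return True if an `ffmpeg` executable is visible on PATH."""
    return shutil.which("ffmpeg") is not None


if __name__ == "__main__":
    print("Missing Python modules:", missing_modules(REQUIRED_MODULES) or "none")
    print("ffmpeg executable found:", ffmpeg_on_path())
```

Because `find_spec` only locates top-level packages without executing them, this check should also work before the `basicsr` import fix in Step 2 has been applied.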

## Step 2: Fix Dependency Issues

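`basicsr` 1.4.2 imports `rgb_to_grayscale` from `torchvision.transforms.functional_tensor`, a module that newer `torchvision` releases no longer ship, so importing `basicsr` (and therefore `realesrgan`) fails. The script below rewrites that import in `basicsr/data/degradations.py` to use `torchvision.transforms.functional` instead.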

In [1]:
!pip show basicsr

Name: basicsr
Version: 1.4.2
Summary: Open Source Image and Video Super-Resolution Toolbox
Home-page: https://github.com/xinntao/BasicSR
Author: Xintao Wang
Author-email: xintao.wang@outlook.com
License: Apache License 2.0
Location: /workspaces/uccde/.conda/lib/python3.12/site-packages
Requires: addict, future, lmdb, numpy, opencv-python, Pillow, pyyaml, requests, scikit-image, scipy, tb-nightly, torch, torchvision, tqdm, yapf
Required-by: gfpgan, realesrgan


In [5]:
%%writefile dependency-fix.sh
#!/bin/bash
# Fix torchvision import in basicsr/data/degradations.py using relative path
sed -i 's/from torchvision.transforms.functional_tensor import rgb_to_grayscale/from torchvision.transforms.functional import rgb_to_grayscale/' ../.conda/lib/python3.12/site-packages/basicsr/data/degradations.py

Writing dependency-fix.sh


In [15]:
# If not using conda
# %%writefile dependency-fix.sh
# #!/bin/bash
# # Fix torchvision import in basicsr/data/degradations.py
# sed -i 's/from torchvision.transforms.functional_tensor import rgb_to_grayscale/from torchvision.transforms.functional import rgb_to_grayscale/' /usr/local/python/3.12.1/lib/python3.12/site-packages/basicsr/data/degradations.py

In [2]:
!chmod +x dependency-fix.sh
!./dependency-fix.sh
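
The `sed` command above depends on a hard-coded site-packages path. As a possible alternative, the sketch below looks up the installed `basicsr` package and applies the same substitution from Python. It is an illustration rather than part of the original notebook: `patch_degradations` and `locate_degradations` are hypothetical helper names, and the sketch assumes `basicsr` is installed as a regular package containing `data/degradations.py`.

```python
import importlib.util
from pathlib import Path

OLD_IMPORT = "from torchvision.transforms.functional_tensor import rgb_to_grayscale"
NEW_IMPORT = "from torchvision.transforms.functional import rgb_to_grayscale"


def locate_degradations():
    """Return the path of basicsr/data/degradations.py, or None if basicsr is not installed."""
    spec = importlib.util.find_spec("basicsr")  # finds the package without executing it
    if spec is None or not spec.submodule_search_locations:
        return None
    return Path(list(spec.submodule_search_locations)[0]) / "data" / "degradations.py"


def patch_degradations(path):
    """Rewrite the outdated import in place; return True only if the file was changed."""
    path = Path(path)
    text = path.read_text(encoding="utf-8")
    if OLD_IMPORT not in text:
        return False  # already patched, or a different basicsr version
    path.write_text(text.replace(OLD_IMPORT, NEW_IMPORT), encoding="utf-8")
    return True


if __name__ == "__main__":
    target = locate_degradations()
    if target is None or not target.exists():
        print("basicsr/data/degradations.py not found; nothing to patch")
    else:
        print("Patched" if patch_degradations(target) else "No change needed", target)
```

Running it a second time should report that no change is needed, since the old import string is no longer present.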

## Step 3: Color Correction Implementation

This section defines the core classes and functions for underwater image color correction. The `GuidedFilter` class implements a guided filter for refining transmission maps, while the `ColourCorrection` class handles color balancing and depth map estimation. Additional functions enhance contrast and color through histogram stretching and LAB color space adjustments.

Key features:
- **Guided Filter**: Smooths transmission maps while preserving edges.
- **Color Compensation**: Adjusts red and blue channels to counter underwater color casts.
- **Depth Map Estimation**: Models light attenuation to estimate scene depth.
- **LAB Stretching**: Enhances luminance and color balance in the LAB color space.
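
For orientation, the restoration performed by `ColourCorrection` can be summarised with the usual underwater image formation model. The constants below are read from the code in this notebook; the symbols ($I$ for the observed image, $J$ for the restored scene radiance, $B$ for the background light, $t$ for transmission, $d$ for depth) are notation introduced here for illustration, with intensities scaled to $[0, 1]$.

$$
I_c(x) = J_c(x)\,t_c(x) + B_c\,\bigl(1 - t_c(x)\bigr), \qquad c \in \{R, G, B\}
$$

The relative depth is estimated from a linear model of the channels (`_compute_depth_map`):

$$
d(x) = 0.5116 + 0.5052\,\max\bigl(I_G(x), I_B(x)\bigr) - 0.9051\,I_R(x)
$$

After stretching $d$ to $[0, 1]$ (written $\tilde d$) and adding the offset $d_0$ from `_compute_min_depth`, the per-channel transmission follows an exponential attenuation (`process` and `_get_rgb_transmission`):

$$
d_f(x) = 6\,\bigl(\tilde d(x) + d_0\bigr), \qquad t_B = 0.98^{\,d_f}, \quad t_G = 0.97^{\,d_f}, \quad t_R = 0.88^{\,d_f}
$$

Inverting the formation model, with the transmission floored at 0.2 to avoid amplifying noise, gives the restored image (`_compute_scene_radiance`):

$$
J_c(x) = \frac{I_c(x) - B_c}{\max\bigl(t_c(x),\, 0.2\bigr)} + B_c
$$

Red has the smallest base (0.88), so it is attenuated fastest with depth, which is consistent with the blue-green cast the pipeline is meant to remove.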

In [3]:
import os
import cv2
import datetime
import numpy as np
import natsort
import math
from scipy import stats
from skimage.color import rgb2hsv, hsv2rgb, rgb2lab, lab2rgb
from multiprocessing import Pool, cpu_count
from functools import partial
import ffmpeg

# Configuration settings
CONFIG = {
    'input_dir': 'input',
    'output_dir': 'output-ip',
    'block_size': 9,
    'gimfilt_radius': 30,
    'eps': 1e-2,
    'rb_compensation_flag': 0,  # 0: Compensate both Red and Blue, 1: Compensate only Red
    'enhancement_strength': 0.6,  # Control enhancement intensity
    'video_extensions': ['.mp4', '.avi', '.mov'],  # Supported video formats
    'output_video_fps': 30,  # Default output video frame rate
    'output_video_codec': 'mp4v',  # Codec for output video
    'temp_video_path': 'temp_output.mp4',  # Temporary video file without audio
}

# Set numpy to ignore overflow warnings
np.seterr(over='ignore')

# Guided Filter Class
class GuidedFilter:
    """Guided filter for image processing to refine transmission maps while preserving edges."""
    def __init__(self, input_image, radius=5, epsilon=0.4):
        self._radius = 2 * radius + 1
        self._epsilon = epsilon
        self._input_image = self._to_float_img(input_image)
        self._init_filter()

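    def _to_float_img(self, img):
        # Floating-point inputs are assumed to be in [0, 1] already;
        # integer inputs are scaled down from [0, 255].
        if np.issubdtype(img.dtype, np.floating):
            return img.astype(np.float32, copy=False)
        return img.astype(np.float32) / 255.0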

    def _init_filter(self):
        img = self._input_image
        r = self._radius
        eps = self._epsilon
        ir, ig, ib = img[:, :, 0], img[:, :, 1], img[:, :, 2]
        ksize = (r, r)
        self._ir_mean = cv2.blur(ir, ksize)
        self._ig_mean = cv2.blur(ig, ksize)
        self._ib_mean = cv2.blur(ib, ksize)
        irr = cv2.blur(ir * ir, ksize) - self._ir_mean ** 2 + eps
        irg = cv2.blur(ir * ig, ksize) - self._ir_mean * self._ig_mean
        irb = cv2.blur(ir * ib, ksize) - self._ir_mean * self._ib_mean
        igg = cv2.blur(ig * ig, ksize) - self._ig_mean ** 2 + eps
        igb = cv2.blur(ig * ib, ksize) - self._ig_mean * self._ib_mean
        ibb = cv2.blur(ib * ib, ksize) - self._ib_mean ** 2 + eps
        det = irr * (igg * ibb - igb * igb) - irg * (irg * ibb - igb * irb) + irb * (irg * igb - igg * irb)
        self._irr_inv = (igg * ibb - igb * igb) / det
        self._irg_inv = -(irg * ibb - igb * irb) / det
        self._irb_inv = (irg * igb - igg * irb) / det
        self._igg_inv = (irr * ibb - irb * irb) / det
        self._igb_inv = -(irr * igb - irb * irg) / det
        self._ibb_inv = (irr * igg - irg * irg) / det

    def _compute_coefficients(self, input_p):
        r = self._radius
        ksize = (r, r)
        ir, ig, ib = self._input_image[:, :, 0], self._input_image[:, :, 1], self._input_image[:, :, 2]
        p_mean = cv2.blur(input_p, ksize)
        ipr_cov = cv2.blur(ir * input_p, ksize) - self._ir_mean * p_mean
        ipg_cov = cv2.blur(ig * input_p, ksize) - self._ig_mean * p_mean
        ipb_cov = cv2.blur(ib * input_p, ksize) - self._ib_mean * p_mean
        ar = self._irr_inv * ipr_cov + self._irg_inv * ipg_cov + self._irb_inv * ipb_cov
        ag = self._irg_inv * ipr_cov + self._igg_inv * ipg_cov + self._igb_inv * ipb_cov
        ab = self._irb_inv * ipr_cov + self._igb_inv * ipg_cov + self._ibb_inv * ipb_cov
        b = p_mean - ar * self._ir_mean - ag * self._ig_mean - ab * self._ib_mean
        return cv2.blur(ar, ksize), cv2.blur(ag, ksize), cv2.blur(ab, ksize), cv2.blur(b, ksize)

    def _compute_output(self, ab):
        ar_mean, ag_mean, ab_mean, b_mean = ab
        ir, ig, ib = self._input_image[:, :, 0], self._input_image[:, :, 1], self._input_image[:, :, 2]
        return ar_mean * ir + ag_mean * ig + ab_mean * ib + b_mean

    def filter(self, input_p):
        p_32f = self._to_float_img(input_p)
        ab = self._compute_coefficients(p_32f)
        return self._compute_output(ab)

# Colour Correction Class
class ColourCorrection:
    """Handles underwater image color correction by compensating for color casts and estimating depth."""
    def __init__(self, block_size=CONFIG['block_size'], gimfilt_radius=CONFIG['gimfilt_radius'], eps=CONFIG['eps']):
        self.block_size = block_size
        self.gimfilt_radius = gimfilt_radius
        self.eps = eps

    def _compensate_rb(self, image, flag):
        b, g, r = cv2.split(image.astype(np.float64))
        min_r, max_r = np.min(r), np.max(r)
        min_g, max_g = np.min(g), np.max(g)
        min_b, max_b = np.min(b), np.max(b)
        if max_r == min_r or max_g == min_g or max_b == min_b:
            return image
        r = (r - min_r) / (max_r - min_r)
        g = (g - min_g) / (max_g - min_g)
        b = (b - min_b) / (max_b - min_b)
        mean_r, mean_g, mean_b = np.mean(r), np.mean(g), np.mean(b)
        compensation_strength = 0.4
        if flag == 0:
            r = (r + compensation_strength * (mean_g - mean_r) * (1 - r) * g) * max_r
            b = (b + compensation_strength * (mean_g - mean_b) * (1 - b) * g) * max_b
            g = g * max_g
        elif flag == 1:
            r = (r + compensation_strength * (mean_g - mean_r) * (1 - r) * g) * max_r
            g = g * max_g
            b = b * max_b
        return cv2.merge([np.clip(b, 0, 255).astype(np.uint8),
                         np.clip(g, 0, 255).astype(np.uint8),
                         np.clip(r, 0, 255).astype(np.uint8)])

    def _estimate_background_light(self, img, depth_map):
        img = img.astype(np.float32) / 255.0 if img.dtype == np.uint8 else img
        height, width = img.shape[:2]
        n_bright = int(np.ceil(0.001 * height * width))
        indices = np.argpartition(depth_map.ravel(), -n_bright)[-n_bright:]
        candidates = img.reshape(-1, 3)[indices]
        magnitudes = np.linalg.norm(candidates, axis=1)
        sorted_indices = np.argsort(magnitudes)[::-1]
        top_n = 10
        top_candidates = candidates[sorted_indices[:top_n]]
        atmospheric_light = np.mean(top_candidates, axis=0) * 255.0
        return atmospheric_light

    def _compute_depth_map(self, img):
        img = img.astype(np.float32) / 255.0
        x_1 = np.maximum(img[:, :, 0], img[:, :, 1])
        x_2 = img[:, :, 2]
        return 0.51157954 + 0.50516165 * x_1 - 0.90511117 * x_2

    def _compute_min_depth(self, img, background_light):
        img = img.astype(np.float32) / 255.0
        background_light = background_light / 255.0
        max_values = np.max(np.abs(img - background_light), axis=(0, 1)) / np.maximum(background_light, 1 - background_light)
        return 1 - np.max(max_values)

    def _global_stretching_depth(self, img_l):
        flat = img_l.ravel()
        indices = np.argsort(flat)
        i_min, i_max = flat[indices[len(flat)//1000]], flat[indices[-len(flat)//1000]]
        result = np.clip((img_l - i_min) / (i_max - i_min + 1e-10), 0, 1)
        return cv2.GaussianBlur(result, (3, 3), 0.5)

    def _get_rgb_transmission(self, depth_map):
        return 0.98 ** depth_map, 0.97 ** depth_map, 0.88 ** depth_map

    def _refine_transmission_map(self, transmission_b, transmission_g, transmission_r, img):
        guided_filter = GuidedFilter(img, self.gimfilt_radius, self.eps)
        transmission = np.stack([
            guided_filter.filter(transmission_b),
            guided_filter.filter(transmission_g),
            guided_filter.filter(transmission_r)
        ], axis=-1)
        return transmission

    def _compute_scene_radiance(self, img, transmission, atmospheric_light):
        img = img.astype(np.float32)
        min_transmission = 0.2
        transmission = np.maximum(transmission, min_transmission)
        scene_radiance = (img - atmospheric_light) / transmission + atmospheric_light
        return np.clip(scene_radiance, 0, 255).astype(np.uint8)

    def process(self, img, rb_compensation_flag=CONFIG['rb_compensation_flag']):
        img_compensated = self._compensate_rb(img, rb_compensation_flag)
        depth_map = self._compute_depth_map(img_compensated)
        depth_map = self._global_stretching_depth(depth_map)
        guided_filter = GuidedFilter(img_compensated, self.gimfilt_radius, self.eps)
        refined_depth_map = guided_filter.filter(depth_map)
        refined_depth_map = np.clip(refined_depth_map, 0, 1)
        atmospheric_light = self._estimate_background_light(img_compensated, depth_map)
        d_0 = self._compute_min_depth(img_compensated, atmospheric_light)
        d_f = 6 * (depth_map + d_0)
        transmission_b, transmission_g, transmission_r = self._get_rgb_transmission(d_f)
        transmission = self._refine_transmission_map(transmission_b, transmission_g, transmission_r, img_compensated)
        return self._compute_scene_radiance(img_compensated, transmission, atmospheric_light)

# Image Enhancement Functions
def cal_equalisation(img, ratio):
    return np.clip(img * ratio, 0, 255)

def rgb_equalisation(img):
    img = img.astype(np.float32)
    current_mean = np.mean(img, axis=(0, 1))
    target_mean = 140
    ratio = target_mean / (current_mean + 1e-10)
    ratio = np.clip(ratio, 0.8, 1.2)
    return cal_equalisation(img, ratio)

def stretch_range(r_array, height, width):
    flat = r_array.ravel()
    mode = stats.mode(flat, keepdims=True).mode[0] if flat.size > 0 else np.median(flat)
    mode_indices = np.where(flat == mode)[0]
    mode_index_before = mode_indices[0] if mode_indices.size > 0 else len(flat) // 2
    dr_min = (1 - 0.755) * mode
    max_index = min(len(flat) - 1, len(flat) - int((len(flat) - mode_index_before) * 0.01))
    sr_max = np.sort(flat)[max_index]
    return dr_min, sr_max, mode

def global_stretching_ab(a, height, width):
    return a * (1.05 ** (1 - np.abs(a / 128)))

def basic_stretching(img):
    img = img.astype(np.float64)
    min_vals = np.percentile(img, 2, axis=(0,1))
    max_vals = np.percentile(img, 98, axis=(0,1))
    range_vals = max_vals - min_vals
    min_range = 50
    mask = range_vals < min_range
    max_vals[mask] = min_vals[mask] + min_range
    img = np.clip((img - min_vals) * 255 / (max_vals - min_vals + 1e-10), 0, 255)
    return img.astype(np.uint8)

def global_stretching_luminance(img_l, height, width):
    flat = img_l.ravel()
    indices = np.argsort(flat)
    i_min, i_max = flat[indices[len(flat)//50]], flat[indices[-len(flat)//50]]
    if i_max == i_min:
        i_min, i_max = flat.min(), flat.max()
        if i_max == i_min:
            return img_l
    return np.clip((img_l - i_min) * 95 / (i_max - i_min + 1e-10), 0, 100)

def lab_stretching(scene_radiance):
    scene_radiance = np.clip(scene_radiance, 0, 255).astype(np.uint8)
    original = scene_radiance.copy()
    img_lab = rgb2lab(scene_radiance)
    l, a, b = img_lab[:, :, 0], img_lab[:, :, 1], img_lab[:, :, 2]
    img_lab[:, :, 0] = global_stretching_luminance(l, *scene_radiance.shape[:2])
    img_lab[:, :, 1] = global_stretching_ab(a, *scene_radiance.shape[:2])
    img_lab[:, :, 2] = global_stretching_ab(b, *scene_radiance.shape[:2])
    enhanced = lab2rgb(img_lab) * 255
    blend_factor = CONFIG['enhancement_strength']
    result = blend_factor * enhanced + (1 - blend_factor) * original
    return result

def global_stretching_advanced(r_array, height, width, lambda_val, k_val):
    flat = r_array.ravel()
    indices = np.argsort(flat)
    i_min, i_max = flat[indices[len(flat)//100]], flat[indices[-len(flat)//100]]
    dr_min, sr_max, mode = stretch_range(r_array, height, width)
    t_n = lambda_val ** 2
    o_max_left = sr_max * t_n * k_val / mode
    o_max_right = 255 * t_n * k_val / mode
    dif = o_max_right - o_max_left
    if dif >= 1:
        indices = np.arange(1, int(dif) + 1)
        sum_val = np.sum((1.326 + indices) * mode / (t_n * k_val))
        dr_max = sum_val / int(dif)
        p_out = np.where(r_array < i_min, (r_array - i_min) * (dr_min / i_min) + i_min,
                         np.where(r_array > i_max, (r_array - dr_max) * (dr_max / i_max) + i_max,
                                  ((r_array - i_min) * (255 - i_min) / (i_max - i_min) + i_min)))
    else:
        p_out = np.where(r_array < i_min, (r_array - r_array.min()) * (dr_min / r_array.min()) + r_array.min(),
                         ((r_array - i_min) * (255 - dr_min) / (i_max - i_min) + dr_min))
    return p_out

def relative_stretching(scene_radiance, height, width):
    scene_radiance = scene_radiance.astype(np.float64)
    scene_radiance[:, :, 0] = global_stretching_advanced(scene_radiance[:, :, 0], height, width, 0.98, 1.1)
    scene_radiance[:, :, 1] = global_stretching_advanced(scene_radiance[:, :, 1], height, width, 0.97, 1.1)
    scene_radiance[:, :, 2] = global_stretching_advanced(scene_radiance[:, :, 2], height, width, 0.88, 0.9)
    return scene_radiance

def image_enhancement(scene_radiance):
    if scene_radiance.shape[2] == 3:
        scene_radiance = cv2.cvtColor(scene_radiance, cv2.COLOR_BGR2RGB)
    if np.max(scene_radiance) == np.min(scene_radiance):
        return scene_radiance
    original = scene_radiance.copy()
    scene_radiance = scene_radiance.astype(np.float64)
    scene_radiance = basic_stretching(scene_radiance)
    scene_radiance = lab_stretching(scene_radiance)
    final_blend = 0.8
    result = final_blend * scene_radiance + (1 - final_blend) * original
    return np.clip(result, 0, 255).astype(np.uint8)

# Processing Functions
def process_image(file, input_dir=CONFIG['input_dir'], output_dir=CONFIG['output_dir']):
    file_path = os.path.join(input_dir, file)
    base_name, extension = os.path.splitext(file)
    print(f'Processing image: {file}')
    img = cv2.imread(file_path)
    if img is None:
        print(f"Could not read image: {file}")
        return
    colour_corrector = ColourCorrection()
    print("Applying color correction...")
    corrected_img = colour_corrector.process(img)
    print("Applying image enhancement...")
    final_result = image_enhancement(corrected_img)
    final_result_bgr = cv2.cvtColor(final_result, cv2.COLOR_RGB2BGR)
    output_file = os.path.join(output_dir, f'{base_name}_ColourCorrected{extension}')
    cv2.imwrite(output_file, final_result_bgr)
    print(f"Completed processing image: {file}")

def process_video_frame(frame, colour_corrector):
    corrected_frame = colour_corrector.process(frame)
    enhanced_frame = image_enhancement(corrected_frame)
    return cv2.cvtColor(enhanced_frame, cv2.COLOR_RGB2BGR)

def process_video(file, input_dir=CONFIG['input_dir'], output_dir=CONFIG['output_dir']):
    file_path = os.path.join(input_dir, file)
    prefix = os.path.splitext(file)[0]  # keeps dots inside the file name, unlike split('.')
    print(f'Processing video: {file}')
    cap = cv2.VideoCapture(file_path)
    if not cap.isOpened():
        print(f"Could not open video: {file}")
        return
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep fractional rates such as 29.97 so the copied audio stays in sync; 0.0 falls back to the default
    fps = cap.get(cv2.CAP_PROP_FPS) or CONFIG['output_video_fps']
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    temp_output_path = os.path.join(output_dir, CONFIG['temp_video_path'])
    final_output_path = os.path.join(output_dir, f'{prefix}_ColourCorrected.mp4')
    fourcc = cv2.VideoWriter_fourcc(*CONFIG['output_video_codec'])
    out = cv2.VideoWriter(temp_output_path, fourcc, fps, (width, height))
    if not out.isOpened():
        print(f"Could not create output video: {temp_output_path}")
        cap.release()
        return
    colour_corrector = ColourCorrection()
    print(f"Processing {frame_count} frames...")
    frame_idx = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        print(f"Processing frame {frame_idx + 1}/{frame_count}")
        processed_frame = process_video_frame(frame, colour_corrector)
        out.write(processed_frame)
        frame_idx += 1
    cap.release()
    out.release()
    try:
        print(f"Merging original audio into: {final_output_path}")
        input_video = ffmpeg.input(temp_output_path)
        input_audio = ffmpeg.input(file_path).audio
        output = ffmpeg.output(input_video.video, input_audio, final_output_path, vcodec='copy', acodec='copy', strict='experimental')
        # capture_stderr=True is required for ffmpeg.Error to carry the stderr text
        ffmpeg.run(output, overwrite_output=True, capture_stderr=True)
        print(f"Completed processing video with audio: {file}")
        if os.path.exists(temp_output_path):
            os.remove(temp_output_path)
    except ffmpeg.Error as e:
        # e.stderr is None when stderr was not captured, so guard before decoding
        details = e.stderr.decode(errors='replace') if e.stderr else str(e)
        print(f"Error merging audio: {details}")
        return

def main_color_correction():
    """Main function to process images and videos in the input directory for color correction."""
    os.makedirs(CONFIG['output_dir'], exist_ok=True)
    if not os.access(CONFIG['output_dir'], os.W_OK):
        # Raise instead of calling exit(), which can stop the notebook kernel
        raise PermissionError(f"No write permissions for directory {CONFIG['output_dir']}")
    start_time = datetime.datetime.now()
    files = natsort.natsorted(os.listdir(CONFIG['input_dir']))
    files = [f for f in files if os.path.isfile(os.path.join(CONFIG['input_dir'], f))]
    image_files = [f for f in files if os.path.splitext(f)[1].lower() not in CONFIG['video_extensions']]
    video_files = [f for f in files if os.path.splitext(f)[1].lower() in CONFIG['video_extensions']]
    if image_files:
        print(f"Processing {len(image_files)} images...")
        with Pool(processes=cpu_count()) as pool:
            pool.map(process_image, image_files)
    if video_files:
        print(f"Processing {len(video_files)} videos...")
        for video_file in video_files:
            process_video(video_file)
    print(f'Total processing time: {datetime.datetime.now() - start_time}')
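
To make the LAB stretching step more concrete, the sketch below applies the same formulas used in `global_stretching_ab` and `global_stretching_luminance` to single values instead of whole arrays. The function names and the sample numbers are illustrative assumptions; `l_low` and `l_high` stand in for the 2nd and 98th percentile luminance values that the notebook code picks from the image.

```python
def stretch_ab(value, gain=1.05):
    """Chroma stretch: boost values near 0 by up to `gain`, leave values near +/-128 unchanged."""
    return value * gain ** (1 - abs(value / 128))


def stretch_luminance(l_value, l_low, l_high):
    """Map [l_low, l_high] linearly onto [0, 95] and clip the result to the valid L range [0, 100]."""
    scaled = (l_value - l_low) * 95 / (l_high - l_low + 1e-10)
    return min(max(scaled, 0.0), 100.0)


if __name__ == "__main__":
    for a in (-128, -64, 0, 64, 128):
        print(f"a = {a:5d} -> {stretch_ab(a):8.2f}")
    for l in (10, 20, 50, 80, 90):
        print(f"L = {l:3d} -> {stretch_luminance(l, l_low=20, l_high=80):6.2f}")
```

The chroma stretch is deliberately mild (at most a 5% gain), and `lab_stretching` further blends the result with the unmodified image using `enhancement_strength`.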

## Step 4: Run Color Correction

Execute the color correction pipeline to process images and videos in the `input/` directory. The processed files are saved in the `output-ip/` directory with a `_ColourCorrected` suffix.
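
As a quick reference for the naming convention, the sketch below derives the expected output name for a given input file. `colour_corrected_name` is a hypothetical helper written for illustration, and the example file names are made up. It mirrors the rule in `process_image` and `process_video`: images keep their extension, while videos are always written as `.mp4`.

```python
import os

# Mirrors CONFIG['video_extensions'] from Step 3
VIDEO_EXTENSIONS = ['.mp4', '.avi', '.mov']


def colour_corrected_name(filename):
    """Return the file name that the colour correction stage writes for `filename`."""
    base, ext = os.path.splitext(filename)
    if ext.lower() in VIDEO_EXTENSIONS:
        return f"{base}_ColourCorrected.mp4"  # videos are re-encoded and saved as MP4
    return f"{base}_ColourCorrected{ext}"     # images keep their original extension


if __name__ == "__main__":
    for name in ("file1.png", "reef.jpg", "dive.MOV"):
        print(name, "->", os.path.join("output-ip", colour_corrected_name(name)))
```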

In [4]:
main_color_correction()

Processing 2 images...
Processing image: file1.png
Processing image: file2.png
Applying color correction...
Applying color correction...
Applying image enhancement...
Applying image enhancement...
Completed processing image: file1.png
Completed processing image: file2.png
Total processing time: 0:00:02.232898


## Step 5: Detail Enhancement with Real-ESRGAN

This section applies the **Real-ESRGAN** model to enhance the details of color-corrected images. Real-ESRGAN uses a deep learning-based super-resolution technique to improve image sharpness and clarity.

**Note**: Ensure the pre-trained model `net_g_5000.pth` is present at `../RealESRGAN/model/` (the `model_path` set in the cell below) before calling `main_esrgan()`.
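
The `tile` setting in the cell below trades memory for speed: large images are processed in pieces rather than in one pass. A rough way to anticipate the `Tile i/N` progress lines is sketched here, assuming the image is split on a regular grid of `tile`-pixel cells. `tile_count` is a hypothetical helper, and the image sizes are made-up examples rather than the sizes of the notebook's test images.

```python
import math


def tile_count(width, height, tile):
    """Estimate how many tiles a width x height image is split into; tile <= 0 means no tiling."""
    if tile <= 0:
        return 1
    return math.ceil(width / tile) * math.ceil(height / tile)


if __name__ == "__main__":
    for width, height in ((800, 800), (1200, 800), (1920, 1080)):
        print(f"{width}x{height} with tile=400 -> {tile_count(width, height, 400)} tiles")
```

A smaller `tile` value increases the tile count but lowers peak memory use, which is the adjustment the error message in `main_esrgan` suggests when CUDA runs out of memory.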

In [6]:
import cv2
import glob
import os
import time
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

def main_esrgan():
    # Configuration
    input_path = 'output-ip'  # Input folder with color-corrected images
    output_path = 'output'  # Output folder for enhanced images
    model_name = 'RealESRGAN_x4plus'
    model_path = '../RealESRGAN/model/net_g_5000.pth'  # Path to pre-trained model
    outscale = 1  # Upsampling scale (1 for no upscaling, just enhancement)
    suffix = 'out'  # Suffix for enhanced images
    tile = 400  # Tile size for processing large images

    # Validate model name
    if model_name != 'RealESRGAN_x4plus':
        raise ValueError('This script only supports RealESRGAN_x4plus model')

    # Initialize Real-ESRGAN model
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    netscale = 4
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        model=model,
        tile=tile,
        tile_pad=10,
        pre_pad=0,
        half=False  # Use fp32 precision
    )

    # Create output directory
    os.makedirs(output_path, exist_ok=True)

    # Get input files
    if os.path.isfile(input_path):
        paths = [input_path]
    else:
        paths = sorted(glob.glob(os.path.join(input_path, '*')))

    # Start total execution timer
    total_start_time = time.time()

    # Process each image
    for idx, path in enumerate(paths):
        imgname, extension = os.path.splitext(os.path.basename(path))
        print(f'Processing {idx}: {imgname}')

        start_time = time.time()
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            print(f'Failed to load image: {path}')
            continue

        img_mode = 'RGBA' if len(img.shape) == 3 and img.shape[2] == 4 else None
        try:
            output, _ = upsampler.enhance(img, outscale=outscale)
        except RuntimeError as error:
            print(f'Error processing {imgname}: {error}')
            print('Try reducing tile size if you encounter CUDA out of memory.')
            continue

        extension = extension[1:] if img_mode != 'RGBA' else 'png'
        save_path = os.path.join(output_path, f'{imgname}_{suffix}.{extension}')
        cv2.imwrite(save_path, output)

        end_time = time.time()
        processing_time = end_time - start_time
        print(f'Saved: {save_path}')
        print(f'Processing time for {imgname}: {processing_time:.2f} seconds')

    total_end_time = time.time()
    total_time = total_end_time - total_start_time
    print(f'Total execution time: {total_time:.2f} seconds')

## Step 6: Run Real-ESRGAN Enhancement

Execute the Real-ESRGAN pipeline to enhance the details of color-corrected images in the `output-ip/` directory. The enhanced images are saved in the `output/` directory with an `_out` suffix.

In [7]:
main_esrgan()

Processing 0: file1_ColourCorrected
	Tile 1/4
	Tile 2/4
	Tile 3/4
	Tile 4/4
Saved: output/file1_ColourCorrected_out.png
Processing time for file1_ColourCorrected: 180.66 seconds
Processing 1: file2_ColourCorrected
	Tile 1/6
	Tile 2/6
	Tile 3/6
	Tile 4/6
	Tile 5/6
	Tile 6/6
Saved: output/file2_ColourCorrected_out.png
Processing time for file2_ColourCorrected: 297.98 seconds
Total execution time: 478.64 seconds


## Step 7: Evaluate Image Quality

This section evaluates the quality of the processed images by comparing them to reference images in the `reference/` directory. It calculates two metrics:
- **PSNR (Peak Signal-to-Noise Ratio)**: Measures the quality of reconstruction (higher is better).
- **MSE (Mean Squared Error)**: Measures the average squared difference between images (lower is better).

Results are saved to `metrics_output_ip.txt` (for color-corrected images) and `metrics_output.txt` (for Real-ESRGAN enhanced images).
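
The two metrics are directly related: for 8-bit images, whose peak value is 255, PSNR follows from MSE as $10 \log_{10}(255^2 / \mathrm{MSE})$. The sketch below is a plain-Python illustration of that relation, not the `skimage` implementation used in the next cell; `psnr_from_mse` is a hypothetical helper name.

```python
import math


def psnr_from_mse(mse, data_range=255.0):
    """Convert a mean squared error into PSNR (dB) for images with the given value range."""
    if mse == 0:
        return math.inf  # identical images
    return 10 * math.log10(data_range ** 2 / mse)


if __name__ == "__main__":
    # MSE values taken from the results listed in Step 8
    for mse in (583.55, 515.22, 569.05, 640.92):
        print(f"MSE {mse:7.2f} -> PSNR {psnr_from_mse(mse):.2f} dB")
```

Because the conversion is monotonic, ranking images by PSNR or by MSE gives the same order; reporting both is mainly a convenience.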

In [17]:
import os
import cv2
from skimage.metrics import peak_signal_noise_ratio, mean_squared_error

def compare_images(ref_folder, target_folder, suffix, output_file):
    with open(output_file, 'w') as f:
        f.write(f"{'File':<30} {'PSNR (dB)':<12} {'MSE':<10}\n")
        f.write('-' * 60 + '\n')
        # Use a distinct loop variable so the open file handle `f` is not visually shadowed
        ref_files = [name for name in os.listdir(ref_folder) if os.path.isfile(os.path.join(ref_folder, name))]
        for ref_file in ref_files:
            base_name, ext = os.path.splitext(ref_file)
            target_file = f"{base_name}{suffix}{ext}"
            target_path = os.path.join(target_folder, target_file)
            if os.path.exists(target_path):
                ref_img = cv2.imread(os.path.join(ref_folder, ref_file))
                target_img = cv2.imread(target_path)
                if ref_img is None or target_img is None:
                    f.write(f"{ref_file:<30} {'Error loading images':<12}\n")
                    continue
                try:
                    psnr = peak_signal_noise_ratio(ref_img, target_img)
                    mse = mean_squared_error(ref_img, target_img)
                    f.write(f"{ref_file:<30} {psnr:<12.2f} {mse:<10.2f}\n")
                except Exception as e:
                    f.write(f"{ref_file:<30} {'Error':<12} {str(e):<10}\n")
            else:
                f.write(f"{ref_file:<30} {'Target not found':<12}\n")

# Define folder paths and suffixes
reference_folder = 'reference/'
output_ip_folder = 'output-ip/'
output_folder = 'output/'
suffix_ip = '_ColourCorrected'
suffix_out = '_ColourCorrected_out'

# Compare images
compare_images(reference_folder, output_ip_folder, suffix_ip, 'metrics_output_ip.txt')
compare_images(reference_folder, output_folder, suffix_out, 'metrics_output.txt')

## Step 8: View Evaluation Results

Display the contents of the evaluation metrics files to review the PSNR and MSE values for both color-corrected and Real-ESRGAN-enhanced images.

In [18]:
!cat metrics_output_ip.txt
!echo
!cat metrics_output.txt

File                           PSNR (dB)    MSE
------------------------------------------------------------
file2.png                      20.47        583.55
file1.png                      21.01        515.22

File                           PSNR (dB)    MSE       
------------------------------------------------------------
file2.png                      20.58        569.05    
file1.png                      20.06        640.92    


## Conclusion

This notebook demonstrates a hybrid approach to underwater image enhancement that combines traditional color correction with Real-ESRGAN for detail enhancement. The pipeline targets color distortion and low contrast. On the two test images the metrics are mixed: Real-ESRGAN raised PSNR slightly for `file2.png` (20.47 dB to 20.58 dB) but lowered it for `file1.png` (21.01 dB to 20.06 dB), so the second stage is best judged visually as well as by PSNR and MSE.

**Next Steps**:
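- Experiment with different `CONFIG` parameters (e.g., `enhancement_strength`, `gimfilt_radius`) to optimize results for specific underwater conditions. Re-run the Step 3 cell after editing `CONFIG`, because values such as `gimfilt_radius` are bound as default arguments when `ColourCorrection` is defined.
- Adjust the `outscale` parameter in Real-ESRGAN for upscaling if higher resolution is desired. The Step 7 metrics require outputs with the same dimensions as the reference images, so upscaled results need to be resized (or compared against higher-resolution references) before evaluation.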
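
One way to organise such experiments is to generate the parameter combinations up front. The sketch below is a minimal illustration: `config_grid` is a hypothetical helper, the `base` dictionary is a trimmed stand-in for `CONFIG`, and the candidate values are arbitrary examples rather than recommended settings.

```python
from itertools import product


def config_grid(base, **options):
    """Yield copies of `base` with every combination of the given option values applied."""
    keys = list(options)
    for values in product(*(options[key] for key in keys)):
        yield {**base, **dict(zip(keys, values))}


if __name__ == "__main__":
    base = {'enhancement_strength': 0.6, 'gimfilt_radius': 30, 'eps': 1e-2}
    for cfg in config_grid(base, enhancement_strength=[0.4, 0.6, 0.8], gimfilt_radius=[15, 30]):
        print(cfg)
```

Each generated dictionary can then be applied explicitly, for example by constructing `ColourCorrection(gimfilt_radius=cfg['gimfilt_radius'], eps=cfg['eps'])` and setting `CONFIG['enhancement_strength']` before calling `image_enhancement`.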