# CE-VAE: Underwater Image Enhancement with Custom Datasets

This notebook runs the CE-VAE model for underwater image enhancement on your own dataset. It covers environment setup, dataset preparation, optional training, inference, and quality metrics.

**Before running:**
1. Go to Runtime → Change runtime type → Select GPU
2. Prepare your dataset with the folder structure described below

## 1. Setup Environment

In [None]:
# Clone the repository
!git clone https://github.com/priyanshuharshbodhi1/ce-vae-underwater-image-enhancement.git
%cd ce-vae-underwater-image-enhancement

In [None]:
# Install dependencies
!pip install -r requirements.txt -q

## 2. Mount Google Drive (for dataset access)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## 3. Dataset Preparation

Your dataset should have this structure:
```
your_dataset/
├── train/
│   ├── GT/      # Ground truth images
│   └── input/   # Degraded images
└── val/
    ├── GT/
    └── input/
```

**Set your dataset path below:**

In [None]:
# ====== CONFIGURE YOUR DATASET PATH HERE ======
# Option 1: Dataset in Google Drive
DATASET_PATH = "/content/drive/MyDrive/my_underwater_dataset"

# Option 2: Upload directly (uncomment and run separately)
# from google.colab import files
# uploaded = files.upload()  # Upload a zip file
# !unzip your_dataset.zip -d /content/dataset
# DATASET_PATH = "/content/dataset"

# Verify the dataset structure
import os
print(f"Dataset path: {DATASET_PATH}")
if os.path.isdir(DATASET_PATH):
    print("\nDataset structure:")
    # `filenames` avoids shadowing the `files` module from google.colab
    for root, dirs, filenames in os.walk(DATASET_PATH):
        level = root.replace(DATASET_PATH, '').count(os.sep)
        indent = ' ' * 2 * level
        print(f"{indent}{os.path.basename(root)}/")
        # Images live two levels down (e.g. train/GT), so report counts through level 2
        if level <= 2 and filenames:
            subindent = ' ' * 2 * (level + 1)
            print(f"{subindent}({len(filenames)} files)")
else:
    print("ERROR: Dataset path does not exist!")
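
The tree printed by the cell above shows what is on disk, but it does not confirm that the four required folders exist or that every degraded image has a ground-truth partner. A minimal sketch of such a check follows. `check_dataset_layout` and `IMAGE_EXTENSIONS` are hypothetical names, and pairing files by identical name is an assumption about your data (the later result and metric cells look images up the same way).

```python
from pathlib import Path

# Assumption: adjust this set to the image formats in your dataset
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}


def check_dataset_layout(dataset_path):
    """Return a list of problems; an empty list means nothing was flagged."""
    root = Path(dataset_path)
    problems = []
    for split in ("train", "val"):
        names = {}
        for sub in ("GT", "input"):
            folder = root / split / sub
            if not folder.is_dir():
                problems.append(f"missing folder: {split}/{sub}")
                names[sub] = set()
                continue
            names[sub] = {
                p.name for p in folder.iterdir()
                if p.suffix.lower() in IMAGE_EXTENSIONS
            }
        # Symmetric difference: names present in only one of the two folders
        unpaired = names["GT"] ^ names["input"]
        if unpaired:
            problems.append(
                f"{split}: {len(unpaired)} file(s) without a GT/input partner"
            )
    return problems
```

Calling `check_dataset_layout(DATASET_PATH)` and printing the result before generating the list files may catch layout mistakes early.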

In [None]:
# Generate dataset text files
!bash scripts/generate_dataset_txt.sh "{DATASET_PATH}"

# Check generated files
!echo "Generated files:"
!ls -la data/*.txt 2>/dev/null || echo "No text files found. Make sure your dataset has the correct structure."
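
The config in section 5 reads `data/LSUI_train_input.txt`, `data/LSUI_train_target.txt`, `data/LSUI_val_input.txt`, and `data/LSUI_val_target.txt`, so it is worth confirming that those files exist and line up. The sketch below compares an input list with its target list. It assumes each list is plain text with one image path per line and that paired images share a file name; neither is confirmed by this notebook. `read_list_file` and `compare_list_files` are hypothetical helpers.

```python
import os


def read_list_file(path):
    """Read a list file, assuming one image path per non-empty line."""
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]


def compare_list_files(input_list_file, target_list_file):
    """Return (n_inputs, n_targets, mismatches).

    mismatches holds the line-wise (input, target) pairs whose file
    names differ. Pairing by name is an assumption.
    """
    inputs = read_list_file(input_list_file)
    targets = read_list_file(target_list_file)
    mismatches = [
        (a, b) for a, b in zip(inputs, targets)
        if os.path.basename(a) != os.path.basename(b)
    ]
    return len(inputs), len(targets), mismatches
```

For example, `compare_list_files("data/LSUI_train_input.txt", "data/LSUI_train_target.txt")` should report equal counts and an empty mismatch list if the two files are aligned under these assumptions.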

## 4. Download Pre-trained Model

In [None]:
# ImageNet pre-trained weights (the starting checkpoint when training on your own dataset)
# Get the download link from the repository README, then fetch the file with gdown or wget

!mkdir -p data

# Option 1: If you have the file in Google Drive
# !cp "/content/drive/MyDrive/imagenet-pre-trained-cevae.ckpt" data/

# Option 2: Download using gdown (if you have the Google Drive file ID)
# !pip install gdown -q
# !gdown "YOUR_GDRIVE_FILE_ID" -O data/imagenet-pre-trained-cevae.ckpt

# Option 3: Download LSUI pre-trained model for inference
# !gdown "YOUR_LSUI_CHECKPOINT_ID" -O data/lsui-pretrained.ckpt

print("Available checkpoints in data/:")
!ls -la data/*.ckpt 2>/dev/null || echo "No checkpoints found. Please download the pre-trained model."

## 5. Create Custom Config (if training)

In [None]:
# Create a custom config for your dataset
custom_config = """
# CE-VAE - Custom Dataset Config
model:
  target: src.models.cevae.CEVAE
  params:
    discriminator: False
    ckpt_path: data/imagenet-pre-trained-cevae.ckpt
    embed_dim: 256
    ddconfig:
      double_z: False
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [1, 1, 2, 2, 4]
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0

    lossconfig:
      target: src.modules.losses.combined.ReconstructionLossWithDiscriminator
      params:
        pixelloss_weight: 10.0
        perceptual_weight: 1.0
        gdl_loss_weight: 0.0
        color_loss_weight: 0.0
        ssim_loss_weight: 1.0
        disc_enabled: False

    optimizer:
      base_learning_rate: 4.5e-6

lightning:
  trainer:
    max_epochs: 100  # Adjust as needed
    accelerator: gpu
    devices: 1
    check_val_every_n_epoch: 10

data:
  target: src.data.dataset_wrapper.DataModuleFromConfig
  params:
    dataset_name: "CustomDataset"
    train_batch_size: 4  # Reduce if OOM
    val_batch_size: 8
    num_workers: 4
    train:
      target: src.data.image_enhancement.DatasetTrainFromImageFileList
      params:
        training_images_list_file: data/LSUI_train_input.txt
        target_images_list_file: data/LSUI_train_target.txt
        random_crop: True
        random_flip: True
        color_jitter:
          brightness: [0.9, 1.1]
          contrast: [0.9, 1.1]
          saturation: [0.9, 1.1]
          hue: [-0.02, 0.02]
        max_size: 288
        size: 256
    validation:
      target: src.data.image_enhancement.DatasetTestFromImageFileList
      params:
        test_images_list_file: data/LSUI_val_input.txt
        test_target_images_list_file: data/LSUI_val_target.txt
        size: 256
    test:
      target: src.data.image_enhancement.DatasetTestFromImageFileList
      params:
        test_images_list_file: data/LSUI_val_input.txt
        test_target_images_list_file: data/LSUI_val_target.txt
        size: 256
"""

with open('configs/cevae_custom.yaml', 'w') as f:
    f.write(custom_config)

print("Custom config created: configs/cevae_custom.yaml")
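
The config above refers to several files by relative path: the checkpoint in `ckpt_path` and the image list files. Training will likely fail if any of them is missing. The following sketch pulls those paths out of the config text with a regular expression and reports the ones that do not exist. `referenced_paths` and `missing_paths` are hypothetical helpers, and the regex only handles unquoted, single-line values like the ones used here; a YAML parser would be more robust.

```python
import os
import re


def referenced_paths(config_text):
    """Collect values of keys ending in '_file' or named 'ckpt_path'."""
    pattern = re.compile(r"^\s*(?:\w+_file|ckpt_path):\s*(\S+)", re.MULTILINE)
    # A set removes duplicates (validation and test share the same lists)
    return sorted(set(pattern.findall(config_text)))


def missing_paths(config_text, base_dir="."):
    """Return the referenced paths that do not exist under base_dir."""
    return [
        p for p in referenced_paths(config_text)
        if not os.path.exists(os.path.join(base_dir, p))
    ]
```

Running `missing_paths(custom_config)` from the repository root returns an empty list when every referenced file is in place.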

## 6. Training (Optional)

Run this if you want to train on your dataset. Skip to section 7 for inference only.

In [None]:
# Set wandb to offline mode (or login with: wandb login)
!wandb offline

# Train the model
!python main.py --config configs/cevae_custom.yaml
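
After training, section 7 needs the path of the checkpoint you want to use. This notebook does not state where `main.py` writes checkpoints, so the sketch below searches a directory tree for the most recently modified `.ckpt` file. `find_latest_checkpoint` is a hypothetical helper, and the directory you pass in is your own choice.

```python
from pathlib import Path


def find_latest_checkpoint(search_root):
    """Return the most recently modified .ckpt file under search_root, or None."""
    checkpoints = list(Path(search_root).rglob("*.ckpt"))
    if not checkpoints:
        return None
    # Compare by modification time; the newest file wins
    return max(checkpoints, key=lambda p: p.stat().st_mtime)
```

Searching the repository root also picks up the pre-trained files in `data/`, so point the search at your run's output folder if you know it, and inspect the returned path before assigning it to `CHECKPOINT_PATH`.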

## 7. Inference: Enhance Your Images

In [None]:
# ====== CONFIGURE PATHS ======
# Path to the checkpoint (use pre-trained or your trained model)
CHECKPOINT_PATH = "data/lsui-pretrained.ckpt"  # Change this to your checkpoint

# Path to images you want to enhance
INPUT_IMAGES_PATH = f"{DATASET_PATH}/val/input"  # or any folder with images

# Where to save enhanced images
OUTPUT_PATH = "/content/enhanced_images"

print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"Input images: {INPUT_IMAGES_PATH}")
print(f"Output folder: {OUTPUT_PATH}")

In [None]:
# Run inference
!python test.py \
    --config configs/cevae_E2E_lsui.yaml \
    --checkpoint "{CHECKPOINT_PATH}" \
    --data-path "{INPUT_IMAGES_PATH}" \
    --output-path "{OUTPUT_PATH}" \
    --batch-size 4 \
    --device cuda:0
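
Once inference finishes, a quick check is whether every input image produced an output. The sketch below compares the two folders by file stem, so an output saved with a different extension still counts. It assumes outputs keep the input's base name, which the result viewer in section 8 also relies on. `missing_outputs` is a hypothetical helper.

```python
from pathlib import Path

# Assumption: adjust this set to the image formats in your dataset
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}


def missing_outputs(input_dir, output_dir):
    """Return sorted input file names that have no output with the same stem."""

    def stems(folder):
        folder = Path(folder)
        if not folder.is_dir():
            return {}
        # Map stem -> full file name for every image in the folder
        return {
            p.stem: p.name for p in folder.iterdir()
            if p.suffix.lower() in IMAGE_EXTENSIONS
        }

    inputs, outputs = stems(input_dir), stems(output_dir)
    return sorted(name for stem, name in inputs.items() if stem not in outputs)
```

`missing_outputs(INPUT_IMAGES_PATH, OUTPUT_PATH)` returns an empty list when every input has a matching output.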

## 8. View Results

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os
import random

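# Get list of enhanced images (case-insensitive extension match)
enhanced_images = [f for f in os.listdir(OUTPUT_PATH) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
if not enhanced_images:
    raise FileNotFoundError(f"No enhanced images found in {OUTPUT_PATH}; run inference first.")

# Display up to 4 random samples
num_samples = min(4, len(enhanced_images))
samples = random.sample(enhanced_images, num_samples)

# squeeze=False keeps axes two-dimensional even when there is a single row
fig, axes = plt.subplots(num_samples, 2, figsize=(12, 4*num_samples), squeeze=False)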

for idx, img_name in enumerate(samples):
    # Enhanced image
    enhanced = Image.open(os.path.join(OUTPUT_PATH, img_name))
    
    # Try to find corresponding input image
    input_path = os.path.join(INPUT_IMAGES_PATH, img_name)
    if os.path.exists(input_path):
        original = Image.open(input_path)
        axes[idx][0].imshow(original)
        axes[idx][0].set_title(f'Input: {img_name}')
    else:
        axes[idx][0].text(0.5, 0.5, 'Input not found', ha='center')
        axes[idx][0].set_title('Input')
    axes[idx][0].axis('off')
    
    axes[idx][1].imshow(enhanced)
    axes[idx][1].set_title(f'Enhanced: {img_name}')
    axes[idx][1].axis('off')

plt.tight_layout()
plt.show()

print(f"\nTotal enhanced images: {len(enhanced_images)}")
print(f"Saved to: {OUTPUT_PATH}")

## 9. Compute Quality Metrics

In [None]:
import numpy as np
from PIL import Image
import sys
sys.path.append('.')
from src.metrics import compute
import os

# Path to ground truth images (for reference metrics)
GT_PATH = f"{DATASET_PATH}/val/GT"  # Set to None if no GT available

# Compute metrics for all enhanced images
all_metrics = []
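# Sort so that "first 10" below is reproducible; os.listdir returns files in arbitrary order
enhanced_images = sorted(f for f in os.listdir(OUTPUT_PATH) if f.lower().endswith(('.png', '.jpg', '.jpeg')))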

for img_name in enhanced_images[:10]:  # Limit to first 10 for speed
    enhanced_path = os.path.join(OUTPUT_PATH, img_name)
    enhanced_img = np.array(Image.open(enhanced_path).convert('RGB'))
    
    # Load GT if available
    gt_img = None
    if GT_PATH and os.path.exists(os.path.join(GT_PATH, img_name)):
        gt_img = np.array(Image.open(os.path.join(GT_PATH, img_name)).convert('RGB'))
        # Resize GT to match enhanced if needed (PIL expects (width, height))
        if gt_img.shape != enhanced_img.shape:
            gt_img = np.array(Image.fromarray(gt_img).resize((enhanced_img.shape[1], enhanced_img.shape[0])))
    
    metrics = compute(enhanced_img, gt_img, gt_metrics=False)
    all_metrics.append(metrics)
    print(f"{img_name}: PSNR={metrics['psnr']:.2f}, SSIM={metrics['ssim']:.4f}, UIQM={metrics['uiqm']:.4f}, UCIQE={metrics['uciqe']:.4f}")

# Average metrics
if all_metrics:
    print("\n=== Average Metrics ===")
    for key in ['psnr', 'ssim', 'uiqm', 'uciqe', 'niqe']:
        # .get() avoids a KeyError if a metric is absent; -1 entries are skipped as before
        values = [m[key] for m in all_metrics if m.get(key, -1) != -1]
        if values:
            print(f"{key.upper()}: {np.mean(values):.4f} ± {np.std(values):.4f}")
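
The cell above only prints the metrics, so they are lost when the Colab session ends. The sketch below writes per-image metrics to a CSV file that can be downloaded or copied to Drive with the results. `write_metrics_csv` is a hypothetical helper and is not part of the repository. It expects `(image_name, metrics_dict)` pairs, which can be built with `zip(enhanced_images[:10], all_metrics)` because the loop above processes the images in that order.

```python
import csv


def write_metrics_csv(rows, csv_path):
    """Write (image_name, metrics_dict) pairs to CSV and return the column names."""
    rows = list(rows)  # allow iterators such as zip(...) to be read twice
    # Union of metric names across images, sorted for a stable column order
    metric_names = sorted({key for _, metrics in rows for key in metrics})
    fieldnames = ["image"] + metric_names
    with open(csv_path, "w", newline="") as f:
        # restval fills in an empty cell when an image lacks a metric
        writer = csv.DictWriter(f, fieldnames=fieldnames, restval="")
        writer.writeheader()
        for image_name, metrics in rows:
            writer.writerow({"image": image_name, **metrics})
    return fieldnames
```

For example, `write_metrics_csv(zip(enhanced_images[:10], all_metrics), "/content/metrics.csv")`; the output path here is only an illustration.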

## 10. Download Results

In [None]:
# Zip and download enhanced images
!zip -r enhanced_images.zip "{OUTPUT_PATH}"

from google.colab import files
files.download('enhanced_images.zip')

In [None]:
# Or copy to Google Drive (requires the Drive mount from section 2)
!cp -r "{OUTPUT_PATH}" "/content/drive/MyDrive/enhanced_underwater_images"
print("Copied to Google Drive!")