<a href="https://colab.research.google.com/github/zhuzihan728/Image-Restore/blob/main/DehazeFormer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Setup
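- First, open the **Runtime** menu → **Change runtime type** and set **Hardware accelerator** to **GPU**.
- Mount Google Drive (the pretrained checkpoints and the evaluation data are read from `MyDrive`).
- Clone the repo and install the dependencies.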


In [1]:
# Mount Google Drive so the checkpoints and evaluation data under MyDrive are reachable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:

# git clone this repository
!git clone https://github.com/IDKiro/DehazeFormer.git
%cd DehazeFormer

# conda create -n pt1102 python=3.7
# conda activate pt1102

# conda install pytorch=1.10.2 torchvision torchaudio cudatoolkit=11.3 -c pytorch
!pip install opencv-python tqdm pytorch-msssim timm

Cloning into 'DehazeFormer'...
remote: Enumerating objects: 94, done.[K
remote: Counting objects: 100% (24/24), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 94 (delta 17), reused 11 (delta 11), pack-reused 70 (from 2)[K
Receiving objects: 100% (94/94), 768.88 KiB | 19.71 MiB/s, done.
Resolving deltas: 100% (43/43), done.
/content/DehazeFormer
Collecting pytorch-msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Installing collected packages: pytorch-msssim
Successfully installed pytorch-msssim-1.0.0


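# 2. Copy Pre-trained Models from Google Drive

The checkpoints must already be in Google Drive under `MyDrive/Dehazeformer/saved_models`. Use the same layout as the repo's `saved_models/` folder, e.g. `reside6k/dehazeformer-b.pth`. They are copied into the cloned repo below.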


In [3]:

!cp -r /content/drive/MyDrive/Dehazeformer/saved_models /content/DehazeFormer/

# 3. Inference

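In `test.py`, change line 54 to `imwrite(out, f'{args.output}/{i}', normalize=True, value_range=(0, 1))`. This only matters for the `test.py` command in the commented-out cell below. The evaluation code that follows does not use `test.py`.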


In [None]:
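# NOTE: leftover command that points at OneRestore checkpoints (embedder / onerestore_cdd-11), not DehazeFormer; kept for reference only.
# !python test.py --embedder-model-path /content/drive/MyDrive/oneRestore/embedder_model.tar --restore-model-path /content/drive/MyDrive/oneRestore/onerestore_cdd-11.tar --input ./image/ --output ./output/ --concat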

In [4]:
import torch
import torch.nn.functional as F
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage import img_as_ubyte
import cv2
import json
import os
from tqdm import tqdm
import numpy as np
from PIL import Image
from collections import OrderedDict

class EvalDataset:
    """Paired corrupted / ground-truth images listed in a metadata JSON file.

    Each metadata entry names its images through the keys ``'corrupted_image'``
    and ``'original_image'``. These are file names relative to ``corrupted_dir``
    and ``original_dir``. The entry itself is returned with the images so
    callers can read any additional fields.
    """

    def __init__(self, corrupted_dir, original_dir, mask_dir, metadata_path,
                 im_size=None, transform=None):
        """
        :param corrupted_dir: path to corrupted images folder
        :param original_dir: path to original (ground-truth) images folder
        :param mask_dir: path to mask images folder (stored only; not read by this class)
        :param metadata_path: path to metadata.json, a JSON array with one dict per sample
        :param im_size: target size in PIL order, ``(width, height)`` -- not (h, w) --
            or None to keep the original resolution
        :param transform: optional callable applied to both images after resizing
        """
        self.corrupted_dir = corrupted_dir
        self.original_dir = original_dir
        self.mask_dir = mask_dir
        self.im_size = im_size
        self.transform = transform

        with open(metadata_path, 'r') as f:
            self.metadata = json.load(f)

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return ``(corrupted, original, meta)`` for sample ``idx``."""
        meta = self.metadata[idx]

        # Load images using names from metadata
        corrupted = self._open_image(os.path.join(self.corrupted_dir, meta['corrupted_image']))
        original = self._open_image(os.path.join(self.original_dir, meta['original_image']))

        # PIL's resize takes (width, height)
        if self.im_size:
            corrupted = corrupted.resize(self.im_size, Image.Resampling.LANCZOS)
            original = original.resize(self.im_size, Image.Resampling.LANCZOS)

        if self.transform:
            corrupted = self.transform(corrupted)
            original = self.transform(original)

        return corrupted, original, meta

    @staticmethod
    def _open_image(path):
        """Load an image fully into memory and close its file handle immediately."""
        # Image.open is lazy and keeps the file open. copy() forces the pixel
        # data to load, so the `with` block can close the file before returning.
        with Image.open(path) as img:
            return img.copy()


class DehazeFormerEvaluator:
    """Restore every sample of an ``EvalDataset`` with DehazeFormer and score it.

    Outputs go to ``<output_dir>_<task>/``:
    - restored images under ``restored/``;
    - per-image metrics in ``eval_results.json``;
    - averaged metrics in ``avg_metrics.json``, with one row per ``alpha_range``
      group plus a final ``'total'`` row.
    """

    # (key written to avg_metrics.json, per-image metric it averages).
    # Shared by the per-alpha rows and the 'total' row, so every row has the same schema.
    _SUMMARY_KEYS = (
        ('avg_psnr (RGB)', 'psnr_rgb'),
        ('avg_ssim (RGB)', 'ssim_rgb'),
        ('avg_psnr (Y)', 'psnr_y'),
        ('avg_ssim (Y)', 'ssim_y'),
        ('avg_mae', 'mae'),
        ('avg_mse', 'mse'),
    )

    def __init__(self, model, eval_dataset, output_dir, task):
        """
        :param model: restoration network, called on a (1, 3, H, W) float tensor in [-1, 1]
        :param eval_dataset: indexable dataset yielding ``(corrupted, original, meta)``
        :param output_dir: results folder prefix; results go to ``<output_dir>_<task>``
        :param task: run label appended to ``output_dir`` (e.g. the model name)
        """
        self.model = model
        self.dataset = eval_dataset
        self.output_dir = f"{output_dir}_{task}"
        os.makedirs(os.path.join(self.output_dir, 'restored'), exist_ok=True)

    def rgb_to_y(self, img):
        """Return the Y (luma) channel of an RGB image via skimage's ``rgb2ycbcr``.

        skimage follows the same ITU-R BT.601 convention as MATLAB's ``rgb2ycbcr``.
        It accepts float [0, 1] or uint8 [0, 255] input.
        """
        from skimage.color import rgb2ycbcr
        return rgb2ycbcr(img)[:, :, 0]

    def evaluate(self):
        """Restore every sample, save images and metrics, and print a summary.

        :return: list of per-image dicts, each being the sample's metadata merged
            with its metrics and the restored file name
        """
        results = []
        device = self._device()

        with torch.no_grad():
            for idx in tqdm(range(len(self.dataset))):
                if device.type == 'cuda':
                    torch.cuda.empty_cache()

                corrupted, original, meta = self.dataset[idx]

                # Convert PIL to numpy if needed
                if hasattr(corrupted, 'convert'):
                    corrupted_np = np.array(corrupted.convert('RGB'))
                    original_np = np.array(original.convert('RGB'))
                else:
                    corrupted_np, original_np = corrupted, original

                restored_np = self._restore(corrupted_np, device)
                metrics = self._metrics(original_np, restored_np)
                filename = self._save(restored_np, meta)
                results.append({**meta, **metrics, 'restored_image': filename})

        self._write_json('eval_results.json', results)

        if not results:
            print("\nNo samples evaluated; nothing to summarize.")
            self._write_json('avg_metrics.json', [])
            return results

        total = self._summarize(results, 'total')
        print("\n=== Evaluation Results ===")
        print(f"Average PSNR (RGB): {total['avg_psnr (RGB)']:.2f} dB")
        print(f"Average SSIM (RGB): {total['avg_ssim (RGB)']:.4f}")
        print(f"Average PSNR (Y):   {total['avg_psnr (Y)']:.2f} dB")
        print(f"Average SSIM (Y):   {total['avg_ssim (Y)']:.4f}")
        print(f"Average MAE:        {total['avg_mae']:.2f}")
        print(f"Average MSE:        {total['avg_mse']:.2f}")

        # Group on str(alpha_range), presumably because the metadata stores it as
        # an (unhashable) list.
        groups = {}
        for r in results:
            groups.setdefault(str(r['alpha_range']), []).append(r)

        avg_metrics = []
        for alpha_range in sorted(groups):
            summary = self._summarize(groups[alpha_range], alpha_range)
            avg_metrics.append(summary)
            print(f"======= Alpha {alpha_range} =======")
            print(f"PSNR (RGB): {summary['avg_psnr (RGB)']:.2f}")
            print(f"SSIM (RGB): {summary['avg_ssim (RGB)']:.4f}")
            print(f"PSNR (Y):   {summary['avg_psnr (Y)']:.2f}")
            print(f"SSIM (Y):   {summary['avg_ssim (Y)']:.4f}")
            print(f"MAE:        {summary['avg_mae']:.2f}")
            print(f"MSE:        {summary['avg_mse']:.2f}")
        avg_metrics.append(total)

        self._write_json('avg_metrics.json', avg_metrics)
        return results

    def _device(self):
        """Device holding the model's parameters; CUDA when that cannot be determined."""
        parameters = getattr(self.model, 'parameters', None)
        first = next(iter(parameters()), None) if callable(parameters) else None
        return first.device if first is not None else torch.device('cuda')

    def _restore(self, rgb_uint8, device):
        """Dehaze one image; expects an HxWx3 uint8 RGB array, returns the same layout."""
        # DehazeFormer works on inputs scaled to [-1, 1].
        input_ = torch.from_numpy(rgb_uint8).float().div(255.).permute(2, 0, 1).unsqueeze(0).to(device)
        input_ = input_ * 2 - 1

        restored = self.model(input_).clamp(-1, 1)
        restored = restored * 0.5 + 0.5  # [-1, 1] -> [0, 1]

        restored_np = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
        return img_as_ubyte(restored_np[0])

    def _metrics(self, original_np, restored_np):
        """PSNR/SSIM on RGB and on the Y channel, plus MAE/MSE on 0-255 pixel values."""
        diff = original_np.astype(float) - restored_np.astype(float)
        original_y = self.rgb_to_y(original_np)
        restored_y = self.rgb_to_y(restored_np)
        return {
            'psnr_rgb': float(peak_signal_noise_ratio(original_np, restored_np, data_range=255)),
            'ssim_rgb': float(structural_similarity(original_np, restored_np, channel_axis=2, data_range=255)),
            'psnr_y': float(peak_signal_noise_ratio(original_y, restored_y, data_range=255)),
            'ssim_y': float(structural_similarity(original_y, restored_y, data_range=255)),
            'mae': float(np.mean(np.abs(diff))),
            'mse': float(np.mean(diff ** 2)),
        }

    def _save(self, restored_np, meta):
        """Write the restored RGB image under ``restored/`` and return the file name used."""
        filename = meta['corrupted_image'].replace('_alpha', '_restored_alpha')
        save_path = os.path.join(self.output_dir, 'restored', filename)
        # The file name may contain sub-folders; cv2.imwrite does not create them.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        # cv2.imwrite reports some failures by returning False instead of raising.
        if not cv2.imwrite(save_path, cv2.cvtColor(restored_np, cv2.COLOR_RGB2BGR)):
            tqdm.write(f"Warning: could not write restored image to {save_path}")
        return filename

    def _write_json(self, name, payload):
        """Dump ``payload`` as indented JSON to ``<output_dir>/<name>``."""
        with open(os.path.join(self.output_dir, name), 'w') as f:
            json.dump(payload, f, indent=2)

    @classmethod
    def _summarize(cls, rows, label):
        """Average each metric over ``rows``; returns one ``avg_metrics.json`` row."""
        summary = {'alpha_range': label, 'count': len(rows)}
        for out_key, metric in cls._SUMMARY_KEYS:
            summary[out_key] = float(np.mean([r[metric] for r in rows]))
        return summary


# Helper function to load DehazeFormer model
def load_dehazeformer_model(model_name, checkpoint_path):
    """
    Load DehazeFormer model
    :param model_name: 'dehazeformer-s', 'dehazeformer-b', 'dehazeformer-l', etc.
    :param checkpoint_path: path to the .pth checkpoint file
    """
    import sys
    sys.path.append('/content/DehazeFormer')
    from models import dehazeformer_s, dehazeformer_b, dehazeformer_l

    # Create model
    model_dict = {
        'dehazeformer-s': dehazeformer_s,
        'dehazeformer-b': dehazeformer_b,
        'dehazeformer-l': dehazeformer_l
    }

    network = model_dict[model_name]()
    network.cuda()

    # Load checkpoint
    state_dict = torch.load(checkpoint_path)['state_dict']
    new_state_dict = OrderedDict()

    # Remove 'module.' prefix if present
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v

    network.load_state_dict(new_state_dict)
    network.eval()

    return network


# Pretrained variant to evaluate: 'dehazeformer-s', 'dehazeformer-b' or 'dehazeformer-l'.
model_name = 'dehazeformer-b'
DRIVE_ROOT = '/content/drive/MyDrive'

# Load DehazeFormer model (checkpoints were copied into saved_models/ in section 2)
model = load_dehazeformer_model(
    model_name=model_name,
    checkpoint_path=f'/content/DehazeFormer/saved_models/reside6k/{model_name}.pth'
)

# Create evaluation dataset
eval_dataset = EvalDataset(
    corrupted_dir=f'{DRIVE_ROOT}/eval_dataset/corrupted/',
    original_dir=f'{DRIVE_ROOT}/image_test/',
    mask_dir=f'{DRIVE_ROOT}/mask/',
    metadata_path=f'{DRIVE_ROOT}/eval_dataset/metadata.json'
)

# Run evaluation; results are written to <DRIVE_ROOT>/eval_results_<model_name>/
evaluator = DehazeFormerEvaluator(
    model=model,
    eval_dataset=eval_dataset,
    output_dir=f'{DRIVE_ROOT}/eval_results',
    task=model_name
)

results = evaluator.evaluate()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 300/300 [10:04<00:00,  2.01s/it]


=== Evaluation Results ===
Average PSNR (RGB): 19.18 dB
Average SSIM (RGB): 0.7785
Average PSNR (Y):   20.81 dB
Average SSIM (Y):   0.8089
Average MAE:        30.58
Average MSE:        2176.23
PSNR (RGB): 26.63
SSIM (RGB): 0.9292
PSNR (Y):   28.31
SSIM (Y):   0.9448
MAE:        11.12
MSE:        267.47
PSNR (RGB): 17.53
SSIM (RGB): 0.7698
PSNR (Y):   19.14
SSIM (Y):   0.8051
MAE:        31.09
MSE:        1809.58
PSNR (RGB): 13.37
SSIM (RGB): 0.6364
PSNR (Y):   14.98
SSIM (Y):   0.6767
MAE:        49.52
MSE:        4451.65



