<a href="https://colab.research.google.com/github/zhuzihan728/Image-Restore/blob/main/restormer_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Restormer: Efficient Transformer for High-Resolution Image Restoration (CVPR 2022 -- Oral) [![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2111.09881)

<hr />

This is a demo to run Restormer on your own images for the following tasks:
- Real Image Denoising
- Single-Image Defocus Deblurring
- Single-Image Motion Deblurring
- Image Deraining


# 1. Setup
- First, in the **Runtime** menu -> **Change runtime type**, make sure to have ```Hardware Accelerator = GPU```
- Clone repo and install dependencies.


In [1]:
import os
!pip install einops

if os.path.isdir('Restormer'):
  !rm -r Restormer

# Clone Restormer
!git clone https://github.com/swz30/Restormer.git
%cd Restormer


Cloning into 'Restormer'...
remote: Enumerating objects: 312, done.[K
remote: Counting objects: 100% (115/115), done.[K
remote: Compressing objects: 100% (43/43), done.[K
remote: Total 312 (delta 74), reused 72 (delta 72), pack-reused 197 (from 2)[K
Receiving objects: 100% (312/312), 1.55 MiB | 1.72 MiB/s, done.
Resolving deltas: 100% (131/131), done.
/content/Restormer


# 2. Define Task and Download Pre-trained Models
Uncomment the task you would like to perform

In [50]:
task = 'Real_Denoising'
# task = 'Single_Image_Defocus_Deblurring'
# task = 'Motion_Deblurring'
# task = 'Deraining'

# Download the pre-trained models
if task == 'Real_Denoising' and not len(os.listdir('Denoising/pretrained_models')) >= 2:
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/real_denoising.pth -P Denoising/pretrained_models
if task == 'Single_Image_Defocus_Deblurring' and not len(os.listdir('Defocus_Deblurring/pretrained_models')) >= 2:
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/single_image_defocus_deblurring.pth -P Defocus_Deblurring/pretrained_models
if task == 'Motion_Deblurring' and not len(os.listdir('Motion_Deblurring/pretrained_models')) >= 2:
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/motion_deblurring.pth -P Motion_Deblurring/pretrained_models
if task == 'Deraining' and not len(os.listdir('Deraining/pretrained_models')) >= 2:
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/deraining.pth -P Deraining/pretrained_models


# 3. Upload Images
Either download the sample images or upload your own images

# 4. Prepare Model and Load Checkpoint

In [51]:
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from runpy import run_path
from skimage import img_as_ubyte
from natsort import natsorted
from glob import glob
import cv2
from tqdm import tqdm
import argparse
import numpy as np

def get_weights_and_parameters(task, parameters):
    """Resolve the checkpoint path and model keyword arguments for a task.

    :param task: one of 'Motion_Deblurring', 'Single_Image_Defocus_Deblurring',
                 'Deraining' or 'Real_Denoising'
    :param parameters: dict of Restormer constructor keyword arguments
    :return: ``(weights, parameters)`` where ``weights`` is the relative path
             ``<task folder>/pretrained_models/<checkpoint>.pth`` and
             ``parameters`` is a copy of the input, with ``LayerNorm_type``
             set to ``'BiasFree'`` for 'Real_Denoising'
    :raises ValueError: if ``task`` is not one of the supported names
    """
    # task -> (folder that holds pretrained_models/, checkpoint file name)
    checkpoints = {
        'Motion_Deblurring': ('Motion_Deblurring', 'motion_deblurring.pth'),
        'Single_Image_Defocus_Deblurring': ('Defocus_Deblurring', 'single_image_defocus_deblurring.pth'),
        'Deraining': ('Deraining', 'deraining.pth'),
        'Real_Denoising': ('Denoising', 'real_denoising.pth'),
    }
    if task not in checkpoints:
        # Previously an unknown task fell through every branch and failed with
        # an UnboundLocalError on `weights`; fail with a clear message instead.
        raise ValueError(f"Unknown task {task!r}; expected one of {sorted(checkpoints)}")

    task_dir, filename = checkpoints[task]
    weights = os.path.join(task_dir, 'pretrained_models', filename)

    # Work on a shallow copy so the caller's dict is not modified behind its back.
    parameters = dict(parameters)
    if task == 'Real_Denoising':
        parameters['LayerNorm_type'] = 'BiasFree'
    return weights, parameters


# Base Restormer constructor arguments; get_weights_and_parameters() returns the
# checkpoint path for `task` together with the (possibly adjusted) arguments.
parameters = {
    'inp_channels': 3,
    'out_channels': 3,
    'dim': 48,
    'num_blocks': [4, 6, 6, 8],
    'num_refinement_blocks': 4,
    'heads': [1, 2, 4, 8],
    'ffn_expansion_factor': 2.66,
    'bias': False,
    'LayerNorm_type': 'WithBias',
    'dual_pixel_task': False,
}
weights, parameters = get_weights_and_parameters(task, parameters)

# Execute the architecture file from the cloned repo and take its Restormer class.
load_arch = run_path(os.path.join('basicsr', 'models', 'archs', 'restormer_arch.py'))
model = load_arch['Restormer'](**parameters)
model.cuda()

# The pretrained state dict is stored under the checkpoint's 'params' key.
checkpoint = torch.load(weights)
model.load_state_dict(checkpoint['params'])

# Kept as the cell's last expression: eval() returns the module, so the
# notebook also displays the layer summary.
model.eval()


Restormer(
  (patch_embed): OverlapPatchEmbed(
    (proj): Conv2d(3, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (encoder_level1): Sequential(
    (0): TransformerBlock(
      (norm1): LayerNorm(
        (body): BiasFree_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(48, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
        (project_out): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): BiasFree_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(48, 254, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(254, 254, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=254, bias=False)
        (project_out): Conv2d(127, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): TransformerBlock(
 

# 5. Inference

In [19]:
# input_dir = 'demo/sample_images/'+task+'/degraded'
# out_dir = 'demo/sample_images/'+task+'/restored'
# os.makedirs(out_dir, exist_ok=True)
# extensions = ['jpg', 'JPG', 'png', 'PNG', 'jpeg', 'JPEG', 'bmp', 'BMP']
# files = natsorted(glob(os.path.join(input_dir, '*')))

# img_multiple_of = 8

# print(f"\n ==> Running {task} with weights {weights}\n ")
# with torch.no_grad():
#   for filepath in tqdm(files):
#       # print(file_)
#       torch.cuda.ipc_collect()
#       torch.cuda.empty_cache()
#       img = cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
#       input_ = torch.from_numpy(img).float().div(255.).permute(2,0,1).unsqueeze(0).cuda()

#       # Pad the input if not_multiple_of 8
#       h,w = input_.shape[2], input_.shape[3]
#       H,W = ((h+img_multiple_of)//img_multiple_of)*img_multiple_of, ((w+img_multiple_of)//img_multiple_of)*img_multiple_of
#       padh = H-h if h%img_multiple_of!=0 else 0
#       padw = W-w if w%img_multiple_of!=0 else 0
#       input_ = F.pad(input_, (0,padw,0,padh), 'reflect')

#       restored = model(input_)
#       restored = torch.clamp(restored, 0, 1)

#       # Unpad the output
#       restored = restored[:,:,:h,:w]

#       restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
#       restored = img_as_ubyte(restored[0])

#       filename = os.path.split(filepath)[-1]
#       cv2.imwrite(os.path.join(out_dir, filename),cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))

In [4]:
from google.colab import drive
# Mount Google Drive at /content/drive; the evaluation inputs and outputs used
# later in this notebook live under /content/drive/MyDrive.
drive.mount('/content/drive')

Mounted at /content/drive


In [55]:
import torch
import torch.nn.functional as F
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage import img_as_ubyte
import cv2
import json
import os
from tqdm import tqdm
import numpy as np
from PIL import Image

class EvalDataset:
    """Paired (corrupted, original) image dataset described by a metadata JSON file.

    Each metadata record must provide the file names under the keys
    ``'corrupted_image'`` and ``'original_image'``; the whole record is handed
    back to the caller alongside the two images.
    """

    def __init__(self, corrupted_dir, original_dir, mask_dir, metadata_path,
                 im_size=None, transform=None):
        """
        :param corrupted_dir: path to corrupted images folder
        :param original_dir: path to original images folder
        :param mask_dir: path to mask images folder (stored, but not read by this class)
        :param metadata_path: path to metadata.json
        :param im_size: target size passed unchanged to PIL ``Image.resize``,
                        i.e. ``(width, height)``, or None to keep original
        :param transform: optional callable applied to both images after the
                          optional resize
        """
        self.corrupted_dir = corrupted_dir
        self.original_dir = original_dir
        self.mask_dir = mask_dir
        self.im_size = im_size
        self.transform = transform

        with open(metadata_path, 'r') as f:
            self.metadata = json.load(f)

    def __len__(self):
        """Number of records in the metadata file."""
        return len(self.metadata)

    @staticmethod
    def _load_image(path):
        """Open ``path`` and decode it fully so no file handle stays open."""
        # Image.open is lazy: it keeps the file open until the pixel data is
        # read. Loading inside the context manager decodes the image and
        # releases the file handle deterministically; the returned image
        # remains usable afterwards.
        with Image.open(path) as img:
            img.load()
        return img

    def __getitem__(self, idx):
        """Return ``(corrupted, original, meta)`` for the record at ``idx``."""
        meta = self.metadata[idx]

        # Load images using names from metadata
        corrupted = self._load_image(os.path.join(self.corrupted_dir, meta['corrupted_image']))
        original = self._load_image(os.path.join(self.original_dir, meta['original_image']))

        # Resize if needed
        if self.im_size:
            corrupted = corrupted.resize(self.im_size, Image.Resampling.LANCZOS)
            original = original.resize(self.im_size, Image.Resampling.LANCZOS)

        if self.transform:
            corrupted = self.transform(corrupted)
            original = self.transform(original)

        return corrupted, original, meta

# Usage in evaluation:
# eval_dataset = EvalDataset(
#     corrupted_dir='eval_dataset/corrupted/',
#     original_dir='images/',  # Original folder
#     mask_dir='backgrounds/',  # Mask folder
#     metadata_path='eval_dataset/metadata.json',
#     im_size=(256, 256)
# )
import torch
import torch.nn.functional as F
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage import img_as_ubyte
import cv2
import json
import os
from tqdm import tqdm
import numpy as np

class ImageRecoveryEvaluator:
    """Runs a restoration model over an evaluation dataset and records quality metrics.

    For every sample the restored image is written to
    ``<output_dir>_<task>/restored`` and PSNR/SSIM (RGB and Y channel), MAE and
    MSE are computed against the original. Per-image results go to
    ``eval_results.json`` and per-alpha-range averages to ``avg_metrics.json``.
    """

    # Inputs are padded so height and width are multiples of this value, as the
    # upstream Restormer demo does before inference.
    IMG_MULTIPLE_OF = 8

    def __init__(self, model, eval_dataset, output_dir, task):
        """
        :param model: restoration network, called as ``model(batch)`` on a CUDA tensor
        :param eval_dataset: indexable dataset yielding ``(corrupted, original, meta)``
        :param output_dir: output path prefix; the task name is appended to it
        :param task: task name used as the output folder suffix
        """
        self.model = model
        self.dataset = eval_dataset
        self.output_dir = f"{output_dir}_{task}"
        os.makedirs(os.path.join(self.output_dir, 'restored'), exist_ok=True)

    def rgb_to_y(self, img):
        """Convert RGB to Y channel - intended to match MATLAB rgb2ycbcr"""
        from skimage.color import rgb2ycbcr
        img_ycbcr = rgb2ycbcr(img)  # Expects float [0,1] or uint8 [0,255]
        return img_ycbcr[:, :, 0]

    @staticmethod
    def _summarize(results):
        """Average the per-image metrics per alpha range, plus a final 'total' row.

        :param results: list of per-image result dicts as built by ``evaluate``
        :return: list of dicts, one per distinct ``str(alpha_range)`` (sorted),
                 followed by one with ``alpha_range == 'total'``
        """
        def average(label, rows):
            def mean_of(key):
                return float(np.mean([row[key] for row in rows]))

            return {
                'alpha_range': label,
                'count': len(rows),
                # 'avg_psnr' / 'avg_ssim' report the RGB metrics. The per-image
                # results only carry 'psnr_rgb' / 'ssim_rgb' (there is no plain
                # 'psnr' / 'ssim' key, which used to raise KeyError here).
                'avg_psnr': mean_of('psnr_rgb'),
                'avg_ssim': mean_of('ssim_rgb'),
                'avg_psnr_y': mean_of('psnr_y'),
                'avg_ssim_y': mean_of('ssim_y'),
                'avg_mae': mean_of('mae'),
                'avg_mse': mean_of('mse'),
            }

        # alpha_range may be a list in the metadata; group on its string form.
        labels = sorted({str(r['alpha_range']) for r in results})
        summary = [
            average(label, [r for r in results if str(r['alpha_range']) == label])
            for label in labels
        ]
        summary.append(average('total', results))
        return summary

    def evaluate(self):
        """Restore every dataset sample, save outputs and metrics, and return the results.

        :return: list with one dict per sample: the sample's metadata plus
                 'psnr_rgb', 'ssim_rgb', 'psnr_y', 'ssim_y', 'mae', 'mse' and
                 'restored_image'
        """
        results = []

        with torch.no_grad():
            for idx in tqdm(range(len(self.dataset))):
                torch.cuda.ipc_collect()
                torch.cuda.empty_cache()

                # Load images
                corrupted, original, meta = self.dataset[idx]

                # Convert PIL to numpy if needed
                if hasattr(corrupted, 'convert'):
                    corrupted_np = np.array(corrupted.convert('RGB'))
                    original_np = np.array(original.convert('RGB'))
                else:
                    corrupted_np = corrupted
                    original_np = original

                # Prepare input: HWC uint8 -> 1xCxHxW float in [0, 1] on the GPU
                input_ = torch.from_numpy(corrupted_np).float().div(255.).permute(2,0,1).unsqueeze(0).cuda()

                # Pad bottom/right up to the next multiple of IMG_MULTIPLE_OF
                # (no padding when the size already is a multiple).
                h, w = input_.shape[2], input_.shape[3]
                padh = (-h) % self.IMG_MULTIPLE_OF
                padw = (-w) % self.IMG_MULTIPLE_OF
                if padh or padw:
                    input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

                # Restore
                restored = self.model(input_)
                restored = torch.clamp(restored, 0, 1)

                # Unpad back to the input's spatial size
                restored = restored[:, :, :h, :w]

                # Convert to numpy
                restored_np = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
                restored_np = img_as_ubyte(restored_np[0])

                # Calculate RGB metrics
                psnr_rgb = peak_signal_noise_ratio(original_np, restored_np, data_range=255)
                ssim_rgb = structural_similarity(original_np, restored_np, channel_axis=2, data_range=255)
                mae = np.mean(np.abs(original_np.astype(float) - restored_np.astype(float)))
                mse = np.mean((original_np.astype(float) - restored_np.astype(float)) ** 2)

                # Calculate Y channel metrics
                original_y = self.rgb_to_y(original_np)
                restored_y = self.rgb_to_y(restored_np)
                psnr_y = peak_signal_noise_ratio(original_y, restored_y, data_range=255)
                ssim_y = structural_similarity(original_y, restored_y, data_range=255)

                # Save restored image
                filename = meta['corrupted_image'].replace('_alpha', '_restored_alpha')
                cv2.imwrite(
                    os.path.join(self.output_dir, 'restored', filename),
                    cv2.cvtColor(restored_np, cv2.COLOR_RGB2BGR)
                )

                # Store results
                result = {
                    **meta,
                    'psnr_rgb': float(psnr_rgb),
                    'ssim_rgb': float(ssim_rgb),
                    'psnr_y': float(psnr_y),
                    'ssim_y': float(ssim_y),
                    'mae': float(mae),
                    'mse': float(mse),
                    'restored_image': filename
                }
                results.append(result)

        # Save per-image metrics first so they survive any later failure
        with open(os.path.join(self.output_dir, 'eval_results.json'), 'w') as f:
            json.dump(results, f, indent=2)

        # Print summary
        print("\n=== Evaluation Results ===")
        print(f"Average PSNR (RGB): {np.mean([r['psnr_rgb'] for r in results]):.2f} dB")
        print(f"Average SSIM (RGB): {np.mean([r['ssim_rgb'] for r in results]):.4f}")
        print(f"Average PSNR (Y):   {np.mean([r['psnr_y'] for r in results]):.2f} dB")
        print(f"Average SSIM (Y):   {np.mean([r['ssim_y'] for r in results]):.4f}")
        print(f"Average MAE:        {np.mean([r['mae'] for r in results]):.2f}")
        print(f"Average MSE:        {np.mean([r['mse'] for r in results]):.2f}")

        avg_metrics = self._summarize(results)
        for entry in avg_metrics[:-1]:  # last entry is the overall 'total' row
            print(f"Alpha {entry['alpha_range']}: PSNR={entry['avg_psnr']:.2f}, SSIM={entry['avg_ssim']:.4f}")

        with open(os.path.join(self.output_dir, 'avg_metrics.json'), 'w') as f:
            json.dump(avg_metrics, f, indent=2)
        return results

# Usage
# Build the evaluation dataset from folders on the mounted Google Drive.
# im_size is left at its default (None), so images keep their original size.
eval_dataset = EvalDataset(
    corrupted_dir='/content/drive/MyDrive/eval_dataset/corrupted/',
    original_dir='/content/drive/MyDrive/image_test/',
    mask_dir='/content/drive/MyDrive/mask/',
    metadata_path='/content/drive/MyDrive/eval_dataset/metadata.json'
)

# # Test on a single image first
# corrupted, original, meta = eval_dataset[2]
# corrupted_np = np.array(corrupted.convert('RGB'))

# # Check input
# print(f"Input shape: {corrupted_np.shape}")
# print(f"Input range: [{corrupted_np.min()}, {corrupted_np.max()}]")

# # Run model
# input_ = torch.from_numpy(corrupted_np).float().div(255.).permute(2,0,1).unsqueeze(0).cuda()
# print(f"Model input shape: {input_.shape}")
# print(f"Model input range: [{input_.min()}, {input_.max()}]")

# with torch.no_grad():
#     restored = model(input_)
# restored = torch.clamp(restored, 0, 1)

# print(f"Model output shape: {restored.shape}")
# print(f"Model output range: [{restored.min()}, {restored.max()}]")

# # Visualize
# import matplotlib.pyplot as plt
# fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# axes[0].imshow(corrupted_np)
# axes[0].set_title('Corrupted Input')
# axes[1].imshow(restored.permute(0,2,3,1).cpu().numpy()[0])
# axes[1].set_title('Model Output')
# axes[2].imshow(np.array(original.convert('RGB')))
# axes[2].set_title('Original')
# plt.show()


In [None]:
# Evaluate the loaded model on the whole dataset. The evaluator appends the
# task name to output_dir, so results land in '/content/drive/MyDrive/eval_results_<task>'
# (restored images under 'restored/', per-image metrics in eval_results.json).
evaluator = ImageRecoveryEvaluator(
    model=model,
    eval_dataset=eval_dataset,
    output_dir='/content/drive/MyDrive/eval_results',
    task = task
)

# One result dict per evaluated image.
results = evaluator.evaluate()

 88%|████████▊ | 264/300 [15:59<02:14,  3.75s/it]