# Inference on Original Model

In [1]:
from google.colab import drive
import os

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
import numpy as np
import random

## Mount Google Drive
Google Drive stores the sand-dust images and the model checkpoints used below.

In [3]:
# Mount Google Drive at /content/drive; the checkpoint and dataset paths
# configured further down point into this mount.
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!pwd

/content


## Clone Dehamer Git Repo and Import necessary files

In [4]:
!git clone https://github.com/Li-Chongyi/Dehamer.git

Cloning into 'Dehamer'...
remote: Enumerating objects: 126, done.[K
remote: Counting objects: 100% (59/59), done.[K
remote: Compressing objects: 100% (56/56), done.[K
remote: Total 126 (delta 23), reused 3 (delta 2), pack-reused 67[K
Receiving objects: 100% (126/126), 8.91 MiB | 16.44 MiB/s, done.
Resolving deltas: 100% (39/39), done.


In [5]:
# install packages
!pip install timm

Collecting timm
  Downloading timm-0.9.16-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.9.16


In [6]:
# import libraries
# swin_unet and utils are modules of the cloned Dehamer repo (Dehamer/src).
# The working directory is switched there for the imports and switched back
# afterwards, so the rest of the notebook keeps running from the start directory.
%cd Dehamer/src
from swin_unet import UNet_emb
from utils import to_psnr, save_image
%cd ../..

/content/Dehamer/src
/content


## Model Testing

In [7]:
class SIEDataset(Dataset):
    """Paired sand-dust / ground-truth images from an SIE dataset directory.

    ``dataset_dir`` must contain a ``Sand_dust_images`` and a ``Ground_truth``
    sub-directory holding ``.jpg`` files with identical names; entries that are
    not ``.jpg`` files are ignored. Each sample is up-scaled when it is smaller
    than the crop size, randomly cropped to the crop size (the same region for
    both images of a pair) and finally trimmed so that width and height are
    multiples of 16.

    Args:
        dataset_dir: directory containing the two sub-directories.
        crop_size: optional ``(width, height)`` of the random crop. When it is
            omitted, the module-level ``W_THRESHOLD`` and ``H_THRESHOLD`` are
            read every time an item is fetched, so they must be defined before
            the dataset is iterated.

    Raises:
        ValueError: if the two sub-directories do not contain the same ``.jpg``
            file names.
    """

    # Class-level default so the attribute exists on every instance.
    crop_size = None

    def __init__(self, dataset_dir, crop_size=None):
        super().__init__()
        self.crop_size = crop_size

        self.ground_truth_images_dir = os.path.join(dataset_dir, "Ground_truth")
        self.sand_dust_images_dir = os.path.join(dataset_dir, "Sand_dust_images")

        self.sand_dust_image_names = self._list_jpg_names(self.sand_dust_images_dir)
        self.ground_truth_image_names = self._list_jpg_names(self.ground_truth_images_dir)

        if len(self.sand_dust_image_names) != len(self.ground_truth_image_names):
            raise ValueError("A number of sand-dust images and ground truth images must be the same")

        # get_images() opens the ground truth under the sand-dust file name, so
        # equal counts are not enough: the names themselves have to match.
        unpaired_names = sorted(set(self.sand_dust_image_names) ^ set(self.ground_truth_image_names))
        if unpaired_names:
            raise ValueError(
                "Sand-dust images and ground truth images must have the same file names; "
                f"unpaired files (first 5): {unpaired_names[:5]}"
            )

        # NOTE(review): presumably the per-channel mean/std the checkpoint was
        # trained with - confirm against the training configuration.
        self.transform_input = Compose([ToTensor() , Normalize((0.64, 0.6, 0.58), (0.14,0.15, 0.152))])
        self.transform_gt = Compose([ToTensor()])

    @staticmethod
    def _list_jpg_names(directory):
        """Return the sorted names of the regular ``.jpg`` files in ``directory``."""
        # os.listdir order is unspecified; sorting keeps the sample order stable.
        return sorted(
            file_name
            for file_name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, file_name))
            and os.path.splitext(file_name)[1] == ".jpg"
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return it as a 3-channel RGB image and close the file."""
        # convert() returns a fully loaded image, so the file handle can be
        # released right away; forcing RGB keeps the 3-channel Normalize valid
        # for grayscale or CMYK JPEGs.
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _random_crop_box(width, height, w_threshold, h_threshold):
        """Pick a random ``(left, top, right, bottom)`` box of the threshold size."""
        assert width >= w_threshold and height >= h_threshold, "to crop, image size must be bigger than or equal to the threshold values"

        # choose top and left randomly -> bottom and right follow automatically
        top = random.randint(0, height - h_threshold)  # inclusive
        left = random.randint(0, width - w_threshold)

        return (left, top, left + w_threshold, top + h_threshold)

    def get_images(self, index):
        """Load and transform the image pair at ``index``.

        Returns:
            ``(sand_dust_tensor, ground_truth_tensor, file_name)`` where both
            images were cropped to the same region.
        """
        img_file_name = self.sand_dust_image_names[index]

        sand_dust_img = self._load_rgb(os.path.join(self.sand_dust_images_dir, img_file_name))
        ground_truth_img = self._load_rgb(os.path.join(self.ground_truth_images_dir, img_file_name))

        if self.crop_size is not None:
            w_threshold, h_threshold = self.crop_size
        else:
            w_threshold, h_threshold = W_THRESHOLD, H_THRESHOLD

        # make sure both images are at least as large as the crop
        if self.is_image_smaller_than_threshold(sand_dust_img, w_threshold, h_threshold):
            sand_dust_img = self.stretch_image(sand_dust_img, w_threshold, h_threshold)
        if self.is_image_smaller_than_threshold(ground_truth_img, w_threshold, h_threshold):
            ground_truth_img = self.stretch_image(ground_truth_img, w_threshold, h_threshold)

        # One random box shared by both images: cropping them independently
        # would compare different regions of the pair. The box is drawn inside
        # the area common to both images so it is valid for each of them.
        crop_box = self._random_crop_box(
            min(sand_dust_img.width, ground_truth_img.width),
            min(sand_dust_img.height, ground_truth_img.height),
            w_threshold,
            h_threshold,
        )
        sand_dust_img = sand_dust_img.crop(crop_box)
        ground_truth_img = ground_truth_img.crop(crop_box)

        # NOTE: the model only accepts width & height that is multiple of 16
        width, height = sand_dust_img.size
        multiple_of_16_box = (0, 0, width - width % 16, height - height % 16)
        sand_dust_img = sand_dust_img.crop(multiple_of_16_box)
        ground_truth_img = ground_truth_img.crop(multiple_of_16_box)

        sand_dust_img = self.transform_input(sand_dust_img)
        ground_truth_img = self.transform_gt(ground_truth_img)
        return sand_dust_img, ground_truth_img, img_file_name

    def crop_image(self, image, w_threshold, h_threshold):
        """Return a random ``w_threshold`` x ``h_threshold`` crop of ``image``."""
        return image.crop(self._random_crop_box(image.width, image.height, w_threshold, h_threshold))

    def is_image_smaller_than_threshold(self, image, w_threshold, h_threshold) -> bool:
        """Tell whether ``image`` is narrower or shorter than the thresholds."""
        return image.width < w_threshold or image.height < h_threshold

    def stretch_image(self, image, w_threshold, h_threshold):
        """Scale ``image`` so it covers the thresholds while keeping its aspect ratio.

        The limiting side ends up exactly at its threshold and the other side
        is rounded up, so the result can always be cropped to the threshold
        size afterwards.
        """
        # Integer arithmetic (ceil via negated floor division) avoids float
        # rounding that could leave a side one pixel short of its threshold.
        if image.width * h_threshold >= image.height * w_threshold:
            # image is relatively wider than the target: height is the limiting side
            new_h = h_threshold
            new_w = -(-image.width * h_threshold // image.height)
        else:
            # image is relatively taller than the target: width is the limiting side
            new_w = w_threshold
            new_h = -(-image.height * w_threshold // image.width)

        return image.resize((new_w, new_h))

    def __getitem__(self, index):
        return self.get_images(index)

    def __len__(self):
        return len(self.sand_dust_image_names)

In [8]:
def test(net, val_data_loader, device, category, save_tag=False):
    psnr_list = []

    for batch_id, val_data in enumerate(val_data_loader):

        with torch.no_grad():
            haze, gt, image_name = val_data
            haze = haze.to(device)
            gt = gt.to(device)
            dehaze = net(haze)

        # --- Calculate the average PSNR --- #
        psnr_list.extend(to_psnr(dehaze, gt))

        # --- Save image --- #
        if save_tag:
            save_image(dehaze, image_name, category)

    avr_psnr = sum(psnr_list) / len(psnr_list)
    return avr_psnr

In [9]:
# constants

# label for this inference run
RUN_NAME = "inference1"

# inputs on the mounted Google Drive
# (despite its name, CHECKPOINT_DIR is the path of the checkpoint file itself)
CHECKPOINT_DIR = "/content/drive/MyDrive/FYP/Sem 2/4. Execution/checkpoints/original/dense/PSNR1662_SSIM05602.pt"
DATASET_DIR = "/content/drive/MyDrive/FYP/Sem 2/4. Execution/Datasets/Sanddust Database/SIE_Dataset/Synthetic_images"
DATASET_NAME = "SIE_Dataset"

# evaluation settings
BATCH_SIZE = 16

# width / height every image is cropped to before it is fed to the model
W_THRESHOLD = 440
H_THRESHOLD = 330

In [10]:
# SIEDataset takes its crops with the `random` module; a fixed seed makes the
# crops, and therefore the reported PSNR, repeatable between runs.
EVAL_SEED = 0
random.seed(EVAL_SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
val_data_loader = DataLoader(SIEDataset(DATASET_DIR), batch_size=BATCH_SIZE, shuffle=False)

net = UNet_emb()
net = net.to(device)

# map_location keeps a checkpoint that was saved on a GPU loadable when the
# notebook falls back to the CPU.
state_dict = torch.load(CHECKPOINT_DIR, map_location=device)

# strict=False skips mismatching keys without raising, which could leave
# layers at their initial weights; report any mismatch so it does not go
# unnoticed.
load_result = net.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: checkpoint does not fully match the model - "
        f"{len(load_result.missing_keys)} missing key(s), "
        f"{len(load_result.unexpected_keys)} unexpected key(s)"
    )
net.eval()

val_psnr = test(net, val_data_loader, device, DATASET_NAME, save_tag=True)
print(f"Average PSNR on {DATASET_NAME}: {val_psnr}")

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
