# <center>This `.ipynb` notebook computes the average MS-SSIM metric between original and generated images</center>

### 1. Import the required libraries

In [1]:
import os
import cv2

import torch
from torchvision import transforms
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure as MS_SSIM

### 2. Custom function for calculating the `MS-SSIM` metric

In [2]:
def _list_images(directory, extensions):
    """Return the sorted names of image files directly inside *directory*.

    A name is kept when it ends (case-insensitively) with one of *extensions*
    and refers to a regular file, so sub-directories such as
    ``.ipynb_checkpoints`` are never picked up.
    """
    return sorted(
        name
        for name in os.listdir(directory)
        if name.lower().endswith(extensions)
        and os.path.isfile(os.path.join(directory, name))
    )


def _load_rgb(path):
    """Load the image at *path* with OpenCV and return it in RGB channel order.

    Raises:
        ValueError: if OpenCV cannot read or decode the file.
    """
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the error surfaces later as an opaque cv2.cvtColor assertion.
    if img is None:
        raise ValueError(f"Could not read image: {path}")
    # OpenCV decodes to BGR; convert so channel order is RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def compute_dataset_msssim(original_dir, generated_dir, extensions=('.jpg',)):
    """Return the mean MS-SSIM between paired images of two directories.

    Images are paired by position after sorting the file names of each
    directory, NOT by matching file names, so both directories must use
    naming schemes that sort into the same order.

    Args:
        original_dir: Directory holding the reference images.
        generated_dir: Directory holding the generated images.
        extensions: File-name suffix, or tuple of suffixes, that identify
            image files (matched case-insensitively). Defaults to ``('.jpg',)``.

    Returns:
        The average per-image MS-SSIM as a float, or ``0.0`` when no images
        are found.

    Raises:
        ValueError: if the two directories contain a different number of
            images, or if an image cannot be read.
    """
    # Accept a single suffix string as well as a tuple of suffixes.
    if isinstance(extensions, str):
        extensions = (extensions,)
    extensions = tuple(ext.lower() for ext in extensions)

    original_images = _list_images(original_dir, extensions)
    generated_images = _list_images(generated_dir, extensions)

    if len(original_images) != len(generated_images):
        raise ValueError(
            "Number of images in original and generated datasets do not match "
            f"({len(original_images)} vs {len(generated_images)})."
        )

    # ToTensor turns an HxWxC uint8 array into a CxHxW float tensor scaled to
    # [0, 1], which is why the metric is configured with data_range=1.0.
    transform = transforms.ToTensor()
    # NOTE(review): torchmetrics may reject images too small for the default
    # number of MS-SSIM scales — confirm against the dataset's resolution.
    ms_ssim_metric = MS_SSIM(data_range=1.0)

    total_msssim = 0.0
    count = 0

    for orig_name, gen_name in zip(original_images, generated_images):
        orig_img = _load_rgb(os.path.join(original_dir, orig_name))
        gen_img = _load_rgb(os.path.join(generated_dir, gen_name))

        # The metric needs equal spatial sizes; bring the generated image to
        # the original's resolution (cv2.resize takes (width, height)).
        if orig_img.shape != gen_img.shape:
            gen_img = cv2.resize(gen_img, (orig_img.shape[1], orig_img.shape[0]))

        orig_tensor = transform(orig_img).unsqueeze(0)  # add batch dimension
        gen_tensor = transform(gen_img).unsqueeze(0)    # add batch dimension

        # Calling the metric returns the score for this single pair.
        total_msssim += ms_ssim_metric(orig_tensor, gen_tensor).item()
        count += 1

    return total_msssim / count if count > 0 else 0.0

### 3. Show the result

In [3]:
def _count_images(directory, extensions=('.jpg',)):
    """Count regular files in *directory* whose name ends with one of *extensions*.

    Matching is case-insensitive and uses the same '.jpg' default as the
    MS-SSIM computation, so the printed totals reflect the images that are
    actually evaluated rather than every file in the directory.
    Raises FileNotFoundError if *directory* does not exist.
    """
    return sum(
        1
        for name in os.listdir(directory)
        if name.lower().endswith(extensions)
        and os.path.isfile(os.path.join(directory, name))
    )


original_dir = '../org'
print('Total images in original_dir:', _count_images(original_dir))

generated_dir = 'gen_14'
print('Total images in generated_dir:', _count_images(generated_dir))

avg_msssim_score = compute_dataset_msssim(original_dir, generated_dir)
print(f"Average MS-SSIM: {avg_msssim_score:.4f}")

Total images in original_dir: 10000
Total images in generated_dir: 10000
Average MS-SSIM: 0.2196
