In [1]:
import os

import numpy as np
import torch
from os import path
import sys
import tqdm
import imageio.v3 as iio
import cv2 as cv

from torchmetrics.functional.image import peak_signal_noise_ratio, structural_similarity_index_measure
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio

In [2]:
# Build the three metric modules first, then move all of them to the GPU.
# data_range=1.0: the evaluation code feeds images divided by 255.
# normalize=True: torchmetrics' LPIPS switch for [0, 1] inputs (see the torchmetrics docs).
ssim_model = StructuralSimilarityIndexMeasure(data_range=1.0)
psnr_model = PeakSignalNoiseRatio(data_range=1.0)
lpips_model = LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True)
ssim_model, psnr_model, lpips_model = (
    metric.cuda() for metric in (ssim_model, psnr_model, lpips_model)
)

# code from AnimatableGaussians
def crop_image(gt_mask, patch_size, *args, pad_size=0):
    """
    Crop images to the bounding box of a mask, centre the crop on a square
    canvas filled with ones and resize the canvas to a fixed patch size.

    :param gt_mask: (H, W) array; entries > 0 define the bounding box
    :param patch_size: resize the cropped patch to the given patch_size
    :param args: some images with shape of (H, W, C) or (H, W); None entries are
        returned as None
    :param pad_size: margin in pixels added on every side of the bounding box and
        clipped to the mask extent (keyword-only, default 0)
    :return: a list with one patch per entry of args when more than one image is
        passed, otherwise the single patch (or None) itself
    :raises ValueError: if gt_mask has no entry > 0, or an image has to be cropped
        to a bounding box of zero size
    """
    mask_uv = np.argwhere(gt_mask > 0.)
    if len(mask_uv) == 0:
        # np.min / np.max on an empty array fail with an unrelated-looking message.
        raise ValueError('crop_image: gt_mask has no positive entries, cannot compute a bounding box')

    height, width = gt_mask.shape[0], gt_mask.shape[1]
    min_v, min_u = mask_uv.min(0)
    max_v, max_u = mask_uv.max(0)
    # NOTE(review): max_v / max_u are the indices of the last masked row / column and
    # are used below as exclusive slice ends, so with pad_size=0 that last row / column
    # is not part of the crop. Left as in the original so earlier scores stay comparable.
    min_v = int(np.clip(min_v - pad_size, 0, height))
    min_u = int(np.clip(min_u - pad_size, 0, width))
    max_v = int(np.clip(max_v + pad_size, 0, height))
    max_u = int(np.clip(max_u + pad_size, 0, width))
    len_v = max_v - min_v
    len_u = max_u - min_u
    max_size = max(len_v, len_u)

    cropped_images = []
    for image in args:
        if image is None:
            cropped_images.append(None)
            continue

        if max_size == 0:
            # cv.resize rejects empty inputs; fail with a message naming the cause.
            raise ValueError('crop_image: the bounding box of gt_mask has zero size')

        # Canvas follows the image's trailing dimensions instead of assuming 3 channels.
        # It is filled with ones in the image dtype (1.0 for float input, 1 for uint8 input).
        canvas = np.ones((max_size, max_size) + tuple(image.shape[2:]), dtype=image.dtype)
        if len_v > len_u:
            # Taller than wide: centre the crop horizontally.
            start_u = (max_size - len_u) // 2
            canvas[:, start_u: start_u + len_u] = image[min_v: max_v, min_u: max_u]
        else:
            # Wider than tall (or square): centre the crop vertically.
            start_v = (max_size - len_v) // 2
            canvas[start_v: start_v + len_v, :] = image[min_v: max_v, min_u: max_u]

        canvas = cv.resize(canvas, (patch_size, patch_size), interpolation=cv.INTER_LINEAR)
        cropped_images.append(canvas)

    if len(cropped_images) > 1:
        return cropped_images
    else:
        return cropped_images[0]

In [None]:
# Root folder of the evaluation data. The code below lists its gt/ and render/
# sub-folders and reads masks from mask/ under the same file names as gt/.
data_dir = ''

import skimage.morphology as sm
from torch.utils.data import DataLoader

class MyDataset:
    """Index-aligned (ground truth, render, mask) images read from three path lists."""

    def __init__(self, gt_paths, render_paths, mask_paths):
        """Store the three path sequences; entry ``idx`` of each belongs to the same sample."""
        self.gt_paths, self.render_paths, self.mask_paths = gt_paths, render_paths, mask_paths

    def __len__(self):
        """Number of samples, taken from the ground-truth path list."""
        return len(self.gt_paths)

    def __getitem__(self, idx):
        """Read sample ``idx`` with ``iio.imread`` and return ``(gt, render, mask)``."""
        sources = (self.gt_paths, self.render_paths, self.mask_paths)
        return tuple(iio.imread(paths[idx]) for paths in sources)

# Collect the file lists. Masks reuse the ground-truth file names; renders are
# matched to ground truth purely by position in the sorted directory listings.
filenames = sorted(os.listdir(path.join(data_dir, 'gt')))
gt_paths = [path.join(data_dir, f'gt/{filename}') for filename in filenames]
mask_paths = [path.join(data_dir, f'mask/{filename}') for filename in filenames]
render_names = sorted(os.listdir(path.join(data_dir, 'render')))
render_paths = [path.join(data_dir, f'render/{filename}') for filename in render_names]

# Position-based pairing only makes sense when both folders hold the same number of
# files; otherwise frames get compared against the wrong reference or indexing fails
# part-way through the run.
if len(render_paths) != len(gt_paths):
    raise ValueError(
        f'gt/ holds {len(gt_paths)} files but render/ holds {len(render_paths)}; '
        'cannot pair them by sorted position'
    )

dataset = MyDataset(gt_paths, render_paths, mask_paths)
# batch_size=None disables automatic batching, so every item is one
# (gt, render, mask) triple; the identity collate_fn hands the loaded arrays
# through unchanged instead of applying the DataLoader's default conversion.
dataloader = DataLoader(dataset=dataset, batch_size=None, num_workers=8, collate_fn=lambda x: x)

def _as_batched_cuda(image):
    """Divide an (H, W, C) array by 255 and return it as a (1, C, H, W) float32 CUDA tensor."""
    return torch.tensor(image / 255).permute(2, 0, 1).float().cuda()[None]


# Per-image scores, appended in dataloader order.
psnr_list, ssim_list, lpips_list = [], [], []

for im_gt, im_gsbody, mask in tqdm.tqdm(dataloader):
    assert mask.ndim == 2
    mask = mask > 128

    # Crops are taken from the raw arrays, before any tensor conversion.
    # NOTE(review): crop_image pads with ones in the image dtype; if iio.imread
    # returns uint8 here, the padding value is 1 (not 255) — confirm this is intended.
    im_gt_crop, im_gsbody_crop = crop_image(mask, 512, im_gt, im_gsbody)

    im_gt, im_gsbody, im_gt_crop, im_gsbody_crop = map(
        _as_batched_cuda, (im_gt, im_gsbody, im_gt_crop, im_gsbody_crop)
    )

    # PSNR and SSIM use the full frames, LPIPS uses the 512x512 crops.
    psnr_list.append(psnr_model(im_gsbody, im_gt).item())
    ssim_list.append(ssim_model(im_gsbody, im_gt).item())
    lpips_list.append(lpips_model(im_gsbody_crop, im_gt_crop).item())

psnr_list, ssim_list, lpips_list = (
    np.array(values) for values in (psnr_list, ssim_list, lpips_list)
)

# Mean over all images, printed in the order PSNR, SSIM, LPIPS.
for scores in (psnr_list, ssim_list, lpips_list):
    print(scores.mean())
