In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell executes, so edits to imported source files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import os
import json
import torch
import numpy as np
import cv2
from pathlib import Path
import matplotlib.pyplot as plt
from nerfstudio.utils.eval_utils import eval_setup
from plane_nerf.inerf_utils import transform_original_space_to_pose



Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# Locations of the trained model run, the evaluation dataset, and the
# background image used later for compositing.
MODEL_PATH = "/workspace/plane-nerf/outputs/jackal_floor_training_data_1/plane-nerf/2024-03-11_145657"
DATA_PATH = "/workspace/plane-nerf/data/jackal_floor_evaluation_data"
BACKGROUND_IMG = "/workspace/plane-nerf/data/jackal_floor_training_data_1/background.png"

# Work from the repository root before loading the run.
# NOTE(review): presumably needed so relative paths stored in the run config
# resolve correctly -- confirm.
os.chdir('/workspace/plane-nerf')

# Load the trained pipeline from the run's config.yml in inference mode; the
# last two values returned by eval_setup are not used in this notebook.
config_path = os.path.join(MODEL_PATH, "config.yml")
config, pipeline, _, _ = eval_setup(Path(config_path), test_mode="inference")

# Parse the evaluation dataset's transforms.json; later cells read its
# "frames" list.
transform_file_path = "transforms.json"
transform_json_path = os.path.join(DATA_PATH, transform_file_path)
with open(transform_json_path) as transform_file:
    transform = json.load(transform_file)




In [4]:
# Load the background image and convert it from OpenCV's BGR channel order to RGB.
# cv2.imread returns None (instead of raising) when the file is missing or
# unreadable, so fail here with the offending path rather than with an opaque
# cvtColor assertion.
background_bgr = cv2.imread(BACKGROUND_IMG)
if background_bgr is None:
    raise FileNotFoundError(f"Could not read background image: {BACKGROUND_IMG}")
background_img = cv2.cvtColor(background_bgr, cv2.COLOR_BGR2RGB)

In [5]:
# Switch the pipeline to eval mode and set up the training datamanager so its
# fixed-indices dataloader is available.
pipeline.eval()
pipeline.datamanager.setup_train()

# Take the first (camera, batch) pair; later cells reuse this camera and only
# overwrite its pose. Using next(iter(...)) instead of a for/break loop makes
# an empty dataloader fail loudly here, rather than leaving `camera` undefined
# (or stale from an earlier kernel run).
first_item = next(iter(pipeline.datamanager.fixed_indices_train_dataloader), None)
if first_item is None:
    raise RuntimeError("fixed_indices_train_dataloader yielded no (camera, batch) pairs")
camera, batch = first_item

In [6]:
from torchmetrics.functional import structural_similarity_index_measure
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# Image-quality metrics used by the evaluation loop.
# SSIM is the stateless functional form, so it needs no device placement.
ssim = structural_similarity_index_measure

# PSNR for images scaled to [0, 1] (data_range=1.0), on the pipeline's device.
psnr = PeakSignalNoiseRatio(data_range=1.0)
psnr = psnr.to(pipeline.device)

# LPIPS with normalize=True, on the pipeline's device.
lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)
lpips = lpips.to(pipeline.device)

In [7]:
def _read_image(path, *imread_args):
    """Read an image with cv2.imread, raising if it cannot be loaded.

    cv2.imread returns None instead of raising when the file is missing or
    unreadable; surface that immediately with the offending path.
    """
    image = cv2.imread(path, *imread_args)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return image


def _to_nchw(image):
    """Move the last axis of a 3-D image tensor to the front and add a batch axis."""
    return image.permute(2, 0, 1).unsqueeze(0)


# One row per frame:
# [psnr_full, ssim_full, lpips_full, psnr_masked, ssim_masked, lpips_masked]
store_results = []

# The background is identical for every frame, so convert it to a tensor once.
background = torch.tensor(background_img, dtype=torch.float32).to(pipeline.device) / 255.0
dataparser_outputs = pipeline.datamanager.train_dataparser_outputs

for frame in transform["frames"]:
    print(frame["file_path"])

    # Ground-truth image as an RGB float tensor in [0, 1].
    img = _read_image(os.path.join(DATA_PATH, frame["file_path"]))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = torch.tensor(img, dtype=torch.float32).to(pipeline.device) / 255.0

    # Keep the top 3x4 block of the frame's transform matrix, with a leading
    # batch axis, and map it into the model's coordinate space.
    tf = torch.tensor(frame["transform_matrix"], dtype=torch.float32)[:3, :4].unsqueeze(0)
    tf = transform_original_space_to_pose(tf,
                                          dataparser_outputs.dataparser_transform,
                                          dataparser_outputs.dataparser_scale,
                                          "opengl")

    # Render from this pose by overwriting the pose of the shared camera.
    camera.camera_to_worlds = tf.to(pipeline.device)
    outputs = pipeline.model.get_outputs_for_camera(camera=camera)
    rendered_img = outputs["rgb"]

    # Binary mask (non-zero pixels -> 1.0), read once as a single-channel
    # image and repeated over three channels.
    mask = _read_image(os.path.join(DATA_PATH, frame["mask_path"]), cv2.IMREAD_GRAYSCALE)
    mask = mask > 0
    mask = torch.tensor(mask, dtype=torch.float32).to(pipeline.device)
    mask = torch.stack([mask, mask, mask], dim=-1)

    # Keep the render inside the mask; fill the rest with the background image.
    masked_img = rendered_img * mask
    masked_img_with_background = masked_img + (1 - mask) * background

    # Only float values are recorded, so no autograd graph is needed for the metrics.
    with torch.no_grad():
        reference = _to_nchw(img)
        composited = _to_nchw(masked_img_with_background)
        masked_only = _to_nchw(masked_img)

        psnr_full = psnr(reference, composited)
        ssim_full = ssim(reference, composited)
        lpips_full = lpips(reference, composited)

        # NOTE(review): the "masked" metrics compare the *unmasked* ground
        # truth with the masked render (zero outside the mask). If the intent
        # is to score only the masked region, the ground truth should
        # presumably be multiplied by `mask` as well -- confirm before
        # relying on these three columns.
        psnr_masked = psnr(reference, masked_only)
        ssim_masked = ssim(reference, masked_only)
        lpips_masked = lpips(reference, masked_only)

    store_results.append([float(psnr_full), float(ssim_full), float(lpips_full),
                          float(psnr_masked), float(ssim_masked), float(lpips_masked)])




In [10]:
# Write the per-frame metrics to metrics.csv inside the model run directory
# (MODEL_PATH), one row per evaluated frame, then print the mean of each
# metric column.
store_results = np.array(store_results)
metrics_csv_path = os.path.join(MODEL_PATH, "metrics.csv")
np.savetxt(metrics_csv_path, store_results, delimiter=",")
print(store_results.mean(axis=0))

[3.62395821e+01 9.86207664e-01 7.78639037e-03 1.50880861e+01
 2.54929587e-02 8.55521262e-01]
