In [1]:
import os, gc, csv
import numpy as np
from pathlib import Path
from tqdm import tqdm
import torch

# Enable OpenCV's OpenEXR codec. The flag is set before the metrics_helpers
# import below -- presumably because that module imports cv2 and OpenCV reads
# the variable when it loads (TODO confirm against metrics_helpers).
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

from metrics_helpers import (
    read_image, pre_hdr_p3, align_hdr_pred_to_gt,
    psnr, vsi, piqe, lpips, hdr_vdp3, pu, reinhard_tonemap,
    initialize_fid, initialize_fvd,
    compute_fid, fid_update, compute_fvd, fvd_update, cvvdp, initialize_cvvdp
)

# Root folders holding one sub-directory per video: predictions and ground truth.
# The defaults are the original machine-specific locations; set HDR_EVAL_PRED_DIR /
# HDR_EVAL_GT_DIR to run the notebook elsewhere without editing this cell.
pred_dir = os.environ.get(
    "HDR_EVAL_PRED_DIR", "/home/tedlasai/hdrvideo/evaluations/ours_stuttgart/under/"
)
gt_dir = os.environ.get(
    "HDR_EVAL_GT_DIR", "/home/tedlasai/hdrvideo/evaluations/stuttgart/hdr/"
)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def describe_frame(x, name="frame"):
    """Return a one-line, human-readable summary of a frame.

    Args:
        x: ``None``, a ``torch.Tensor``, or anything ``np.asarray`` accepts.
        name: Label placed at the start of the summary.

    Returns:
        A string of the form
        ``"<name>: type=..., dtype=..., device=..., shape=..., min=..., max=...,
        mean=..., std=..., nan=..., inf=..."``, or ``"<name>: None"`` when ``x``
        is ``None``. For a zero-size input the four statistics are reported as
        ``nan`` and both counts as ``0``.
    """
    if x is None:
        return f"{name}: None"
    if isinstance(x, torch.Tensor):
        # Statistics are computed on a detached CPU copy of the tensor.
        x_np = x.detach().cpu().numpy()
        t = "torch.Tensor"
        dtype = str(x.dtype)
        device = str(x.device)
    else:
        x_np = np.asarray(x)
        t = type(x).__name__
        dtype = str(x_np.dtype)
        device = "cpu"

    header = f"{name}: type={t}, dtype={dtype}, device={device}, shape={x_np.shape}, "

    # np.nanmin / np.nanmax raise ValueError on zero-size arrays; report an
    # empty frame instead of crashing the debugging helper.
    if x_np.size == 0:
        return header + "min=nan, max=nan, mean=nan, std=nan, nan=0, inf=0"

    return header + (
        f"min={np.nanmin(x_np):.6g}, max={np.nanmax(x_np):.6g}, "
        f"mean={np.nanmean(x_np):.6g}, std={np.nanstd(x_np):.6g}, "
        f"nan={np.isnan(x_np).sum()}, inf={np.isinf(x_np).sum()}"
    )

def to_float(x):
    """Coerce a metric result to a plain Python ``float``.

    Tensors and NumPy arrays are flattened and only their first element is
    used; any other value is passed straight to ``float()``.
    """
    if isinstance(x, torch.Tensor):
        first = x.detach().cpu().reshape(-1)[0]
        return float(first.item())
    if isinstance(x, np.ndarray):
        return float(x.ravel()[0])
    return float(x)

def assert_same_shapes(a, b, name_a="a", name_b="b"):
    """Assert that ``a`` and ``b`` have identical shapes.

    Each argument may be a ``torch.Tensor`` or anything ``np.asarray`` accepts.
    Raises ``AssertionError`` naming both shapes when they differ.
    """
    def _to_numpy(obj):
        # Tensors are detached and moved to CPU before conversion.
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().numpy()
        return np.asarray(obj)

    shape_a = _to_numpy(a).shape
    shape_b = _to_numpy(b).shape
    assert shape_a == shape_b, f"Shape mismatch: {name_a}={shape_a} vs {name_b}={shape_b}"

def basic_sanity(pred, gt, tag=""):
    """Print a summary of the prediction and ground-truth frames, then check
    that their shapes match (``AssertionError`` if they do not)."""
    banner = f"\n--- SANITY {tag} ---"
    print(banner)
    for frame, label in ((pred, "pred"), (gt, "gt")):
        print(describe_frame(frame, label))
    assert_same_shapes(pred, gt, "pred", "gt")


In [3]:
# Debug run over a single video: for each of its first frames, compute LPIPS on
# Reinhard-tonemapped images, record PU-normalisation statistics, and collect
# the normalised PU frames that a later cell feeds to FVD.
video_paths = sorted(os.listdir(pred_dir))
print("Num videos:", len(video_paths))
print("First few:", video_paths[:3])

# Pick a single video to debug
video_path = video_paths[0]
pred_video_dir = os.path.join(pred_dir, video_path)
gt_video_dir   = os.path.join(gt_dir, video_path)
# Only the first 16 frames (in sorted filename order) are inspected.
im_paths = sorted(os.listdir(pred_video_dir))[:16]

print("Debug video:", video_path)
print("Frames:", len(im_paths), "example:", im_paths[:3])

# Fresh metric state each run
pu_fvd_metric = initialize_fvd()

lpips_raw_list = []      # store raw output (might be tensor)
lpips_float_list = []    # store float versions
pu_norm_stats = []       # store denom + range info

# Collect frames used for FVD
pu_preds = []
pu_gts   = []

for idx, im_name in enumerate(im_paths):
    # Prediction and ground truth are paired by identical file name.
    pred_im_path = os.path.join(pred_video_dir, im_name)
    gt_im_path   = os.path.join(gt_video_dir, im_name)

    cv2_hdr_pred = read_image(pred_im_path)
    cv2_hdr_gt   = read_image(gt_im_path)
    cv2_hdr_gt   = pre_hdr_p3(cv2_hdr_gt)
    cv2_hdr_pred, cv2_hdr_gt, _ = align_hdr_pred_to_gt(cv2_hdr_pred, cv2_hdr_gt)

    pu_pred, pu_gt = pu(cv2_hdr_pred), pu(cv2_hdr_gt)

    # IMPORTANT: both frames are normalised by max(pu_gt); fall back to 1.0 when
    # that maximum is not positive so the division stays well defined.
    gt_peak = np.max(pu_gt)
    gt_max = gt_peak if gt_peak > 0 else 1.0
    pu_pred_norm = pu_pred / gt_max
    pu_gt_norm   = pu_gt   / gt_max

    # The denominator gets its own key ("denom"). Previously it shared the key
    # "gt_max" with the normalised GT maximum, so the later entry silently
    # overwrote it and the denominator was never actually recorded.
    pu_norm_stats.append({
        "idx": idx,
        "denom": float(gt_max),
        "pred_min": float(np.min(pu_pred_norm)),
        "pred_max": float(np.max(pu_pred_norm)),
        "gt_min": float(np.min(pu_gt_norm)),
        "gt_max": float(np.max(pu_gt_norm)),
    })

    # LPIPS is computed on tonemapped frames in your script
    reinhard_pred = reinhard_tonemap(cv2_hdr_pred)
    reinhard_gt   = reinhard_tonemap(cv2_hdr_gt)
    lp = lpips(reinhard_pred, reinhard_gt)

    lpips_raw_list.append(lp)
    lpips_float_list.append(to_float(lp))

    # Collect for PU-FVD
    pu_preds.append(pu_pred_norm)
    pu_gts.append(pu_gt_norm)

    # Verbose sanity output for the first three frames only.
    if idx < 3:
        basic_sanity(pu_pred_norm, pu_gt_norm, tag=f"PU norm frame {idx}")
        print("LPIPS raw:", type(lp), lp)
        print("LPIPS float:", lpips_float_list[-1])

print("\nLPIPS types in list:", {type(x) for x in lpips_raw_list})
print("LPIPS float stats: mean/std/min/max =",
      float(np.mean(lpips_float_list)),
      float(np.std(lpips_float_list)),
      float(np.min(lpips_float_list)),
      float(np.max(lpips_float_list)))

# Summarise the per-frame normalisation denominators (not the normalised maxima).
pu_denoms = [d["denom"] for d in pu_norm_stats]
print("\nPU norm denom stats (gt_max): min/median/max =",
      min(pu_denoms),
      np.median(pu_denoms),
      max(pu_denoms))
print("Example PU norm stats first 3:", pu_norm_stats[:3])


Num videos: 10
First few: ['bistro_01', 'bistro_02', 'bistro_03']
Debug video: bistro_01
Frames: 16 example: ['frame_0000.exr', 'frame_0001.exr', 'frame_0002.exr']
Loading videomae model ...

--- SANITY PU norm frame 0 ---
pred: type=ndarray, dtype=float32, device=cpu, shape=(704, 1280, 3), min=1.30219e-12, max=1, mean=0.160027, std=0.120588, nan=0, inf=0
gt: type=ndarray, dtype=float32, device=cpu, shape=(704, 1280, 3), min=0.0857099, max=1, mean=0.177455, std=0.142863, nan=0, inf=0
LPIPS raw: <class 'float'> 0.19627350568771362
LPIPS float: 0.19627350568771362

--- SANITY PU norm frame 1 ---
pred: type=ndarray, dtype=float32, device=cpu, shape=(704, 1280, 3), min=1.30219e-12, max=1, mean=0.159911, std=0.121087, nan=0, inf=0
gt: type=ndarray, dtype=float32, device=cpu, shape=(704, 1280, 3), min=0.0857267, max=1, mean=0.177433, std=0.142914, nan=0, inf=0
LPIPS raw: <class 'float'> 0.1977795958518982
LPIPS float: 0.1977795958518982

--- SANITY PU norm frame 2 ---
pred: type=ndarray, dty

In [4]:
# Confirm every collected prediction frame shares one shape and one dtype
shapes = [np.asarray(frame).shape for frame in pu_preds]
dtypes = [np.asarray(frame).dtype for frame in pu_preds]
print("Unique shapes:", sorted(set(shapes)))
print("Unique dtypes:", sorted({str(dt) for dt in dtypes}))

# Look at the first frame in more detail
example = np.asarray(pu_preds[0])
print("Example shape:", example.shape)

# For a rank-3 frame (HxWxC layout assumed), report the size of the last axis
if example.ndim == 3:
    print("Channels:", example.shape[-1])

# Push the collected frames into the FVD metric state and read out the score
fvd_update(pu_preds, pu_gts, pu_fvd_metric)
pu_fvd = compute_fvd(pu_fvd_metric)
print("PU-FVD (single video debug run):", float(pu_fvd))


Unique shapes: [(704, 1280, 3)]
Unique dtypes: ['float32']
Example shape: (704, 1280, 3)
Channels: 3
PU-FVD (single video debug run): 0.005932159731206265


  arg2 = norm(X.dot(X) - A, 'fro')**2 / norm(A, 'fro')


In [None]:
# Current script: build the array straight from the raw LPIPS outputs
arr_raw = np.array(lpips_raw_list)
print("Raw np.array(lpips_raw_list) dtype:", arr_raw.dtype)
print("Raw array shape:", arr_raw.shape)

# New approach: build the array from the values already converted to Python floats
arr_float = np.array(lpips_float_list, dtype=np.float32)
print("Float array dtype:", arr_float.dtype)
print("LPIPS std (float):", float(arr_float.std()))
print("LPIPS mean (float):", float(arr_float.mean()))


Raw np.array(lpips_raw_list) dtype: float64
Raw array shape: (16,)
Float array dtype: float32
LPIPS std (float): 0.002422226592898369
LPIPS mean (float): 0.1965004950761795
