In [None]:
# %% [markdown]
# # 🔍 Evaluate 3×3 Grid Image Quality
# 
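# This notebook compares generated 3×3 grid images against their ground-truth (GT) grids. Each grid is cut into 9 tiles that are treated as a 9-frame video, with the frames laid out in the grid as 1 2 3 / 6 5 4 / 7 8 9.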
# 
# ✅ Supported modes:
# - **strict** → skip the pair if the GT and GEN sizes differ  
# - **resize_gen_to_gt** → resize GEN to the GT size before cutting  
# 
# Output metrics: FVD, SSIM, PSNR, LPIPS, MSE

# %%
import json
from collections import Counter
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch
from PIL import Image




In [None]:

# ==== Configuration ====
# NOTE(review): jsonl_path is an absolute, machine-specific path; adjust it per environment.
CONFIG = dict(
    # Original index JSONL; relative GT paths found in results.jsonl are resolved
    # against the folder that contains this file.
    jsonl_path="/root/PhotoDoodle/data/bridge_test/text_test_index.jsonl",
    # Inference output folder (must contain results.jsonl).
    output_dir="inference/outputs_bridgeV2_text",
    # Either "strict" or "resize_gen_to_gt".
    fit_mode="resize_gen_to_gt",
    device=("cuda" if torch.cuda.is_available() else "cpu"),
    # If True, only the last frame is compared.
    only_final=False,
)


In [None]:

# === Import the existing metric functions ===
from calculate_fvd import calculate_fvd
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim
from calculate_lpips import calculate_lpips
from calculate_mse import calculate_mse

In [None]:
# %%
# ==== Utility functions ====
def resolve_path_maybe_relative(p: str, base: Path) -> Path:
    """Convert ``p`` to a Path, anchoring it at ``base`` when it is relative.

    Absolute inputs are returned unchanged (they are NOT resolved); relative
    inputs are joined onto ``base`` and the result is resolved.
    """
    candidate = Path(p)
    if candidate.is_absolute():
        return candidate
    return (base / candidate).resolve()

def _assert_divisible_by_3(W: int, H: int, tag: str, path: Path):
    assert W % 3 == 0 and H % 3 == 0, f"[{tag}] Grid size not divisible by 3: {path} ({W}x{H})"

def _cut_3x3_to_video(im: Image.Image) -> torch.Tensor:
    """Cut a 3x3 grid image into a 9-frame float video tensor.

    Frames are taken in "snake" order: the top row left-to-right, the middle
    row right-to-left, the bottom row left-to-right. Numbering the frames
    1..9, their positions in the grid are::

        1 2 3
        6 5 4
        7 8 9

    Args:
        im: grid image whose width and height are both multiples of 3. Non-RGB
            images are converted to RGB first so the output always has 3 channels.

    Returns:
        Tensor of shape [9, 3, H // 3, W // 3] with float values in [0, 1].

    Raises:
        AssertionError: if the image size is not divisible by 3.
    """
    if im.mode != "RGB":
        # "L" would give a 2-D array (permute fails) and "RGBA" a 4-channel video.
        im = im.convert("RGB")
    W, H = im.size
    _assert_divisible_by_3(W, H, "CUT", Path("<in-memory>"))
    tile_w, tile_h = W // 3, H // 3

    # Convert the whole grid once. np.array makes a writable copy; np.asarray on a
    # PIL image can be read-only, which makes torch.from_numpy emit a warning.
    grid = torch.from_numpy(np.array(im, dtype=np.uint8)).permute(2, 0, 1).float() / 255.0  # [3,H,W]

    # Row-major tile indices (r * 3 + c) listed in output-frame order.
    order = [0, 1, 2, 5, 4, 3, 6, 7, 8]
    frames = []
    for idx in order:
        r, c = divmod(idx, 3)
        # Same pixels as PIL crop box (c*tile_w, r*tile_h, (c+1)*tile_w, (r+1)*tile_h).
        frames.append(grid[:, r * tile_h:(r + 1) * tile_h, c * tile_w:(c + 1) * tile_w])
    return torch.stack(frames, dim=0)  # [9,3,H,W]

def cut_grid_as_video_with_fit(gt_path: Path, gen_path: Path, fit_mode: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load a GT/GEN grid pair and cut both into 9-frame video tensors.

    Args:
        gt_path: ground-truth 3x3 grid image; its size must be divisible by 3.
        gen_path: generated 3x3 grid image.
        fit_mode: "strict" requires GEN to have exactly the GT size;
            "resize_gen_to_gt" bicubically resizes GEN to the GT size when they differ.

    Returns:
        (v_gt, v_gen), each as produced by _cut_3x3_to_video.

    Raises:
        ValueError: if fit_mode is not one of the two supported modes.
        AssertionError: if the GT size is not divisible by 3, or the sizes differ in strict mode.
    """
    if fit_mode not in ("strict", "resize_gen_to_gt"):
        # Previously an unknown mode silently skipped fitting, hiding config typos.
        raise ValueError(f"unknown fit_mode {fit_mode!r}; expected 'strict' or 'resize_gen_to_gt'")

    # `with` closes the file handles even if a later assertion fires; convert()
    # returns an in-memory copy that stays valid after the file is closed.
    with Image.open(gt_path) as gt_src:
        gt_im = gt_src.convert("RGB")
    Wg, Hg = gt_im.size
    _assert_divisible_by_3(Wg, Hg, "GT", gt_path)

    with Image.open(gen_path) as gen_src:
        gen_im = gen_src.convert("RGB")
    Wp, Hp = gen_im.size

    if fit_mode == "strict":
        assert (Wp, Hp) == (Wg, Hg), f"size mismatch GT({Wg}x{Hg}) vs GEN({Wp}x{Hp})"
    elif (Wp, Hp) != (Wg, Hg):
        gen_im = gen_im.resize((Wg, Hg), Image.BICUBIC)

    # GEN's own size is deliberately not checked before fitting: in resize mode a
    # non-multiple-of-3 GEN (e.g. 512 or 1024) is still usable once resized. In
    # both modes GEN now has the GT size, which was already checked above.
    return _cut_3x3_to_video(gt_im), _cut_3x3_to_video(gen_im)

# %%
# ==== Load sample pairs ====
cfg = CONFIG
JSONL_PATH   = Path(cfg["jsonl_path"]).resolve()
JSONL_BASE   = JSONL_PATH.parent
OUTPUT_DIR   = Path(cfg["output_dir"]).resolve()
RESULTS_PATH = OUTPUT_DIR / "results.jsonl"

pairs: List[Tuple[Path, Path, dict]] = []
n_total = 0
n_ok = 0
# Why records were dropped, so a low usable_pairs count can be diagnosed
# instead of being silently swallowed.
skip_counts = Counter()

assert RESULTS_PATH.exists(), f"results.jsonl not found: {RESULTS_PATH}"
with open(RESULTS_PATH, "r", encoding="utf-8") as fr:
    for line in fr:
        n_total += 1
        if not line.strip():
            skip_counts["blank_line"] += 1
            continue
        try:
            rec = json.loads(line)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            skip_counts["bad_json"] += 1
            continue
        if not isinstance(rec, dict):
            # e.g. a bare number/list line; rec.get below would raise AttributeError.
            skip_counts["not_an_object"] += 1
            continue
        if not rec.get("ok", False):
            skip_counts["not_ok"] += 1
            continue

        out_rel = rec.get("output")
        if not out_rel:
            skip_counts["no_output"] += 1
            continue
        gen_path = (OUTPUT_DIR / out_rel).resolve()

        # GT may be stored under "target" or "gt"; relative paths are anchored at the JSONL's folder.
        gt_raw = rec.get("target") or rec.get("gt")
        if not gt_raw:
            skip_counts["no_gt"] += 1
            continue
        gt_path = resolve_path_maybe_relative(gt_raw, JSONL_BASE)

        if not gen_path.exists():
            skip_counts["gen_missing"] += 1
            continue
        if not gt_path.exists():
            skip_counts["gt_missing"] += 1
            continue

        pairs.append((gt_path, gen_path, rec))
        n_ok += 1

print(f"[PAIRING] total_lines={n_total}, usable_pairs={n_ok}")
if skip_counts:
    print(f"[PAIRING] skipped: {dict(skip_counts)}")

# %%
# ==== Build video tensors ====
videos_gt_list, videos_gen_list = [], []
FIT_MODE = cfg["fit_mode"]
# Fail fast on a config typo instead of letting every pair be skipped one by one below.
if FIT_MODE not in ("strict", "resize_gen_to_gt"):
    raise ValueError(f"unknown fit_mode {FIT_MODE!r}; expected 'strict' or 'resize_gen_to_gt'")

build_errors: List[str] = []  # one message per dropped pair

for gt_path, gen_path, _rec in pairs:
    try:
        v_gt, v_gen = cut_grid_as_video_with_fit(gt_path, gen_path, FIT_MODE)
    except Exception as e:
        # Best-effort on purpose: one unreadable or ill-sized pair must not abort
        # the whole evaluation, but the reason is recorded rather than swallowed.
        build_errors.append(f"{gen_path}: {type(e).__name__}: {e}")
        continue
    if v_gt.shape != v_gen.shape:
        build_errors.append(f"{gen_path}: shape mismatch GT{tuple(v_gt.shape)} vs GEN{tuple(v_gen.shape)}")
        continue
    videos_gt_list.append(v_gt)
    videos_gen_list.append(v_gen)

print(f"[BUILD] kept={len(videos_gt_list)}, skipped={len(build_errors)}")
for msg in build_errors[:5]:  # a few examples, without flooding the output
    print(f"[SKIP] {msg}")

assert len(videos_gt_list) > 0, "No valid pairs after processing."
# torch.stack needs identical shapes, i.e. every GT grid must have the same size.
frame_shapes = sorted({tuple(v.shape) for v in videos_gt_list})
assert len(frame_shapes) == 1, f"GT grids differ in size across pairs, cannot stack: {frame_shapes}"

videos_gt  = torch.stack(videos_gt_list,  dim=0)
videos_gen = torch.stack(videos_gen_list, dim=0)
NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, H, W = videos_gt.shape
print(f"[DATA] videos_gt={videos_gt.shape}, videos_gen={videos_gen.shape}, FIT_MODE={FIT_MODE}")



In [None]:
import json as pyjson
from datetime import datetime

DEVICE = torch.device(cfg["device"])
ONLY_FINAL = cfg["only_final"]

# GPU tensors for FVD / LPIPS
v1_gpu = videos_gt.to(DEVICE, non_blocking=True)
v2_gpu = videos_gen.to(DEVICE, non_blocking=True)

# CPU tensors for SSIM / PSNR / MSE (their implementations call .numpy())
v1_cpu = videos_gt.cpu()
v2_cpu = videos_gen.cpu()

# ---- compute ----
metrics = {}
# GPU-heavy. This is pure evaluation: no_grad keeps autograd from recording a graph
# through whatever networks the FVD / LPIPS helpers run (presumably feature
# extractors that may hold trainable parameters; TODO confirm), saving GPU memory.
with torch.no_grad():
    metrics["fvd"]   = float(calculate_fvd(v1_gpu, v2_gpu, DEVICE, method='styleganv', only_final=ONLY_FINAL))
    metrics["lpips"] = float(calculate_lpips(v1_gpu, v2_gpu, DEVICE, only_final=ONLY_FINAL))
# CPU / numpy
metrics["ssim"]  = float(calculate_ssim(v1_cpu, v2_cpu, only_final=ONLY_FINAL))
metrics["psnr"]  = float(calculate_psnr(v1_cpu, v2_cpu, only_final=ONLY_FINAL))
metrics["mse"]   = float(calculate_mse(v1_cpu, v2_cpu, only_final=ONLY_FINAL))

# ---- meta (flattened) ----
N, T, C, H, W = videos_gt.shape
flat_record = {
    "timestamp": datetime.now().isoformat(timespec="seconds"),
    "fit_mode": str(FIT_MODE),
    "only_final": bool(ONLY_FINAL),
    "num_pairs": int(N),
    "video_length": int(T),
    "channels": int(C),
    "height": int(H),
    "width": int(W),
    # metrics flattened:
    **metrics,
}

print(pyjson.dumps(flat_record, indent=2))

# ---- save flattened json ----
# only_final runs get their own file so they do not overwrite the full-video metrics
# of the same fit mode; the default (only_final=False) filename is unchanged.
final_tag = "_only_final" if ONLY_FINAL else ""
metrics_path = OUTPUT_DIR / f"metrics_fit_{FIT_MODE}{final_tag}_flat.json"
with open(metrics_path, "w", encoding="utf-8") as f:
    pyjson.dump(flat_record, f, indent=2)
print(f"[OK] flattened metrics saved to: {metrics_path}")