In [1]:
# %% [markdown]
# # 3×3 Grid 合成图 质量评估 Notebook
# 使用 tile-wise 适配（避免整图 resize 跨格模糊） + 扁平化指标记录

# %%
import json
from pathlib import Path
from typing import List, Tuple
from datetime import datetime

import numpy as np
from PIL import Image
import torch
import tqdm

# ==== Configuration ====
# Everything a reader might tune lives here; later cells only read from this dict.
CONFIG = dict(
    # Source JSONL index; relative GT paths in results.jsonl are resolved against its parent folder.
    jsonl_path="/root/PhotoDoodle/data/bridge_test/traj_test_index.jsonl",
    # Folder with the generated images and their results.jsonl.
    output_dir="inference/outputs_bridgeV2_traj_noop",
    # Per-tile size matching: "tile_strict" | "tile_gen_to_gt" | "tile_gt_to_gen".
    tile_fit_mode="tile_gt_to_gen",
    device="cuda" if torch.cuda.is_available() else "cpu",
    # If True, metrics are requested with only_final=True (compare only the last frame).
    only_final=False,
)

# 导入指标函数（来自 common_metrics_on_video_quality）
from calculate_fvd   import calculate_fvd
from calculate_psnr  import calculate_psnr
from calculate_ssim  import calculate_ssim
from calculate_lpips import calculate_lpips
from calculate_mse   import calculate_mse

# %%
# ==== 工具函数 ====
def resolve_path_maybe_relative(p: str, base: Path) -> Path:
    """Return ``p`` as a Path, anchoring it at ``base`` when it is relative.

    Absolute inputs are returned untouched (no ``.resolve()``); relative inputs are
    joined onto ``base`` and then resolved.
    """
    candidate = Path(p)
    if candidate.is_absolute():
        return candidate
    joined = base / candidate
    return joined.resolve()

def _grid_splits(length: int) -> List[int]:
    edges = np.linspace(0, length, 4)
    return [int(round(x)) for x in edges]

def _cut_3x3_tiles(im: Image.Image) -> List[Image.Image]:
    """Crop ``im`` into its 3x3 grid of tiles and return them in serpentine order.

    Row 0 is read left-to-right, row 1 right-to-left, row 2 left-to-right, i.e. the
    row-major tile indices 0, 1, 2, 5, 4, 3, 6, 7, 8.
    NOTE(review): presumably this mirrors how the 9 frames were laid out when the grid
    image was composed — confirm against the generator.
    """
    x_edges = _grid_splits(im.size[0])
    y_edges = _grid_splits(im.size[1])
    tiles: List[Image.Image] = []
    for row in range(3):
        # Even rows go left-to-right, odd rows right-to-left.
        columns = range(3) if row % 2 == 0 else range(2, -1, -1)
        for col in columns:
            box = (x_edges[col], y_edges[row], x_edges[col + 1], y_edges[row + 1])
            tiles.append(im.crop(box))
    return tiles

def _tile_sizes(tiles: List[Image.Image]) -> List[Tuple[int,int]]:
    """Return each tile's PIL ``.size`` (width, height), in input order."""
    sizes: List[Tuple[int, int]] = []
    for tile in tiles:
        sizes.append(tile.size)
    return sizes

def _resize_tiles_to(tiles: List[Image.Image], target_sizes: List[Tuple[int,int]]) -> List[Image.Image]:
    """Resize each tile to its paired ``(width, height)`` target with bicubic resampling.

    Tiles already at their target size are passed through unchanged (same object).

    Args:
        tiles: tiles to fit.
        target_sizes: one ``(width, height)`` per tile, positionally matched.

    Returns:
        A new list with one (possibly resized) tile per input tile.

    Raises:
        ValueError: if ``tiles`` and ``target_sizes`` differ in length. Plain ``zip`` would
            silently drop the surplus and yield a shorter "video".
    """
    if len(tiles) != len(target_sizes):
        raise ValueError(f"got {len(tiles)} tiles but {len(target_sizes)} target sizes")
    fitted = []
    for tile, (tw, th) in zip(tiles, target_sizes):
        fitted.append(tile if tile.size == (tw, th) else tile.resize((tw, th), Image.BICUBIC))
    return fitted

def _tiles_to_tensor(tiles: List[Image.Image]) -> torch.Tensor:
    """Stack RGB tiles into one float tensor of shape [T, C, H, W] with values in [0, 1].

    Each tile is read as a uint8 H x W x C array and moved to channel-first layout.
    All tiles must share one size, otherwise ``torch.stack`` raises.
    """
    def _to_chw(tile) -> torch.Tensor:
        hwc = torch.from_numpy(np.asarray(tile, dtype=np.uint8))
        return hwc.permute(2, 0, 1).float() / 255.0

    return torch.stack([_to_chw(tile) for tile in tiles], dim=0)  # [T=9, C=3, H, W] for a 3x3 grid

def cut_grid_as_video_tilefit(gt_path: Path, gen_path: Path, tile_fit_mode: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load a GT and a generated 3x3 grid image and turn each into a 9-frame "video".

    Both images are converted to RGB and cut into 9 tiles by ``_cut_3x3_tiles``. Size
    matching then happens per tile, so no resampling ever crosses a tile border:

    * ``"tile_strict"``    - no resizing; per-tile sizes must already match.
    * ``"tile_gen_to_gt"`` - generated tiles are resized (bicubic) to the GT tile sizes.
    * ``"tile_gt_to_gen"`` - GT tiles are resized (bicubic) to the generated tile sizes.

    NOTE(review): tiles within one image are stacked into a single tensor, so an image
    whose width/height is not divisible by 3 (unequal tiles) will make the stack raise.

    Returns:
        ``(v_gt, v_gen)``, each a float tensor [9, 3, H, W] with values in [0, 1].

    Raises:
        ValueError: unknown ``tile_fit_mode`` (checked before any file is opened).
        AssertionError: ``"tile_strict"`` and the per-tile sizes differ.
    """
    if tile_fit_mode not in ("tile_strict", "tile_gen_to_gt", "tile_gt_to_gen"):
        raise ValueError(f"Unknown tile_fit_mode: {tile_fit_mode}")

    # Context managers release both files and both decoded images on every path. This includes
    # a corrupt second image and failures in resizing or tensor conversion, where the old
    # manual close() calls were skipped.
    with Image.open(gt_path) as gt_src, Image.open(gen_path) as gen_src:
        with gt_src.convert("RGB") as gt_im, gen_src.convert("RGB") as gen_im:
            gt_tiles = _cut_3x3_tiles(gt_im)
            gen_tiles = _cut_3x3_tiles(gen_im)

            if tile_fit_mode == "tile_strict":
                gt_sizes, gen_sizes = _tile_sizes(gt_tiles), _tile_sizes(gen_tiles)
                if gt_sizes != gen_sizes:
                    raise AssertionError(f"[tile_strict] per-tile sizes mismatch: GT={gt_sizes} vs GEN={gen_sizes}")
            elif tile_fit_mode == "tile_gen_to_gt":
                gen_tiles = _resize_tiles_to(gen_tiles, _tile_sizes(gt_tiles))
            else:  # "tile_gt_to_gen"
                gt_tiles = _resize_tiles_to(gt_tiles, _tile_sizes(gen_tiles))

            # Convert while the parent images are still open, so no tile outlives its source.
            v_gt = _tiles_to_tensor(gt_tiles)
            v_gen = _tiles_to_tensor(gen_tiles)

    return v_gt, v_gen



Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [on]
Loading model from: /root/PhotoDoodle/lpips/weights/v0.1/alex.pth


In [2]:
# %%
# ==== Collect (GT, generated) pairs ====
cfg          = CONFIG
JSONL_PATH   = Path(cfg["jsonl_path"]).resolve()
JSONL_BASE   = JSONL_PATH.parent          # relative GT paths are resolved against this folder
OUTPUT_DIR   = Path(cfg["output_dir"]).resolve()
RESULTS_PATH = OUTPUT_DIR / "results.jsonl"

pairs: List[Tuple[Path, Path, dict]] = []
n_total = n_usable = n_no_target = n_missing = 0
# Reasons a line can be dropped that were previously skipped without being counted.
n_bad_json = n_not_ok = n_no_output = 0

assert RESULTS_PATH.exists(), f"results.jsonl not found: {RESULTS_PATH}"
with open(RESULTS_PATH, "r", encoding="utf-8") as fr:
    for line in fr:
        n_total += 1
        try:
            rec = json.loads(line)  # blank lines also raise JSONDecodeError and are counted here
        except json.JSONDecodeError:
            n_bad_json += 1
            continue
        if not isinstance(rec, dict):
            # Valid JSON but not a record object; .get() below would crash on it.
            n_bad_json += 1
            continue
        if not rec.get("ok", False):
            n_not_ok += 1
            continue
        out_rel = rec.get("output")
        if not out_rel:
            n_no_output += 1
            continue
        gen_path = (OUTPUT_DIR / out_rel).resolve()

        gt_raw = rec.get("target") or rec.get("gt")
        if not gt_raw:
            n_no_target += 1
            continue
        gt_path = resolve_path_maybe_relative(gt_raw, JSONL_BASE)

        if not (gen_path.exists() and gt_path.exists()):
            n_missing += 1
            continue

        pairs.append((gt_path, gen_path, rec))
        n_usable += 1

print(f"[PAIRING] total_lines={n_total}, usable_pairs={n_usable}, no_target_field={n_no_target}, missing_files={n_missing}, "
      f"unparsable_lines={n_bad_json}, not_ok={n_not_ok}, no_output_field={n_no_output}")
assert len(pairs) > 0, "No valid (GT, Generated) pairs found."

# %%


[PAIRING] total_lines=516, usable_pairs=516, no_target_field=0, missing_files=0


In [3]:
# ==== Build tensors ====
# Each pair becomes two [T=9, C, H, W] "videos" (one frame per grid tile). Pairs that fail to load
# or fit, or whose shapes disagree, are skipped (best-effort) but counted, and the first few
# reasons are printed.
TILE_FIT_MODE   = cfg["tile_fit_mode"]
videos_gt_list  = []
videos_gen_list = []
failed_convert  = 0
MAX_FAILURE_EXAMPLES = 5
failure_examples: List[Tuple[str, str]] = []   # (generated image path, reason)

for gt_path, gen_path, _rec in pairs:
    try:
        v_gt, v_gen = cut_grid_as_video_tilefit(gt_path, gen_path, TILE_FIT_MODE)
        if v_gt.shape != v_gen.shape:
            raise ValueError(f"GT/GEN shape mismatch: {tuple(v_gt.shape)} vs {tuple(v_gen.shape)}")
        # torch.stack below needs one shape for all pairs; reject outliers here instead of crashing there.
        if videos_gt_list and v_gt.shape != videos_gt_list[0].shape:
            raise ValueError(f"shape {tuple(v_gt.shape)} differs from first pair {tuple(videos_gt_list[0].shape)}")
    except Exception as exc:  # one bad pair must not abort the whole evaluation
        failed_convert += 1
        if len(failure_examples) < MAX_FAILURE_EXAMPLES:
            failure_examples.append((str(gen_path), f"{type(exc).__name__}: {exc}"))
        continue
    videos_gt_list.append(v_gt)
    videos_gen_list.append(v_gen)

assert len(videos_gt_list) > 0, "No valid pairs after tile-wise fitting."

videos_gt  = torch.stack(videos_gt_list,  dim=0)  # [N,9,3,H,W]
videos_gen = torch.stack(videos_gen_list, dim=0)
# torch.stack copies, so keeping the per-pair lists would double peak RAM
# (~4 GB per float32 tensor at the recorded 516 x 9 x 3 x 240 x 320).
del videos_gt_list, videos_gen_list

N, T, C, H, W = videos_gt.shape
print(f"[DATA] videos_gt={videos_gt.shape}, videos_gen={videos_gen.shape}, failed_convert={failed_convert}, TILE_FIT_MODE={TILE_FIT_MODE}")
for skipped_path, reason in failure_examples:
    print(f"  [SKIP] {skipped_path}: {reason}")



[DATA] videos_gt=torch.Size([516, 9, 3, 240, 320]), videos_gen=torch.Size([516, 9, 3, 240, 320]), failed_convert=0, TILE_FIT_MODE=tile_gt_to_gen


In [4]:
# %% [markdown]
# === 低显存评估：分批上GPU + 混合精度 + FVD自动扩帧 ===

# %%
import json as pyjson
import csv, gc
from statistics import mean
from datetime import datetime
from tqdm import tqdm

from calculate_mse   import calculate_mse
from calculate_fvd   import calculate_fvd
from calculate_psnr  import calculate_psnr
from calculate_ssim  import calculate_ssim
from calculate_lpips import calculate_lpips

# Evaluation only: turn autograd off globally so no graphs/activations are retained.
torch.set_grad_enabled(False)

# ====== Tunables (the main levers for lowering GPU memory) ======
GPU_DTYPE   = torch.float16          # fp16 can cut GPU memory substantially; switch to torch.bfloat16 on precision errors
LPIPS_BS    = 1000                   # videos per LPIPS sub-batch (split along dim 0); 1000 exceeds the 516 recorded pairs, i.e. one batch
FVD_MODE    = "approx_batch_mean"    # "whole" | "approx_batch_mean"
FVD_BS      = 50                     # videos per FVD batch; only used in approx mode
FVD_MIN_FRM = 10                     # minimum frame count for FVD; shorter videos are padded on the CPU

DEVICE     = torch.device(cfg["device"])
ONLY_FINAL = bool(cfg["only_final"])

def _tolist(x):
    if hasattr(x, "tolist"):
        return x.tolist()
    return list(x)

def last_or_none(seq):
    """Return the final element of ``seq`` as a float, or None when ``seq`` is empty."""
    if not len(seq):
        return None
    return float(seq[-1])

def mean_or_none(seq):
    """Arithmetic mean of ``seq`` (``statistics.mean``) as a float, or None when empty."""
    return None if not len(seq) else float(mean(seq))

def ensure_min_frames_cpu(videos: torch.Tensor, min_frames=10):
    """Pad the time axis (dim 1) up to ``min_frames`` by cyclically repeating frames.

    Intended to run on the CPU before tensors are moved to the GPU, so the padded copy
    never occupies GPU memory. Nothing here enforces that ``videos`` is on the CPU.

    Args:
        videos: tensor indexed as [N, T, ...]; the notebook passes [N, T, C, H, W].
        min_frames: minimum number of frames required downstream (FVD uses 10).

    Returns:
        ``videos`` itself when T >= min_frames. Otherwise a new tensor whose frames are
        0..T-1, 0..T-1, ... truncated to exactly ``min_frames``.

    Raises:
        ValueError: if ``videos`` has zero frames (there is nothing to repeat).
    """
    t = videos.size(1)
    if t >= min_frames:
        return videos
    if t == 0:
        raise ValueError("cannot pad a video with zero frames")
    repeat_factor = -(-min_frames // t)  # ceil(min_frames / t)
    # Repeat only along dim 1. Building the spec from videos.dim() avoids the hard-coded
    # 5-arg repeat, which silently added a leading dim for non-5-D inputs.
    repeats = [1] * videos.dim()
    repeats[1] = repeat_factor
    return videos.repeat(*repeats)[:, :min_frames]

In [5]:
# ---------- 1) SSIM / PSNR / MSE on the CPU ----------
# videos_gt / videos_gen were built from numpy on the CPU, and `.cpu()` on a CPU tensor returns the
# same object, so these names are aliases rather than extra copies.
v1_cpu, v2_cpu = videos_gt.cpu(), videos_gen.cpu()

print("[CPU] computing SSIM/PSNR/MSE ...")

# Order matters only for the log output: PSNR, then MSE, then SSIM (the slowest).
psnr_res = calculate_psnr(v1_cpu, v2_cpu, only_final=ONLY_FINAL)
mse_res = calculate_mse(v1_cpu, v2_cpu, only_final=ONLY_FINAL)
ssim_res = calculate_ssim(v1_cpu, v2_cpu, only_final=ONLY_FINAL)




[CPU] computing SSIM/PSNR/MSE ...
calculate_psnr...


100%|██████████| 516/516 [00:01<00:00, 335.98it/s]


calculate_mse...


MSE (per video): 100%|██████████| 516/516 [00:00<00:00, 2309.13it/s]


calculate_ssim...


100%|██████████| 516/516 [07:09<00:00,  1.20it/s]


In [6]:
# Flatten each metric result into plain lists of means ("value") and stds ("value_std").
# Presumably one entry per tile/timestep (one total with only_final) — TODO confirm against the metric library.
ssim_series = _tolist(ssim_res["value"])
ssim_std    = _tolist(ssim_res["value_std"])
psnr_series = _tolist(psnr_res["value"])
psnr_std    = _tolist(psnr_res["value_std"])
mse_series  = _tolist(mse_res ["value"])
mse_std     = _tolist(mse_res ["value_std"])

# v1_cpu / v2_cpu came from `.cpu()` on tensors already on the CPU, i.e. they alias
# videos_gt / videos_gen, so dropping them frees no memory by itself. Rebinding to None
# (instead of `del`) keeps this cell re-runnable without a NameError.
v1_cpu = v2_cpu = None
gc.collect();


300

In [7]:
# ---------- 2) LPIPS 在 GPU 按 batch 计算 ----------
def calc_lpips_batched(v1, v2, bs=4, dtype=torch.float16, device=DEVICE):
    """Compute LPIPS in sub-batches of ``bs`` videos and pool the statistics per timestep.

    Each ``calculate_lpips`` call returns per-timestep means ("value") and stds
    ("value_std") over the videos it saw. Simply concatenating those lists across
    sub-batches gives a series of length T * n_batches, whose "last" entry is the last
    timestep of the last sub-batch only. Instead, the per-timestep statistics are merged
    with the exact parallel mean/variance update (Chan et al.), so the result matches a
    single call over all videos regardless of ``bs``.
    NOTE(review): assumes "value_std" is the population std (ddof=0, numpy's default) —
    confirm against calculate_lpips.

    Args:
        v1, v2: video tensors sliced along dim 0 (here [N, T, C, H, W]).
        bs: number of videos per sub-batch.
        dtype: reduced-precision dtype for inputs/autocast, used only on CUDA devices.
        device: device passed through to ``calculate_lpips``.

    Returns:
        {"value": [...], "value_std": [...]} with one float per timestep.

    Raises:
        ValueError: if sub-batches return series of different lengths.
    """
    n_videos = v1.size(0)
    dev = torch.device(device)
    on_cuda = dev.type == "cuda"
    # fp16/bf16 autocast only makes sense on CUDA; on CPU keep fp32 inputs and disable autocast.
    in_dtype = dtype if on_cuda else torch.float32

    count = 0
    pooled_mean = None   # per-timestep mean over all videos seen so far
    pooled_m2 = None     # per-timestep sum of squared deviations over all videos seen so far

    pbar = tqdm(range(0, n_videos, bs), desc="LPIPS (batched)", leave=False)
    for start in pbar:
        stop = min(start + bs, n_videos)
        with torch.autocast(device_type=dev.type, dtype=dtype, enabled=on_cuda):
            out = calculate_lpips(
                v1[start:stop].to(dev, dtype=in_dtype, non_blocking=True),
                v2[start:stop].to(dev, dtype=in_dtype, non_blocking=True),
                device, only_final=ONLY_FINAL
            )
        batch_n = stop - start
        batch_mean = np.asarray(_tolist(out["value"]), dtype=np.float64)
        batch_std = np.asarray(_tolist(out["value_std"]), dtype=np.float64)
        del out  # release the sub-batch result before the next iteration
        batch_m2 = batch_std ** 2 * batch_n

        if pooled_mean is None:
            count, pooled_mean, pooled_m2 = batch_n, batch_mean, batch_m2
        else:
            if batch_mean.shape != pooled_mean.shape:
                raise ValueError(
                    f"LPIPS series length changed between sub-batches: {pooled_mean.shape} vs {batch_mean.shape}"
                )
            total = count + batch_n
            delta = batch_mean - pooled_mean
            pooled_mean = pooled_mean + delta * (batch_n / total)
            pooled_m2 = pooled_m2 + batch_m2 + delta ** 2 * (count * batch_n / total)
            count = total

        if on_cuda:
            torch.cuda.empty_cache()

    if pooled_mean is None:  # no videos at all
        return {"value": [], "value_std": []}
    pooled_std = np.sqrt(np.maximum(pooled_m2 / count, 0.0))
    return {"value": pooled_mean.tolist(), "value_std": pooled_std.tolist()}

print("[GPU] computing LPIPS (batched) ...")
lpips_res = calc_lpips_batched(
    videos_gt, videos_gen,
    bs=LPIPS_BS, dtype=GPU_DTYPE, device=DEVICE,
)
# Plain-list copies of the per-timestep means and stds.
lpips_series, lpips_std = (_tolist(lpips_res[key]) for key in ("value", "value_std"))





[GPU] computing LPIPS (batched) ...


LPIPS (batched):   0%|          | 0/1 [00:00<?, ?it/s]

calculate_lpips...


100%|██████████| 516/516 [00:15<00:00, 33.49it/s]
                                                              

In [8]:
# ---------- 3) FVD 在 GPU：whole 精确 or approx 分批 ----------
def calc_fvd_whole(v1, v2, dtype=torch.float16, device=DEVICE):
    """Exact FVD over all videos in a single ``calculate_fvd`` call (highest GPU memory use).

    Videos are padded to ``FVD_MIN_FRM`` frames on the CPU first, then moved to ``device``
    once. Reduced precision (``dtype`` + autocast) is applied only on CUDA; on CPU, fp32
    inputs are used and autocast is disabled. The device copies are released in
    ``finally``, so memory is also freed when FVD raises (e.g. out of memory).

    Returns:
        Whatever ``calculate_fvd`` returns; the notebook reads its "value" entry.
    """
    dev = torch.device(device)
    on_cuda = dev.type == "cuda"
    in_dtype = dtype if on_cuda else torch.float32
    v1e = ensure_min_frames_cpu(v1, FVD_MIN_FRM).to(dev, dtype=in_dtype, non_blocking=True)
    v2e = ensure_min_frames_cpu(v2, FVD_MIN_FRM).to(dev, dtype=in_dtype, non_blocking=True)
    try:
        with torch.autocast(device_type=dev.type, dtype=dtype, enabled=on_cuda):
            out = calculate_fvd(v1e, v2e, device, method='styleganv', only_final=ONLY_FINAL)
    finally:
        del v1e, v2e
        if on_cuda:
            torch.cuda.empty_cache()
    return out

def calc_fvd_approx_batched(v1, v2, bs=4, dtype=torch.float16, device=DEVICE):
    """
    Approximate FVD: compute FVD separately on batches of ``bs`` videos.

    This is NOT equivalent to whole-set FVD (FVD of a union is not the mean of per-batch
    FVDs); it only saves GPU memory. For an exact result, extract I3D features for all
    videos, concatenate them, then compute a single Frechet distance.

    For each batch only the value at the last evaluated timestamp is kept, so
    ``"value"`` holds exactly one FVD per batch (previously every per-timestamp value of
    every batch was concatenated into one flat list). The trailing batch can be much
    smaller than ``bs``; its FVD is noisier, so consult ``"batch_sizes"`` before reading
    the last entry on its own.

    Returns:
        {"value": [fvd per batch], "batch_sizes": [videos per batch]}
    """
    n_videos = v1.size(0)
    dev = torch.device(device)
    on_cuda = dev.type == "cuda"
    # Reduced precision/autocast only on CUDA; on CPU use fp32 and disable autocast.
    in_dtype = dtype if on_cuda else torch.float32
    vals, batch_sizes = [], []
    pbar = tqdm(range(0, n_videos, bs), desc="FVD approx (batched)", leave=False)
    for start in pbar:
        stop = min(start + bs, n_videos)
        # Pad frames on the CPU so the padded copy never lives on the GPU longer than needed.
        v1e = ensure_min_frames_cpu(v1[start:stop], FVD_MIN_FRM)
        v2e = ensure_min_frames_cpu(v2[start:stop], FVD_MIN_FRM)
        with torch.autocast(device_type=dev.type, dtype=dtype, enabled=on_cuda):
            out = calculate_fvd(
                v1e.to(dev, dtype=in_dtype, non_blocking=True),
                v2e.to(dev, dtype=in_dtype, non_blocking=True),
                device, method='styleganv', only_final=ONLY_FINAL
            )
        batch_vals = _tolist(out["value"])
        if batch_vals:
            vals.append(batch_vals[-1])  # last timestamp of this batch
            batch_sizes.append(stop - start)
        del out, v1e, v2e
        if on_cuda:
            torch.cuda.empty_cache()
    return {"value": vals, "batch_sizes": batch_sizes}

print(f"[GPU] computing FVD mode = {FVD_MODE} ...")
# "whole" = exact single pass; anything else falls back to the batched approximation.
fvd_res = (
    calc_fvd_whole(videos_gt, videos_gen, dtype=GPU_DTYPE, device=DEVICE)
    if FVD_MODE == "whole"
    else calc_fvd_approx_batched(videos_gt, videos_gen, bs=FVD_BS, dtype=GPU_DTYPE, device=DEVICE)
)

fvd_series = _tolist(fvd_res["value"])

[GPU] computing FVD mode = approx_batch_mean ...


FVD approx (batched):   0%|          | 0/11 [00:00<?, ?it/s]

calculate_fvd...
Loading model from: /root/PhotoDoodle/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:03<00:00,  3.37s/it]
FVD approx (batched):   9%|▉         | 1/11 [00:04<00:45,  4.59s/it]

calculate_fvd...
Loading model from: /root/PhotoDoodle/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:05<00:00,  5.37s/it]
FVD approx (batched):  18%|█▊        | 2/11 [00:11<00:51,  5.77s/it]

calculate_fvd...
Loading model from: /root/PhotoDoodle/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:01<00:00,  2.00s/it]
FVD approx (batched):  27%|██▋       | 3/11 [00:14<00:36,  4.61s/it]

calculate_fvd...
Loading model from: /root/PhotoDoodle/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:06<00:00,  6.52s/it]
FVD approx (batched):  36%|███▋      | 4/11 [00:22<00:40,  5.85s/it]

calculate_fvd...
Loading model from: /root/PhotoDoodle/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:01<00:00,  1.96s/it]
FVD approx (batched):  45%|████▌     | 5/11 [00:25<00:29,  4.89s/it]

calculate_fvd...
Loading model from: /root/PhotoDoodle/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:03<00:00,  3.59s/it]
FVD approx (batched):  55%|█████▍    | 6/11 [00:30<00:24,  4.87s/it]

calculate_fvd...
Loading model from: /root/PhotoDoodle/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:04<00:00,  4.89s/it]
FVD approx (batched):  64%|██████▎   | 7/11 [00:36<00:21,  5.29s/it]

calculate_fvd...
Loading model from: /root/PhotoDoodle/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:03<00:00,  3.23s/it]
FVD approx (batched):  73%|███████▎  | 8/11 [00:40<00:15,  5.02s/it]

calculate_fvd...
Loading model from: /root/PhotoDoodle/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:05<00:00,  5.03s/it]
FVD approx (batched):  82%|████████▏ | 9/11 [00:47<00:10,  5.42s/it]

calculate_fvd...
Loading model from: /root/PhotoDoodle/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:05<00:00,  5.87s/it]
FVD approx (batched):  91%|█████████ | 10/11 [00:54<00:05,  5.93s/it]

calculate_fvd...
Loading model from: /root/PhotoDoodle/fvd/styleganv/i3d_torchscript.pt


100%|██████████| 1/1 [00:03<00:00,  3.47s/it]
                                                                     

In [9]:
# ---------- 4) Summarize / save ----------
from itertools import zip_longest

N, T, C, H, W = videos_gt.shape
summary = {
    "num_pairs": int(N),
    "video_length": int(T),
    "channels": int(C),
    "height": int(H),
    "width": int(W),
    "tile_fit_mode": cfg["tile_fit_mode"],
    "only_final": ONLY_FINAL,
    "gpu_dtype": str(GPU_DTYPE).split(".")[-1],
    "fvd_mode": FVD_MODE,

    # Last entry of each series. NOTE: in "approx_batch_mean" mode the FVD series is
    # indexed by batch, so fvd_last is the FVD of the final (possibly small) batch.
    "fvd_last":   last_or_none(fvd_series),
    "lpips_last": last_or_none(lpips_series),
    "psnr_last":  last_or_none(psnr_series),
    "ssim_last":  last_or_none(ssim_series),
    "mse_last":   last_or_none(mse_series),

    # Unweighted mean of each series (over batches for approx-mode FVD).
    "fvd_mean":   mean_or_none(fvd_series),
    "lpips_mean": mean_or_none(lpips_series),
    "psnr_mean":  mean_or_none(psnr_series),
    "ssim_mean":  mean_or_none(ssim_series),
    "mse_mean":   mean_or_none(mse_series),
}

full_record = {
    "timestamp": datetime.now().isoformat(timespec="seconds"),
    **summary,
    "series": {
        "fvd": fvd_series,
        "lpips": lpips_series, "lpips_std": lpips_std,
        "psnr": psnr_series,   "psnr_std": psnr_std,
        "ssim": ssim_series,   "ssim_std": ssim_std,
        "mse":  mse_series,    "mse_std":  mse_std,
    }
}

print(pyjson.dumps(summary, indent=2))

json_path = OUTPUT_DIR / f"metrics_fit_{cfg['tile_fit_mode']}_full.json"
with open(json_path, "w", encoding="utf-8") as f:
    pyjson.dump(full_record, f, indent=2)
print(f"[OK] saved JSON: {json_path}")

# Rows are aligned by list position only; shorter series are padded with "".
# The fvd column may use a different index (per batch in approx mode) than the per-frame metrics.
csv_path = OUTPUT_DIR / f"metrics_fit_{cfg['tile_fit_mode']}_series.csv"
csv_columns = [
    fvd_series,
    lpips_series, lpips_std,
    psnr_series,  psnr_std,
    ssim_series,  ssim_std,
    mse_series,   mse_std,
]
with open(csv_path, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["index","fvd","lpips","lpips_std","psnr","psnr_std","ssim","ssim_std","mse","mse_std"])
    for i, values in enumerate(zip_longest(*csv_columns, fillvalue="")):
        writer.writerow([i, *values])
print(f"[OK] saved CSV: {csv_path}")

# Drop the raw metric dicts and reclaim memory. Rebinding to None (instead of `del`)
# keeps this cell re-runnable without a NameError.
lpips_res = psnr_res = ssim_res = mse_res = fvd_res = None
torch.cuda.empty_cache()
gc.collect();


{
  "num_pairs": 516,
  "video_length": 9,
  "channels": 3,
  "height": 240,
  "width": 320,
  "tile_fit_mode": "tile_gt_to_gen",
  "only_final": false,
  "gpu_dtype": "float16",
  "fvd_mode": "approx_batch_mean",
  "fvd_last": 4095.478073094209,
  "lpips_last": 0.9942231984563576,
  "psnr_last": 6.290071680789513,
  "ssim_last": 0.017831154379311338,
  "mse_last": 0.21397601068019867,
  "fvd_mean": 4377.132740365402,
  "lpips_mean": 0.9391541193118262,
  "psnr_mean": 7.18772746832456,
  "ssim_mean": 0.06423100679865956,
  "mse_mean": 0.17802256014611986
}
[OK] saved JSON: /root/PhotoDoodle/inference/outputs_bridgeV2_traj_noop/metrics_fit_tile_gt_to_gen_full.json
[OK] saved CSV: /root/PhotoDoodle/inference/outputs_bridgeV2_traj_noop/metrics_fit_tile_gt_to_gen_series.csv


2114