In [None]:
# %% [markdown]
# # 3×3 Grid 逐块适配评估（Tile-wise Fit, No Global Resize）
# - Pairing: results.jsonl 里 ok==True 的样本，使用 output 定位生成图（相对 output_dir），target/gt 定位 GT 图（相对路径基于 jsonl_path 所在目录解析）
# - Cutting: 先切 3×3，再逐块适配，顺序 123 / 654 / 789
# - Tile Fit Modes:
#   * "tile_strict":    每块尺寸必须一致，否则跳过
#   * "tile_gen_to_gt": 生成图每块 -> 适配到 GT 对应块尺寸
#   * "tile_gt_to_gen": GT 每块     -> 适配到 生成 对应块尺寸
# - Metrics:
#   * GPU: FVD, LPIPS
#   * CPU: SSIM, PSNR, MSE
# - Output: 扁平 JSON（含 mse），保存为 output_dir/metrics_fit_<tile_fit_mode>_flat.json

# %%
import json
from pathlib import Path
from typing import List, Tuple
from datetime import datetime

import numpy as np
from PIL import Image
import torch

# ==== Configuration ====
CONFIG = {
    # Original input JSONL. Only its parent directory is used: it is the base against
    # which relative target/gt paths found in results.jsonl are resolved.
    "jsonl_path": "/root/PhotoDoodle/data/bridge_test/text_test_index.jsonl",
    # Inference output directory; must contain results.jsonl, and generated-image
    # paths in that file are relative to it.
    "output_dir": "inference/outputs_bridgeV2_text",
    # One of: "tile_strict" | "tile_gen_to_gt" | "tile_gt_to_gen"
    "tile_fit_mode": "tile_gen_to_gt",
    # Device for FVD/LPIPS: CUDA when available, otherwise CPU.
    # Replace with "cpu" to force CPU evaluation.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    # True: every metric compares only the last frame (forwarded as only_final).
    "only_final": False,
}

# ==== 你已有的指标脚本（需可导入） ====
from calculate_fvd import calculate_fvd
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim
from calculate_lpips import calculate_lpips
from calculate_mse import calculate_mse


In [None]:
# %% [markdown]
# ## 工具函数：路径解析、切 3×3（蛇形顺序）、逐块适配

# %%
def resolve_path_maybe_relative(p: str, base: Path) -> Path:
    """Interpret ``p`` relative to ``base`` unless it is already absolute.

    Absolute inputs are returned unchanged (they are NOT resolved); relative
    inputs are joined onto ``base`` and then resolved to an absolute path.
    """
    candidate = Path(p)
    if candidate.is_absolute():
        return candidate
    return (base / candidate).resolve()

def _grid_splits(length: int) -> List[int]:
    """
    用 linspace 生成分割边界（0..length），四个边界 -> 三段。
    即使 length 不是 3 的倍数，也无缝覆盖且不重叠。
    """
    edges = np.linspace(0, length, 4)
    return [int(round(x)) for x in edges]

def _cut_3x3_tiles(im: Image.Image) -> List[Image.Image]:
    """Cut ``im`` into a 3x3 grid and return the 9 crops in snake order.

    Rows are walked top to bottom; the first and last rows go left-to-right and
    the middle row right-to-left, so the returned positions map onto the grid as
    123 / 654 / 789. Cell edges come from ``_grid_splits``.
    """
    width, height = im.size
    col_edges = _grid_splits(width)
    row_edges = _grid_splits(height)

    tiles = []
    for row in range(3):
        # Middle row is reversed to follow the snake path.
        cols = range(2, -1, -1) if row == 1 else range(3)
        for col in cols:
            box = (col_edges[col], row_edges[row], col_edges[col + 1], row_edges[row + 1])
            tiles.append(im.crop(box))
    return tiles

def _tile_sizes(tiles: List[Image.Image]) -> List[Tuple[int,int]]:
    """Return the ``(W, H)`` size of every tile, in input order."""
    sizes = []
    for tile in tiles:
        sizes.append(tile.size)
    return sizes

def _resize_tiles_to(tiles: List[Image.Image], target_sizes: List[Tuple[int,int]]) -> List[Image.Image]:
    """Resize each tile to the ``(W, H)`` found at the same position in ``target_sizes``.

    A tile that already has its target size is passed through as the same object;
    the others are resampled with bicubic interpolation. As with ``zip``, the
    result is as long as the shorter of the two inputs.
    """
    def _fit(tile, size):
        if tile.size == size:
            return tile
        return tile.resize(size, Image.BICUBIC)

    return [_fit(tile, (w, h)) for tile, (w, h) in zip(tiles, target_sizes)]

def _tiles_to_tensor(tiles: List[Image.Image]) -> torch.Tensor:
    """Stack RGB PIL tiles into a float tensor ``[T, 3, H, W]`` with values in [0, 1].

    The 3x3 cut uses rounded edges, so when an image's width/height is not a
    multiple of 3 its tiles differ by one pixel (e.g. 1024 -> 341/342/341), which
    ``torch.stack`` rejects. Every tile is therefore cropped from its top-left
    corner to the smallest width and smallest height among ``tiles``. Tiles that
    already share one size are left untouched.

    Args:
        tiles: Non-empty list of RGB ``PIL.Image`` tiles.

    Returns:
        ``torch.float32`` tensor of shape ``[len(tiles), 3, H, W]``.

    Raises:
        ValueError: If ``tiles`` is empty.
    """
    if not tiles:
        raise ValueError("_tiles_to_tensor: got an empty tile list")

    min_w = min(t.size[0] for t in tiles)
    min_h = min(t.size[1] for t in tiles)

    frames = []
    for t in tiles:
        if t.size != (min_w, min_h):
            t = t.crop((0, 0, min_w, min_h))
        # np.array copies into a writable buffer; the array np.asarray returns for
        # a PIL image can be read-only, which torch.from_numpy warns about.
        arr = np.array(t, dtype=np.uint8)
        frames.append(torch.from_numpy(arr).permute(2, 0, 1).float() / 255.0)
    return torch.stack(frames, dim=0)

def cut_grid_as_video_tilefit(gt_path: Path, gen_path: Path, tile_fit_mode: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Turn a GT grid image and a generated grid image into two 9-frame "videos".

    Steps:
      1) open both images, convert to RGB and cut each into 9 tiles in snake
         order (123 / 654 / 789);
      2) fit the tiles one by one according to ``tile_fit_mode``. There is no
         whole-image resize, so resampling never blurs pixels across cells:
           - "tile_strict":    per-tile sizes must already match;
           - "tile_gen_to_gt": each generated tile is resized to its GT tile's size;
           - "tile_gt_to_gen": each GT tile is resized to its generated tile's size;
      3) convert each tile list with ``_tiles_to_tensor`` (nominally [9, 3, H, W]).

    Returns:
        ``(v_gt, v_gen)``.

    Raises:
        ValueError: Unknown ``tile_fit_mode`` (checked before any file is opened).
        AssertionError: "tile_strict" mode and the per-tile sizes differ.
    """
    if tile_fit_mode not in ("tile_strict", "tile_gen_to_gt", "tile_gt_to_gen"):
        raise ValueError(f"Unknown tile_fit_mode: {tile_fit_mode}")

    # `with` closes the opened source files even if decoding/conversion raises;
    # the crops are independent images, so they stay valid after the files close.
    with Image.open(gt_path) as gt_src, Image.open(gen_path) as gen_src:
        gt_tiles = _cut_3x3_tiles(gt_src.convert("RGB"))
        gen_tiles = _cut_3x3_tiles(gen_src.convert("RGB"))

    gt_sizes = _tile_sizes(gt_tiles)
    gen_sizes = _tile_sizes(gen_tiles)

    if tile_fit_mode == "tile_strict":
        if gt_sizes != gen_sizes:
            raise AssertionError(f"[tile_strict] per-tile sizes mismatch: GT={gt_sizes} vs GEN={gen_sizes}")
    elif tile_fit_mode == "tile_gen_to_gt":
        gen_tiles = _resize_tiles_to(gen_tiles, gt_sizes)
    else:  # "tile_gt_to_gen"
        gt_tiles = _resize_tiles_to(gt_tiles, gen_sizes)

    return _tiles_to_tensor(gt_tiles), _tiles_to_tensor(gen_tiles)


In [None]:
# %% [markdown]
# ## Pairing：从 results.jsonl 配对 GT 与生成图

# %%
cfg = CONFIG
JSONL_PATH   = Path(cfg["jsonl_path"]).resolve()
JSONL_BASE   = JSONL_PATH.parent          # base dir for relative target/gt paths
OUTPUT_DIR   = Path(cfg["output_dir"]).resolve()
RESULTS_PATH = OUTPUT_DIR / "results.jsonl"

# (gt_path, gen_path, record) for every usable sample
pairs: List[Tuple[Path, Path, dict]] = []
n_total = n_usable = n_no_target = n_missing = 0
n_not_ok = n_no_output = n_malformed = 0

assert RESULTS_PATH.exists(), f"results.jsonl not found: {RESULTS_PATH}"
with open(RESULTS_PATH, "r", encoding="utf-8") as fr:
    for line in fr:
        n_total += 1
        if not line.strip():
            continue  # blank/trailing line: not a record and not an error
        # Malformed lines are counted rather than silently dropped, so a corrupted
        # results.jsonl is visible in the summary below.
        try:
            rec = json.loads(line)
        except json.JSONDecodeError:
            n_malformed += 1
            continue
        if not isinstance(rec, dict):
            n_malformed += 1
            continue
        if not rec.get("ok", False):
            n_not_ok += 1
            continue

        out_rel = rec.get("output")
        if not out_rel:
            n_no_output += 1
            continue
        # Generated image paths are relative to output_dir.
        gen_path = (OUTPUT_DIR / out_rel).resolve()

        gt_raw = rec.get("target") or rec.get("gt")
        if not gt_raw:
            n_no_target += 1
            continue
        gt_path = resolve_path_maybe_relative(gt_raw, JSONL_BASE)

        if not (gen_path.exists() and gt_path.exists()):
            n_missing += 1
            continue

        pairs.append((gt_path, gen_path, rec))
        n_usable += 1

print(f"[PAIRING] total_lines={n_total}, usable_pairs={n_usable}, "
      f"no_target_field={n_no_target}, missing_files={n_missing}, "
      f"not_ok={n_not_ok}, no_output_field={n_no_output}, malformed_lines={n_malformed}")
assert len(pairs) > 0, "No valid (GT, Generated) pairs found."


In [None]:
# %% [markdown]
# ## 构建 [N, 9, 3, H, W] 张量（逐块适配）

# %%
TILE_FIT_MODE = cfg["tile_fit_mode"]
videos_gt_list, videos_gen_list = [], []
# (generated image path, reason) for every pair that could not be converted
convert_failures: List[Tuple[Path, str]] = []

for gt_path, gen_path, _rec in pairs:
    try:
        v_gt, v_gen = cut_grid_as_video_tilefit(gt_path, gen_path, TILE_FIT_MODE)
    except Exception as exc:
        # Best effort: one bad pair must not abort the run, but keep the reason.
        convert_failures.append((gen_path, f"{type(exc).__name__}: {exc}"))
        continue
    # The two videos are compared frame by frame, so their shapes must match.
    if v_gt.shape != v_gen.shape:
        convert_failures.append(
            (gen_path, f"shape mismatch GT={tuple(v_gt.shape)} vs GEN={tuple(v_gen.shape)}"))
        continue
    videos_gt_list.append(v_gt)
    videos_gen_list.append(v_gen)

failed_convert = len(convert_failures)
for bad_path, reason in convert_failures[:5]:
    print(f"[SKIP] {bad_path}: {reason}")
if failed_convert > 5:
    print(f"[SKIP] ... and {failed_convert - 5} more")

assert len(videos_gt_list) > 0, "No valid pairs after tile-wise fitting."

# torch.stack needs one common shape across all pairs; mixed resolutions would
# otherwise surface as an opaque RuntimeError from torch.stack.
distinct_shapes = {tuple(v.shape) for v in videos_gt_list}
if len(distinct_shapes) > 1:
    raise ValueError(f"Pairs produce different video shapes {sorted(distinct_shapes)}; "
                     "use images of a single resolution before stacking.")

videos_gt  = torch.stack(videos_gt_list,  dim=0)  # [N,9,3,H,W]
videos_gen = torch.stack(videos_gen_list, dim=0)  # [N,9,3,H,W]
N, T, C, H, W = videos_gt.shape
print(f"[DATA] videos_gt={videos_gt.shape}, videos_gen={videos_gen.shape}, "
      f"failed_convert={failed_convert}, TILE_FIT_MODE={TILE_FIT_MODE}")


In [None]:
# %% [markdown]
# ## 计算指标：GPU(FVD/LPIPS) + CPU(SSIM/PSNR/MSE)，并扁平化保存

# %%
DEVICE = torch.device(cfg["device"])
ONLY_FINAL = bool(cfg["only_final"])

# FVD / LPIPS inputs live on DEVICE (named *_gpu, but DEVICE follows cfg["device"]).
v1_gpu, v2_gpu = (video.to(DEVICE, non_blocking=True) for video in (videos_gt, videos_gen))

# SSIM / PSNR / MSE inputs stay on the CPU (numpy-compatible).
v1_cpu, v2_cpu = videos_gt.cpu(), videos_gen.cpu()

# Metrics are evaluated in this order; each result is coerced to a plain float for JSON.
# NOTE(review): every video has 9 frames (one per grid cell); some StyleGAN-V-style
# FVD implementations need at least 10 frames per clip -- confirm calculate_fvd accepts 9.
metric_jobs = [
    ("fvd",   lambda: calculate_fvd(v1_gpu, v2_gpu, DEVICE, method='styleganv', only_final=ONLY_FINAL)),
    ("lpips", lambda: calculate_lpips(v1_gpu, v2_gpu, DEVICE, only_final=ONLY_FINAL)),
    ("ssim",  lambda: calculate_ssim(v1_cpu, v2_cpu, only_final=ONLY_FINAL)),
    ("psnr",  lambda: calculate_psnr(v1_cpu, v2_cpu, only_final=ONLY_FINAL)),
    ("mse",   lambda: calculate_mse(v1_cpu, v2_cpu, only_final=ONLY_FINAL)),
]
metrics = {}
for metric_name, compute in metric_jobs:
    metrics[metric_name] = float(compute())

# Run metadata first, then the flat metric values.
flat_record = {
    "timestamp": datetime.now().isoformat(timespec="seconds"),
    "tile_fit_mode": TILE_FIT_MODE,
    "only_final": ONLY_FINAL,
    "num_pairs": int(N),
    "video_length": int(T),
    "channels": int(C),
    "height": int(H),
    "width": int(W),
}
flat_record.update(metrics)

report_json = json.dumps(flat_record, indent=2)
print(report_json)

out_path = OUTPUT_DIR / f"metrics_fit_{TILE_FIT_MODE}_flat.json"
out_path.write_text(report_json, encoding="utf-8")
print(f"[OK] flattened metrics saved to: {out_path}")
