# Calculate Validation Metrics
### PSNR + SSIM

In [None]:
#import statements

import os
import re
import cv2
import torch
import kornia.metrics as km
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from datetime import datetime
from collections import defaultdict

In [2]:

# ---- CONFIG ----
# Training visualization folder: one sub-folder per validation frame, each holding
# that frame's SR outputs saved as <frame_name>_<step>.png at every checkpoint step.
VIS_DIR = Path("/teamspace/studios/this_studio/neosr/experiments/train_plksr_3x/visualization")
# Ground-truth HR images, looked up as <frame_name>.png.
GT_DIR = Path("/teamspace/studios/upres-ml-dataset-small/sr_dataset/3_0x/val/HR")

MODEL_NAME = "plksr"  # prefix of the output CSV file name
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ----------------

In [3]:
def read_and_tensorize(path):
    """Load an image file as an RGB float tensor of shape [C, H, W] scaled to [0, 1].

    Args:
        path: str or Path of the image file.

    Returns:
        torch.float32 tensor, channels first, values divided by 255.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (missing or not an image).
    """
    img = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising; without
        # this check the next call fails with an opaque OpenCV assertion.
        raise FileNotFoundError(f"Could not read image: {path}")
    # OpenCV loads channels in BGR order; convert so tensors are RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return torch.from_numpy(img).permute(2, 0, 1).float() / 255.0  # [C, H, W]

def chunked(iterable, size):
    """Yield successive slices of at most ``size`` items from a sliceable sequence.

    The final slice is shorter when ``len(iterable)`` is not a multiple of ``size``.
    """
    starts = range(0, len(iterable), size)
    yield from (iterable[offset:offset + size] for offset in starts)

In [4]:
def _collect_pairs():
    """Map each checkpoint step to its list of (sr_path, gt_path) pairs found under VIS_DIR."""
    step_to_pairs = defaultdict(list)
    for frame_dir in tqdm(sorted(VIS_DIR.iterdir()), desc="Scanning frame folders"):
        if not frame_dir.is_dir():
            continue

        frame_name = frame_dir.name
        gt_path = GT_DIR / f"{frame_name}.png"
        if not gt_path.exists():
            continue

        # The folder name is literal text, not a pattern: escape it, and require the
        # whole file name to be <frame>_<digits>.png (this also excludes *_lq.png).
        sr_pattern = re.compile(rf"{re.escape(frame_name)}_(\d+)\.png")
        for sr_file in frame_dir.glob("*.png"):
            match = sr_pattern.fullmatch(sr_file.name)
            if match:
                step_to_pairs[int(match.group(1))].append((sr_file, gt_path))
    return step_to_pairs


def _per_image_metrics(sr_batch, gt_batch):
    """Return (psnr_list, ssim_list) with one value per image of two equal-shape [B, C, H, W] batches."""
    # km.psnr reduces over the whole batch (one pooled MSE), which makes the result
    # depend on batching; compute PSNR per image instead (max value 1.0 for [0, 1] data).
    mse = (sr_batch - gt_batch).pow(2).flatten(1).mean(dim=1)
    psnr = 10.0 * torch.log10(1.0 / mse)
    # km.ssim returns an SSIM map per pixel; average it per image.
    ssim = km.ssim(sr_batch, gt_batch, 11).flatten(1).mean(dim=1)
    return psnr.tolist(), ssim.tolist()


def _evaluate_step(step, pairs, batch_size):
    """Load one checkpoint's SR/GT pairs and return per-image (psnr_list, ssim_list)."""
    # torch.stack needs equal sizes and validation images may differ, so group by shape.
    by_shape = defaultdict(list)
    for sr_path, gt_path in pairs:
        try:
            sr_tensor = read_and_tensorize(sr_path)
            gt_tensor = read_and_tensorize(gt_path)
            if sr_tensor.shape != gt_tensor.shape:
                raise ValueError(f"Shape mismatch: {sr_path.name}")
        except Exception as e:  # best-effort: report and skip this image
            print(f"Error reading {sr_path.name}: {e}")
            continue
        by_shape[tuple(sr_tensor.shape)].append((sr_tensor, gt_tensor))

    psnrs, ssims = [], []
    for group in by_shape.values():
        for chunk in chunked(group, batch_size):
            try:
                sr_batch = torch.stack([sr for sr, _ in chunk]).to(DEVICE)
                gt_batch = torch.stack([gt for _, gt in chunk]).to(DEVICE)
                chunk_psnrs, chunk_ssims = _per_image_metrics(sr_batch, gt_batch)
            except Exception as e:  # best-effort: report and skip this chunk
                print(f"Error in chunked PSNR/SSIM for step {step}: {e}")
                continue
            psnrs.extend(chunk_psnrs)
            ssims.extend(chunk_ssims)
    return psnrs, ssims


def main(batch_size=8):
    """Compute average validation PSNR/SSIM for every checkpoint step and save them to CSV.

    Frame folders in VIS_DIR are paired with GT_DIR/<frame>.png. Metrics are computed
    per image and averaged with equal weight per image (not per batch).

    Args:
        batch_size: maximum number of same-shaped images moved to DEVICE at once.
    """
    # Step 1: collect all SR-GT pairs grouped by checkpoint step
    step_to_pairs = _collect_pairs()
    print(f"\n📊 Checkpoints found: {len(step_to_pairs)}")

    # Step 2: batched per-image PSNR & SSIM on DEVICE
    results = []
    for step in tqdm(sorted(step_to_pairs), desc="Evaluating checkpoints"):
        psnrs, ssims = _evaluate_step(step, step_to_pairs[step], batch_size)
        if not psnrs:
            continue
        results.append({
            "checkpoint_step": step,
            "avg_psnr": np.mean(psnrs),
            "avg_ssim": np.mean(ssims),
            "num_images": len(psnrs),  # images actually scored, not merely loaded
        })

    if not results:
        # An empty DataFrame has no "checkpoint_step" column to sort by.
        print("\n⚠️ No results to save.")
        return

    # Step 3: save results
    df = pd.DataFrame(results).sort_values("checkpoint_step")
    timestamp = datetime.now().strftime("%Y%m%d")
    output_csv = f"{MODEL_NAME}_val_metrics_{timestamp}_{len(results)}ckpts.csv"
    df.to_csv(output_csv, index=False)
    print(f"\n✅ Saved results to: {output_csv}")


In [None]:
# Run the full evaluation when executed as a script; in a notebook kernel __name__ is
# also "__main__", so executing this cell runs it too.
if __name__ == "__main__":
    main()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Load your results CSV (written by main(); update the name if different)
df = pd.read_csv("plksr_val_metrics_20250416_49ckpts.csv")

# Sort by checkpoint_step (just in case)
df = df.sort_values("checkpoint_step")

# PSNR is in dB (10*log10 scale, typically tens) while SSIM lies in [0, 1]. On one
# shared y-axis the SSIM curve collapses to a flat line near 0, so plot each metric
# on its own panel with a shared x-axis.
fig, (ax_psnr, ax_ssim) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

ax_psnr.plot(df["checkpoint_step"], df["avg_psnr"], label="PSNR", marker="o")
ax_psnr.set_ylabel("PSNR (dB)")
ax_psnr.grid(True)
ax_psnr.legend()

ax_ssim.plot(df["checkpoint_step"], df["avg_ssim"], label="SSIM", marker="s", color="tab:orange")
ax_ssim.set_ylabel("SSIM")
ax_ssim.set_xlabel("Checkpoint Step")
ax_ssim.grid(True)
ax_ssim.legend()

fig.suptitle("PLKSR Validation Metrics Avg by Checkpoint Step")
fig.tight_layout()
# Save before show(): show() may close the figure, after which nothing is left to save.
fig.savefig("psnr_ssim_plot.png", dpi=300)
plt.show()

In [None]:
# The plot cell above already saves psnr_ssim_plot.png before plt.show(). After show()
# the displayed figure is typically closed, so an unconditional plt.savefig() here
# would write a new, empty figure over that file. Only re-save when an open figure
# that actually has axes is still current.
if plt.get_fignums() and plt.gcf().get_axes():
    plt.savefig("psnr_ssim_plot.png", dpi=300)
else:
    print("No open figure with content; psnr_ssim_plot.png left unchanged.")

In [None]:
import pandas as pd

# Report the best-scoring PLKSR checkpoints from its PSNR/SSIM CSV.
df = pd.read_csv("./plksr_val_metrics_20250416_49ckpts.csv")

# idxmax gives the row label of the maximum of each metric column.
best_psnr_idx = df["avg_psnr"].idxmax()
best_ssim_idx = df["avg_ssim"].idxmax()

# .at fetches single cells, keeping each column's own dtype.
best_psnr_step = df.at[best_psnr_idx, "checkpoint_step"]
best_psnr_value = df.at[best_psnr_idx, "avg_psnr"]
best_ssim_step = df.at[best_ssim_idx, "checkpoint_step"]
best_ssim_value = df.at[best_ssim_idx, "avg_ssim"]

print("For PLKSR: ")
print(f"📈 Best PSNR:  {best_psnr_value:.4f} at step {best_psnr_step}")
print(f"📈 Best SSIM:  {best_ssim_value:.4f} at step {best_ssim_step}")


In [None]:
# Report the best-scoring RealPLKSR checkpoints from its PSNR/SSIM CSV
# (relies on `pd` imported by an earlier cell).
df = pd.read_csv("./real_plksr_val_metrics_20250417_42ckpts.csv")

# idxmax gives the row label of the maximum of each metric column.
best_psnr_idx = df["avg_psnr"].idxmax()
best_ssim_idx = df["avg_ssim"].idxmax()

# .at fetches single cells, keeping each column's own dtype.
best_psnr_step = df.at[best_psnr_idx, "checkpoint_step"]
best_psnr_value = df.at[best_psnr_idx, "avg_psnr"]
best_ssim_step = df.at[best_ssim_idx, "checkpoint_step"]
best_ssim_value = df.at[best_ssim_idx, "avg_ssim"]

print("For RealPLKSR: ")
print(f"📈 Best PSNR:  {best_psnr_value:.4f} at step {best_psnr_step}")
print(f"📈 Best SSIM:  {best_ssim_value:.4f} at step {best_ssim_step}")