In [None]:
import argparse
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
from utils import get_config, load_model, PatchImageProcessor, remove_center_padding
from big_vision.models.proj.vitok.naflex_vit_vae import patches_to_image
import jax
from jax2torch import jax2torch
import numpy as np
from PIL import Image
import jax.numpy as jnp
import os

# Default XLA GPU flags (only applied when XLA_FLAGS is not already set).
# NOTE(review): `jax` is imported above; XLA reads these flags when its backend
# initialises, so this presumably still takes effect as long as no JAX
# computation has run yet — confirm if flags seem to be ignored.
os.environ.setdefault(
    "XLA_FLAGS",
    " ".join((
        "--xla_gpu_enable_triton_gemm=true",
        "--xla_gpu_use_runtime_fusion=true",
        # "--xla_gpu_enable_async_collectives=true",  # disabled
    )),
)

# Constants
MAX_SIDE_RESOLUTION = 1024
SAVE_VARIANT = 'S_B/16x64' # 'S_B/16x32' or 'S_B/24x64'
# "S_B/16x64" → "16x64" → "16": the number before the 'x' is the patch size.
patch_size = int(SAVE_VARIANT.split('/')[1].split('x', 1)[0])
# Checkpoint location for each saved variant.
GCS_PATH = {
    'S_B/16x64': "gs://vidtok-data/vae_10/S_B_high_res_finetune/params.npz",
    'S_B/16x32': "gs://vidtok-data/vae_17/4096_S_B_16_32/params.npz",
    "S_B/24x64": "gs://vidtok-data/vae_18/1600_S_B_24_64rr/params.npz",
    "S_B/16x16+256_fixed_AR": "gs://vidtok-data/final/S_B_16_fixed_AR_256_params.npz",
    "S_B/16x32+256": "gs://vidtok-data/final/S_B_16_32_256_params.npz",
    "S_B/24x64+1800": "gs://vidtok-data/final/S_B_24_64_1800_params.npz",
}[SAVE_VARIANT]
# Token budget for a square image whose side is MAX_SIDE_RESOLUTION:
# ceil(side / patch) patches per side, squared.
MAX_TOKENS = int(np.ceil(MAX_SIDE_RESOLUTION / patch_size) ** 2)
print(MAX_TOKENS)

In [None]:
import torch.backends.cudnn as cudnn
DEVICE = torch.device("cuda")
DTYPE  = torch.bfloat16
# Allow TF-32 for float32 matmuls and cuDNN convolutions, and let cuDNN
# benchmark/auto-select convolution algorithms.
torch.set_float32_matmul_precision("high")
cudnn.allow_tf32 = True
cudnn.benchmark = True

In [None]:
class NaFlexImageFolder(ImageFolder):
    """ImageFolder that yields NaFlex patch tuples instead of (image, label).

    Each item is the first five fields returned by
    ``PatchImageProcessor.preprocess_pil`` for the (transformed) image — used
    elsewhere in this notebook as (patches, ptype, yidx, xidx, orig_hw). The
    class label produced by ImageFolder is discarded.
    """

    def __init__(self, *args, patch_size=16, max_tokens=1024, **kwargs):
        super().__init__(*args, **kwargs)
        self.patch_processor = PatchImageProcessor(
            patch_size=patch_size,
            token_match=True,
            max_tokens=max_tokens,
        )

    def __getitem__(self, index):
        pil_image, _label = super().__getitem__(index)
        processed = self.patch_processor.preprocess_pil(pil_image)
        return processed[:5]

In [None]:
# Model variant without the "+<tokens>" suffix, e.g. "S_B/16x64".
VARIANT = SAVE_VARIANT.split('+')[0]
config = get_config(f"variant={VARIANT}")
# NOTE: this overrides the patch_size parsed from SAVE_VARIANT earlier;
# MAX_TOKENS was computed with that parsed value.
patch_size, max_grid_size = config.patch_size, config.max_grid_size
print(max_grid_size)
model, params = load_model(
    config,
    checkpoint_path=GCS_PATH,
    max_sequence_len=MAX_TOKENS,
    patch_size=patch_size,
)

In [None]:
@jax.jit
def recon_apply(batch):
    """Run the VAE on a NaFlex patch batch and rasterise input and output.

    Args:
        batch: passed unchanged to ``model.apply`` and ``patches_to_image``.
            Callers in this notebook build it as (patches, ptype, yidx, xidx);
            exact shapes/dtypes are defined by PatchImageProcessor (not visible
            here).

    Returns:
        (recon, reference): ``patches_to_image`` renderings of the model output
        and of the input, on a max_grid_size × max_grid_size patch grid.
    """
    # `params` is closed over, so jit treats it as a constant of the compiled
    # computation; re-define this function if the parameters change.
    recon_tuple, _ = model.apply({'params': params}, batch, None)
    recon = patches_to_image(recon_tuple, max_grid_size, max_grid_size, patch_size)
    reference = patches_to_image(batch, max_grid_size, max_grid_size, patch_size)
    return recon, reference

# Wrap the jitted JAX function so it can be called with torch tensors.
forward_torch = jax2torch(recon_apply)

def resize_long_side(pil: Image.Image, max_side: int | None = None) -> Image.Image:
    """
    Rescale *any* image (up- or down-sampling) so its longer edge equals
    `max_side`, preserving aspect ratio (LANCZOS resampling).

    Args:
        pil: input PIL image.
        max_side: target length of the longer edge in pixels. Defaults to the
            module-level MAX_SIDE_RESOLUTION, read at call time so reassigning
            that constant in a later cell takes effect.

    Returns:
        `pil` itself if it already has the target size, otherwise a resized copy.
    """
    if max_side is None:
        max_side = MAX_SIDE_RESOLUTION
    w, h = pil.size                     # PIL gives (W, H)
    scale = max_side / max(w, h)        # > 1 → upsample, < 1 → downsample
    # Keep at least 1 px so extreme aspect ratios cannot round a side to 0,
    # which PIL's resize would reject.
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    if (new_w, new_h) == (w, h):        # already correct
        return pil
    return pil.resize((new_w, new_h), Image.LANCZOS)


In [None]:
import torch
from torchvision import transforms
from pathlib import Path
import time
from PIL import Image
from torchmetrics.image import FrechetInceptionDistance as FID, InceptionScore as IS, StructuralSimilarityIndexMeasure as SSIM, PeakSignalNoiseRatio as PSNR
from torchmetrics import MeanSquaredError as MSE
import glob

patch_processor = PatchImageProcessor(patch_size=patch_size)

# ────────── 4. metrics & misc  ──────────────────────────────────────
ssim_metric = SSIM(data_range=1.0).to('cuda')
psnr_metric = PSNR(data_range=1.0).to('cuda')
mse_metric  = MSE().to('cuda')

def to_uint8(chw):                # torch.float in [0,1]
    """Convert a CHW float tensor in [0, 1] to an HWC uint8 numpy array.

    Values are clamped to [0, 1] first. A float → uint8 cast does not saturate
    in PyTorch, so reconstruction overshoot would otherwise appear as
    wrapped-around speckle in the saved PNGs.
    """
    return (chw.clamp(0,1)*255).byte().cpu().permute(1,2,0).numpy()

# ────────── 5. main loop ───────────────────────────────────────────
SAVE_DIR = Path(f"decoded_images_{SAVE_VARIANT}".replace("/", "_")); SAVE_DIR.mkdir(exist_ok=True)
stats = []

for path in sorted(glob.glob("processed/*.[jp][pn]g")):
    img_name = Path(path).stem
    with Image.open(path) as src:            # close the file handle promptly
        pil = src.convert("RGB")
    orig_w, orig_h = pil.size                # PIL gives (W, H)
    orig_hw = (orig_h, orig_w)
    patches, ptype, yidx, xidx, new_orig_hw = patch_processor.preprocess_pil(pil)[:5]
    if new_orig_hw != orig_hw:
        # An explicit raise (unlike `assert False`) is not stripped by `python -O`.
        raise ValueError(
            f"Original size mismatch for {path}: PIL gives {orig_hw}, "
            f"patch processor reports {new_orig_hw}"
        )

    # Per the original note there is a batching issue between JAX and PyTorch,
    # so the single sample is duplicated to batch 2; only index 0 is used below.
    batch = (
        patches.to('cuda').contiguous().unsqueeze(0).repeat(2,1,1),
        ptype.to('cuda').contiguous().repeat(2,1),
        yidx.to('cuda').contiguous().repeat(2,1),
        xidx.to('cuda').contiguous().repeat(2,1),
    )

    torch.cuda.synchronize()
    t0 = time.perf_counter()
    recon, ref = forward_torch(batch)
    # JAX dispatch is asynchronous: wait for the device before stopping the
    # clock. The first image additionally includes JIT compilation time.
    torch.cuda.synchronize()
    ms = (time.perf_counter() - t0)*1000

    # Extent of the valid patch grid; patches with ptype == 0 are treated as
    # gray padding.
    not_gray = (ptype != 0)
    max_y = torch.where(not_gray, yidx, torch.full_like(yidx, -1)).max().item()
    max_x = torch.where(not_gray, xidx, torch.full_like(xidx, -1)).max().item()
    row_end = (max_y + 1) * patch_size
    col_end = (max_x + 1) * patch_size

    # Crop sample 0 to the valid region, then remove centre padding to get
    # back to the original size.
    target_hw = (int(orig_hw[0]), int(orig_hw[1]))
    ref_final = remove_center_padding(ref[0, :row_end, :col_end, :], target_hw)
    recon_final = remove_center_padding(recon[0, :row_end, :col_end, :], target_hw)

    # HWC (presumably in [-1, 1]) → contiguous NCHW in [0, 1] for torchmetrics.
    ref_final = ref_final.permute(2, 0, 1).unsqueeze(0).add(1).div(2).contiguous()
    recon_final = recon_final.permute(2, 0, 1).unsqueeze(0).add(1).div(2).contiguous()

    # SSIM / PSNR / RMSE -------------------------------------------
    ssim_val = ssim_metric(recon_final, ref_final).item()
    psnr_val = psnr_metric(recon_final, ref_final).item()
    rmse_val = mse_metric(recon_final, ref_final).item() ** 0.5

    # Save images as PNG (clamped uint8, HWC)
    Image.fromarray(to_uint8(recon_final[0])).save(SAVE_DIR/f"{img_name}.png")
    Image.fromarray(to_uint8(ref_final[0])).save(SAVE_DIR/f"{img_name}_ref.png")

    print(f"{img_name:20s}  SSIM {ssim_val:.4f}  PSNR {psnr_val:.2f} dB  RMSE {rmse_val:.6f}  {ms:6.1f} ms")

    stats.append(dict(file=Path(path).name,
                      height=orig_hw[0], width=orig_hw[1],
                      ssim=ssim_val, psnr_db=psnr_val, rmse=rmse_val,
                      elapsed_ms=ms))

# ────────── 6. CSV ─────────────────────────────────────────────────
import pandas as pd
pd.DataFrame(stats).to_csv(SAVE_DIR/"metrics.csv", index=False)
print("✓ metrics written:", SAVE_DIR/"metrics.csv")

In [None]:
MAX_SIDE_RESOLUTION = 2048   # also read at call time by resize_long_side
MAX_TOKENS = (np.ceil(MAX_SIDE_RESOLUTION / patch_size)) ** 2
MAX_TOKENS = int(np.ceil(MAX_TOKENS))
print(MAX_TOKENS)
# Batch size per resolution (KeyError for resolutions not listed here).
batch_size = {256: 512, 512: 128, 1024: 32, 2048: 1}[MAX_SIDE_RESOLUTION]
transform = transforms.Compose([
    transforms.Lambda(resize_long_side),
])

# NOTE(review): machine-specific absolute paths — adjust for your setup.
DATASET_ROOTS = {
    'imagenet2012': '/home/ubuntu/imagenet2012/val',
    'div8k': '/home/ubuntu/datasets/div8k/TestSets/test8k',
}
DATASET = 'div8k' #'imagenet2012' or 'div8k'
if DATASET not in DATASET_ROOTS:
    # Previously an unknown name left `path` undefined (or stale from an earlier cell).
    raise ValueError(f"Unknown DATASET {DATASET!r}; expected one of {sorted(DATASET_ROOTS)}")
path = DATASET_ROOTS[DATASET]

dataset = NaFlexImageFolder(root=path, transform=transform, patch_size=patch_size, max_tokens=MAX_TOKENS)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

In [None]:
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio
from torchmetrics import MeanSquaredError
#from torchmetrics.image.fid import FrechetInceptionDistance
#from torchmetrics.image.inception import InceptionScore
device = 'cuda'

# GPU metric modules; data_range=1.0 means inputs are expected in [0, 1].
ssim_metric, psnr_metric, mse_metric = (
    metric.to(device)
    for metric in (
        StructuralSimilarityIndexMeasure(data_range=1.0),
        PeakSignalNoiseRatio(data_range=1.0),
        MeanSquaredError(),
    )
)


#FID is really slow and will get effected hard by gray padding if we do a batch based version, also bad at high resolutions...
#fid_metric  = FrechetInceptionDistance(feature=2048, normalize=True)
#is_metric   = InceptionScore(feature=2048, normalize=True)

In [None]:
compute_metrics = True  # Set to False to measure timing only (the time measurement gets messed up otherwise); SSIM/PSNR/RMSE are then recorded as 0.0.

In [None]:
# Full evaluation: decode every batch from `loader`, accumulate SSIM / PSNR / RMSE
# and per-image timing, save a few example images, then append the means to a CSV.
from tqdm import tqdm

import torch
import time
import numpy as np
from pathlib import Path
import os
import csv

out_dir = Path(f"decoded_images_{SAVE_VARIANT}".replace("/", "_"))
out_dir.mkdir(exist_ok=True)

device = 'cuda'
ssim_vals = []
psnr_vals = []
rmse_vals = []
total_ms = 0
num_examples = 0


def _valid_extent(ptype_i, yidx_i, xidx_i):
    """Return (row_end, col_end) in pixels of one sample's valid patch grid.

    Patches with ptype == 0 are treated as gray padding and ignored.
    """
    not_gray = ptype_i != 0
    max_y = torch.where(not_gray, yidx_i, torch.full_like(yidx_i, -1)).max().item()
    max_x = torch.where(not_gray, xidx_i, torch.full_like(xidx_i, -1)).max().item()
    return (max_y + 1) * patch_size, (max_x + 1) * patch_size


def _to_unit_nchw(hwc):
    """HWC image (presumably in [-1, 1]) → contiguous float32 1×C×H×W in [0, 1]."""
    return hwc.permute(2, 0, 1).unsqueeze(0).add(1).div(2).to(torch.float32).contiguous()


def _to_uint8_hwc(chw):
    """CHW float tensor in [0, 1] → HWC uint8 array, clamped so overshoot cannot wrap."""
    return (chw.clamp(0, 1) * 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy()


def _nan_to_zero(value, metric_name, sample_idx):
    """Return 0.0 (with a printed warning) for a NaN metric, else the value unchanged."""
    if np.isnan(value):
        print(f"NaN detected in {metric_name} for sample {sample_idx}, setting to 0.0")
        return 0.0
    return value


# Create tqdm object manually so we can update the description
pbar = tqdm(loader, desc='Eval')

for patches, ptype, yidx, xidx, orig_hw in pbar:
    patches = patches.to('cuda')
    ptype = ptype.to('cuda')
    yidx = yidx.to('cuda')
    xidx = xidx.to('cuda')
    n = patches.shape[0]

    torch.cuda.synchronize()
    time_start = time.perf_counter()
    with torch.inference_mode(), torch.amp.autocast(dtype=torch.bfloat16, device_type='cuda'):
        if batch_size == 1:
            # A single sample is fed without its leading batch dimension.
            batch = (patches.to(torch.bfloat16)[0].contiguous(), ptype[0].contiguous(), yidx[0].contiguous(), xidx[0].contiguous())
        else:
            # Build the batched input explicitly; without this branch the call
            # silently reused a stale `batch` left over from an earlier cell.
            batch = (patches.to(torch.bfloat16).contiguous(), ptype.contiguous(), yidx.contiguous(), xidx.contiguous())
        recon, ref = forward_torch(batch)
    # JAX dispatch is asynchronous; wait for the device so timing covers the decode.
    torch.cuda.synchronize()
    total_ms += (time.perf_counter() - time_start) * 1000

    if compute_metrics:
        for i in range(n):
            row_end, col_end = _valid_extent(ptype[i], yidx[i], xidx[i])
            target_hw = (int(orig_hw[0][i]), int(orig_hw[1][i]))
            # Crop to valid region first, then remove center padding to get original size
            ref_final = _to_unit_nchw(remove_center_padding(ref[i][:row_end, :col_end, :], target_hw))
            recon_final = _to_unit_nchw(remove_center_padding(recon[i][:row_end, :col_end, :], target_hw))

            # SSIM / PSNR / RMSE -------------------------------------------
            ssim_val = ssim_metric(recon_final, ref_final).item()
            psnr_val = psnr_metric(recon_final, ref_final).item()
            rmse_val = (mse_metric(recon_final, ref_final).item()) ** 0.5

            # Save the first 4 images of the first batch
            if i < 4 and num_examples == 0:
                Image.fromarray(_to_uint8_hwc(recon_final[0])).save(out_dir / f"recon_{i}.png")
                Image.fromarray(_to_uint8_hwc(ref_final[0])).save(out_dir / f"ref_{i}.png")

            ssim_vals.append(_nan_to_zero(ssim_val, "SSIM", i))
            psnr_vals.append(_nan_to_zero(psnr_val, "PSNR", i))
            rmse_vals.append(_nan_to_zero(rmse_val, "RMSE", i))
    else:
        # Timing-only mode: placeholder zeros keep the lists aligned with num_examples.
        ssim_vals.extend([0.0] * n)
        psnr_vals.extend([0.0] * n)
        rmse_vals.extend([0.0] * n)

    num_examples += n

    # Update tqdm with current mean metrics
    if len(ssim_vals) > 0:
        pbar.set_postfix({
            "SSIM": f"{float(np.mean(ssim_vals)):.4f}",
            "PSNR": f"{float(np.mean(psnr_vals)):.2f}",
            "RMSE": f"{float(np.mean(rmse_vals)):.4f}",
            "TimePerImage_ms": f"{float(total_ms / num_examples):.2f}"
        })

# Compute metrics
mean_ssim = float(np.mean(ssim_vals))
mean_psnr = float(np.mean(psnr_vals))
mean_rmse = float(np.mean(rmse_vals))
# Average per image (dividing by len(loader) gave a per-batch figure).
time_per_image = float(total_ms / max(num_examples, 1))

# Print the metrics
print(f"SSIM: {mean_ssim}")
print(f"PSNR: {mean_psnr}")
print(f"RMSE: {mean_rmse}")
print(f"Time per image: {time_per_image} ms")

# Append one row of mean metrics to <resolution>_<dataset>_metrics.csv in out_dir
csv_path = os.path.join(out_dir, f"{MAX_SIDE_RESOLUTION}_{DATASET}_metrics.csv")

csv_fields = ["SSIM", "PSNR", "RMSE", "TimePerImage_ms"]
csv_values = [mean_ssim, mean_psnr, mean_rmse, time_per_image]

# If file exists, append; else, write header
write_header = not os.path.exists(csv_path)
try:
    with open(csv_path, "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(csv_fields)
        writer.writerow(csv_values)
    print(f"Metrics saved to {csv_path}")
except Exception as e:
    print(f"Could not save metrics to {csv_path}: {e}")


In [None]:
"""
VAE FIGURE GENERATOR – fixed‑vs‑native bar charts for RMSE, SSIM and PSNR (per‑run aggregation)
NeurIPS‑ready (v2: bigger, clearer, no clipping)
──────────────────────────────────────────────────────────────────────────────
* Enlarged base font, thicker bars, higher DPI, wider figure.
* Tick labels wrapped onto two lines to avoid clipping.
* Legend anchored below plot; no in‑plot title.
* All label‑swap logic and data handling remain unchanged.
"""

# ── 0 · Global style (LaTeX OFF) ────────────────────────────────
import matplotlib as mpl, matplotlib.pyplot as plt
BASE_FONTSIZE = 16  # bigger for camera-ready

# Typography — every size is derived from BASE_FONTSIZE.
mpl.rcParams.update({
    "text.usetex": False,
    "font.family": "sans-serif",
    "font.size": BASE_FONTSIZE,
    "axes.labelsize": BASE_FONTSIZE + 1,
    "axes.titlesize": BASE_FONTSIZE + 1,
    "xtick.labelsize": BASE_FONTSIZE - 1,
    "ytick.labelsize": BASE_FONTSIZE - 1,
    "legend.fontsize": BASE_FONTSIZE - 1,
})
# Figure resolution, layout and axes decoration.
mpl.rcParams.update({
    "figure.dpi": 180,            # higher resolution
    "figure.constrained_layout.use": True,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.linewidth": 0.7,
    "xtick.direction": "out",
    "ytick.direction": "out",
    "grid.alpha": 0.35,
})

SHOW_FIGS = True  # toggle windows when running interactively

# ── 1 · Configuration ──────────────────────────────────────────
from pathlib import Path
import os, re, numpy as np, pandas as pd

ROOT = Path(".")  # directory with decoded_images_* folders

# Run tag → two-line x-tick label. Insertion order is the plotting order,
# so RUNS is derived from this mapping.
RUN_LABELS = {
    "S_B_16x32+256_fixedAR": "256p\nFixed 32c",
    "S_B_16x32+256": "256‑Tok\n32c",
    "S_B_16x32": "4k‑Tok\n32c",
    "S_B_16x64": "4k‑Tok\n64c",
    "Cosmos-Tokenizer-CI8x8": "Cosmos 8x8",
}
RUNS = list(RUN_LABELS)

CSV_NAME = "metrics.csv"

# Image style tags: "box" = fixed 1:1, "ar" = native aspect ratio.
STYLE_ORDER = ["box", "ar"]
STYLE_COLOURS = {"box": "#1f77b4", "ar": "#ff7f0e"}
STYLE_LABELS = {"box": "Fixed 1:1", "ar": "Native AR"}

OUT_DIR = Path("figs")
OUT_DIR.mkdir(exist_ok=True)

# ── 2 · Helper functions ──────────────────────────────────────
_res_pat = re.compile(r"_(256|512|1024)(?:\D|$)")

# ── 2 · Helper functions  ───────────────────────────────────────
METRICS      = ["rmse", "ssim", "psnr_db"]               # moved up so helpers can see it
METRIC_LABEL = {"rmse": "RMSE (↓)",
                "ssim": "SSIM (↑)",
                "psnr_db": "PSNR (↑)"}

_res_pat = re.compile(r"_(256|512|1024)(?:\D|$)")

def _parse_res_style(fname: str):
    base = os.path.basename(fname)
    m = _res_pat.search(base)
    res = int(m.group(1)) if m else 1024
    style = "ar" if "_ar_" in base or base.endswith("_ar.png") else "box"
    if m is None and "_ar_" not in base:
        style = "ar"             # default 1024-AR
    return res, style


def load_means(run_tag: str):
    """
    Return {(res, style, metric) -> mean value} for *all* metrics in METRICS.

    Reads ``ROOT / f"decoded_images_{run_tag}" / CSV_NAME``, tags each row with
    the (res, style) parsed from its ``file`` column, and averages every metric
    per (res, style) group.

    Raises:
        FileNotFoundError: if the run's metrics CSV does not exist.
    """
    csv_path = ROOT / f"decoded_images_{run_tag}" / CSV_NAME
    if not csv_path.exists():
        raise FileNotFoundError(csv_path)

    df = pd.read_csv(csv_path)
    # Build both tag columns from a list of tuples (no per-row pd.Series);
    # this also yields well-formed empty columns for an empty CSV.
    tags = pd.DataFrame(
        [_parse_res_style(f) for f in df["file"]],
        columns=["res", "style"],
        index=df.index,
    )
    df = df.assign(res=tags["res"], style=tags["style"])

    # Swap labels if this is the mis-labelled run
    if "fixedAR" in run_tag:
        df["style"] = df["style"].map({"box": "ar", "ar": "box"})

    # group means for every metric we care about
    grouped = df.groupby(["res", "style"])[METRICS].mean()

    # flatten to {(res, style, metric): value}
    return {
        (res, sty, metric): row[metric]
        for (res, sty), row in grouped.iterrows()
        for metric in METRICS
    }


# ── 3 · Pre-compute per-run means  ──────────────────────────────
run_means = {run: load_means(run) for run in RUNS}

# One bar series per (style, resolution) pair present in any run, ordered by
# STYLE_ORDER then resolution. Derived from the loaded data so the plot does
# not depend on a STYLE_SEQ defined outside this cell.
STYLE_SEQ = sorted(
    {(sty, res) for means in run_means.values() for (res, sty, _metric) in means},
    key=lambda sr: (STYLE_ORDER.index(sr[0]), sr[1]),
)

FIGSIZE = (11.0, 5.8)
BAR_W   = 0.18
# Centre each cluster on its tick; with 4 series the offsets are −0.27 … +0.27.
GROUP_OFF = (np.arange(len(STYLE_SEQ)) - (len(STYLE_SEQ) - 1) / 2) * BAR_W
x = np.arange(len(RUNS))


def _bar_label(metric, v):
    """Annotation text for one bar: '–' when missing, metric-specific precision otherwise."""
    if np.isnan(v):
        return "–"
    if metric == "psnr_db":
        return f"{v:.1f}"
    if metric == "ssim":
        return f"{v:.2f}"
    txt = f"{v:.3f}"
    # Drop the leading zero of small RMSE values (0.123 → .123) without
    # blanking '–' or truncating values ≥ 1 as an unconditional txt[1:] would.
    return txt[1:] if txt.startswith("0.") else txt


for metric in METRICS:
    fig, ax = plt.subplots(figsize=FIGSIZE)

    for idx, (sty, res) in enumerate(STYLE_SEQ):
        # fetch values (NaN if missing)
        vals = [
            run_means[run].get((res, sty, metric), np.nan)
            for run in RUNS
        ]
        ax.bar(
            x + GROUP_OFF[idx],
            [0 if np.isnan(v) else v for v in vals],
            width=BAR_W,
            color=STYLE_COLOURS[sty],
            alpha=1.0 if res == 1024 else 0.55,
            label=f"{STYLE_LABELS[sty]} {res}p",
        )

        # per-bar annotation (y offset is in data units)
        for i, v in enumerate(vals):
            ax.text(
                x[i] + GROUP_OFF[idx],
                (v if not np.isnan(v) else 0) + 0.003,
                _bar_label(metric, v),
                ha="center", va="bottom",
                fontsize=BASE_FONTSIZE - 5,
            )

    ax.set_xticks(x, [RUN_LABELS[r] for r in RUNS])
    ax.set_ylabel(METRIC_LABEL[metric])
    ax.grid(axis="y", linestyle=":", linewidth=0.4)
    ax.set_ylim(bottom=0)

    ax.legend(
        frameon=False, ncol=max(1, min(4, len(STYLE_SEQ))), loc="upper center",
        bbox_to_anchor=(0.5, -0.18), columnspacing=1.8,
    )

    fname = OUT_DIR / f"{metric}_fixed_vs_native_clustered.pdf"
    fig.savefig(fname, format="pdf", dpi=300, bbox_inches="tight")
    print("✓ Saved", fname)
    if SHOW_FIGS:
        plt.show()
    plt.close(fig)

print("All done – PDFs are in", OUT_DIR.resolve())


In [None]:
# ================================================================
#   VAE FIGURE GENERATOR  –  drag-and-drop PDFs for Overleaf
# ================================================================
# • Mean-only bars, dense bars
# • Full-image strip
# • Full-image strip with centre-crop inset
# • 10 random 384×384 crops (×4 zoom)  + concatenated contact-sheet
# • Single centre 384×384 crop (×4 zoom)
# ------------------------------------------------
#  1.  Edit RUNS / RUN_LABELS / METRICS / ROOT if needed
#  2.  Run the script (python or notebook cell)
#  3.  All PDFs appear in ./figs/
# ================================================================


# ── 0 · Global style (LaTeX OFF) ────────────────────────────────
import matplotlib as mpl, matplotlib.pyplot as plt
# Text sizes (small, for dense multi-panel figures).
mpl.rcParams.update({
    "text.usetex":    False,
    "font.size":      9,
    "axes.labelsize": 9,
    "axes.titlesize": 9,
})
# Figure resolution, layout and axes decoration.
mpl.rcParams.update({
    "figure.dpi":                    110,
    "figure.constrained_layout.use": True,
    "axes.spines.top":               False,
    "axes.spines.right":             False,
    "axes.linewidth":                0.6,
    "xtick.direction":               "out",
    "ytick.direction":               "out",
    "grid.alpha":                    0.3,
})

# ── 1 · Configuration ──────────────────────────────────────────
from pathlib import Path
import os, re, random, numpy as np, pandas as pd
from PIL import Image
from functools import reduce

ROOT = Path(".")    # directory with decoded_images_* folders

# Run tag → panel title; insertion order is the panel order.
# NOTE(review): "S_B_16x64" is labelled "Channel 32" here, while the earlier
# bar-chart cell labels the same run "64c" — confirm which is correct.
RUN_LABELS = {
    "Cosmos-Tokenizer-CI8x8": "Patch Size 8, Channel 8",
    "S_B_16x64":              "Patch Size 16, Channel 32",
}
RUNS = list(RUN_LABELS)

CSV_NAME = "metrics.csv"
IMG_NAME = "owl_eye_1024.png"

# Strip / crop geometry
IMG_STRIP_H = 3.8      # inches – height of full-image strip
CROP_PX     = 384
N_CROPS     = 10
ZOOM_FACTOR = 4
SEED_START  = 1        # seeds = 1 … 10

PALETTE = ["#1f77b4", "#ff7f0e", "#2ca02c"]
IMG_DIR = Path("processed")
OUT_DIR = Path("figs")
OUT_DIR.mkdir(exist_ok=True)

# ── 2 · Locate runs & load metrics ─────────────────────────────
pat  = re.compile(r"decoded_images_(.+)")
runs = {m.group(1): ROOT / d for d in os.listdir(ROOT)
        if (m := pat.match(d)) and (ROOT / d).is_dir()}
missing = set(RUNS) - runs.keys()
if missing:
    raise RuntimeError(f"Missing decoded_images_* folders for: {missing}")
print("▶ Using runs:", ", ".join(RUN_LABELS[t] for t in RUNS))

# Inner-join the per-run CSVs on file name; every other column gets a
# "_<tag>" suffix so the runs sit side by side.
dfs = []
for tag in RUNS:
    df = pd.read_csv(runs[tag] / CSV_NAME).add_suffix(f"_{tag}")
    df = df.rename(columns={f"file_{tag}": "file"})
    dfs.append(df)
df = reduce(lambda a, b: pd.merge(a, b, on="file", how="inner"), dfs)

# Keep images whose heights agree across runs AND whose widths agree across
# runs. Counting unique values over height_* and width_* together required
# height == width, which silently dropped every non-square image.
height_cols = [c for c in df.columns if c.startswith("height_")]
width_cols  = [c for c in df.columns if c.startswith("width_")]
shape_cols  = height_cols + width_cols
same_shape = (df[height_cols].nunique(axis=1).eq(1)
              & df[width_cols].nunique(axis=1).eq(1))
df = df[same_shape].reset_index(drop=True)
print(f"✓ {len(df)} images with matching resolution")

def crop_strip(x0, y0, crop_px, zoom, title, height_in=4.2):
    """Side-by-side panels showing the same square crop of IMG_NAME from every run.

    Args:
        x0, y0: top-left corner of the crop in pixels.
        crop_px: side length of the square crop in pixels.
        zoom: integer nearest-neighbour upscaling factor for display.
        title: figure suptitle.
        height_in: figure height in inches.

    Returns:
        The matplotlib Figure; the caller saves/shows/closes it.
    """
    # squeeze=False keeps `axes` 2-D even when there is only one run.
    fig, axes = plt.subplots(1, len(RUNS),
                             figsize=(len(RUNS) * 4.0, height_in),
                             squeeze=False)
    for i, (ax, tag) in enumerate(zip(axes[0], RUNS)):
        # Cycle the palette so extra runs are not silently dropped.
        colour = PALETTE[i % len(PALETTE)]
        with Image.open(runs[tag] / IMG_NAME) as img:
            crop = img.crop((x0, y0, x0 + crop_px, y0 + crop_px))
        crop = crop.resize((crop_px * zoom, crop_px * zoom),
                           resample=Image.NEAREST)
        ax.imshow(crop)
        ax.set_axis_off()
        ax.set_title(RUN_LABELS[tag], color=colour, fontsize=9)
    fig.suptitle(title, y=0.995, fontsize=9)
    return fig

def _reference_size():
    """(W, H) of IMG_NAME in the first run, used to place crops."""
    with Image.open(runs[RUNS[0]] / IMG_NAME) as ref:
        return ref.size

def random_crop_strip(seed, crop_px=CROP_PX, zoom=ZOOM_FACTOR):
    """crop_strip at a seed-determined random position.

    A local ``random.Random(seed)`` gives the same draws as ``random.seed(seed)``
    without resetting the global RNG. crop_px is clamped to the image's shorter
    side so the crop always fits.
    """
    rng = random.Random(seed)
    W, H = _reference_size()
    crop_px = min(crop_px, W, H)
    x0, y0 = rng.randint(0, W - crop_px), rng.randint(0, H - crop_px)
    title = f"Random {crop_px}×{crop_px} crop ×{zoom} (seed={seed})"
    return crop_strip(x0, y0, crop_px, zoom, title)

def center_crop_strip(crop_px=CROP_PX, zoom=ZOOM_FACTOR):
    """crop_strip at the image centre (crop_px clamped to the shorter side)."""
    W, H = _reference_size()
    crop_px = min(crop_px, W, H)
    x0, y0 = (W - crop_px) // 2, (H - crop_px) // 2
    title = f"Centre {crop_px}×{crop_px} crop ×{zoom}"
    return crop_strip(x0, y0, crop_px, zoom, title)


In [None]:
from matplotlib.patches import Rectangle
import random
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt

DISPLAY_PX  = 512                # side (px) of each saved, upscaled tile
RNG         = random.Random()    # unseeded: crop position differs per call unless `seed` is given


def save_fixed_tile_and_full(crop_px=CROP_PX, seed: int | None = None):
    """
    Save a full view and a zoomed tile of IMG_NAME for the reference and every run.

    • Draws ONE random crop position (seed-controlled) on the reference image
      ``IMG_DIR / IMG_NAME`` and re-uses the *exact same* (x0, y0) for every run.
    • Always writes <stem>_full.png and <tag>_<stem>_full.png; when a crop was
      drawn these carry a red rectangle marking it.
    • For large images (either side ≥ 1024 px) also writes <stem>_tile.png and
      <tag>_<stem>_tile.png: the crop_px × crop_px region upscaled to
      DISPLAY_PX × DISPLAY_PX with nearest-neighbour resampling.

    Args:
        crop_px: side of the square crop in source pixels; clamped to the
            reference image's shorter side so the crop always fits.
        seed: if given, reseeds the module-level RNG first for reproducibility.
    """
    if seed is not None:
        RNG.seed(seed)

    stem = IMG_NAME.replace('.png', '')

    # ── pick (x0, y0) ONCE from reference ───────────────────────────
    with Image.open(IMG_DIR / IMG_NAME) as src:
        ref_img = src.copy()               # detach from the file so the handle closes
    W, H = ref_img.size
    make_tile = W >= 1024 or H >= 1024
    if make_tile:
        # e.g. a 1024×300 image would otherwise make randint() fail
        crop_px = min(crop_px, W, H)
        x0 = RNG.randint(0, W - crop_px)
        y0 = RNG.randint(0, H - crop_px)
    else:
        x0 = y0 = crop_px = 0

    def save_full(img, out_path):
        """Save `img` without axes, outlining the crop when one was drawn."""
        fig, ax = plt.subplots(figsize=(4.2, 3.2))
        ax.imshow(img); ax.set_axis_off()
        if make_tile:
            ax.add_patch(Rectangle((x0, y0), crop_px, crop_px,
                                   edgecolor="red", linewidth=1.2, facecolor="none"))
        fig.savefig(out_path, dpi=300, bbox_inches="tight",
                    pad_inches=0, transparent=True)
        plt.close(fig)

    def save_tile(img, out_path):
        """Crop the shared region and upscale it to DISPLAY_PX × DISPLAY_PX."""
        crop = img.crop((x0, y0, x0 + crop_px, y0 + crop_px))
        crop = crop.resize((DISPLAY_PX, DISPLAY_PX), Image.NEAREST)
        crop.save(out_path, format="png")

    def save_views(img, prefix):
        """Write the full view (plus tile when applicable); return the saved file names."""
        full_out = OUT_DIR / f"{prefix}{stem}_full.png"
        save_full(img, full_out)
        saved = [full_out.name]
        # Only tile when a crop was drawn and this image is itself large.
        if make_tile and (img.size[0] >= 1024 or img.size[1] >= 1024):
            tile_out = OUT_DIR / f"{prefix}{stem}_tile.png"
            save_tile(img, tile_out)
            saved.append(tile_out.name)
        return saved

    # ── reference outputs ───────────────────────────────────────────
    print("  ↳ saved", *save_views(ref_img, ""))

    # ── same crop for every run ─────────────────────────────────────
    for tag in RUNS:
        with Image.open(runs[tag] / IMG_NAME) as src:
            img = src.copy()
        print("  ↳ saved", *save_views(img, f"{tag}_"))

# ── call once near the end ─────────────────────────────────────────
# Example: save_fixed_tile_and_full(seed=42)  # fixed RNG for full reproducibility
save_fixed_tile_and_full()
