In [None]:
import torch
from diffusers import AutoencoderKL
from PIL import Image
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
import os
import torch, os
from diffusers import AutoencoderKL
import time, math, torch, pandas as pd, matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
import torch.nn.functional as F
from torchmetrics.image import FrechetInceptionDistance as FID, InceptionScore as IS, StructuralSimilarityIndexMeasure as SSIM, PeakSignalNoiseRatio as PSNR
from torchmetrics import MeanSquaredError
from contextlib import nullcontext    

# Candidate autoencoders / tokenizers; VAE_CHOICE (set in the backend cell) indexes this list.
VAE_CHOICES = [
    "black-forest-labs/FLUX.1-dev",                 # 0
    "stabilityai/stable-diffusion-3.5-large",       # 1
    "stabilityai/stable-diffusion-xl-base-1.0",     # 2
    "cosmos/Cosmos-Tokenizer-CI16x16",              # 3
    "cosmos/Cosmos-Tokenizer-CI8x8",                # 4
]
COMPILE_MODEL = False  # wrap vae.encode/decode in torch.compile; left off (reported as generally bugged)


In [None]:
from huggingface_hub import snapshot_download
from cosmos_tokenizer.image_lib import ImageTokenizer      # comes from Cosmos-Tokenizer
import os, torch

# -----------------------------------------------------------
# 1 · loader for an NVIDIA Cosmos continuous image tokenizer
#     (loads CI8x8 by default, despite the function name)
# -----------------------------------------------------------
def load_cosmos_ci16x16(device="cuda",
                        dtype=torch.bfloat16,
                        cache_dir="pretrained_ckpts",
                        model_name="Cosmos-Tokenizer-CI8x8"):
    """
    Load a Cosmos image tokenizer and expose it as a pair of plain callables.

    The checkpoint is downloaded once from the Hugging Face repo ``nvidia/<model_name>``
    into ``<cache_dir>/<model_name>``; later calls reuse the local ``encoder.jit`` /
    ``decoder.jit`` files.

    Args:
        device: device the encoder and decoder are moved to.
        dtype: dtype the encoder and decoder are cast to.
        cache_dir: parent directory for downloaded checkpoints.
        model_name: tokenizer repo name under ``nvidia/``. The default,
            ``"Cosmos-Tokenizer-CI8x8"``, corresponds to ``VAE_CHOICES[4]``; pass
            ``"Cosmos-Tokenizer-CI16x16"`` for the 16x16 variant.

    Returns:
        (encode, decode):
          encode(x) -- first element of the encoder's output (the latent tensor).
          decode(z) -- the decoder's reconstruction of ``z``.
    """
    ckpt_dir = os.path.join(cache_dir, model_name)
    enc_path = os.path.join(ckpt_dir, "encoder.jit")
    dec_path = os.path.join(ckpt_dir, "decoder.jit")

    # Download whenever a checkpoint file is missing, not only when the directory is
    # absent, so a directory left behind by an interrupted download gets repaired.
    if not (os.path.isfile(enc_path) and os.path.isfile(dec_path)):
        snapshot_download(repo_id=f"nvidia/{model_name}",
                          local_dir=ckpt_dir,
                          local_dir_use_symlinks=False)

    enc = ImageTokenizer(checkpoint_enc=enc_path).to(device).to(dtype)
    dec = ImageTokenizer(checkpoint_dec=dec_path).to(device).to(dtype)

    # Plain callables so the notebook can use them as vae_encode / vae_decode.
    encode = lambda x: enc.encode(x)[0]      # latent tensor
    decode = lambda z: dec.decode(z)         # reconstructed image
    return encode, decode

In [None]:

USE_COSMOS = True          # ⇠ flip this flag

if not USE_COSMOS:
    VAE_CHOICE = 2
    # AutoencoderKL backend: the "vae" subfolder of the chosen repo, bf16, on the GPU.
    vae = AutoencoderKL.from_pretrained(
        VAE_CHOICES[VAE_CHOICE],
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    vae_encode, vae_decode = (
        (torch.compile(vae.encode), torch.compile(vae.decode))
        if COMPILE_MODEL
        else (vae.encode, vae.decode)
    )
else:
    VAE_CHOICE = 4
    vae_encode, vae_decode = load_cosmos_ci16x16()


In [None]:
MAX_SIDE_RESOLUTION = 1024

def resize_long_side(pil: Image.Image, max_side: int = 1024, upscale: bool = True) -> Image.Image:
    """
    Rescale an image so its longer edge is `max_side` pixels, preserving aspect ratio.

    Args:
        pil: image to resize.
        max_side: target length of the longer edge, in pixels.
        upscale: when True (the default), images whose longer edge is shorter than
            `max_side` are enlarged too, so every result has a longer edge of exactly
            `max_side`. When False, such images are returned unchanged and only
            larger images are shrunk.

    Returns:
        The LANCZOS-resampled image, or `pil` itself when `upscale` is False and no
        resize is needed. Each output side is at least 1 pixel.
    """
    w, h = pil.size  # PIL gives (W, H)
    long_side = max(w, h)
    if not upscale and long_side <= max_side:
        return pil
    scale = max_side / long_side
    # max(1, ...): for a very thin image the short side could otherwise round to 0 px,
    # which is not a valid resize target.
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    return pil.resize((new_w, new_h), Image.LANCZOS)


def pad(t: torch.Tensor, pad_size: int):
    """
    Replicate-pad the last two (H, W) dims of `t` up to the next multiples of `pad_size`.

    Padding is added on the bottom and right only, so the original content stays at
    ``[..., :h, :w]`` and can be cropped back with the returned sizes.

    Args:
        t: tensor shaped (C, H, W) or (N, C, H, W).
        pad_size: required multiple for H and W; must be positive.

    Returns:
        (padded, h, w): the padded tensor and the original height and width.

    Raises:
        ValueError: if `pad_size` is not positive.
    """
    if pad_size <= 0:
        raise ValueError(f"pad_size must be positive, got {pad_size}")
    h, w = t.shape[-2:]
    pad_w = (pad_size - w % pad_size) % pad_size
    pad_h = (pad_size - h % pad_size) % pad_size
    return F.pad(t, (0, pad_w, 0, pad_h), "replicate"), h, w

In [None]:
# make sure your VAE is loaded

# ---------------------------------------------------------------------
# 0 · config
# ---------------------------------------------------------------------
src_dir = Path("processed")
# Output folder named after the selected VAE repo id, slashes replaced by underscores.
vae_choice_str = VAE_CHOICES[VAE_CHOICE].replace("/", "_")
out_dir = Path("decoded_images_" + vae_choice_str)
out_dir.mkdir(exist_ok=True)

# ---------------------------------------------------------------------
# 2 · metrics — all applied to images scaled to [0, 1]
# ---------------------------------------------------------------------
device = "cuda"
ssim_metric, psnr_metric, mse_metric = (
    SSIM(data_range=1.0).to(device),
    PSNR(data_range=1.0).to(device),
    MeanSquaredError().to(device),
)

def encode_decode_paths(paths):
    """
    Round-trip each image file through the loaded VAE / tokenizer and score it.

    Each image is loaded as RGB at its native resolution (no resizing), converted to a
    [0, 1] tensor, and replicate-padded bottom/right. The pad is to a multiple of 16
    when USE_COSMOS is set, otherwise 8. The padded image is scaled to [-1, 1] and
    encoded/decoded under bf16 autocast. The reconstruction is then cropped back to
    the original size, and SSIM, PSNR and RMSE against the original are printed.

    Args:
        paths: iterable of pathlib.Path image files.

    Returns:
        (results, stats):
          results -- dict mapping each path to its reconstruction as a CPU float32
                     tensor of shape (1, C, H, W), roughly in [-1, 1] (not clamped).
          stats   -- list of per-image dicts with keys
                     file, width, height, ssim, psnr_db, rmse.
    """
    results, stats = {}, []
    pad_multiple = 16 if USE_COSMOS else 8
    for i, p in enumerate(paths):
        pil = Image.open(p).convert("RGB")
        ref = TF.to_tensor(pil)                      # (C, H, W) in [0, 1]
        refp, H, W = pad(ref, pad_multiple)
        refp = refp.unsqueeze(0).to(device)

        with torch.no_grad(), torch.amp.autocast(dtype=torch.bfloat16, device_type=device):
            ref_scaled = refp.mul(2).sub(1)
            if USE_COSMOS:
                z = vae_encode(ref_scaled)
                recon = vae_decode(z)
            else:
                z = vae_encode(ref_scaled).latent_dist.sample()
                recon = vae_decode(z).sample
            if i == 0:
                print(z.shape)  # latent shape, once per run
            # Drop the padding; float32 so the metrics below are not computed in bf16.
            recon = recon.to(torch.float32)[:, :, :H, :W].contiguous()

        # Quality on [0, 1] images. The reconstruction is clamped, as the saved PNGs
        # and the ImageNet loop do, so values respect the metrics' data_range=1.0.
        recon_01 = recon.add(1).div(2).clamp(0, 1)
        # .contiguous(): sliced views triggered a non-contiguous .view() RuntimeError
        # inside torchmetrics (Lightning-AI/torchmetrics#1862).
        ref_01 = refp[:, :, :H, :W].contiguous()
        ssim = ssim_metric(recon_01, ref_01).item()
        psnr = psnr_metric(recon_01, ref_01).item()
        rmse = mse_metric(recon_01, ref_01).item() ** 0.5

        print(f"{p.name:22s}  SSIM {ssim:.4f}  PSNR {psnr:.2f} dB   RMSE {rmse:.6f}")

        results[p] = recon.cpu()
        stats.append(dict(
            file=p.name, width=W, height=H,
            ssim=ssim, psnr_db=psnr, rmse=rmse,
        ))

        del ref, ref_scaled, z, recon
        torch.cuda.empty_cache()
    return results, stats

# ---------------------------------------------------------------------
# 4 · run the batch
# ---------------------------------------------------------------------
# Select PNG/JPEG inputs by suffix (case-insensitive). A glob such as "*.[jp][pn]g"
# would miss ".jpeg" / upper-case suffixes and also match ".jng" or ".ppg".
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}
paths = sorted(
    p for p in (src_dir.iterdir() if src_dir.is_dir() else [])
    if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
)
if not paths:
    raise RuntimeError(f"No PNG/JPG files in {src_dir.resolve()}")

recon_imgs, stats = encode_decode_paths(paths)

# ---------------------------------------------------------------------
# 5 · save & optionally preview
# ---------------------------------------------------------------------
# Write every reconstruction into out_dir under its source file name.
for path in paths:
    rec = recon_imgs[path]                                        # (1, C, H, W)
    unit = (rec.clamp(-1, 1) + 1) / 2                             # [-1, 1] → [0, 1]
    img = unit[0].permute(1, 2, 0).cpu().numpy()                  # CHW → HWC
    img8 = (img * 255).round().astype("uint8")
    save_path = out_dir / path.name
    Image.fromarray(img8).save(save_path)
    print("✓ saved", save_path)

# ---------------------------------------------------------------------
# 6 · write metrics CSV
# ---------------------------------------------------------------------
csv_path = out_dir / "metrics.csv"
metrics_df = pd.DataFrame(stats)
metrics_df.to_csv(csv_path, index=False)
print("✓ metrics written to", csv_path)


In [None]:
from torchvision import transforms
from PIL import Image
from torchvision.datasets import ImageFolder
import functools
import torch.nn.functional as F

MAX_SIDE_RESOLUTION = 1024

# Batch size per evaluation resolution (any other resolution raises KeyError).
batch_size = {
    256: 64,
    512: 16,
    1024: 16,
}[MAX_SIDE_RESOLUTION]

# PIL image → longer edge rescaled to MAX_SIDE_RESOLUTION → float32 CHW tensor in [-1, 1]
transform = transforms.Compose([
    transforms.Lambda(functools.partial(resize_long_side, max_side=MAX_SIDE_RESOLUTION)),
    transforms.ToTensor(),                       # [0, 1]
    transforms.Lambda(lambda t: t * 2 - 1),      # [0, 1] → [-1, 1]
])

class ARImageFolder(ImageFolder):
    """ImageFolder whose items are (image, label, (H, W)) instead of (image, label).

    (H, W) is the size of the transformed image before collate_fn pads it, so the
    evaluation can crop the padding back off.
    """

    def __getitem__(self, index):
        # the dataset's `transform` is applied inside ImageFolder.__getitem__
        image, label = super().__getitem__(index)
        return image, label, image.shape[1:]

# ImageNet-2012 validation split in ImageFolder layout (one sub-directory per class).
# Set the IMAGENET_VAL_DIR environment variable to point at a different location.
IMAGENET_VAL_DIR = os.environ.get("IMAGENET_VAL_DIR", "/home/ubuntu/imagenet2012/val")
dataset = ARImageFolder(root=IMAGENET_VAL_DIR, transform=transform)

def collate_fn(batch, target_size=None):
    """
    Stack variable-size images into one batch by replicate-padding each to a square.

    Each image is padded on the bottom/right only, so its content stays at
    ``[:, :h, :w]`` and can be cropped back with the returned sizes.

    Args:
        batch: sequence of (image, label, orig_hw) items, with (C, H, W) image tensors
            (as produced by ARImageFolder).
        target_size: side length to pad to; defaults to MAX_SIDE_RESOLUTION.

    Returns:
        (imgs_pad, labels, orig_hw): a (B, C, S, S) tensor, a (B,) int64 label tensor,
        and the tuple of per-image sizes passed through unchanged.

    Raises:
        ValueError: if an image exceeds the target size in either dimension. A
            negative pad amount would otherwise crop or fail inside F.pad.
    """
    imgs, labels, orig_hw = zip(*batch)
    side = MAX_SIDE_RESOLUTION if target_size is None else target_size

    imgs_pad = []
    for t in imgs:
        h, w = t.shape[-2:]
        if h > side or w > side:
            raise ValueError(f"image of size {h}x{w} exceeds padding target {side}x{side}")
        imgs_pad.append(F.pad(t, (0, side - w, 0, side - h), "replicate"))
    imgs_pad = torch.stack(imgs_pad, dim=0)      # B × C × H × W
    labels = torch.as_tensor(labels, dtype=torch.long)
    return imgs_pad, labels, orig_hw



In [None]:
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    collate_fn=collate_fn,  # pads every image to a fixed square so the batch can be stacked
)

In [None]:
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio
from torchmetrics import MeanSquaredError
#from torchmetrics.image.fid import FrechetInceptionDistance
#from torchmetrics.image.inception import InceptionScore

# Metrics are applied to images scaled to [0, 1].
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)
mse_metric = MeanSquaredError().to(device)

# FID / Inception Score are deliberately left disabled: they are very slow, a
# batch-based version would be strongly affected by the padding, and they do
# poorly at high resolutions.
#fid_metric  = FrechetInceptionDistance(feature=2048, normalize=True)
#is_metric   = InceptionScore(feature=2048, normalize=True)

In [None]:
compute_metrics = False  # True: also compute per-image SSIM/PSNR/RMSE; False: only time encode/decode

In [None]:
# Evaluate every batch in `loader`: time encode+decode per image and, when
# compute_metrics is True, record per-image SSIM / PSNR / RMSE.
from tqdm import tqdm

import torch
import time
import numpy as np

import os
import csv

# Output folder named after the selected VAE (slashes in the repo id become underscores).
vae_choice_str = "_".join(VAE_CHOICES[VAE_CHOICE].split("/"))
out_dir = Path("decoded_images_" + vae_choice_str)
out_dir.mkdir(exist_ok=True)

batch_idx = 0
device = 'cuda'
ssim_vals = []
psnr_vals = []
rmse_vals = []
total_ms = 0.0   # accumulated encode+decode wall time, in milliseconds
n_images = 0     # images actually processed (the last batch may be smaller than batch_size)


def _cuda_sync():
    """Wait for queued CUDA work so wall-clock timestamps include the GPU execution."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _mean_or_nan(vals):
    """Mean of `vals` as a float, or NaN (without numpy's empty-slice warning) if empty."""
    return float(np.mean(vals)) if vals else float("nan")


with tqdm(loader, desc="Evaluating", dynamic_ncols=True) as pbar:
    for imgs_pad, _labels, orig_hw in pbar:
        imgs_pad = imgs_pad.to(device, non_blocking=True)  # already [-1, 1]

        # CUDA launches are asynchronous; synchronize on both sides of the timed region,
        # otherwise the timer mostly measures kernel-launch overhead.
        _cuda_sync()
        time_start = time.perf_counter()
        with torch.no_grad(), torch.amp.autocast(dtype=torch.bfloat16, device_type=device):
            if USE_COSMOS:
                z = vae_encode(imgs_pad)
            else:
                z = vae_encode(imgs_pad).latent_dist.sample()
            if batch_idx == 0:
                print(z.shape)
            recon = vae_decode(z) if USE_COSMOS else vae_decode(z).sample
            # float32 for both backends so the metrics below are never computed in bf16
            recon = recon.to(torch.float32)  # still padded
        _cuda_sync()
        total_ms += (time.perf_counter() - time_start) * 1000

        if compute_metrics:
            # [-1, 1] → [0, 1] to match the metrics' data_range=1.0
            ref01 = imgs_pad.mul(0.5).add(0.5).clamp(0, 1)
            recon01 = recon.mul(0.5).add(0.5).clamp(0, 1)
            recon_h, recon_w = recon01.shape[-2:]
            for i, (orig_h, orig_w) in enumerate(orig_hw):
                # Crop away the padding, never past what the decoder returned.
                crop_h = min(orig_h, recon_h)
                crop_w = min(orig_w, recon_w)
                if (crop_h, crop_w) != (orig_h, orig_w):
                    print(f"orig {orig_h}x{orig_w} exceeds recon {recon_h}x{recon_w}; "
                          f"cropping to {crop_h}x{crop_w}")

                # .contiguous(): sliced views are non-contiguous, which triggered a
                # .view() RuntimeError inside torchmetrics.
                ref01_i = ref01[i:i+1, :, :crop_h, :crop_w].contiguous()
                recon01_i = recon01[i:i+1, :, :crop_h, :crop_w].contiguous()

                ssim_vals.append(ssim_metric(recon01_i, ref01_i).item())
                psnr_vals.append(psnr_metric(recon01_i, ref01_i).item())
                rmse_vals.append(mse_metric(recon01_i, ref01_i).item() ** 0.5)

        batch_idx += 1
        n_images += imgs_pad.shape[0]

        # Running means for the progress bar (0.0 until any metric has been recorded)
        mean_ssim = float(np.mean(ssim_vals)) if ssim_vals else 0.0
        mean_psnr = float(np.mean(psnr_vals)) if psnr_vals else 0.0
        mean_rmse = float(np.mean(rmse_vals)) if rmse_vals else 0.0
        ms_per_image = total_ms / n_images if n_images > 0 else 0.0

        pbar.set_postfix({
            "SSIM": f"{mean_ssim:.4f}",
            "PSNR": f"{mean_psnr:.2f}",
            "RMSE": f"{mean_rmse:.4f}",
            "ms/img": f"{ms_per_image:.2f}"
        })

# Final metrics (NaN when compute_metrics is False, since nothing was measured)
mean_ssim = _mean_or_nan(ssim_vals)
mean_psnr = _mean_or_nan(psnr_vals)
mean_rmse = _mean_or_nan(rmse_vals)
# Divide by the number of images actually processed, not batch_idx * batch_size.
time_per_image = total_ms / n_images if n_images > 0 else 0.0

print(f"SSIM: {mean_ssim}")
print(f"PSNR: {mean_psnr}")
print(f"RMSE: {mean_rmse}")
print(f"Time per image: {time_per_image} ms")

# Append this run's summary to <out_dir>/<MAX_SIDE_RESOLUTION>_imagenet_metrics.csv,
# writing the header row only when the file does not exist yet.
csv_path = os.path.join(out_dir, f"{MAX_SIDE_RESOLUTION}_imagenet_metrics.csv")

csv_fields = ["SSIM", "PSNR", "RMSE", "TimePerImage_ms"]
csv_values = [mean_ssim, mean_psnr, mean_rmse, time_per_image]

write_header = not os.path.exists(csv_path)
try:
    with open(csv_path, "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerows([csv_fields, csv_values] if write_header else [csv_values])
    print(f"Metrics saved to {csv_path}")
except Exception as e:  # report the failure and carry on rather than raising
    print(f"Could not save metrics to {csv_path}: {e}")
