In [12]:
# import libraries and config files

import os
import cv2
from PIL import Image
import io
import os
import torch
import torchvision.transforms as transforms
from PIL import Image
from diffusers import AutoencoderKL
import numpy as np
from models.vae import encode_img, decode_img

from utils.config import SIMULATED_FOLDER,EXPERIMENTAL_FOLDER, FILES_FIG1A
from utils.preprocess import preprocess_simulation_output_data, preprocess_experimental_backgroundwhite_rawfiles
from utils.preprocess import scale_latents
from utils.display import display_images_with_ssim_3rows




In [19]:
# Show the comparisons for 4 images: the first two FILES_FIG1A entries, each
# loaded once as a simulation file and once as an experimental file.
N_SELECTED_FILES = 2
selected_files = FILES_FIG1A[:N_SELECTED_FILES]

In [10]:
# Bits per pixel (bpp) of the SDVAE latent representation.
# A 256 x 256 RGB image (8 bits per channel) is represented by a 32 x 32 x 4
# latent stored at 32-bit precision.
#   compression ratio = original image bits / latent bits
#   bpp               = 24 bits per original pixel / compression ratio

image_bits = 256 * 256 * 3 * 8   # size of the uncompressed RGB image, in bits
latent_bits = 32 * 32 * 4 * 32   # size of the latent, in bits

# NOTE: despite its name, `compression_bpp` holds the compression ratio.
compression_bpp = image_bits / latent_bits
print(f"Compression ratio for SDVAE: {compression_bpp}")

bits_per_rgb_pixel = 24  # 8 bits per channel for 3 RGB channels
bpp_sdvae = bits_per_rgb_pixel / compression_bpp
print(f"BPP for SDVAE images: {bpp_sdvae}")



Compression ratio for SDVAE: 12.0
BPP for SDVAE images: 2.0


In [20]:

# compare the bpp results with jpeg compression 


def jpeg_bytes_and_bpp(img_path, quality):
    """Re-encode an image as an in-memory JPEG and measure its bits per pixel.

    Args:
        img_path: Path of the image to open; it is converted to RGB first.
        quality: JPEG quality setting forwarded to the encoder.

    Returns:
        Tuple ``(data, bpp)``: the encoded JPEG bytes, and the number of
        encoded bits divided by the image's pixel count (height * width).
    """
    rgb = Image.open(img_path).convert("RGB")
    width, height = rgb.size

    stream = io.BytesIO()
    # No chroma subsampling, optimized encoder settings, non-progressive scan.
    # exif/icc are deliberately not passed to the encoder.
    rgb.save(
        stream,
        format="JPEG",
        quality=quality,
        subsampling=0,
        optimize=True,
        progressive=False,
    )

    encoded = stream.getvalue()
    n_bits = len(encoded) * 8
    return encoded, n_bits / (height * width)

# JPEG compression: measure the bpp at several quality settings and pick the
# first setting whose bpp is at least the SDVAE bpp.

# image_file='/hpc/group/youlab/ks723/storage/Random_images/SSIM_comparisions/cameraman.png'
image_file = os.path.join(SIMULATED_FOLDER, selected_files[0])  # demo test with a simulated file
qualities = [80, 85, 86, 88, 90, 95]  # candidates, scanned in this order

# Map each quality setting to the bpp of the JPEG it produces.
bpp_dict = {q: jpeg_bytes_and_bpp(image_file, quality=q)[1] for q in qualities}

# First quality level in `qualities` where BPP >= bpp_sdvae; both values stay
# None when no candidate reaches the target.
quality_desired = next((q for q in qualities if bpp_dict[q] >= bpp_sdvae), None)
bpp_desired = bpp_dict.get(quality_desired)

print(f"\nTarget BPP for SDVAE: {bpp_sdvae:.4f}")
if quality_desired is not None:
    print(f"Closest JPEG quality: {quality_desired}, BPP: {bpp_desired:.4f}")
else:
    # Report the miss explicitly instead of printing nothing.
    print(f"No JPEG quality in {qualities} reaches the target BPP")


# Settings for the downsample-then-upsample (plain resize) baseline.

# Factor found by rough checking (original note: "~ 2.0", presumably the bpp target).
value_downsampling = 2.5

# Reference image: file name and the folder holding the original images.
filename = 'cameraman.png'
input_folder = '/hpc/group/youlab/ks723/storage/Random_images/SSIM_comparisions'

def bpp_of_downscaled(path_in, H_orig, W_orig, scale):
    """Downscale an image, save it, and return the saved file's bits per original pixel.

    Two files are written next to ``path_in``:
    ``<base>_downsampled_<scale><ext>`` (the shrunken image) and
    ``<base>_upsampled_<scale><ext>`` (the shrunken image resized back to the
    resolution of the loaded image with bicubic interpolation).

    Args:
        path_in: Path of the image to read (loaded as a 3-channel colour image).
        H_orig: Reference height; sets the downscaled height and the bpp denominator.
        W_orig: Reference width; sets the downscaled width and the bpp denominator.
        scale: Downsampling factor; the target size is
            ``int(H_orig / scale)`` x ``int(W_orig / scale)``.

    Returns:
        Size on disk of the downsampled file, in bits, divided by ``H_orig * W_orig``.

    Raises:
        ValueError: If ``scale`` is not positive or the target size has a zero dimension.
        FileNotFoundError: If the input image cannot be read.
        OSError: If either output image cannot be written.
    """
    # Validate the geometry first so bad arguments fail before any I/O happens.
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")
    h, w = int(H_orig / scale), int(W_orig / scale)
    if h < 1 or w < 1:
        raise ValueError(
            f"scale {scale} reduces {H_orig}x{W_orig} to an empty image ({h}x{w})"
        )

    # cv2.imread signals failure by returning None rather than raising.
    img = cv2.imread(path_in, cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path_in}")

    small = cv2.resize(img, (w, h), interpolation=cv2.INTER_AREA)

    base, ext = os.path.splitext(path_in)
    out_path_small = f"{base}_downsampled_{scale}{ext}"
    # Output format follows the extension; encoder defaults are used.
    # cv2.imwrite reports failure through its return value.
    if not cv2.imwrite(out_path_small, small):
        raise OSError(f"Could not write image: {out_path_small}")

    # Write the upscaled version, rebuilt from the downsampled version.
    upsampled = cv2.resize(small, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
    out_path_resized = f"{base}_upsampled_{scale}{ext}"
    if not cv2.imwrite(out_path_resized, upsampled):
        raise OSError(f"Could not write image: {out_path_resized}")

    bytes_on_disk = os.path.getsize(out_path_small)
    return (bytes_on_disk * 8) / (H_orig * W_orig)

# Run the resize baseline on the reference image, using 256 x 256 as the reference size.
downsample_input_path = os.path.join(input_folder, filename)
bpp_downsampled = bpp_of_downscaled(downsample_input_path, 256, 256, value_downsampling)
print(f"BPP results with {value_downsampling}x downsampling is {bpp_downsampled}")




Target BPP for SDVAE: 2.0000
BPP results with 2.5x downsampling is 2.18212890625


In [24]:
# Display the JPEG quality -> bpp mapping.
bpp_dict

{80: 0.9703708726015323,
 85: 1.0560715980744457,
 86: 1.07959861685538,
 88: 1.1229235880398671,
 90: 1.172554071462472,
 95: 1.3621940470540375}

In [None]:


def _encode_individually(batch, device):
    """Encode ``batch`` with ``encode_img`` one sample at a time.

    Each sample is sliced as ``batch[i:i+1]`` (the leading dimension is kept),
    moved to ``device``, encoded, and the resulting latent is moved to the CPU.

    Args:
        batch: Tensor whose first dimension indexes the samples.
        device: Device on which each sample is encoded.

    Returns:
        List with one CPU latent per sample, in input order.
    """
    return [
        encode_img(batch[idx:idx + 1].to(device)).cpu()
        for idx in range(batch.shape[0])
    ]


# --- Load the selected images ---------------------------------------------
# Simulation: the first item of every returned record is the image
# (grayscale, per the original note).
sim_data = preprocess_simulation_output_data(SIMULATED_FOLDER, 0, len(selected_files), img_filenames=selected_files)
sim_images = np.array([data[0] for data in sim_data])

# Experimental: the preprocessing function already returns a list of images
# (RGB, per the original note).
exp_data = preprocess_experimental_backgroundwhite_rawfiles(EXPERIMENTAL_FOLDER, 0, len(selected_files), img_filenames=selected_files)
exp_images = np.array(exp_data)

# --- Simulation: encode and reconstruct with the VAE -------------------------
# Replicate the single channel three times to get RGB, then reorder to
# channels-first and normalize by 255.
sim_rgb = np.stack([sim_images] * 3, axis=-1)
sim_rgb = np.transpose(sim_rgb, (0, 3, 1, 2)) / 255.0

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

X = torch.Tensor(sim_rgb)
encoded_latents = _encode_individually(X, device)
latents_sim = torch.cat(encoded_latents, dim=0)
latents_scaled_sim = scale_latents(latents_sim)
# NOTE(review): decoding uses the unscaled latents, which live on the CPU at
# this point; confirm that is what decode_img expects.
reconstructed = decode_img(latents_sim)

# --- Experimental: encode and reconstruct with the VAE -----------------------
exp_images = np.transpose(exp_images, (0, 3, 1, 2)) / 255.0  # channels-first, normalized
X = torch.Tensor(exp_images)
encoded_latents = _encode_individually(X, device)
latents_exp = torch.cat(encoded_latents, dim=0)
latents_scaled_exp = scale_latents(latents_exp)
# NOTE(review): same as above, the unscaled CPU latents are decoded.
reconstructed_exp = decode_img(latents_exp)
