In [9]:
import torch
import os
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from diffusers import AutoencoderKL



In [40]:

class Reconstructor:
    """Base class for image reconstructors (image -> model -> image round trips).

    Subclasses are expected to override ``preprocess`` and ``reconstruct``.

    Args:
        device: Target device, either a string such as ``"cuda"``, ``"cuda:0"``,
            ``"cpu"``, or a ``torch.device``. If falsy, ``"cuda"`` is used when
            available, otherwise ``"cpu"``.
        use_fp16: Request float16 computation. Honoured only when the device is
            a CUDA device; otherwise float32 is used.
        *args, **kwargs: Accepted so subclasses can forward extra arguments;
            ignored here.
    """

    def __init__(self, device, use_fp16, *args, **kwargs):
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        # str() so torch.device objects work too: `in` is not defined on torch.device.
        self.use_fp16 = bool(use_fp16) and "cuda" in str(self.device)
        self.dtype = torch.float16 if self.use_fp16 else torch.float32

    def crop_by_eight(self, img):
        """Center-crop ``img`` so both width and height are multiples of 8.

        The 8 is presumably the autoencoder's spatial downsampling factor
        (TODO confirm per model).

        Args:
            img: A PIL-style image exposing ``.size`` as ``(width, height)`` and
                ``.crop(box)``.

        Returns:
            ``img`` itself when no crop is needed, otherwise the cropped image.

        Raises:
            ValueError: If either side is smaller than 8 pixels, since cropping
                would produce an empty image.
        """
        width, height = img.size
        new_w, new_h = (width // 8) * 8, (height // 8) * 8
        if new_w == 0 or new_h == 0:
            raise ValueError(
                f"Image of size {width}x{height} is too small; both sides must be at least 8 px."
            )
        if (new_w, new_h) == (width, height):
            return img
        left, top = (width - new_w) // 2, (height - new_h) // 2
        return img.crop((left, top, left + new_w, top + new_h))

    def preprocess(self, *args, **kwargs):
        """Convert an input image into model input; must be overridden."""
        raise NotImplementedError(f"{type(self).__name__} must implement preprocess()")

    def reconstruct(self, *args, **kwargs):
        """Run the reconstruction round trip; must be overridden."""
        raise NotImplementedError(f"{type(self).__name__} must implement reconstruct()")

    

In [51]:

class VAEReconstructor(Reconstructor):
    """Round-trips images through a diffusers ``AutoencoderKL`` (encode -> decode).

    Args:
        model_id: Hugging Face Hub id or local path of the VAE. If loading it
            directly raises ``OSError``, loading from its ``vae`` subfolder is
            tried instead (layout of a full diffusion-pipeline repository).
        device: Target device; falsy means CUDA when available, else CPU.
        use_fp16: Load weights and run in float16; only honoured on CUDA.
        *args, **kwargs: Forwarded to ``Reconstructor.__init__``.
    """

    def __init__(self, model_id="stabilityai/sd-vae-ft-mse", device=None, use_fp16=True, *args, **kwargs):
        super().__init__(device, use_fp16, *args, **kwargs)

        print(f"Loading VAE from '{model_id}' to {self.device}...")
        try:
            self.vae = AutoencoderKL.from_pretrained(model_id, torch_dtype=self.dtype)
        except OSError:
            # Pipeline repos keep the autoencoder under `vae/`.
            self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=self.dtype)

        self.vae.to(self.device)
        self.vae.eval()

    def preprocess(self, img):
        """Prepare a PIL image for the VAE.

        The image is converted to RGB if needed. Without this, grayscale inputs
        produce a 2-D array that cannot be permuted to CHW, and RGBA or CMYK
        inputs produce 4 channels.

        Returns:
            ``(cropped_img, x)``, where ``cropped_img`` is the RGB, center-cropped
            PIL image and ``x`` is a ``(1, 3, H, W)`` tensor scaled to
            ``[-1, 1]`` on ``self.device`` with ``self.dtype``.
        """
        if img.mode != "RGB":
            img = img.convert("RGB")
        img = self.crop_by_eight(img)
        # HWC uint8 -> [0, 1] float -> (1, C, H, W)
        x = torch.from_numpy(np.array(img)).float() / 255.0
        x = x.permute(2, 0, 1).unsqueeze(0)
        return img, (2.0 * x - 1.0).to(self.device, dtype=self.dtype)

    @torch.no_grad()
    def reconstruct(self, img, sample_posterior=True, generator=None):
        """Encode and decode ``img`` with the VAE.

        Args:
            img: Input PIL image.
            sample_posterior: If True (the default, matching previous
                behaviour), draw latents from the encoder's posterior, so
                repeated calls can differ. If False, use the posterior mode for
                a deterministic round trip.
            generator: Optional ``torch.Generator`` used when sampling, for
                reproducible results.

        Returns:
            ``(original_cropped, reconstruction)``, both PIL images. The
            reconstruction is uint8 RGB and is expected to match the cropped
            input's size.
        """
        original_cropped, input_tensor = self.preprocess(img)

        posterior = self.vae.encode(input_tensor).latent_dist
        latents = posterior.sample(generator=generator) if sample_posterior else posterior.mode()
        decoded = self.vae.decode(latents).sample

        # [-1, 1] -> [0, 1], then NCHW -> NHWC uint8; take the single batch item.
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        decoded = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
        decoded = (decoded * 255).round().astype("uint8")[0]

        return original_cropped, Image.fromarray(decoded)


In [46]:
# Default SD VAE; fp16 is requested but only takes effect on a CUDA device.
reconstructor = VAEReconstructor(
    model_id="stabilityai/sd-vae-ft-mse",
    device=None,
    use_fp16=True,
)

Loading VAE from 'stabilityai/sd-vae-ft-mse' to cpu...


In [48]:

# Load one real image. Take an in-memory RGB copy so the file handle is closed
# right away and grayscale/RGBA/CMYK files are normalised to 3 channels.
IMAGE_PATH = "../../../../data/generated-image-detection/train/real/AADB_newtest/0.050_farm1_269_19943988589_646a5d1dda_b.jpg"
with Image.open(IMAGE_PATH) as src_img:
    # convert() returns a new, fully loaded image, so it remains usable after
    # the source file is closed.
    img = src_img.convert("RGB")
real_img, recon_img = reconstructor.reconstruct(img)

In [58]:
# Save the center-cropped input for side-by-side comparison with the reconstruction.
real_img.save(fp="../src/real.png")

In [59]:
# Save the VAE round-trip output.
recon_img.save(fp="../src/recon.png")

In [55]:
# import torch
# import numpy as np
# from PIL import Image
# from diffusers import AutoencoderKL, VQModel

# # --- 1. SDXL Reconstructor (Handles the FP16 'Black Image' Bug) ---
# class SDXLReconstructor(Reconstructor):
#     def __init__(self, model_id="stabilityai/sdxl-vae", device=None, use_fp16=True, *args, **kwargs):
#         # SDXL TRAP: The original 'stabilityai/sdxl-vae' often outputs NaNs (black images) in FP16.
#         # If the user requests FP16, we automatically switch to the community-standard fix.
#         if use_fp16 and model_id == "stabilityai/sdxl-vae":
#             print("Warning: Switching to 'madebyollin/sdxl-vae-fp16-fix' to prevent NaNs in FP16.")
#             model_id = "madebyollin/sdxl-vae-fp16-fix"
            
#         super().__init__(device, use_fp16, *args, **kwargs)
        
#         print(f"Loading SDXL VAE from '{model_id}'...")
#         self.vae = AutoencoderKL.from_pretrained(model_id, torch_dtype=self.dtype)
#         self.vae.to(self.device)
#         self.vae.eval()

#     def preprocess(self, img):
#         # SDXL uses the same preprocessing as SD1.5 ([-1, 1] range)
#         img = self.crop_by_eight(img)
#         x = torch.from_numpy(np.array(img)).float() / 255.0
#         x = x.permute(2, 0, 1).unsqueeze(0)
#         return img, (2.0 * x - 1.0).to(self.device, dtype=self.dtype)

#     @torch.no_grad()
#     def reconstruct(self, img):
#         original_cropped, input_tensor = self.preprocess(img)
        
#         # SDXL Encoder works the same as standard KL
#         latents = self.vae.encode(input_tensor).latent_dist.sample()
        
#         # Decoding
#         decoded = self.vae.decode(latents).sample

#         # Post-process (Standard clamping)
#         decoded = (decoded / 2 + 0.5).clamp(0, 1)
#         decoded = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
#         decoded = (decoded * 255).round().astype("uint8")[0]

#         return original_cropped, Image.fromarray(decoded)


# # --- 2. VQ-GAN Reconstructor (Handles Discrete Vector Quantization) ---
# class VQReconstructor(Reconstructor):
#     def __init__(self, model_id="microsoft/vq-diffusion-ithq", device=None, use_fp16=True, *args, **kwargs):
#         super().__init__(device, use_fp16, *args, **kwargs)
        
#         print(f"Loading VQModel from '{model_id}'...")
#         # Note: We use VQModel class here, not AutoencoderKL
#         self.vae = VQModel.from_pretrained(model_id, subfolder="vqvae", torch_dtype=self.dtype)
#         self.vae.to(self.device)
#         self.vae.eval()

#     def preprocess(self, img):
#         # VQ-GANs usually behave better with standard [-1, 1] scaling
#         img = self.crop_by_eight(img)
#         x = torch.from_numpy(np.array(img)).float() / 255.0
#         x = x.permute(2, 0, 1).unsqueeze(0)
#         return img, (2.0 * x - 1.0).to(self.device, dtype=self.dtype)

#     @torch.no_grad()
#     def reconstruct(self, img):
#         original_cropped, input_tensor = self.preprocess(img)
        
#         # DIFFERENCE: VQModel encoder output is not a distribution (.sample()), 
#         # it is a direct latent representation (often .latents)
#         encoded_output = self.vae.encode(input_tensor)
        
#         # Some VQ implementations return a tuple or object, handle specifically:
#         if hasattr(encoded_output, "latents"):
#             latents = encoded_output.latents
#         else:
#             latents = encoded_output # Fallback for some older diffusers versions

#         decoded = self.vae.decode(latents).sample

#         decoded = (decoded / 2 + 0.5).clamp(0, 1)
#         decoded = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
#         decoded = (decoded * 255).round().astype("uint8")[0]

#         return original_cropped, Image.fromarray(decoded)

In [54]:
# reconstructor = VQReconstructor()