In [None]:
# --- Mount Drive ---
# Colab-only step: exposes the user's Google Drive under /content/gdrive.
# The dataset and checkpoint paths configured later in this notebook point into that mount.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# --- Imports ---
import torch
import torchvision
import math
import matplotlib.pyplot as plt

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Latents are multiplied by this factor after VAE encoding and divided by it before decoding.
LATENT_SCALE = 0.18215  # SD convention

# Model identifier handed to AutoencoderKL.from_pretrained when the VAE is loaded.
VAE_MODEL_ID = "stabilityai/sd-vae-ft-mse"

## Loading VAE

In [None]:
# Import next to first use so this cell does not depend on a later cell having
# run first (without it, a fresh kernel raises NameError on AutoencoderKL).
from diffusers import AutoencoderKL

# Pretrained VAE moved to DEVICE and switched to eval mode.
vae = AutoencoderKL.from_pretrained(VAE_MODEL_ID).to(DEVICE).eval()

## Loading project-specific packages

In [None]:
from Configurations import ReverseConfig

In [None]:
# Reverse-process configuration built without arguments, i.e. with whatever
# defaults ReverseConfig defines.
# NOTE: `rcfg` is rebound with explicit arguments in the Sampling section further down.
rcfg = ReverseConfig()

In [None]:
# The statement was truncated to `torch.loa`, which raises AttributeError.
# Load the latents from the path the Sampling section saves them to (rcfg.output_path).
# NOTE(review): assumes that file already exists from a previous run — confirm
# this is the intended source of `latents`.
latents = torch.load(rcfg.output_path, map_location=DEVICE)

In [None]:
# --- Config ---
# Folder of input images on the mounted Drive; passed to load_images by the pipeline cell.
SHARED_DRIVE_PATH = "/content/gdrive/MyDrive/HANDS_ON_GEN_AI/Datasets/COCO/test2017"

# Checkpoint file handed to torch.load when the UNet weights are restored.
UNET_WEIGHTS = "/content/gdrive/MyDrive/HANDS_ON_GEN_AI/Datasets/last_model"
# Batch size: number of images requested from load_images and first dimension of the sampling shape.
N_IMAGES = 8
# Side length in pixels for loaded images; the sampling shape uses IMAGE_SIZE // 8 per spatial side.
IMAGE_SIZE = 256



# Loading packages and Model

In [None]:
# --- Project-specific imports (you already have these) ---
# Forward process
from subVP_processes import DiffusionProcess  # TODO: ensure this exposes a callable forward

from Configurations import ReverseConfig
# Score-based prediction
# from sampler import score_based_prediction   # TODO: your function signature
# Denoising
# from denoising_functions import denoise_images  # TODO: your denoising entrypoint
# UNet definition
from unet_model import UNetModel  # TODO: your UNet class

In [None]:
# --- 1) Load X images and resize to 256 ---
# This helper was fully commented out, yet the pipeline cell calls `load_images`,
# which raised NameError. It is restored here using torchvision (already imported
# by the notebook) for decoding and resizing.
import random
from pathlib import Path


def load_images(folder, n, size):
    """Load up to ``n`` randomly chosen images from ``folder`` as one batch.

    Args:
        folder: Directory holding ``.jpg``/``.jpeg``/``.png`` files (not searched recursively).
        n: Maximum number of images to sample; fewer are returned when the
            folder contains fewer.
        size: Side length in pixels; every image is resized to ``size x size``.

    Returns:
        ``(x, paths)`` where ``x`` is a float32 tensor of shape
        ``[B, 3, size, size]`` with values in ``[0, 1]`` and ``paths`` lists the
        sampled files in the same order as the batch.

    Raises:
        FileNotFoundError: If ``folder`` does not exist or holds no supported image.
    """
    exts = {".jpg", ".jpeg", ".png"}
    # Sorted so the candidate order (and therefore a seeded sample) is reproducible.
    candidates = sorted(p for p in Path(folder).iterdir() if p.suffix.lower() in exts)
    if not candidates:
        raise FileNotFoundError(f"No images in {folder}")
    paths = random.sample(candidates, k=min(n, len(candidates)))

    resize = torchvision.transforms.Resize((size, size), antialias=True)
    imgs = []
    for p in paths:
        # Force 3-channel RGB so grayscale / RGBA files stack with the rest.
        rgb = torchvision.io.read_image(str(p), mode=torchvision.io.ImageReadMode.RGB)   # uint8 [3,H,W]
        rgb = torchvision.transforms.functional.convert_image_dtype(rgb, torch.float32)  # [0,1]
        imgs.append(resize(rgb))
    x = torch.stack(imgs, 0)                  # [B,3,size,size] in [0,1]
    return x, paths

In [None]:
# --- 2) Load VAE ---
# Import next to first use so this cell does not depend on a later cell having run.
from diffusers import AutoencoderKL

# Loading the checkpoint is expensive: reuse a VAE that is already in the
# notebook namespace and only load one when none is present.
vae = globals().get("vae")
if not isinstance(vae, AutoencoderKL):
    vae = AutoencoderKL.from_pretrained(VAE_MODEL_ID).to(DEVICE).eval()

In [None]:
# --- 5) Load UNet weights ---
# Builds the UNet, restores its weights from UNET_WEIGHTS and switches it to eval mode on DEVICE.
unet = UNetModel()  # TODO: pass your model hyperparams
# map_location remaps the checkpoint's tensors onto DEVICE while loading.
state = torch.load(UNET_WEIGHTS, map_location=DEVICE)
# Accept either a bare state dict or a checkpoint dict that nests it under "state_dict".
state_dict = state.get("state_dict", state)
# strict=False: key mismatches are returned (and printed below) instead of raising,
# so inspect the printed lists — a wrong checkpoint will not fail loudly here.
missing, unexpected = unet.load_state_dict(state_dict, strict=False)
unet = unet.to(DEVICE).eval()
print("UNet loaded. Missing:", missing, "Unexpected:", unexpected)

In [None]:
# --- 4) Forward process instance ---
# ForwardConfig was used here without being imported by any cell (NameError).
# NOTE(review): presumably it lives next to ReverseConfig in `Configurations` — confirm.
from Configurations import ForwardConfig

fwd = DiffusionProcess()  # TODO: pass beta schedule params if required
cfg = ForwardConfig()
# The result is kept as a path to the produced latents — TODO confirm against DiffusionProcess.run.
str_path_latent = fwd.run(cfg)

In [None]:
# --- Pipeline ---
# End-to-end check: load images -> VAE-encode -> noise the latents -> predict
# noise with the UNet -> denoise -> VAE-decode -> show a grid.
# NOTE(review): this cell calls `load_images`, `fwd.noise`, `score_based_prediction`
# and `denoise_images`; each must be defined or imported before it runs (the imports
# for the last two are commented out in the project-imports cell).
with torch.no_grad():
    # 1) Load
    x_pixels, paths = load_images(SHARED_DRIVE_PATH, N_IMAGES, IMAGE_SIZE)
    x_pixels = x_pixels.to(DEVICE)

    # 2) VAE already loaded

    # 3) Encode -> latent_dist
    x_in = x_pixels * 2 - 1                     # [-1,1]
    latents_dist = vae.encode(x_in).latent_dist
    # Sample from the posterior and apply the latent scaling factor.
    z0 = latents_dist.sample() * LATENT_SCALE   # [B,C,h,w]

    # 4) Apply forward process to latents
    #    Assume your API looks like: z_t = fwd.noise(z0, t)
    B = z0.shape[0]
    t_T = torch.ones(B, device=DEVICE)          # target time = 1.0
    z_t = fwd.noise(z0, t_T)                    # TODO: adapt to your exact method name

    # 6) Score-based prediction (predict noise on z_t)
    #    Assume API: pred_noise = score_based_prediction(unet, z_t, t_T, ...)
    pred_noise = score_based_prediction(unet, z_t, t_T)

    # 7) Denoise
    #    Assume API: z_denoised = denoise_images(z_t, pred_noise, t_T, ...)
    z_denoised = denoise_images(z_t, pred_noise, t_T)

    # 8) Decode with VAE
    #    Undo the latent scaling before decoding, then map [-1,1] to [0,1] for display.
    x_dec = vae.decode(z_denoised / LATENT_SCALE).sample  # [-1,1]
    x_out = (x_dec.clamp(-1, 1) + 1) / 2                  # [0,1]

# --- 9) Print images ---
# At most 4 images per grid row; permute CHW -> HWC for matplotlib.
grid = torchvision.utils.make_grid(x_out.cpu(), nrow=min(4, x_out.size(0)))
plt.figure(figsize=(8, 8))
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.axis("off")
plt.show()


# Sampling

In [None]:
# Reverse-process configuration built without arguments; this rebinds `cfg`,
# which the forward-process cell used for a ForwardConfig.
cfg = ReverseConfig()
# NOTE(review): `DiffusionProcesses` (plural) is not imported by any cell in this
# notebook — only `DiffusionProcess` is imported from subVP_processes. Confirm
# which class name is correct before running.
denoised_images = DiffusionProcesses.run_reverse(cfg, unet)

In [None]:
# `vae.decode()` with no latent argument raises TypeError. Keep a reference to
# the bound method instead, so it can be called later as decoder(latents).
decoder = vae.decode

In [None]:
import math, torch, torchvision
import matplotlib.pyplot as plt
from torch import nn

from Configurations import ReverseConfig
from subVP_SDE import subVP_SDE

# These repeat the values set in the config cells earlier in the notebook
# (same device rule, N_IMAGES, IMAGE_SIZE, LATENT_SCALE, VAE_MODEL_ID); only DTYPE is new.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Passed to ReverseConfig as the dtype for the reverse process.
DTYPE = torch.float32
N_IMAGES = 8
IMAGE_SIZE = 256
LATENT_SCALE = 0.18215
VAE_MODEL_ID = "stabilityai/sd-vae-ft-mse"

In [None]:
# Pick the network up from the notebook namespace: it must already be bound to
# `unet` and be a torch module, otherwise stop with an explicit error.
model = globals().get("unet")
if model is None or not isinstance(model, nn.Module):
    raise RuntimeError("Model not found. Make sure your UNet is loaded into a variable named `unet`.")

In [None]:
# ---- config ----
# Explicit reverse-process settings used for sampling; replaces any earlier `rcfg`.
rcfg = ReverseConfig(
    # presumably the reverse process runs from t0 down to t1 — TODO confirm against ReverseConfig
    t0=1.0, t1=0.0,
    beta_min=0.1, beta_max=20.0,
    N=1000,                      # reverse steps
    schedule="linear",
    device=DEVICE, dtype=DTYPE,
    # (batch, 4, IMAGE_SIZE // 8, IMAGE_SIZE // 8); 4 is presumably the VAE latent channel count — confirm
    shape=(N_IMAGES, 4, IMAGE_SIZE // 8, IMAGE_SIZE // 8),
    rev_type="sde",              # switch to "ode" for probability-flow ODE
)

In [None]:
# Where to save denoised latents
# (the path is read from the config object rather than being set in this notebook)
latents_out = rcfg.output_path  # e.g., "latents_denoised.pt"

In [None]:
# 1) Run reverse to obtain denoised latents
# NOTE(review): `DiffusionProcesses` (plural) is not imported by any cell in this
# notebook — only `DiffusionProcess` is imported from subVP_processes. Confirm
# which class name is correct before running.
latents = DiffusionProcesses.run_reverse(rcfg, model)
# Persist the result at the path taken from rcfg.output_path.
torch.save(latents, latents_out)

In [None]:
# 2) Decode with VAE and display
from diffusers import AutoencoderKL

# Loading the checkpoint is expensive: reuse a VAE that is already in the
# notebook namespace and only load one when none is present.
vae = globals().get("vae")
if not isinstance(vae, AutoencoderKL):
    vae = AutoencoderKL.from_pretrained(VAE_MODEL_ID).to(DEVICE).eval()

In [None]:
# Decode without tracking gradients. The latents are divided by LATENT_SCALE first
# (the encode step in the pipeline cell multiplies by the same factor).
with torch.no_grad():
    imgs = vae.decode(latents / LATENT_SCALE).sample  # [-1,1]

In [None]:
# Map decoder output from [-1, 1] to [0, 1] under a new name. Rebinding `imgs`
# to its own rescaled value made this cell non-idempotent: running it a second
# time rescaled the already-rescaled images.
imgs_display = (imgs.clamp(-1, 1) + 1) / 2.0  # [0,1]
B = imgs_display.size(0)
if B == 0:
    # An empty batch would otherwise hit a ZeroDivisionError in the figure-size maths below.
    raise ValueError("No images to display: the decoded batch is empty.")

# nrow = images per grid row: a square layout for perfect-square batch sizes,
# otherwise rows of at most 8 images.
side = math.isqrt(B)
nrow = side if side * side == B else min(B, 8)

grid = torchvision.utils.make_grid(imgs_display, nrow=nrow, padding=2)
plt.figure(figsize=(nrow * 2.0, math.ceil(B / nrow) * 2.0))
plt.axis("off")
# make_grid returns CHW; matplotlib expects HWC on the CPU.
plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
plt.show()
print(f"Denoised latents saved to: {latents_out}")
