# Sampling

In [None]:
import math, torch, torchvision
import matplotlib.pyplot as plt
from torch import nn

from Configurations import ReverseConfig
from subVP_SDE import subVP_SDE

# ---- runtime settings shared by the cells below ----
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
DTYPE = torch.float32  # dtype handed to the reverse-process config

N_IMAGES = 8  # batch size of the sampled latent tensor
IMAGE_SIZE = 256  # target image side in pixels; latents use IMAGE_SIZE // 8 per side

LATENT_SCALE = 0.18215  # latents are divided by this before VAE decoding
VAE_MODEL_ID = "stabilityai/sd-vae-ft-mse"  # pretrained VAE checkpoint used to decode

In [None]:
# The score network must already exist in this notebook namespace as `unet`
# (e.g. loaded by an earlier cell); anything that is not an nn.Module is rejected.
_MISSING_MODEL_MSG = "Model not found. Make sure your UNet is loaded into a variable named `unet`."
model = globals().get("unet")  # None when `unet` has not been defined
if not isinstance(model, nn.Module):
    raise RuntimeError(_MISSING_MODEL_MSG)

In [None]:
# ---- config ----
# Latent batch: N_IMAGES samples, 4 channels, IMAGE_SIZE // 8 on each spatial side.
_latent_side = IMAGE_SIZE // 8
rcfg = ReverseConfig(
    # time interval handed to the reverse process (t0=1.0 -> t1=0.0)
    t0=1.0,
    t1=0.0,
    # beta noise schedule
    schedule="linear",
    beta_min=0.1,
    beta_max=20.0,
    N=1000,  # number of reverse steps
    rev_type="sde",  # "ode" selects the probability-flow ODE instead
    shape=(N_IMAGES, 4, _latent_side, _latent_side),
    device=DEVICE,
    dtype=DTYPE,
)

In [None]:
# Where to save denoised latents.
# The destination path is read from the reverse config; it is passed to torch.save
# after sampling and printed once the image grid has been displayed.
latents_out = rcfg.output_path  # e.g., "latents_denoised.pt"

In [None]:
# 1) Run reverse to obtain denoised latents.
# NOTE(review): `DiffusionProcesses` is not imported in the import cell above
# (only `subVP_SDE` is); make sure it is in scope before running this cell.
#
# Sampling never needs gradients. Running it under no_grad stops autograd from
# holding activations for every one of the rcfg.N model calls. If run_reverse
# already manages grad mode internally, this is harmless.
model.eval()  # put dropout/normalization layers (if any) into inference mode
with torch.no_grad():
    latents = DiffusionProcesses.run_reverse(rcfg, model)

# Persist a detached CPU copy so the file can be torch.load-ed without CUDA;
# `latents` itself stays on its device for decoding below.
torch.save(latents.detach().cpu(), latents_out)

In [None]:
# 2) Decode with VAE and display
# Load the pretrained VAE named by VAE_MODEL_ID, move it to DEVICE and switch
# it to eval mode for decoding.
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(VAE_MODEL_ID)
vae = vae.to(DEVICE)
vae.eval()

In [None]:
# Undo the latent scaling and decode back to pixel space (values roughly in [-1, 1]).
# The latents are first moved to the VAE's own device/dtype, because the reverse
# process may return them on another device or at another precision (e.g. float64),
# which vae.decode would reject.
with torch.no_grad():
    z_unscaled = latents.to(device=vae.device, dtype=vae.dtype) / LATENT_SCALE
    imgs = vae.decode(z_unscaled).sample  # [-1,1]

In [None]:
# Map decoder output from [-1, 1] to [0, 1] for display.
imgs = (imgs.clamp(-1, 1) + 1) / 2.0  # [0,1]
B = imgs.size(0)

# Use a square grid when B is a perfect square, otherwise at most 8 images per row.
# math.isqrt is exact on integers, unlike int(math.sqrt(B)).
_side = math.isqrt(B)
nrow = _side if _side * _side == B else min(B, 8)

grid = torchvision.utils.make_grid(imgs, nrow=nrow, padding=2)
plt.figure(figsize=(nrow * 2.0, math.ceil(B / nrow) * 2.0))
plt.axis("off")
# matplotlib needs a CPU numpy array in HWC layout. Tensor.numpy() fails on
# bfloat16 and on tensors that require grad, so detach and upcast first.
plt.imshow(grid.permute(1, 2, 0).detach().float().cpu().numpy())
plt.show()
print(f"Denoised latents saved to: {latents_out}")
