In [None]:
A notebook for testing Stable Diffusion (v1.4) image generation with a DDIM sampler.

In [None]:
import os
import sys
from pathlib import Path  # needed here: this cell runs before the main import cell

# Make the project root (parent of the notebook's working directory) importable,
# presumably so project modules such as `utils`, `data_utils` and `ldm` resolve.
# Assumes the kernel's cwd is the notebook's own directory (Jupyter's default).
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
    print('Top of sys.path:', sys.path[0])

In [None]:
import argparse
import json
import os
import wandb
from functools import partial

import matplotlib.pyplot as plt
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from data_utils import  TargetReferenceDataset, collate_prompts
from torchvision.transforms.functional import to_pil_image
import numpy as np
from transformers import CLIPTextModel, CLIPTokenizer
import torch
from tqdm import tqdm
from ldm.models.diffusion.ddimcopy import DDIMSampler
from utils import get_models, print_trainable_parameters, set_seed

In [None]:
# Load the SD v1.4 checkpoint via the project's `get_models` helper, which
# returns two (model, sampler) pairs: an "orig" pair and a working pair.
# NOTE(review): paths are relative to the kernel's cwd; confirm in utils that
# get_models accepts a single device string here.
SD_CONFIG_PATH = "../configs/stable-diffusion/v1-inference.yaml"
SD_CKPT_PATH = "../models/sd-v1-4.ckpt"
model_orig, sampler_orig, model, sampler = get_models(SD_CONFIG_PATH, SD_CKPT_PATH, "cuda")

In [None]:
def _save_images(imgs: torch.Tensor, out_dir: str, prefix: str) -> Path:
    """Write each image of a [B, 3, H, W] tensor in [0, 1] to `out_dir` as `{prefix}{i:04d}.png`."""
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)  # parents=True: allow nested out_dir
    # .float() before quantizing: autocast may yield fp16, and some CPU ops
    # (e.g. round) lack Half kernels on older torch builds.
    for i, im in enumerate(imgs.detach().float().cpu()):
        im_u8 = (im.clamp(0, 1) * 255).round().to(torch.uint8)  # [3,H,W]
        to_pil_image(im_u8).save(out_path / f"{prefix}{i:04d}.png")
    return out_path


def generate_and_save_sd_images(
    model,
    sampler,
    prompt: str,
    device: torch.device,
    steps: int = 50,
    eta: float = 0.0,
    batch_size: int = 1,
    out_dir: str = "tmp",
    prefix: str = "unl_",
    start_code: torch.Tensor = None,   # optional noise tensor [B,4,64,64] (for 512x512)
    guidance_scale: float = 7.5,
):
    """
    Generate images with classifier-free guidance (CFG) from a CompVis-style
    Stable Diffusion model + DDIMSampler, save them as PNGs and return them.

    Args:
        model: CompVis LDM-style model; must provide `get_learned_conditioning`
            and `decode_first_stage`.
        sampler: a DDIMSampler wrapping `model`.
        prompt: text prompt, used for every sample in the batch.
        device: target device, as a `torch.device` or a string such as "cuda".
        steps: number of DDIM steps.
        eta: DDIM eta (0.0 => deterministic given `start_code`).
        batch_size: number of samples; ignored when `start_code` is given
            (its first dimension sets the batch size).
        out_dir: output folder; created (including parents) if missing.
        prefix: filename prefix; files are named `{prefix}{i:04d}.png`.
        start_code: optional initial latent noise of shape [B, 4, H/8, W/8]
            (e.g. [B, 4, 64, 64] for 512x512). Drawn with `torch.randn`
            (global torch RNG) when None. Moved to `device` if given.
        guidance_scale: CFG scale (default 7.5, the previously hard-coded value).

    Returns:
        Tensor [B, 3, H, W] with values in [0, 1] (decoded latents rescaled
        from [-1, 1]).
    """
    device = torch.device(device)  # also accept plain strings like "cuda"

    if start_code is None:
        start_code = torch.randn(batch_size, 4, 64, 64, device=device)  # 512x512
    else:
        start_code = start_code.to(device)
    n_samples = start_code.shape[0]

    with torch.no_grad(), torch.autocast(device_type=device.type, enabled=(device.type == "cuda")):
        cond = model.get_learned_conditioning([prompt] * n_samples)
        uncond = model.get_learned_conditioning([""] * n_samples)

        # NOTE(review): plain tensors rather than {"c_crossattn": [...]} dicts.
        # The CompVis SD v1 DDIMSampler reads `.shape` on the conditioning
        # and builds CFG inputs with torch.cat([uc, c]), both of which fail
        # for dicts of lists. Plain tensors work with v1 and v2 samplers.
        # Confirm against ldm/models/diffusion/ddimcopy.py.
        samples_latent, _ = sampler.sample(
            S=steps,
            conditioning=cond,
            batch_size=n_samples,
            shape=start_code.shape[1:],  # (4, H/8, W/8)
            verbose=False,
            unconditional_guidance_scale=guidance_scale,
            unconditional_conditioning=uncond,
            eta=eta,
            x_T=start_code,
        )

        imgs = model.decode_first_stage(samples_latent)   # [-1, 1]
        imgs = (imgs.clamp(-1, 1) + 1) / 2.0              # [0, 1]

    out_path = _save_images(imgs, out_dir, prefix)
    print(f"Saved {len(imgs)} image(s) to {out_path}/ with prefix '{prefix}'")
    return imgs  # [B,3,H,W] in [0,1]

In [None]:
# Build a DDIMSampler (from ldm.models.diffusion.ddimcopy) around `model`.
# NOTE(review): this rebinds the `sampler` already returned by get_models;
# confirm whether that sampler is a different class or this cell is redundant.
sampler = DDIMSampler(model)

In [None]:
# Generate one image (batch_size defaults to 1) for a fixed prompt and save it
# under tmp/ with prefix "orig_". Seed torch's RNG first: the start latent is
# drawn with torch.randn, so without a seed every Restart & Run All gives a
# different image.
SEED = 0
torch.manual_seed(SEED)
img = generate_and_save_sd_images(
    model=model,
    sampler=sampler,
    prompt="a photo of the bird",
    device=torch.device("cuda"),
    steps=50,
    out_dir="tmp",
    prefix="orig_",
)

In [None]:
import torch
import matplotlib.pyplot as plt

# Show the first image of the batch returned by generate_and_save_sd_images
# ([B, 3, H, W]); matplotlib expects channels-last (HWC).
x = img[0].detach().cpu()
if x.dtype == torch.uint8:
    arr = x.permute(1, 2, 0).numpy()                       # HWC uint8 [0,255]
else:
    # upcast (autocast may yield fp16) and clamp so imshow does not warn about clipping
    arr = x.float().clamp(0, 1).permute(1, 2, 0).numpy()   # HWC float32 [0,1]

fig, ax = plt.subplots()
ax.imshow(arr)
ax.axis("off")
plt.show()  # end with show() so the axis-limits tuple is not echoed as cell output