In [1]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"   

import torch 
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark   = False
torch.use_deterministic_algorithms(True)


In [None]:
from share import *
import config

import cv2
import einops
import gradio as gr
import numpy as np
import torch
import random

from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3

from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from omegaconf import OmegaConf
from cldm.cldm import ControlLDM  # Your model class
import tempfile, imageio

from cldm.config import CKPT_PATH

# modified code

yaml_config = "./models/cldm_v15.yaml"           # YAML configuration file
ckpt_path= CKPT_PATH


# ------------------------------
# 1. Load the Model from Checkpoint
# ------------------------------
config_yaml = OmegaConf.load(yaml_config)
params = OmegaConf.to_container(config_yaml.model.params, resolve=True)
model = ControlLDM.load_from_checkpoint(ckpt_path, **params)
model = model.cuda()
model.eval()

# torch.manual_seed(seed_value)
# torch.cuda.manual_seed_all(seed_value)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark   = False
# torch.use_deterministic_algorithms(True)

# Set up the DDIM sampler.
ddim_sampler = DDIMSampler(model)

def latent_to_rgba(z_bchw: torch.Tensor) -> np.ndarray:
    # z_bchw: [1,4,h,w]
    z = z_bchw[0].float()                # [4,h,w]
    c, h, w = z.shape
    assert c == 4, "expected 4-channel latents"
    chans = []
    for k in range(4):
        m = z[k]
        m = (m - m.min()) / (m.max() - m.min() + 1e-8)   # 0..1 per-channel
        chans.append(m)
    rgba = torch.stack(chans, 0)                          # [4,h,w]
    rgba = (rgba.permute(1,2,0) * 255).byte().cpu().numpy()  # (h,w,4) uint8
    return rgba

def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
    with torch.no_grad():
        img = resize_image(HWC3(input_image), image_resolution)
        H, W, C = img.shape

        print(type(img))
        # detected_map = np.zeros_like(img, dtype=np.uint8)
        # detected_map[np.min(img, axis=2) < 127] = 255

        control = torch.from_numpy(img.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)

        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                     shape, cond, verbose=False, eta=eta,
                                                     log_every_t=1,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=un_cond)
        
        print(">>> intermediates type:", type(intermediates))
        if isinstance(intermediates, dict):
            print(">>> keys:", list(intermediates.keys()))
        for k,v in intermediates.items():
            print(f"  {k!r}: {type(v)} with length {len(v)}; first element type {type(v[0])}")

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        x_samples = model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        results = [x_samples[i] for i in range(num_samples)]


        # decode each x_inter latent 
        latent_list = intermediates.get('x_inter', [])

        # two tracks: [sample0_frames, sample1_frames]
        frames_dec = [[], []]
        frames_lat = [[], []]

        for z in latent_list:
            # sample 0
            if z.shape[0] >= 1:
                z0 = z[0:1].to(model.device)                       # [1,4,h,w]
                frames_lat[0].append(latent_to_rgba(z0))
                img0_t = model.decode_first_stage(z0)[0]            # CHW, [-1,1]
                img0 = ((img0_t.clamp(-1,1)+1)/2).permute(1,2,0).mul(255).byte().cpu().numpy()
                frames_dec[0].append(img0)

            # sample 1
            if z.shape[0] >= 2:
                z1 = z[1:2].to(model.device)
                frames_lat[1].append(latent_to_rgba(z1))
                img1_t = model.decode_first_stage(z1)[0]
                img1 = ((img1_t.clamp(-1,1)+1)/2).permute(1,2,0).mul(255).byte().cpu().numpy()
                frames_dec[1].append(img1)

        def write_gif(frames):
            tmp = tempfile.NamedTemporaryFile(suffix=".gif", delete=False); tmp.close()
            imageio.mimsave(tmp.name, frames, format="GIF-PIL", duration=1/12, loop=0)
            return tmp.name

        s0_dec = write_gif(frames_dec[0]); s0_lat = write_gif(frames_lat[0])
        if len(frames_dec[1]) and len(frames_lat[1]):
            s1_dec = write_gif(frames_dec[1]); s1_lat = write_gif(frames_lat[1])
        else:
            s1_dec, s1_lat = s0_dec, s0_lat
        downloads = [s0_dec, s0_lat, s1_dec, s1_lat]

    return results, s0_dec, s0_lat, s1_dec, s1_lat, downloads

    # return results ,tmp_dec.name, tmp_lat.name


block = gr.Blocks().queue()

with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion between two Simulated Parameters")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                guess_mode = gr.Checkbox(label='Guess Mode', value=False) 
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
            video_s0_dec = gr.Video(label="Sample 1 • Decoded ▶️", format="gif")
            video_s0_lat = gr.Video(label="Sample 1 • Latent ▶️",  format="gif")
            video_s1_dec = gr.Video(label="Sample 2 • Decoded ▶️", format="gif")
            video_s1_lat = gr.Video(label="Sample 2 • Latent ▶️",  format="gif")
            downloads     = gr.Files(label="Download videos")   
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery, video_s0_dec, video_s0_lat, video_s1_dec, video_s1_lat, downloads])


block.queue().launch(debug='True',server_name='0.0.0.0',share=False)

  from .autonotebook import tqdm as notebook_tqdm


[2025-11-11 21:13:43,031] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
logging improved.
No module 'xformers'. Proceeding without it.
ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Running on local URL:  http://0.0.0.0:7860

To create a public link, set `share=True` in `launch()`.


<class 'numpy.ndarray'>


Global seed set to 729397049


Data shape for DDIM sampling is (1, 4, 32, 32), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:13<00:00,  3.71it/s]


>>> intermediates type: <class 'dict'>
>>> keys: ['x_inter', 'pred_x0']
  'x_inter': <class 'list'> with length 51; first element type <class 'torch.Tensor'>
  'pred_x0': <class 'list'> with length 51; first element type <class 'torch.Tensor'>


Global seed set to 729397049


<class 'numpy.ndarray'>
Data shape for DDIM sampling is (2, 4, 32, 32), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:11<00:00,  4.42it/s]


>>> intermediates type: <class 'dict'>
>>> keys: ['x_inter', 'pred_x0']
  'x_inter': <class 'list'> with length 51; first element type <class 'torch.Tensor'>
  'pred_x0': <class 'list'> with length 51; first element type <class 'torch.Tensor'>




Keyboard interruption in main thread... closing server.


