In [None]:
from share import *
import config
import einops
import gradio as gr
import numpy as np
import torch
import random


from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3


from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from omegaconf import OmegaConf
from cldm.cldm import ControlLDM  # Your model class

# Import custom functions (seed extraction and ResNet pattern emulation):
from image_to_seed import image_to_seed
from prediction_seedtosim import prediction as predict_from_seed


# ------------------------------
# Load the ControlNet model from checkpoint
# ------------------------------
# Architecture definition: the ControlNet / SD-1.5 YAML whose `model.params`
# section is forwarded to the model constructor below.
yaml_config = "./models/cldm_v15.yaml"

# Trained checkpoint to serve. Previously used run (sim1 -> sim2):
# ckpt_path = "/hpc/dctrl/ks723/Huggingface_repos/ControlNet_repo/controlnet_repo/lightning_logs/version_25478352/checkpoints/epoch=4-step=94129.ckpt"
ckpt_path = '/hpc/dctrl/ks723/Huggingface_repos/ControlNet_repo/controlnet_repo/lightning_logs/version_25478850/checkpoints/epoch=4-step=51124.ckpt'

config_yaml = OmegaConf.load(yaml_config)
# Plain-dict copy of model.params with ${...} interpolations resolved, so it
# can be splatted as keyword arguments.
params = OmegaConf.to_container(config_yaml.model.params, resolve=True)

# Restore the Lightning module from the checkpoint (YAML params passed as
# keyword arguments) and move it to the GPU in one expression;
# nn.Module.cuda() returns the module itself.
model = ControlLDM.load_from_checkpoint(ckpt_path, **params).cuda()

# DDIM sampler bound to the loaded model; used by `process`.
ddim_sampler = DDIMSampler(model)

def _sketch_to_canvas(input_image, size=256):
    """Reduce the Gradio sketch to the inverted, resized canvas used by ``image_to_seed``.

    Only the first colour channel of the sketch is kept (``HWC3`` replicates it
    back to three channels). The result is resized with ``resize_image`` to
    target *size*, and the pixel values are inverted (``255 - x``).
    """
    # Accept a 2-D grayscale array as well as the H x W x C array Gradio sends.
    first_channel = input_image if input_image.ndim == 2 else input_image[:, :, 0]
    canvas = resize_image(HWC3(first_channel), size)
    # Invert so the drawn strokes, rather than the background, become the bright
    # pixels -- presumably what image_to_seed expects; confirm.
    return 255 - canvas


def _hint_channels(default=3):
    """Return ``hint_channels`` from the loaded YAML ``params``, or *default* if absent."""
    try:
        return int(params["control_stage_config"]["params"]["hint_channels"])
    except (KeyError, TypeError, ValueError):
        return default


def _pattern_to_hint(emulated_pattern, channels=3):
    """Convert the emulator output into an H x W x C numpy hint image.

    The first item of *emulated_pattern* is permuted from C x H x W to
    H x W x C. A single-channel result is repeated to *channels* channels so it
    matches the channel count the ControlNet hint encoder was configured with.
    """
    hint = emulated_pattern[0].permute(1, 2, 0).cpu().numpy()
    if hint.shape[2] == 1 and channels > 1:
        hint = np.repeat(hint, channels, axis=2)
    return hint


def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
    """Generate images from a canvas sketch via seed extraction, ResNet emulation and ControlNet.

    Pipeline: sketch -> inverted 256-px canvas -> ``image_to_seed`` ->
    ``predict_from_seed`` (emulated pattern) -> ControlNet hint -> DDIM sampling.

    Args:
        input_image: numpy image from the Gradio canvas (the logged run shows
            shape (256, 256, 3)).
        prompt: positive prompt; ``a_prompt`` is appended after ", ".
        a_prompt: quality suffix appended to the prompt.
        n_prompt: negative prompt used for the unconditional branch.
        num_samples: number of images to generate.
        image_resolution: accepted for UI compatibility but currently unused;
            the canvas is always resized to 256.
        ddim_steps: number of DDIM sampling steps.
        guess_mode: if true, the unconditional branch gets no control and the
            13 control scales decay as ``strength * 0.825 ** (12 - i)``.
        strength: base control scale.
        scale: classifier-free guidance scale.
        seed: RNG seed; -1 picks a random seed in [0, 65535].
        eta: DDIM eta.

    Returns:
        list of ``num_samples`` H x W x C ``uint8`` arrays decoded from the latents.

    Raises:
        ValueError: if no input image was provided.
    """
    if input_image is None:
        raise ValueError("No input image: draw on the canvas before pressing Run.")

    with torch.no_grad():
        print(input_image.shape)

        # Step 1: sketch -> inverted canvas (fixed at 256 px; image_resolution is ignored).
        img = _sketch_to_canvas(input_image, 256)

        # Step 2: image-to-seed conversion.
        seed_img = image_to_seed(img, num_dots=50, min_area=40)

        # Step 3: ResNet emulation; its first output becomes the ControlNet hint.
        emulated_pattern = predict_from_seed(seed_img)
        img = _pattern_to_hint(emulated_pattern, _hint_channels())
        H, W = img.shape[:2]

        # NOTE(review): dividing by 255 assumes the emulator output is on a 0-255
        # scale. If predict_from_seed returns values in [0, 1], the hint ends up
        # near zero -- TODO confirm.
        control = torch.from_numpy(img.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        cond = {
            "c_concat": [control],
            "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]
        }
        un_cond = {
            "c_concat": None if guess_mode else [control],
            "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]
        }
        # Latent shape: 4 channels at 1/8 of the hint resolution.
        shape = (4, H // 8, W // 8)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)

        # Guess mode weights deeper control outputs more (last entry == strength).
        model.control_scales = ([strength * (0.825 ** float(12 - i)) for i in range(13)]
                                if guess_mode else ([strength] * 13))
        samples, intermediates = ddim_sampler.sample(
            ddim_steps, num_samples, shape, cond, verbose=False, eta=eta,
            unconditional_guidance_scale=scale, unconditional_conditioning=un_cond
        )

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        x_samples = model.decode_first_stage(samples)
        # Map decoder output from [-1, 1] to [0, 255] and clip before the uint8 cast.
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy()
        x_samples = x_samples.clip(0, 255).astype(np.uint8)
        results = [x_samples[i] for i in range(num_samples)]
    return results


def create_canvas(w, h):
    """Return a blank white RGB canvas of width *w* and height *h*.

    Args:
        w: canvas width in pixels (an int, or an integral float such as a
            slider value may arrive as).
        h: canvas height in pixels.

    Returns:
        ``uint8`` array of shape ``(h, w, 3)`` filled with 255.
    """
    # int() so float dimensions (e.g. 512.0) don't break the shape argument.
    return np.full((int(h), int(w), 3), 255, dtype=np.uint8)

# ------------------------------
# Gradio Interface Setup
# ------------------------------
# Left column: drawing canvas, prompt and sampling options. Right column:
# output gallery. Component creation order below determines the layout, and
# the order of `ips` must match the parameter order of `process`.
# NOTE(review): queue() is called here and again at launch() below.
block = gr.Blocks().queue()

with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion (Drawing Input)")
    with gr.Row():
        with gr.Column():
            # Size of the blank canvas produced by the "Open drawing canvas!" button.
            canvas_width = gr.Slider(label="Canvas Width", minimum=256, maximum=1024, value=512, step=1)
            canvas_height = gr.Slider(label="Canvas Height", minimum=256, maximum=1024, value=512, step=1)
            create_button = gr.Button(label="Start", value='Open drawing canvas!')

            # Drawing surface: 'sketch' tool on a canvas source, delivered to `process`
            # as a numpy array (the logged run printed shape (256, 256, 3)).
            input_image = gr.Image(source='canvas', tool="sketch", type="numpy", shape=(256,256))

            # Replace the canvas contents with a white image of the chosen width/height.
            create_button.click(fn=create_canvas, outputs=[input_image], inputs=[canvas_width, canvas_height])

            prompt = gr.Textbox(label="Prompt")
            # Triggers `process` (wired at the end of this `with block:` section).
            run_button = gr.Button(label="Run")

            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                # NOTE(review): `process` ignores this value; it resizes the canvas to 256.
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                # -1 asks `process` to draw a random seed; randomize=True randomizes the initial value.
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    # Positional inputs for `process`; keep in the same order as its signature.
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])

# debug=True keeps the call blocking (the notebook cell runs until interrupted);
# listens on all interfaces without a public share link.
block.queue().launch(debug=True, server_name='0.0.0.0', share=False)


  from .autonotebook import tqdm as notebook_tqdm


[2025-03-06 23:41:15,411] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
logging improved.
No module 'xformers'. Proceeding without it.
ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Running on local URL:  http://0.0.0.0:7860

To create a public link, set `share=True` in `launch()`.


(256, 256, 3)


Traceback (most recent call last):
  File "/hpc/dctrl/ks723/miniconda3/envs/test_pytorch_ipy_v2/lib/python3.10/site-packages/gradio/routes.py", line 337, in run_predict
    output = await app.get_blocks().process_api(
  File "/hpc/dctrl/ks723/miniconda3/envs/test_pytorch_ipy_v2/lib/python3.10/site-packages/gradio/blocks.py", line 1015, in process_api
    result = await self.call_function(
  File "/hpc/dctrl/ks723/miniconda3/envs/test_pytorch_ipy_v2/lib/python3.10/site-packages/gradio/blocks.py", line 833, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/hpc/dctrl/ks723/miniconda3/envs/test_pytorch_ipy_v2/lib/python3.10/site-packages/anyio/to_thread.py", line 28, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(func, *args, cancellable=cancellable,
  File "/hpc/dctrl/ks723/miniconda3/envs/test_pytorch_ipy_v2/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 818, in run_sync_in_worker_thread
    return await future
  File "/hpc

Keyboard interruption in main thread... closing server.


