In [None]:
# Install Gradio (used below for the web UI); -q keeps pip's output quiet.
!pip install -q gradio

In [None]:
import gradio as gr
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from PIL import Image

# --- CONFIGURATION ---
# Path to your saved ControlNet model in Drive
MODEL_PATH = "/content/drive/MyDrive/controlnet_model_output_v1"
# Base Stable Diffusion checkpoint the ControlNet is attached to
BASE_MODEL = "runwayml/stable-diffusion-v1-5"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on the GPU; fall back to float32 when the
# notebook runs without CUDA.
DTYPE = torch.float16 if device == "cuda" else torch.float32

print(f"Loading model from {MODEL_PATH}...")

# 1. Load the Model (Cached globally so it doesn't reload every time)
controlnet = ControlNetModel.from_pretrained(MODEL_PATH, torch_dtype=DTYPE)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL, controlnet=controlnet, torch_dtype=DTYPE, safety_checker=None
)

# Speed up inference
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

if device == "cuda":
    # Saves VRAM. Model offload manages device placement itself, so the
    # pipeline must NOT be moved to the GPU with .to("cuda") beforehand.
    pipe.enable_model_cpu_offload()
else:
    # No GPU available: offloading does not apply, run everything on CPU.
    pipe = pipe.to(device)

# 2. Define the Generation Function
def generate(cond_image, prompt, num_steps, guidance_scale, seed):
    """Generate one image from a conditioning image and a text prompt.

    Args:
        cond_image: Conditioning image as a NumPy array, or ``None`` when
            nothing has been uploaded.
        prompt: Text prompt passed to the pipeline.
        num_steps: Number of inference steps; converted to ``int``.
        guidance_scale: Guidance scale, passed to the pipeline unchanged.
        seed: Seed for the random generator. ``-1`` or ``None`` (e.g. an
            emptied number field) means "do not seed", so every call
            produces a different result.

    Returns:
        The first image returned by the pipeline, or ``None`` when
        ``cond_image`` is ``None``.
    """
    if cond_image is None:
        return None

    # Pre-processing: Resize to 512x512 (Standard for SD v1.5)
    # This prevents dimension errors
    input_image = Image.fromarray(cond_image).resize((512, 512))

    # Handle Seed: a missing value is treated like the -1 sentinel instead
    # of crashing in int(None).
    if seed is None or seed == -1:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(int(seed))

    # Generate
    output = pipe(
        prompt,
        image=input_image,
        num_inference_steps=int(num_steps),
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]

    return output

# 3. Create the Gradio UI
# 3. Create the Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Saliency Mask Guided ControlNet")
    gr.Markdown("Upload your **contextual saliency mask** and type a prompt to generate an image.")

    with gr.Row():
        with gr.Column():
            # Inputs
            input_img = gr.Image(label="Contextual Saliency Mask", type="numpy")
            prompt = gr.Textbox(label="Prompt", placeholder="e.g. A cake with a number on it")
            # TODO: add a negative-prompt textbox once generate() accepts one.

            with gr.Accordion("Advanced Options", open=False):
                steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=20, step=1)
                guidance = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=7.5, step=0.5)
                # precision=0 makes the field integer-valued, so a fractional
                # seed is never silently truncated later.
                seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)

            run_btn = gr.Button("Generate Image", variant="primary")

        with gr.Column():
            # Output
            output_img = gr.Image(label="Generated Result")

    # Connect inputs to function; the order of `inputs` must match the
    # positional parameters of generate().
    run_btn.click(
        fn=generate,
        inputs=[input_img, prompt, steps, guidance, seed],
        outputs=output_img
    )

# 4. Launch
# share=True creates a public link (e.g., https://xxxx.gradio.live) accessible from anywhere
# debug=True keeps the cell running so errors and logs stay visible.
# NOTE(review): the pipeline is loaded with safety_checker=None while this link
# is public; confirm that is acceptable or set share=False.
demo.launch(share=True, debug=True)

Loading model from /content/drive/MyDrive/controlnet_model_output_v1...


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
  with gr.Blocks(theme=gr.themes.Soft()) as demo:


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://084fbc54c810abd63d.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://084fbc54c810abd63d.gradio.live


