In [None]:
# ---------------------------------------------------------------------------
# 1. CONFIGURATION CELL
# ---------------------------------------------------------------------------
# Pipeline mode switch, read by the pipeline-loading cell, by generate_sdxl_turbo()
# and by the /generate endpoint:
#   True  -> load the ControlNet Img2Img pipeline (requests must include an init_image)
#   False -> load the ControlNet Txt2Img pipeline (requests must NOT include an init_image)
# Only the selected pipeline is loaded, so VRAM is used for one pipeline at a time.

USE_IMG2IMG = True

In [None]:
import os, sys

# Location of the zrok API token file inside the attached Kaggle input dataset.
p = "/kaggle/input/sage-zrok-token/.zrok_api_key"
zrok_token = None

if os.path.isfile(p):
    # errors="ignore" drops undecodable bytes instead of raising; strip() removes
    # the trailing newline that editors usually append to the token file.
    with open(p, "r", encoding="utf-8", errors="ignore") as f:
        zrok_token = f.read().strip()

if not zrok_token:
    print("❌ Token not found or empty:", p)
    # Raise a regular exception instead of sys.exit(1): inside a notebook cell
    # sys.exit() surfaces as a bare SystemExit that carries no explanation,
    # whereas this error states exactly what is missing and still stops the run.
    raise RuntimeError(f"zrok token not found or empty: {p}")

In [None]:
import os
import shutil

print("Setting up Models...")


def copy_model_dir(source, dest):
    """Copy the directory tree ``source`` to ``dest``.

    Returns:
        True when the copy completed, False when ``source`` does not exist
        (checked up front, or reported by ``shutil.copytree`` as
        ``FileNotFoundError``).

    Any partially written ``dest`` is removed when the copy fails, so that a
    later re-run does not mistake an incomplete copy for a finished one.
    Errors other than ``FileNotFoundError`` are re-raised after that cleanup.
    """
    if not os.path.exists(source):
        return False
    parent = os.path.dirname(dest)
    if parent:
        os.makedirs(parent, exist_ok=True)
    try:
        shutil.copytree(source, dest)
    except FileNotFoundError:
        shutil.rmtree(dest, ignore_errors=True)
        return False
    except BaseException:
        # Includes KeyboardInterrupt: an interrupted copy must not leave a
        # half-populated model directory behind.
        shutil.rmtree(dest, ignore_errors=True)
        raise
    return True


# Names of the models that could not be made available; drives the final status line.
missing_models = []

# --- 1. Copy SDXL Turbo ---
source_sd = "/kaggle/input/sdxl-turbo"
dest_sd = "/kaggle/working/stabilityai/sdxl-turbo"

if os.path.exists(dest_sd):
    print(f"✓ SDXL Turbo already exists at {dest_sd}")
else:
    print(f"  Copying SDXL Turbo... (this may take a moment)")
    if copy_model_dir(source_sd, dest_sd):
        print(f"  ✓ Copied SDXL Turbo")
    else:
        print("❌ Could not find SDXL Turbo dataset.")
        missing_models.append("SDXL Turbo")

# --- 2. Copy ControlNet OpenPose SDXL ---
source_cn = "/kaggle/input/controlnet-openpose-sdxl"
dest_cn = "/kaggle/working/controlnet-openpose-sdxl"

if os.path.exists(dest_cn):
    print(f"✓ ControlNet OpenPose already exists at {dest_cn}")
else:
    print(f"  Copying ControlNet OpenPose...")
    if copy_model_dir(source_cn, dest_cn):
        print(f"  ✓ Copied ControlNet")
    else:
        print(f"⚠️ ControlNet source not found at {source_cn}. Pipeline will fail if ControlNet is missing.")
        missing_models.append("ControlNet OpenPose")

# Only claim success when every model directory is actually in place.
if missing_models:
    print(f"❌ Model setup incomplete, missing: {', '.join(missing_models)}")
else:
    print("✅ Models ready!")

In [None]:
# ---------------------------------------------------------------------------
# LOAD PIPELINE WITH CONTROLNET
# ---------------------------------------------------------------------------
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline
)

# --- Paths ---
# Local copies of the model directories (expected to exist under /kaggle/working).
MODEL_PATH = '/kaggle/working/stabilityai/sdxl-turbo'
CONTROLNET_PATH = '/kaggle/working/controlnet-openpose-sdxl'

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; the CPU fallback uses float32 so that the
# pipeline stays runnable when no CUDA device is available.
model_dtype = torch.float16 if device == "cuda" else torch.float32
pipe = None

print(f"Loading SDXL Turbo + ControlNet | Mode: {'IMG2IMG' if USE_IMG2IMG else 'TXT2IMG'}...")

# 1. Load ControlNet Model
print("Loading ControlNet...")
try:
    controlnet = ControlNetModel.from_pretrained(
        CONTROLNET_PATH,
        torch_dtype=model_dtype
    )
except Exception:
    print(f"❌ Failed to load ControlNet from {CONTROLNET_PATH}")
    # Bare raise keeps the original traceback intact.
    raise

# 2. Load Main Pipeline
# Both pipeline classes take the same arguments, so only the class depends on the mode.
pipeline_cls = (
    StableDiffusionXLControlNetImg2ImgPipeline
    if USE_IMG2IMG
    else StableDiffusionXLControlNetPipeline
)
pipe = pipeline_cls.from_pretrained(
    MODEL_PATH,
    controlnet=controlnet,
    torch_dtype=model_dtype,
)

pipe = pipe.to(device)

# 3. Safety checker
# If the loaded pipeline exposes a safety checker, replace it with a pass-through
# that returns the images unchanged and flags none of them.
if hasattr(pipe, "safety_checker") and pipe.safety_checker is not None:
    pipe.safety_checker = lambda images, clip_input=None: (images, [False] * len(images))

print("✅ Pipeline loaded successfully!")

In [None]:
# ------------------------------------------------------------------------------
# LONG PROMPT HANDLER (MANUAL CHUNKING - NO WARNINGS)
# ------------------------------------------------------------------------------
import torch
import logging

# Suppress the tokenizer's "sequence length is longer than the maximum" warning,
# which is expected here because prompts are tokenized without truncation.
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)

def get_long_prompt_embeddings(pipe, prompt, negative_prompt):
    """
    Manually encode prompts of any length for an SDXL pipeline.

    Each prompt is tokenized without truncation, split into chunks of
    ``model_max_length - 2`` tokens, and every chunk is wrapped as
    ``[BOS] tokens [EOS] [PAD...]`` before being run through both text
    encoders. The per-chunk hidden states (second-to-last layer) are
    concatenated along the sequence axis, and the two encoders' outputs are
    concatenated along the feature axis.

    Args:
        pipe: Pipeline exposing ``device``, ``tokenizer``, ``tokenizer_2``,
            ``text_encoder`` and ``text_encoder_2``.
        prompt: Positive prompt text. ``None`` or ``""`` is encoded as a
            single empty chunk.
        negative_prompt: Negative prompt text, handled like ``prompt``.

    Returns:
        Tuple ``(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
        negative_pooled_prompt_embeds)``. The negative embeddings are
        zero-padded or truncated to the sequence length of the positive ones.
        The pooled embeddings come from the first chunk of the second encoder.
    """
    device = pipe.device

    def build_chunks(tokenizer, text):
        """Return a list of (1, model_max_length) id tensors covering ``text``."""
        max_len = tokenizer.model_max_length
        body_len = max_len - 2  # room left after BOS and EOS

        # add_special_tokens=False is essential: BOS/EOS are added per chunk
        # below, so letting the tokenizer add them too would encode
        # [BOS, BOS, ..., EOS, EOS] and shift every token position.
        token_ids = list(
            tokenizer(text or "", add_special_tokens=False, truncation=False).input_ids
        )

        # An empty prompt still needs one (empty) chunk so there is something to encode.
        pieces = [
            token_ids[start:start + body_len]
            for start in range(0, len(token_ids), body_len)
        ] or [[]]

        rows = []
        for piece in pieces:
            row = [tokenizer.bos_token_id] + piece + [tokenizer.eos_token_id]
            row += [tokenizer.pad_token_id] * (max_len - len(row))
            rows.append(torch.tensor([row], dtype=torch.long, device=device))
        return rows

    def process_text(text):
        """Encode ``text`` with both encoders; return (embeds, pooled_embed)."""
        encoder_pairs = [
            (pipe.tokenizer, pipe.text_encoder),
            (pipe.tokenizer_2, pipe.text_encoder_2),
        ]
        embeds_list = []
        pooled_embed = None

        for i, (tokenizer, text_encoder) in enumerate(encoder_pairs):
            layer_hidden_states = []
            for j, chunk in enumerate(build_chunks(tokenizer, text)):
                with torch.no_grad():
                    output = text_encoder(chunk, output_hidden_states=True)
                layer_hidden_states.append(output.hidden_states[-2])
                # Pooled embedding: first chunk of the second encoder only.
                if i == 1 and j == 0:
                    pooled_embed = output.text_embeds

            embeds_list.append(torch.cat(layer_hidden_states, dim=1))

        # Guard against the two tokenizers producing different chunk counts.
        min_len = min(embeds.shape[1] for embeds in embeds_list)
        final_prompt_embeds = torch.cat(
            [embeds[:, :min_len, :] for embeds in embeds_list], dim=-1
        )
        return final_prompt_embeds, pooled_embed

    pos_embeds, pos_pooled = process_text(prompt)
    neg_embeds, neg_pooled = process_text(negative_prompt)

    p_len = pos_embeds.shape[1]
    n_len = neg_embeds.shape[1]

    # Positive and negative embeddings must share a sequence length.
    if p_len > n_len:
        pad = torch.zeros((1, p_len - n_len, neg_embeds.shape[-1]), device=device, dtype=neg_embeds.dtype)
        neg_embeds = torch.cat([neg_embeds, pad], dim=1)
    elif n_len > p_len:
        neg_embeds = neg_embeds[:, :p_len, :]

    return pos_embeds, neg_embeds, pos_pooled, neg_pooled

In [None]:
# -----------------------------
# GENERATION FUNCTION
# Handles Turbo, ControlNet + Resize Fix
# -----------------------------
import random
from PIL import Image

def generate_sdxl_turbo(
    pipe,
    prompt: str,
    negative_prompt: str = "",
    init_image = None,
    control_image = None,
    strength: float = 0.5,
    num_inference_steps: int = 2,
    guidance_scale: float = 0.0,
    controlnet_scale: float = 0.5,
    seed: int = None,
    height: int = 512,
    width: int = 512
):
    """Generate one image with the loaded SDXL Turbo + ControlNet pipeline.

    The pipeline mode is taken from the module-level ``USE_IMG2IMG`` flag.

    Args:
        pipe: The loaded ControlNet pipeline (Img2Img or Txt2Img variant).
        prompt: Positive prompt; may exceed the tokenizer limit.
        negative_prompt: Negative prompt.
        init_image: PIL image to start from. Required in Img2Img mode,
            ignored in Txt2Img mode.
        control_image: PIL ControlNet conditioning image. When ``None`` a black
            placeholder is used and the ControlNet scale is forced to 0.0.
        strength: Img2Img strength (Img2Img mode only).
        num_inference_steps: Number of inference steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        controlnet_scale: ControlNet conditioning scale.
        seed: RNG seed; a random one is drawn when ``None``.
        height: Output height (Txt2Img mode only).
        width: Output width (Txt2Img mode only).

    Returns:
        The first image from the pipeline output.

    Raises:
        ValueError: In Img2Img mode when ``init_image`` is ``None``.
    """
    # Fail fast, before spending time on prompt encoding.
    if USE_IMG2IMG and init_image is None:
        raise ValueError("Img2Img mode requires init_image.")

    device = pipe.device
    if seed is None:
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
    generator = torch.Generator(device=device).manual_seed(seed)

    # 1. Get Long Prompt Embeddings
    pos_emb, neg_emb, pos_pool, neg_pool = get_long_prompt_embeddings(pipe, prompt, negative_prompt)

    # 2. Setup Dimensions & Control Image
    # Img2Img: the init image defines the target size.
    # Txt2Img: the requested width/height define it (a stray init_image is ignored).
    if USE_IMG2IMG:
        target_w, target_h = init_image.width, init_image.height
    else:
        target_w, target_h = int(width), int(height)

    pipeline_control_image = control_image
    effective_control_scale = float(controlnet_scale)

    if pipeline_control_image is None:
        # Dummy black image if no control provided; scale 0.0 disables its influence.
        pipeline_control_image = Image.new("RGB", (target_w, target_h), (0, 0, 0))
        effective_control_scale = 0.0

    # [CRITICAL FIX] Resize ControlNet image to match the target size exactly.
    # This prevents the "tensor size mismatch (128 vs 64)" error.
    if pipeline_control_image.size != (target_w, target_h):
        pipeline_control_image = pipeline_control_image.resize((target_w, target_h), Image.LANCZOS)

    # 3. Prepare Arguments
    call_kwargs = {
        "prompt_embeds": pos_emb,
        "pooled_prompt_embeds": pos_pool,
        "negative_prompt_embeds": neg_emb,
        "negative_pooled_prompt_embeds": neg_pool,
        "num_inference_steps": int(num_inference_steps),
        "guidance_scale": float(guidance_scale),
        "generator": generator,
        "controlnet_conditioning_scale": effective_control_scale,
    }

    # 4. Mode Specific Arguments
    if USE_IMG2IMG:
        # NOTE(review): the SDXL Turbo usage notes ask for
        # num_inference_steps * strength >= 1 in Img2Img — confirm callers respect it.
        call_kwargs["image"] = init_image
        call_kwargs["control_image"] = pipeline_control_image
        call_kwargs["strength"] = float(strength)
    else:
        # In the Txt2Img ControlNet pipeline, "image" is the conditioning image.
        call_kwargs["image"] = pipeline_control_image
        call_kwargs["height"] = int(height)
        call_kwargs["width"] = int(width)

    # 5. Execute
    with torch.inference_mode():
        output = pipe(**call_kwargs)
        image = output.images[0]

    return image

In [None]:
# --------------------------
# FastAPI app init
# --------------------------
from fastapi import FastAPI
import nest_asyncio
import uvicorn

# Single FastAPI application instance; the /generate route is registered on it
# and it is the object handed to uvicorn when the server is started.
app = FastAPI()
# Patch asyncio so event loops can be nested.
# NOTE(review): presumably needed because the notebook kernel already runs an
# event loop — confirm it is still required with the server in its own thread.
nest_asyncio.apply()

In [None]:
# ---------------------------
# SYNCHRONOUS ENDPOINT
# ---------------------------
from fastapi import Form, UploadFile, File
from fastapi.responses import StreamingResponse, JSONResponse
from io import BytesIO
from PIL import Image

async def _read_upload_as_rgb(upload):
    """Read an uploaded file fully and decode it into an RGB PIL image."""
    return Image.open(BytesIO(await upload.read())).convert("RGB")


@app.post("/generate")
async def generate(
    prompt: str = Form(...),
    negative_prompt: str = Form(""),
    seed: int = Form(None),
    num_inference_steps: int = Form(2),
    guidance_scale: float = Form(0.0),
    height: int = Form(512),
    width: int = Form(512),
    strength: float = Form(0.5),
    controlnet_scale: float = Form(0.5),
    init_image: UploadFile = File(None),
    control_image: UploadFile = File(None),
):
    """Generate one image from multipart form fields and return it as PNG.

    Responses:
        200: PNG image stream.
        400: ``init_image`` presence does not match the notebook's mode
             (``USE_IMG2IMG``), or an uploaded file could not be decoded
             as an image.
        500: Any other failure; the traceback is printed server-side and the
             error text is returned as JSON.
    """
    # --- Validation ---
    if USE_IMG2IMG and init_image is None:
        return JSONResponse(status_code=400, content={"error": "Notebook is in Img2Img mode, but no init_image provided."})
    if not USE_IMG2IMG and init_image is not None:
        return JSONResponse(status_code=400, content={"error": "Notebook is in Txt2Img mode, but init_image was provided."})

    try:
        # --- Read Images ---
        # A file that is not a readable image is the client's mistake, so it is
        # reported as 400 rather than as an internal error. Pillow signals
        # unreadable/truncated images with OSError (sub)classes.
        try:
            pil_init = await _read_upload_as_rgb(init_image) if init_image else None
            pil_control = await _read_upload_as_rgb(control_image) if control_image else None
        except OSError as e:
            return JSONResponse(status_code=400, content={"error": f"Could not decode uploaded image: {e}"})

        # --- Generate ---
        generated_image = generate_sdxl_turbo(
            pipe=pipe,
            prompt=prompt,
            negative_prompt=negative_prompt,
            init_image=pil_init,
            control_image=pil_control,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_scale=controlnet_scale,
            seed=seed,
            height=height,
            width=width
        )

        # --- Return Stream ---
        buffer = BytesIO()
        generated_image.save(buffer, format="PNG")
        buffer.seek(0)
        return StreamingResponse(buffer, media_type="image/png")

    except Exception as e:
        # Log the full traceback in the notebook output; the client only gets the message.
        import traceback
        traceback.print_exc()
        return JSONResponse(status_code=500, content={"error": str(e)})

In [None]:
# Download and unpack the pinned zrok v1.1.3 Linux amd64 release (the latest version when this notebook was written)
!wget https://github.com/openziti/zrok/releases/download/v1.1.3/zrok_1.1.3_linux_amd64.tar.gz
!tar -xzf zrok_1.1.3_linux_amd64.tar.gz
!chmod +x zrok

In [None]:
# Enable the zrok environment with the token loaded earlier (automatic migration from 0.4)
!./zrok enable --headless "$zrok_token"

In [None]:
import uvicorn
import threading

def run_uvicorn():
    """Serve the FastAPI ``app`` on 0.0.0.0:8000; blocks until the server stops."""
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")

# Start in a background daemon thread so the notebook stays usable.
# Guarded so that re-running this cell does not launch a second server that
# would only fail to bind the already-used port.
_running_server = globals().get("server_thread")
if _running_server is not None and _running_server.is_alive():
    print("Uvicorn server thread is already running; not starting another one.")
else:
    server_thread = threading.Thread(target=run_uvicorn, daemon=True)
    server_thread.start()

In [None]:
import subprocess
import time

def start_zrok_tunnel(port=8000):
    """Start ``zrok share public`` for ``localhost:<port>`` in the background.

    The child's stdout and stderr are merged and echoed line by line (prefixed
    with ``[zrok]``) from a daemon thread. Draining the pipe matters: an unread
    pipe eventually fills up and blocks the child process.

    Args:
        port: Local port to expose.

    Returns:
        The ``subprocess.Popen`` handle of the running tunnel process.

    Raises:
        RuntimeError: If zrok has already exited after the start-up wait.
    """
    import threading  # local import: this cell only imports subprocess and time

    # Start the tunnel
    process = subprocess.Popen(
        ["./zrok", "share", "public", f"localhost:{port}", "--headless"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )

    def _drain_output():
        # Runs until the child closes its output (i.e. until it exits).
        for line in process.stdout:
            print(f"[zrok] {line.rstrip()}")

    threading.Thread(target=_drain_output, daemon=True).start()

    # Give it a moment to start
    time.sleep(3)

    # Surface an immediate failure instead of reporting a tunnel that is not there.
    if process.poll() is not None:
        raise RuntimeError(
            f"zrok exited early with code {process.returncode}; see the [zrok] output above."
        )

    # Check agent status to get the URL
    status_process = subprocess.run([
        "./zrok", "agent", "status"
    ], capture_output=True, text=True)

    print("Agent Status:")
    print(status_process.stdout)

    return process

# Start the tunnel
tunnel_process = start_zrok_tunnel(8000)
print("Zrok tunnel started! Check the agent status above for your public URL.")

In [None]:
import time

print("Server and zrok tunnel are running. Keeping the notebook alive...")

# Block this cell indefinitely, waking once a minute, until the kernel is
# interrupted. The handler only prints a message; it does not stop the server
# thread or the tunnel process itself.
try:
    while True:
        time.sleep(60)
except KeyboardInterrupt:
    print("Shutting down.")