In [None]:
# ===========================================================================
# CONFIGURATION
# ===========================================================================
# Pipeline selector: True loads the Img2Img ControlNet pipeline, False loads the
# Txt2Img ControlNet pipeline. Toggling this value is the only change needed
# between the two copies of this notebook.
USE_IMG2IMG: bool = True

In [None]:
import os, sys

p = "/kaggle/input/sage-zrok-token/.zrok_api_key"
zrok_token = None

# Read the zrok account token from the attached Kaggle dataset, if present.
if os.path.isfile(p):
    with open(p, "r", encoding="utf-8", errors="ignore") as f:
        zrok_token = f.read().strip()

if not zrok_token:
    # sys.exit() does not terminate a Jupyter kernel: IPython intercepts
    # SystemExit and only prints a warning. Raise a regular exception instead,
    # so the failure shows a normal traceback and halts "Run All" before the
    # zrok cells try to use a missing token.
    raise RuntimeError(f"❌ Token not found or empty: {p}")

In [None]:
import os
import shutil

def _copy_model_dir(label: str, source: str, dest: str) -> None:
    """
    Copy the model directory `source` to `dest` once, never leaving a partial copy at `dest`.

    The tree is copied into `dest + ".partial"` and renamed into place only
    after `copytree` succeeds. An interrupted copy therefore cannot be mistaken
    for a finished one by the `os.path.exists(dest)` skip check on the next run.

    Raises:
        FileNotFoundError: if `source` is not a directory (dataset not attached).
    """
    if os.path.exists(dest):
        print(f"✓ {label} already exists at {dest}, skipping copy")
        return
    if not os.path.isdir(source):
        raise FileNotFoundError(f"{label} source directory not found: {source} (is the Kaggle dataset attached?)")

    staging = dest + ".partial"
    if os.path.exists(staging):
        # Leftover from an earlier interrupted copy; start clean.
        shutil.rmtree(staging)

    print(f"  Copying {label}... (this may take a moment)")
    shutil.copytree(source, staging)  # also creates missing parent dirs
    os.rename(staging, dest)          # same filesystem, so this is a cheap rename
    print(f"  ✓ Copied to {dest}")


print("Setting up models...")

# --- Animagine-XL 3.1 ---
source = "/kaggle/input/animagine-xl-3-1"
dest = "/kaggle/working/cagliostrolab/animagine-xl-3.1"
_copy_model_dir("Animagine-XL v3.1", source, dest)

# --- ControlNet OpenPose for SDXL ---
controlnet_source = "/kaggle/input/controlnet-openpose-sdxl"
controlnet_dest = "/kaggle/working/controlnet-openpose-sdxl"
_copy_model_dir("ControlNet OpenPose SDXL", controlnet_source, controlnet_dest)

print("✅ All models ready!")

In [None]:
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
    DPMSolverMultistepScheduler
)

# --- Device / dtype ---
device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 weights are only used on the GPU. On the CPU fallback, half precision is
# unsupported or very slow for many ops, so load in fp32 there instead.
model_dtype = torch.float16 if device == "cuda" else torch.float32

# --- Model paths (populated by the copy cell) ---
local_model_path = "/kaggle/working/cagliostrolab/animagine-xl-3.1"
local_controlnet_path = "/kaggle/working/controlnet-openpose-sdxl"

# --- Load ControlNet (common for both pipelines) ---
print("Loading ControlNet model...")
controlnet = ControlNetModel.from_pretrained(
    local_controlnet_path,
    torch_dtype=model_dtype,
).to(device)
print("✓ ControlNet loaded.")

# --- Pick the pipeline class from the USE_IMG2IMG flag, then load it once ---
if USE_IMG2IMG:
    mode_name, pipeline_cls = "Img2Img", StableDiffusionXLControlNetImg2ImgPipeline
else:
    mode_name, pipeline_cls = "Txt2Img", StableDiffusionXLControlNetPipeline

print(f"Mode: {mode_name}. Loading pipeline...")
pipe = pipeline_cls.from_pretrained(
    local_model_path,
    controlnet=controlnet,
    torch_dtype=model_dtype,
    # variant="fp16" is intentionally not requested (kept disabled as in the original cell).
).to(device)
print(f"✓ {mode_name} pipeline with ControlNet loaded.")

# --- Scheduler: DPM-Solver++ (SDE variant), built from the model's own scheduler config ---
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    algorithm_type="sde-dpmsolver++",
)

# Pass-through "safety checker" that flags nothing.
# NOTE(review): SDXL pipelines are not expected to have a safety_checker
# component, so this assignment is presumably a no-op carried over from SD1.x
# code. Confirm before relying on it.
pipe.safety_checker = lambda images, clip_input=None: (images, [False] * len(images))


In [None]:
from PIL import Image
import torch
from typing import Optional
import random

# Adapted from the generation function of the working `sage-t2i-counterfeit-xl-25` notebook.
def generate_txt2img_xl(
    pipe, # StableDiffusionXLControlNetPipeline
    prompt: str,
    negative_prompt: Optional[str] = None,
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 40,
    guidance_scale: float = 7.5,
    seed: Optional[int] = None,
    control_image: Optional[Image.Image] = None,
    controlnet_scale: float = 0.5,
):
    """
    Generate one image with a Txt2Img SDXL ControlNet pipeline.

    Args:
        pipe: the loaded Txt2Img ControlNet pipeline (called as `pipe(**kwargs)`).
        prompt / negative_prompt: text conditioning.
        height, width: output size in pixels.
        num_inference_steps, guidance_scale: sampler settings passed through.
        seed: RNG seed; a random one is drawn from torch's global RNG when None.
        control_image: ControlNet conditioning image. When None, an all-black
            placeholder is passed and the ControlNet scale is forced to 0.0.
        controlnet_scale: ControlNet strength used when `control_image` is given.

    Returns:
        The first entry of the pipeline output's `.images`.
    """
    if seed is None:
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())

    # Seed a generator on the pipeline's device. A hard-coded "cuda" generator
    # fails when the model was loaded on the CPU fallback.
    gen_device = getattr(pipe, "device", None)
    if gen_device is None:
        gen_device = "cuda" if torch.cuda.is_available() else "cpu"
    gen = torch.Generator(device=gen_device).manual_seed(seed)

    use_controlnet = control_image is not None

    # This pipeline always requires a control image via `image`. Without one,
    # pass a black placeholder; its conditioning scale is set to 0.0 below.
    pipeline_image = control_image if use_controlnet else Image.new("RGB", (width, height), (0, 0, 0))

    call_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image": pipeline_image,  # the ControlNet conditioning image for this pipeline
        "height": height,
        "width": width,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "controlnet_conditioning_scale": controlnet_scale if use_controlnet else 0.0,
        "generator": gen,
    }

    pipeline_output = pipe(**call_kwargs)
    return pipeline_output.images[0]
    

In [None]:
from PIL import Image
import torch
from typing import Optional, Tuple

def _ensure_rgb(img: Image.Image) -> Image.Image:
    """Return `img` itself when it is already RGB, otherwise an RGB-converted copy."""
    return img if img.mode == "RGB" else img.convert("RGB")

def _resize_control_to_init(control: Image.Image, init: Image.Image, use_nearest: bool = True) -> Image.Image:
    """
    Make `control` the same pixel size as `init`.

    Returns `control` unchanged when the sizes already agree. Otherwise it is
    resized with nearest-neighbour (`use_nearest=True`, for label-like maps
    such as pose/edge/segmentation, where blending would invent new values)
    or bicubic (`use_nearest=False`, for photographic controls).
    """
    target_size = init.size
    if control.size != target_size:
        resample_filter = Image.NEAREST if use_nearest else Image.BICUBIC
        control = control.resize(target_size, resample=resample_filter)
    return control

def _pad_to_multiple(img: Image.Image, mult: int = 8, fill=(0,0,0)) -> Image.Image:
    """
    Grow `img` so both sides are multiples of `mult`, filling new pixels with `fill`.

    The original pixels stay anchored at the top-left corner, so padding only
    appears on the right and bottom edges. Returns `img` itself when no padding
    is needed.
    """
    def _round_up(value: int) -> int:
        return ((value + mult - 1) // mult) * mult

    width, height = img.size
    target_size = (_round_up(width), _round_up(height))
    if target_size == (width, height):
        return img

    canvas = Image.new(img.mode, target_size, fill)
    canvas.paste(img, (0, 0))
    return canvas

def generate_img2img_xl(
    pipe,
    prompt: str,
    init_image: Image.Image,
    negative_prompt: Optional[str] = None,
    strength: float = 0.7,
    num_inference_steps: int = 40,
    guidance_scale: float = 7.5,
    seed: Optional[int] = None,
    control_image: Optional[Image.Image] = None,
    controlnet_scale: float = 0.5,
    control_is_map: bool = True,        # <-- set True for pose/edge/seg maps; False for photos
    divisible_by: int = 8,              # make sizes multiples of this (8 is typical)
):
    """
    Run one Img2Img generation (with optional ControlNet) and return the first image.

    Size handling:
      * The init image is converted to RGB.
      * The control image (or an all-black placeholder when none is given) is
        resized to the init image's size.
      * Both are padded on the right/bottom with black up to a multiple of
        `divisible_by`.
      * After generation the padding is cropped off again, so the result has
        the caller's original init-image size.

    Args:
        pipe: the loaded Img2Img ControlNet pipeline (called as `pipe(**kwargs)`).
        prompt / negative_prompt: text conditioning.
        init_image: starting image.
        strength, num_inference_steps, guidance_scale: passed through to the pipeline.
        seed: RNG seed; a random one is drawn from torch's global RNG when None.
        control_image: optional ControlNet conditioning image. Without it the
            ControlNet scale is forced to 0.0.
        controlnet_scale: ControlNet strength used when `control_image` is given.
        control_is_map: resize the control image with nearest-neighbour (True)
            or bicubic (False).
        divisible_by: pad sizes up to a multiple of this.

    Returns:
        The first entry of the pipeline output's `.images`, cropped back to the
        original init-image size when padding was applied.

    Raises:
        ValueError: if `divisible_by` is less than 1.
    """
    if divisible_by < 1:
        raise ValueError(f"divisible_by must be >= 1, got {divisible_by}")

    if seed is None:
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
    # Seed on the pipeline's device rather than a hard-coded "cuda", so the
    # function also works when the model was loaded on the CPU fallback.
    gen_device = getattr(pipe, "device", None)
    if gen_device is None:
        gen_device = "cuda" if torch.cuda.is_available() else "cpu"
    gen = torch.Generator(device=gen_device).manual_seed(seed)

    init_image = _ensure_rgb(init_image)
    orig_size = init_image.size
    use_controlnet = control_image is not None

    if use_controlnet:
        control_image = _ensure_rgb(control_image)
        if control_image.size != init_image.size:
            print(f"[debug] control image size {control_image.size} != init image size {init_image.size}, resizing control -> init")
            control_image = _resize_control_to_init(control_image, init_image, use_nearest=control_is_map)
    else:
        # The pipeline requires a control image. Use a black placeholder; the
        # call below passes controlnet_conditioning_scale=0.0 in this case.
        control_image = Image.new("RGB", init_image.size, (0, 0, 0))

    # Pad both images identically (top-left anchored) so their sizes stay equal.
    padded_init = _pad_to_multiple(init_image, mult=divisible_by, fill=(0, 0, 0))
    if padded_init.size != orig_size:
        print(f"[debug] padding images from {orig_size} -> {padded_init.size} to satisfy divisible_by={divisible_by}")
    init_image = padded_init
    pipeline_control_image = _pad_to_multiple(control_image, mult=divisible_by, fill=(0, 0, 0))

    print(f"[debug] final init image size: {init_image.size}, control image size: {pipeline_control_image.size}")

    call_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image": init_image,
        "control_image": pipeline_control_image,
        "strength": strength,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "controlnet_conditioning_scale": controlnet_scale if use_controlnet else 0.0,
        "generator": gen,
    }

    image = pipe(**call_kwargs).images[0]

    # Remove the right/bottom padding so the caller gets back an image of the
    # size they sent. Only crop when the output matches the padded size we fed in.
    if image.size == init_image.size and init_image.size != orig_size:
        image = image.crop((0, 0, orig_size[0], orig_size[1]))
    return image

In [None]:
from fastapi import FastAPI
import nest_asyncio
import uvicorn

# Patch asyncio to allow nested event loops, so uvicorn can be run from inside a notebook.
nest_asyncio.apply()

# The FastAPI application that the endpoint cells below register their routes on.
app = FastAPI()

In [None]:
# ===================================================================
# ASYNCHRONOUS TASK HANDLING SETUP
# ===================================================================
from fastapi import BackgroundTasks, HTTPException, Response
import uuid
import io
import time
from PIL import Image
import traceback

# In-memory job registry shared by the async endpoints and the background worker.
# It lives only as long as the kernel. /status removes an entry once it reports a
# finished (COMPLETED or FAILED) job; nothing else here evicts entries, so jobs
# that are never polled stay in memory.
# Schema: { "task_id": {"status": "...", "result": ..., "error": ...} }
tasks = dict()


In [None]:
from PIL import Image
import io
import traceback

import threading

# Serializes every use of the shared `pipe`. A diffusers pipeline keeps mutable
# per-call state (e.g. the scheduler's timesteps), and background tasks run in a
# thread pool, so two generations must never overlap. Reuse the lock if another
# cell already created it.
try:
    _GENERATION_LOCK
except NameError:
    _GENERATION_LOCK = threading.Lock()


def run_generation_task(task_id: str, params: dict):
    """
    Background worker for /generate-async: run one job and record its outcome in `tasks`.

    Args:
        task_id: key of an entry the caller already created in the global `tasks` dict.
        params: request parameters. "init_image_bytes" and "control_image_bytes"
            are popped from it (the dict is mutated); the remaining keys feed
            generate_img2img_xl / generate_txt2img_xl.

    Side effects:
        tasks[task_id]["status"] becomes "PROCESSING" once this job owns the
        pipeline lock, then "COMPLETED" (with "result" = PNG bytes) or "FAILED"
        (with "error" = str(exception)). Exceptions are logged, never propagated.
    """
    try:
        # Re-create PIL images from the raw uploaded bytes, if any.
        init_img_bytes = params.pop("init_image_bytes", None)
        control_img_bytes = params.pop("control_image_bytes", None)

        init_image = Image.open(io.BytesIO(init_img_bytes)).convert("RGB") if init_img_bytes else None
        control_image = Image.open(io.BytesIO(control_img_bytes)).convert("RGB") if control_img_bytes else None

        # Only one pipeline type is loaded. Fail with a clear message instead of
        # calling a generation helper with the wrong pipeline.
        if init_image is not None and not USE_IMG2IMG:
            raise RuntimeError("Received init_image, but notebook is in Txt2Img mode.")
        if init_image is None and USE_IMG2IMG:
            raise RuntimeError("Did not receive init_image, but notebook is in Img2Img mode.")

        # Queued jobs keep their PENDING status until they actually own the pipeline.
        with _GENERATION_LOCK:
            tasks[task_id]["status"] = "PROCESSING"
            if init_image is not None:
                generated_image = generate_img2img_xl(
                    pipe=pipe,
                    init_image=init_image,
                    control_image=control_image,
                    prompt=params["prompt"],
                    negative_prompt=params["negative_prompt"],
                    strength=params["strength"],
                    num_inference_steps=params["num_inference_steps"],
                    guidance_scale=params["guidance_scale"],
                    seed=params["seed"],
                    controlnet_scale=params.get("controlnet_scale", 0.5),
                )
            else:
                generated_image = generate_txt2img_xl(
                    pipe=pipe,
                    control_image=control_image,
                    prompt=params["prompt"],
                    negative_prompt=params["negative_prompt"],
                    height=params["height"],
                    width=params["width"],
                    num_inference_steps=params["num_inference_steps"],
                    guidance_scale=params["guidance_scale"],
                    seed=params["seed"],
                    controlnet_scale=params.get("controlnet_scale", 0.5),
                )

        # Encode outside the lock so the next job can start on the GPU.
        if generated_image.mode != "RGB":
            generated_image = generated_image.convert("RGB")
        buffer = io.BytesIO()
        generated_image.save(buffer, format="PNG")

        # Write the payload BEFORE flipping the status. /status runs on another
        # thread and deletes the entry as soon as it sees COMPLETED, so the
        # result must already be present at that point.
        tasks[task_id]["result"] = buffer.getvalue()
        tasks[task_id]["status"] = "COMPLETED"

    except Exception as e:
        print(f"--- ASYNC TASK {task_id} FAILED ---")
        traceback.print_exc()
        tasks[task_id]["error"] = str(e)  # same ordering rule as above
        tasks[task_id]["status"] = "FAILED"

In [None]:
from fastapi import Form, UploadFile, File, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
import io  # Using `io` to match the working notebook
from PIL import Image
import traceback

import threading

# Serializes every use of the shared `pipe` (a diffusers pipeline keeps mutable
# per-call state such as scheduler timesteps). Reuse the lock if another cell
# already created it.
try:
    _GENERATION_LOCK
except NameError:
    _GENERATION_LOCK = threading.Lock()


@app.post("/generate")
async def generate(
    prompt: str = Form(...),
    negative_prompt: str = Form(None),
    seed: int = Form(None),
    num_inference_steps: int = Form(40),
    guidance_scale: float = Form(7.5),

    # Txt2Img params
    height: int = Form(1024),
    width: int = Form(1024),

    # Img2Img params
    strength: float = Form(0.7),
    init_image: UploadFile = File(None),  # If present, triggers Img2Img

    # ControlNet params
    control_image: UploadFile = File(None),
    controlnet_scale: float = Form(0.5),
):
    """
    Synchronously generate one image and stream it back as PNG.

    The presence of `init_image` selects Img2Img and must agree with the
    notebook's USE_IMG2IMG flag. `control_image`, when given, drives ControlNet.

    Responses:
        200 image/png on success.
        400 {"error": ...} when the request does not match the loaded mode, or
            when an uploaded file cannot be decoded as an image.
        500 {"error": ..., "traceback": ...} when generation itself fails.
    """
    from fastapi.concurrency import run_in_threadpool

    # A mode mismatch is a client error, not a server failure.
    if init_image is not None and not USE_IMG2IMG:
        return JSONResponse(status_code=400, content={"error": "Received init_image, but notebook is in Txt2Img mode."})
    if init_image is None and USE_IMG2IMG:
        return JSONResponse(status_code=400, content={"error": "Did not receive init_image, but notebook is in Img2Img mode."})

    # Decode uploads up front. PIL reports undecodable or truncated data with
    # OSError (UnidentifiedImageError is a subclass of it).
    try:
        controlnet_img = None
        if control_image is not None:
            controlnet_img = Image.open(io.BytesIO(await control_image.read())).convert("RGB")
        init_img = None
        if init_image is not None:
            init_img = Image.open(io.BytesIO(await init_image.read())).convert("RGB")
    except OSError as e:
        return JSONResponse(status_code=400, content={"error": f"Could not decode uploaded image: {e}"})

    def _generate_locked():
        # Runs in a worker thread and holds the pipeline lock for the whole call.
        with _GENERATION_LOCK:
            if init_img is not None:
                return generate_img2img_xl(
                    pipe=pipe, prompt=prompt, init_image=init_img, negative_prompt=negative_prompt,
                    strength=strength, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
                    seed=seed, control_image=controlnet_img, controlnet_scale=controlnet_scale,
                )
            return generate_txt2img_xl(
                pipe=pipe, prompt=prompt, negative_prompt=negative_prompt, height=height, width=width,
                num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, seed=seed,
                control_image=controlnet_img, controlnet_scale=controlnet_scale,
            )

    try:
        # The GPU work is blocking. Running it off the event loop keeps other
        # requests (e.g. /status polls for async jobs) responsive meanwhile.
        generated_image = await run_in_threadpool(_generate_locked)
    except Exception as e:
        traceback.print_exc()
        return JSONResponse(status_code=500, content={"error": str(e), "traceback": traceback.format_exc()})

    # Response logic kept identical to the working
    # `sage-t2i-counterfeit-xl-25` notebook: PNG streamed from an in-memory buffer.
    buffer = io.BytesIO()
    generated_image.save(buffer, format="PNG")
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="image/png")



In [None]:
@app.post("/generate-async")
async def generate_async(
    background_tasks: BackgroundTasks,
    prompt: str = Form(...),
    negative_prompt: str = Form(None),
    seed: int = Form(None),
    # NOTE(review): the default is 50 steps here, but 40 in /generate and in the
    # generate_*_xl helpers. Left unchanged so existing clients keep identical results.
    num_inference_steps: int = Form(50),
    guidance_scale: float = Form(7.5),
    height: int = Form(1024),
    width: int = Form(1024),
    strength: float = Form(0.7),
    init_image: UploadFile = File(None),
    control_image: UploadFile = File(None),
    controlnet_scale: float = Form(0.5),
):
    """
    Queue a generation job in the background and immediately return its task ID.

    Poll GET /status/{task_id} for the result.

    Raises:
        HTTPException(400): when the presence of `init_image` does not match the
            notebook's USE_IMG2IMG mode (the same check /generate performs).
    """
    # Reject mode mismatches before queuing. Otherwise the job would be accepted
    # and only fail later, inside the worker, with a less obvious error.
    if init_image is not None and not USE_IMG2IMG:
        raise HTTPException(status_code=400, detail="Received init_image, but notebook is in Txt2Img mode.")
    if init_image is None and USE_IMG2IMG:
        raise HTTPException(status_code=400, detail="Did not receive init_image, but notebook is in Img2Img mode.")

    task_id = str(uuid.uuid4())

    # Read the uploads now, while the request is still open. Only raw bytes are
    # stored; decoding happens in the worker.
    init_image_bytes = await init_image.read() if init_image else None
    control_image_bytes = await control_image.read() if control_image else None

    tasks[task_id] = {"status": "PENDING"}

    # Bundle all parameters for the background worker.
    params = {
        "prompt": prompt, "negative_prompt": negative_prompt, "seed": seed,
        "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale,
        "height": height, "width": width, "strength": strength,
        "controlnet_scale": controlnet_scale,
        "init_image_bytes": init_image_bytes, "control_image_bytes": control_image_bytes
    }

    # Runs after the response has been sent.
    background_tasks.add_task(run_generation_task, task_id, params)

    return {"task_id": task_id, "status": "PENDING"}

In [None]:
@app.get("/status/{task_id}")
async def get_status(task_id: str):
    """
    Report an async job's state. A finished job is returned once, then forgotten.

    - Unknown (or empty) entry -> HTTP 404.
    - PENDING / PROCESSING -> JSON {"task_id", "status"}.
    - COMPLETED -> the PNG bytes, and the entry is removed.
    - FAILED -> HTTP 500 carrying the stored error, and the entry is removed.
    """
    entry = tasks.get(task_id)
    if not entry:
        raise HTTPException(status_code=404, detail="Task not found")

    state = entry.get("status")
    if state not in ("COMPLETED", "FAILED"):
        return {"task_id": task_id, "status": state}

    # Terminal state: evict the entry to free memory. `entry` still holds the data.
    del tasks[task_id]
    if state == "COMPLETED":
        return Response(content=entry.get("result"), media_type="image/png")
    raise HTTPException(status_code=500, detail=f"Task failed: {entry.get('error', 'Unknown error')}")
        

In [None]:
# Download zrok v1.1.3 (latest)
!wget https://github.com/openziti/zrok/releases/download/v1.1.3/zrok_1.1.3_linux_amd64.tar.gz
!tar -xzf zrok_1.1.3_linux_amd64.tar.gz
!chmod +x zrok

In [None]:
# Enable this zrok environment with the account token read earlier (`zrok_token`).
# (Original note: the v1 CLI migrates a 0.4 environment automatically.)
# The token is passed as its own argv element instead of being interpolated into
# a shell string, so quote or `$` characters in it cannot break or alter the command.
import subprocess

_zrok_enable = subprocess.run(
    ["./zrok", "enable", "--headless", zrok_token],
    capture_output=True,
    text=True,
)
if _zrok_enable.stdout:
    print(_zrok_enable.stdout)
if _zrok_enable.stderr:
    print(_zrok_enable.stderr)
if _zrok_enable.returncode != 0:
    # Best-effort, like the original `!` call: report but do not raise. A likely
    # cause (unverified) is an environment already enabled by an earlier run.
    print(f"⚠️ zrok enable exited with code {_zrok_enable.returncode}")

In [None]:
# To release this zrok environment (undoing `zrok enable`), uncomment and run:
# !./zrok disable

In [None]:
import uvicorn
import threading

def run_uvicorn():
    """Serve `app` on 0.0.0.0:8000; blocks the calling thread until the server stops."""
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")


# Start the server in a daemon thread so the notebook stays interactive.
# Re-running this cell while an earlier server thread is still alive would make a
# second server try to bind port 8000 again, so only one is started per kernel.
_previous_server_thread = globals().get("uvicorn_thread")
if _previous_server_thread is not None and _previous_server_thread.is_alive():
    print("uvicorn server thread is already running; not starting another one.")
else:
    uvicorn_thread = threading.Thread(target=run_uvicorn, name="uvicorn-server", daemon=True)
    uvicorn_thread.start()

In [None]:
import subprocess
import time

def start_zrok_tunnel(port=8000, log_path="zrok_share.log"):
    """
    Start `zrok share public` for localhost:<port> as a long-running background process.

    The share's combined stdout/stderr goes to `log_path`. Unread PIPEs are
    avoided because, once the OS pipe buffer fills, the child blocks on its next
    write and the tunnel stalls. After a short start-up wait, the log and
    `zrok agent status` are printed. The headless share is expected to log its
    public URL there (verify).

    Args:
        port: local port the uvicorn server listens on.
        log_path: file that receives the share process output (overwritten).

    Returns:
        The subprocess.Popen handle of the share process.
    """
    with open(log_path, "w", encoding="utf-8") as log_file:
        # The child inherits its own copy of the file descriptor, so closing
        # ours when the `with` block exits does not affect the running share.
        process = subprocess.Popen(
            ["./zrok", "share", "public", f"localhost:{port}", "--headless"],
            stdout=log_file,
            stderr=subprocess.STDOUT,
        )

    # Give it a moment to start.
    time.sleep(3)

    exit_code = process.poll()
    if exit_code is not None:
        print(f"❌ zrok share exited early with code {exit_code}")

    with open(log_path, "r", encoding="utf-8", errors="replace") as log_file:
        print("Share output:")
        print(log_file.read())

    status_process = subprocess.run(
        ["./zrok", "agent", "status"],
        capture_output=True,
        text=True,
    )
    print("Agent Status:")
    print(status_process.stdout)

    return process


# Start the tunnel, unless one started by an earlier run of this cell is still alive.
_previous_tunnel = globals().get("tunnel_process")
if _previous_tunnel is not None and _previous_tunnel.poll() is None:
    print("Zrok tunnel already running; not starting another one.")
else:
    tunnel_process = start_zrok_tunnel(8000)
    print("Zrok tunnel started! Check the share output / agent status above for your public URL.")

In [None]:
# Print the zrok account overview (presumably environments and active shares)
# as a quick check that the public share is registered.
!./zrok overview

In [None]:
import time

# Block this cell so a "Run All" session keeps the notebook busy while it serves.
print("Server and zrok tunnel are running. Keeping the notebook alive...")

try:
    while True:
        time.sleep(60)
except KeyboardInterrupt:
    # Interrupting stops only this loop. The uvicorn server runs in a daemon
    # thread of the still-running kernel, and the zrok share is a separate
    # subprocess, so both keep serving. Say so instead of claiming a shutdown.
    print("Keep-alive loop stopped. The uvicorn server and zrok tunnel are still running "
          "(call tunnel_process.terminate() or restart the kernel to stop them).")