In [None]:
import os, sys

# Location of the zrok account token inside the attached Kaggle dataset.
p = "/kaggle/input/sage-zrok-token/.zrok_api_key"


def _read_token(path):
    """Return the whitespace-stripped text of `path`, or None when `path` is not a regular file."""
    if not os.path.isfile(path):
        return None
    with open(path, "r", encoding="utf-8", errors="ignore") as token_file:
        return token_file.read().strip()


zrok_token = _read_token(p)

# A missing file (None) and an empty file ("") both abort the run here.
if not zrok_token:
    print("❌ Token not found or empty:", p)
    sys.exit(1)

In [None]:
import os
import shutil

print("Setting up models...")


def _copy_model_if_missing(label, src, dst, note=""):
    """Copy directory `src` to `dst` unless `dst` already exists.

    The copy is first written to a sibling `<dst>.partial` directory. It is renamed to `dst`
    only after `shutil.copytree` completes. An interrupted copy (e.g. a kernel restart while a
    large model is copying) therefore never looks like a finished one on the next run. A stale
    `.partial` directory is discarded and the copy is redone.

    Args:
        label: human-readable model name used in progress messages.
        src: source directory (read-only Kaggle input).
        dst: destination directory under /kaggle/working.
        note: optional text appended to the "Copying ..." message.
    """
    if os.path.exists(dst):
        print(f"✓ {label} already exists at {dst}, skipping copy")
        return

    print(f"  Copying {label}...{note}")
    staging = dst + ".partial"
    if os.path.exists(staging):
        shutil.rmtree(staging)  # leftover from an interrupted earlier copy
    shutil.copytree(src, staging)  # creates missing parent directories
    os.replace(staging, dst)  # `dst` was checked absent above, so this rename cannot clobber data
    print(f"  ✓ Copied to {dst}")


# --- Copy Animagine-XL model if not already present ---
source = "/kaggle/input/animagine-xl-3-1"
dest = "/kaggle/working/cagliostrolab/animagine-xl-3.1"
_copy_model_if_missing("Animagine-XL", source, dest, note=" (this may take a moment, it's large)")

# --- Copy LCM-LoRA SDXL if not already present ---
lora_source = "/kaggle/input/lcm-lora-sdxl"
lora_dest = "/kaggle/working/latent-consistency/lcm-lora-sdxl"
_copy_model_if_missing("LCM-LoRA SDXL", lora_source, lora_dest)

print("✅ All models ready!")

In [None]:
import os
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# --- Device & dtype ---
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on GPU: many CPU kernels have no float16 implementation,
# so loading fp16 weights on a CPU-only session fails or is extremely slow.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# --- Model / LoRA paths & IDs ---
MODEL_ID = "cagliostrolab/animagine-xl-3.1"
local_model_path = "/kaggle/working/cagliostrolab/animagine-xl-3.1"  # local copy made by the model-setup cell
#LORA_ID = "latent-consistency/lcm-lora-sdxl"
#local_lora_path = "/kaggle/working/latent-consistency/lcm-lora-sdxl"

# --- Load the diffusion pipeline, preferring the local copy over a Hub download ---
have_local_copy = os.path.exists(local_model_path)
pipe_txt2img = DiffusionPipeline.from_pretrained(
    local_model_path if have_local_copy else MODEL_ID,
    torch_dtype=torch_dtype,
).to(device)

if not have_local_copy:
    # Save the downloaded pipeline so later runs take the local branch above instead of re-downloading.
    try:
        pipe_txt2img.save_pretrained(local_model_path)
    except Exception as e:
        print(f"Warning: failed to save pipeline locally to {local_model_path}: {e}")

# --- Scheduler: DPM-Solver++ (SDE variant), built from the pipeline's existing scheduler config ---
pipe_txt2img.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe_txt2img.scheduler.config,
    algorithm_type="sde-dpmsolver++",
)

# --- Optional LCM-LoRA (currently disabled) ---
# To enable, uncomment LORA_ID / local_lora_path above and the lines below.
#if os.path.exists(local_lora_path):
#    pipe_txt2img.load_lora_weights(local_lora_path)
#else:
#    pipe_txt2img.load_lora_weights(LORA_ID)

# Replace any safety checker with a pass-through that flags nothing.
# NOTE(review): SDXL pipelines may not define a safety_checker component at all, in which case
# this assignment has no effect on generation — confirm for this model.
pipe_txt2img.safety_checker = lambda images, clip_input=None: (images, [False] * len(images))

In [None]:
# Reusable text-to-image generator (returns PIL.Image for easy FastAPI integration)
from datetime import datetime
from IPython.display import display
import torch
import io
import os

def generate_txt2img(
    pipe_txt2img,
    prompt,
    negative_prompt=None,
    seed=None,
    num_inference_steps=40,
    guidance_scale=7.5,
    height=768,
    width=768,
    device=None
):
    """
    Generate one image from a text prompt with `pipe_txt2img` and return it. Nothing is written to disk.

    Args:
        pipe_txt2img: diffusers-style pipeline, called with keyword arguments.
        prompt: text prompt.
        negative_prompt: optional negative prompt; left out of the pipeline call when None.
        seed: RNG seed; when None a random seed in [0, 2**31 - 2] is drawn.
        num_inference_steps: number of denoising steps (coerced to int).
        guidance_scale: classifier-free guidance scale (coerced to float).
        height, width: output size in pixels (coerced to int). When None they are left out of
            the call, so the pipeline's own default applies.
        device: device for the torch.Generator. Defaults to "cuda" when available, else "cpu".

    Returns:
        The first generated image. Presumably a PIL.Image with the pipeline's default output_type.
    """
    # device detection
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # choose seed (upper bound of torch.randint is exclusive)
    if seed is None:
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())

    # create generator for the chosen device (fallback to CPU if unsupported)
    try:
        gen = torch.Generator(device=device).manual_seed(seed)
    except Exception:
        gen = torch.Generator().manual_seed(seed)

    # build call kwargs; optional arguments are only passed when given
    call_kwargs = {
        "prompt": prompt,
        "num_inference_steps": int(num_inference_steps),
        "guidance_scale": float(guidance_scale),
        "generator": gen,
    }
    if negative_prompt is not None:
        call_kwargs["negative_prompt"] = negative_prompt
    if height is not None:
        call_kwargs["height"] = int(height)
    if width is not None:
        call_kwargs["width"] = int(width)

    pipeline_output = pipe_txt2img(**call_kwargs)

    # Diffusers returns an output object with `.images` by default. With return_dict=False it
    # returns a tuple whose first element is the list of images. Unwrap both to a single image
    # (the old `pipeline_output[0]` fallback returned the whole list in the tuple case).
    images = getattr(pipeline_output, "images", None)
    if images is None:
        images = pipeline_output[0]
        if not isinstance(images, (list, tuple)):
            # The output was itself a sequence of images, so its first element is the image.
            return images
    return images[0]

In [None]:
# --------------------------
# FastAPI app init
# --------------------------

from fastapi import FastAPI
import nest_asyncio
import uvicorn

# `generate_txt2img` and `pipe_txt2img` come from earlier cells; the /generate route is registered on `app` in the next cell.
app = FastAPI()
# Patch asyncio so an event loop can be re-entered inside the notebook kernel, which already runs its own loop.
# NOTE(review): uvicorn is started in its own thread in a later cell, where this patch is likely unnecessary.
nest_asyncio.apply()

In [None]:

# --------------------------
# FastAPI route to expose generate_txt2img
# --------------------------
from fastapi import Form, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
import io
from datetime import datetime

# NOTE:
# - This cell expects `app`, `generate_txt2img`, and `pipe_txt2img` to already exist.
#   It only registers a route on `app` and does not re-create or modify it otherwise.
# - The pipeline cell already replaces `pipe_txt2img.safety_checker` with a pass-through.
import threading

from fastapi.concurrency import run_in_threadpool

# Request limits. The endpoint is published through a public zrok tunnel, so unbounded sizes or
# step counts would let any caller exhaust GPU memory or monopolise the GPU.
MAX_IMAGE_SIDE = 2048
MAX_INFERENCE_STEPS = 150

# One pipeline instance is shared by every request and is not assumed to be thread-safe,
# so generations run one at a time under this lock.
_generation_lock = threading.Lock()


def _validate_generation_params(num_inference_steps, height, width):
    """Return an error message for out-of-range request parameters, or None if they are acceptable."""
    if not 1 <= num_inference_steps <= MAX_INFERENCE_STEPS:
        return f"num_inference_steps must be between 1 and {MAX_INFERENCE_STEPS}"
    for name, value in (("height", height), ("width", width)):
        # diffusers SD/SDXL pipelines reject sizes that are not divisible by 8.
        if not 8 <= value <= MAX_IMAGE_SIDE or value % 8:
            return f"{name} must be a multiple of 8 between 8 and {MAX_IMAGE_SIDE}"
    return None


def _generate_png(**kwargs):
    """Run `generate_txt2img` under the generation lock and return the image as PNG in a rewound BytesIO."""
    with _generation_lock:
        image = generate_txt2img(pipe_txt2img, **kwargs)
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    buffer.seek(0)
    return buffer


@app.post("/generate")
async def generate_endpoint(
    prompt: str = Form(...),
    negative_prompt: str = Form(None),
    seed: int = Form(None),
    num_inference_steps: int = Form(40),
    guidance_scale: float = Form(7.5),
    height: int = Form(768),
    width: int = Form(768)
):
    """Generate an image from form fields and return it as image/png.

    Errors are returned as JSON `{"error": ...}`: 422 for out-of-range steps/size, 500 if generation fails.
    """
    error = _validate_generation_params(num_inference_steps, height, width)
    if error is not None:
        return JSONResponse({"error": error}, status_code=422)

    try:
        # Generation is blocking GPU work; run it in the threadpool so the event loop stays responsive.
        buffer = await run_in_threadpool(
            _generate_png,
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=seed,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )
    except Exception as e:
        # return a JSON error so the client can debug
        return JSONResponse({"error": str(e)}, status_code=500)

    return StreamingResponse(buffer, media_type="image/png")

In [None]:
# Download the pinned zrok v1.1.3 CLI release and unpack the `zrok` binary into the working directory.
# `-nc` skips the download when the archive is already present, so re-running this cell does not
# pile up zrok_1.1.3_linux_amd64.tar.gz.1, .2, ... copies.
# NOTE(review): the archive is not checksum-verified; consider checking it against the release's published checksums.
!wget -nc https://github.com/openziti/zrok/releases/download/v1.1.3/zrok_1.1.3_linux_amd64.tar.gz
!tar -xzf zrok_1.1.3_linux_amd64.tar.gz
!chmod +x zrok

In [None]:
# Enable this session as a zrok environment, using the account token read in the first cell.
# IPython substitutes the Python variable `zrok_token` for "$zrok_token" before the shell runs the command.
# NOTE(review): presumably fails if this environment is already enabled — confirm; `./zrok disable` resets it.
!./zrok enable --headless "$zrok_token"

# Alternative (unused): run shares through the zrok agent instead of the subprocess started in a later cell.
#!./zrok agent start &
#!./zrok share public localhost:8000 --headless

In [None]:
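# Uncomment and run to disable this zrok environment (e.g. before enabling it again with a different token).
#!./zrok disable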

In [None]:
import uvicorn
import threading

def run_uvicorn(host="0.0.0.0", port=8000, log_level="info"):
    """Serve the FastAPI `app` with uvicorn. Blocks until the server stops, so call it from a thread.

    NOTE(review): zrok connects to localhost:8000, so host="127.0.0.1" would suffice and expose less.
    """
    uvicorn.run(app, host=host, port=port, log_level=log_level)


# Start the server in a named daemon thread, but only once per kernel. Re-running this cell would
# otherwise start a second server that fails to bind port 8000 because the first one holds it.
_UVICORN_THREAD_NAME = "uvicorn-server"
if any(t.name == _UVICORN_THREAD_NAME and t.is_alive() for t in threading.enumerate()):
    print("uvicorn server thread is already running; not starting another.")
else:
    threading.Thread(target=run_uvicorn, name=_UVICORN_THREAD_NAME, daemon=True).start()

In [None]:
import subprocess
import re
import time

def start_zrok_tunnel(port=8000, log_path="zrok_share.log", wait_seconds=20.0):
    """Start `./zrok share public localhost:<port> --headless` in the background and return its Popen handle.

    zrok's stdout and stderr go to `log_path` instead of pipes. Nothing ever read the old PIPEs, so
    once zrok had logged enough to fill the pipe buffer it would block on write and the tunnel stalled.

    The share is started directly rather than through the zrok agent (which this notebook never
    starts), so `zrok agent status` cannot report it. Instead this waits up to `wait_seconds` for
    an https URL to appear in zrok's own log (or for zrok to exit), then prints the log so the
    public URL can be copied from it.
    """
    with open(log_path, "w", encoding="utf-8") as log_file:
        process = subprocess.Popen(
            ["./zrok", "share", "public", f"localhost:{port}", "--headless"],
            stdout=log_file,
            stderr=subprocess.STDOUT,
        )
    # The child holds its own copy of the log descriptor, so closing ours above is safe.

    deadline = time.monotonic() + wait_seconds
    while True:
        # Check for exit before reading, so that if zrok has died the log we read is complete.
        exited = process.poll() is not None
        with open(log_path, "r", encoding="utf-8", errors="replace") as log_file:
            log_text = log_file.read()
        if exited or re.search(r"https://\S+", log_text) or time.monotonic() >= deadline:
            break
        time.sleep(0.5)

    print("zrok output:")
    print(log_text)
    if process.poll() is not None:
        print(f"❌ zrok exited with code {process.returncode}; see {log_path}")

    return process

# Start the tunnel
tunnel_process = start_zrok_tunnel(8000)
print("Zrok tunnel started! Look for the public https:// URL in the zrok output above.")

In [None]:
# Print zrok's account overview (environments and shares) to confirm the public share is registered.
!./zrok overview

In [None]:
import time

import subprocess

print("Server and zrok tunnel are running. Keeping the notebook alive...")

try:
    # Keep this cell busy so the session is not considered idle while the server runs.
    while True:
        time.sleep(60)
except KeyboardInterrupt:
    print("Shutting down.")
    # Stop the zrok share started by the tunnel cell (if it exists and is still running), so the
    # public endpoint stops serving. The uvicorn daemon thread keeps running until the kernel exits.
    zrok_share = globals().get("tunnel_process")
    if zrok_share is not None and zrok_share.poll() is None:
        zrok_share.terminate()
        try:
            zrok_share.wait(timeout=10)
        except subprocess.TimeoutExpired:
            zrok_share.kill()