In [None]:
import os, sys

# Location of the zrok account token inside the attached Kaggle dataset.
p = "/kaggle/input/sage-zrok-token/.zrok_api_key"
zrok_token = None

if os.path.isfile(p):
    # utf-8-sig drops a leading BOM (which str.strip() would keep). Decoding
    # errors are no longer ignored: silently dropping bytes would corrupt the secret.
    with open(p, "r", encoding="utf-8-sig") as f:
        zrok_token = f.read().strip()

if not zrok_token:
    print("❌ Token not found or empty:", p)
    # sys.exit() in a notebook only raises SystemExit, which IPython reports with
    # a confusing "use 'exit', 'quit'" hint. A descriptive exception stops
    # "Run All" here with a clear traceback instead.
    raise RuntimeError(f"zrok token not found or empty: {p}")

In [None]:
import os
import shutil

print("Setting up model...")

# --- Copy LCM SoteMix if not already present ---
source = "/kaggle/input/lcm-sotemix"
dest = "/kaggle/working/Disty0/LCM_SoteMix"

if os.path.exists(dest):
    print(f"✓ LCM SoteMix already exists at {dest}, skipping copy")
else:
    # Ensure the parent directory of the destination exists
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    # Copy into a staging directory and move it into place only once the copy
    # has finished. Copying straight into `dest` meant an interrupted copy left
    # a partial model behind, which the exists() check above then accepted.
    staging = dest + ".partial"
    if os.path.isdir(staging):
        shutil.rmtree(staging)  # leftover from an earlier interrupted run
    print("  Copying LCM SoteMix... (this may take a moment)")
    shutil.copytree(source, staging)
    os.replace(staging, dest)  # same parent dir -> a single rename
    print(f"  ✓ Copied to {dest}")

print("✅ Model ready!")


In [None]:
# ----------------------------
# TXT2IMG + LCM pipeline
# ----------------------------
import torch
from diffusers import AutoPipelineForText2Image, LCMScheduler

# --- Configuration ---
# Local copy of the 'Disty0/LCM_SoteMix' model files (the setup cell copies them
# from /kaggle/input into /kaggle/working).
MODEL_PATH = '/kaggle/working/Disty0/LCM_SoteMix'
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only dependable on GPU: many CPU kernels have no float16
# implementation, so the CPU fallback would fail at inference time.
dtype = torch.float16 if device == "cuda" else torch.float32

print(f"Loading Text-to-Image pipeline from local path: {MODEL_PATH}")

# --- Load the pipeline from the local dataset path ---
pipe_txt2img = AutoPipelineForText2Image.from_pretrained(
    MODEL_PATH,
    torch_dtype=dtype
)

# --- Set the LCM Scheduler ---
pipe_txt2img.scheduler = LCMScheduler.from_config(pipe_txt2img.scheduler.config)

# --- Move to device and disable safety checker ---
pipe_txt2img = pipe_txt2img.to(device)
# With safety_checker=None the pipeline skips the check entirely. A stub callable
# still made it run the CLIP feature extractor on every generated image.
pipe_txt2img.safety_checker = None

print("✅ Text-to-Image pipeline is ready!")

In [None]:
# ----------------------
# IMG2IMG + LCM pipeline
# ----------------------
import torch
from diffusers import AutoPipelineForImage2Image, LCMScheduler

# --- Configuration ---
# Local copy of the 'Disty0/LCM_SoteMix' model files (the setup cell copies them
# from /kaggle/input into /kaggle/working).
MODEL_PATH = '/kaggle/working/Disty0/LCM_SoteMix'
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only dependable on GPU; use float32 for the CPU fallback.
dtype = torch.float16 if device == "cuda" else torch.float32

# --- Build the pipeline ---
if "pipe_txt2img" in globals():
    # Reuse the already-loaded text-to-image components (UNet, VAE, text encoder)
    # instead of loading a second copy of the same weights into GPU memory.
    print("Building Image-to-Image pipeline from the loaded Text-to-Image pipeline (shared weights)")
    pipe_img2img = AutoPipelineForImage2Image.from_pipe(pipe_txt2img)
else:
    print(f"Loading Image-to-Image pipeline from local path: {MODEL_PATH}")
    pipe_img2img = AutoPipelineForImage2Image.from_pretrained(
        MODEL_PATH,
        torch_dtype=dtype
    )

# --- Set the LCM Scheduler ---
# A separate scheduler instance: set_timesteps() mutates scheduler state, so the
# two pipelines should not share one object even when they share model weights.
pipe_img2img.scheduler = LCMScheduler.from_config(pipe_img2img.scheduler.config)

# --- Move to device and disable safety checker ---
pipe_img2img = pipe_img2img.to(device)
# None makes the pipeline skip the check (and its CLIP feature extraction) entirely.
pipe_img2img.safety_checker = None

print("✅ Image-to-Image pipeline is ready!")

In [None]:

# -----------------------------
# METHOD text-to-image generator
# (returns PIL.Image for easy FastAPI integration)
# -----------------------------

from datetime import datetime
from IPython.display import display
import torch
import io
import os

def generate_txt2img(
    pipe_txt2img,
    prompt,
    negative_prompt=None,
    seed=None,
    num_inference_steps=4,
    guidance_scale=1.0,
    height=768,
    width=768,
    device=None
):
    """
    Generate one image from a text prompt with ``pipe_txt2img`` and return it.

    Nothing is written to disk; the caller decides what to do with the image.

    Args:
        pipe_txt2img: Callable text-to-image pipeline (e.g. a diffusers pipeline),
            invoked with keyword arguments only.
        prompt: Text prompt.
        negative_prompt: Optional negative prompt; not passed when None.
        seed: RNG seed. When None, a random seed in [0, 2**31 - 1) is drawn.
        num_inference_steps: Denoising steps, cast to int (low values suit LCM).
        guidance_scale: Guidance scale, cast to float.
        height, width: Output size in pixels, cast to int; not passed when None.
        device: Device for the torch.Generator. Defaults to "cuda" when
            available, else "cpu". If a generator cannot be created on it, a
            default (CPU) generator is used instead.

    Returns:
        The first generated image (a PIL.Image with diffusers' default output type).

    Raises:
        RuntimeError: If the pipeline returned an empty image list.
    """
    # device detection
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # choose seed
    if seed is None:
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())

    # create generator for the chosen device (fallback to CPU if unsupported)
    try:
        gen = torch.Generator(device=device).manual_seed(seed)
    except Exception:
        gen = torch.Generator().manual_seed(seed)

    # build call kwargs; optional ones are omitted so the pipeline's own defaults apply
    call_kwargs = {
        "prompt": prompt,
        "num_inference_steps": int(num_inference_steps),
        "guidance_scale": float(guidance_scale),
        "generator": gen,
    }
    if negative_prompt is not None:
        call_kwargs["negative_prompt"] = negative_prompt
    if height is not None:
        call_kwargs["height"] = int(height)
    if width is not None:
        call_kwargs["width"] = int(width)

    # call the pipeline
    pipeline_output = pipe_txt2img(**call_kwargs)

    # diffusers returns an output object with `.images` by default. With
    # return_dict=False it returns an (images, nsfw_flags) tuple, where [0] is the
    # image *list*, not an image. A bare sequence of images is also accepted.
    images = getattr(pipeline_output, "images", None)
    if images is None:
        first = pipeline_output[0]
        images = first if isinstance(first, (list, tuple)) else pipeline_output
    if len(images) == 0:
        raise RuntimeError("Pipeline returned no images")
    return images[0]

In [None]:
# ------------------------------------- #
# Reusable img2img generator (SDXL + LCM)
# ------------------------------------- #
import torch
import random

def generate_img2img(
    pipe_img2img,
    prompt,
    init_image,
    negative_prompt=None,
    strength=0.8,
    guidance_scale=1.0,
    num_inference_steps=4,
    seed=None,
    dtype=torch.float16,
    device=None
):
    """
    Transform ``init_image`` according to ``prompt`` with ``pipe_img2img`` and
    return the first resulting image (nothing is written to disk).

    Args:
        pipe_img2img: Callable image-to-image pipeline (e.g. a diffusers pipeline),
            invoked with keyword arguments only.
        prompt: Text prompt.
        init_image: Source image passed to the pipeline as ``image``.
        negative_prompt: Optional negative prompt; not passed when None.
        strength: How far to move away from ``init_image``, cast to float.
        guidance_scale: Guidance scale, cast to float.
        num_inference_steps: Denoising steps, cast to int (low values suit LCM).
        seed: RNG seed. When None, a random seed in [0, 2**31 - 1) is drawn.
        dtype: Autocast dtype used around the pipeline call.
        device: Device string or ``torch.device`` (e.g. "cuda", "cuda:0").
            Defaults to "cuda" when available, else "cpu".

    Returns:
        The first generated image (a PIL.Image with diffusers' default output type).

    Raises:
        RuntimeError: If the pipeline returned an empty image list.
    """

    # --- Device detection ---
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # torch.autocast wants the bare device *type* string ("cuda"/"cpu"). Passing
    # "cuda:0" or a torch.device object raises, so normalise it here.
    device_type = torch.device(device).type

    # --- Seed handling ---
    if seed is None:
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())

    try:
        generator = torch.Generator(device=device).manual_seed(seed)
    except Exception:
        generator = torch.Generator().manual_seed(seed)

    # --- Build call kwargs ---
    call_kwargs = {
        "prompt": prompt,
        "image": init_image,
        "strength": float(strength),
        "guidance_scale": float(guidance_scale),
        "num_inference_steps": int(num_inference_steps),
        "generator": generator,
    }
    if negative_prompt is not None:
        call_kwargs["negative_prompt"] = negative_prompt

    # --- Run pipeline ---
    with torch.inference_mode(), torch.autocast(device_type, dtype=dtype):
        pipeline_output = pipe_img2img(**call_kwargs)

    # --- Extract result ---
    # `.images` for diffusers output objects. With return_dict=False the output
    # is an (images, nsfw_flags) tuple whose [0] is the image list itself.
    images = getattr(pipeline_output, "images", None)
    if images is None:
        first = pipeline_output[0]
        images = first if isinstance(first, (list, tuple)) else pipeline_output
    if len(images) == 0:
        raise RuntimeError("Pipeline returned no images")
    return images[0]

In [None]:
# --------------------------
# FastAPI app init
# --------------------------

import nest_asyncio
import uvicorn
from fastapi import FastAPI

# Patch asyncio so an event loop can be entered again from inside the notebook
# kernel, which is already running its own loop.
nest_asyncio.apply()

# Application object; the /generate route is registered on it in the next cell
# and it is served by uvicorn further down.
app = FastAPI()

In [None]:


# ---------------------------
# Unified T2I/I2I endpoint /generate (LCM pipelines)
# ---------------------------

import threading

from fastapi import Form, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response, StreamingResponse, JSONResponse
from io import BytesIO
from PIL import Image

# Serialises access to the diffusion pipelines. They hold mutable state (the
# scheduler's timesteps) and compete for the same GPU memory, so at most one
# generation may run at a time even though requests now run in worker threads.
_pipeline_lock = threading.Lock()


def _run_locked(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` while holding ``_pipeline_lock`` (runs in a worker thread)."""
    with _pipeline_lock:
        return fn(*args, **kwargs)


@app.post("/generate")
async def generate(
    prompt: str = Form(...),
    negative_prompt: str = Form(None),
    height: int = Form(768),
    width: int = Form(768),
    num_inference_steps: int = Form(4),       # LCM default low steps
    guidance_scale: float = Form(1.0),        # LCM default guidance
    seed: int = Form(None),

    # img2img-only params
    strength: float = Form(0.8),
    init_image: UploadFile = File(None),      # optional: if present → img2img
):
    """Unified endpoint: txt2img if no init_image, img2img if init_image provided.

    Generation is blocking GPU work. Running it directly inside this coroutine
    froze the event loop, so no other request could even be accepted meanwhile.
    It is therefore handed to FastAPI's thread pool and serialised with
    ``_pipeline_lock``.

    Returns:
        The PNG bytes (``image/png``) on success.
        400 JSON ``{"error": ...}`` if ``init_image`` cannot be decoded.
        500 JSON ``{"error": ...}`` if generation raises.
        ``height``/``width`` are not used on the img2img path.
    """

    # --- Img2img path ---
    if init_image is not None:
        try:
            init_img = Image.open(BytesIO(await init_image.read())).convert("RGB")
        except Exception as e:
            return JSONResponse(
                status_code=400,
                content={"error": f"Failed to read init_image: {str(e)}"},
            )

        try:
            generated_image = await run_in_threadpool(
                _run_locked,
                generate_img2img,
                pipe_img2img,
                prompt=prompt,
                negative_prompt=negative_prompt,
                init_image=init_img,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                seed=seed,
            )
        except Exception as e:
            return JSONResponse(status_code=500, content={"error": str(e)})

    # --- Txt2img path ---
    else:
        try:
            generated_image = await run_in_threadpool(
                _run_locked,
                generate_txt2img,
                pipe_txt2img,
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                seed=seed,
            )
        except Exception as e:
            return JSONResponse(status_code=500, content={"error": str(e)})

    # --- Return PNG ---
    # A plain Response with the encoded bytes sets Content-Length. A StreamingResponse
    # over a BytesIO would re-chunk the binary data on b"\n" boundaries.
    buffer = BytesIO()
    generated_image.save(buffer, format="PNG")
    return Response(content=buffer.getvalue(), media_type="image/png")

In [None]:
# Download zrok v1.1.3 (latest)
!wget https://github.com/openziti/zrok/releases/download/v1.1.3/zrok_1.1.3_linux_amd64.tar.gz
!tar -xzf zrok_1.1.3_linux_amd64.tar.gz
!chmod +x zrok

In [None]:
import subprocess

# Enable this machine as a zrok environment using the token read in the first cell.
# The token is passed as its own argv element with no shell involved. Inside the
# previous `!... "$zrok_token"` form, characters such as $ or ` in the secret
# would have been expanded by the shell.
enable_result = subprocess.run(
    ["./zrok", "enable", "--headless", zrok_token],
    capture_output=True,
    text=True,
)
print(enable_result.stdout)
print(enable_result.stderr)
if enable_result.returncode != 0:
    # NOTE(review): a re-run may fail here if the environment is already enabled —
    # check the output above.
    print(f"⚠️ zrok enable exited with code {enable_result.returncode}")

In [None]:
# Teardown (manual): uncomment and run to release the environment created by `zrok enable`.
#!./zrok disable

In [None]:
import socket
import threading
import time

import uvicorn

SERVER_PORT = 8000


def run_uvicorn():
    """Serve the FastAPI `app` on 0.0.0.0:8000; blocks, so it is run in a background thread."""
    uvicorn.run(app, host="0.0.0.0", port=SERVER_PORT, log_level="info")


def _wait_for_port(port, timeout=30.0):
    """Return True once 127.0.0.1:<port> accepts TCP connections, False after `timeout` seconds."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection(("127.0.0.1", port), timeout=1):
                return True
        except OSError:
            time.sleep(0.5)
    return False


# Re-running this cell must not start a second server. The port would already be
# bound, and the new thread would fail in the background.
_existing_server = globals().get("server_thread")
if _existing_server is not None and _existing_server.is_alive():
    print(f"Uvicorn is already running on port {SERVER_PORT}; not starting another.")
else:
    # Start in background thread; daemon=True so it never blocks interpreter exit.
    server_thread = threading.Thread(target=run_uvicorn, daemon=True)
    server_thread.start()
    # Wait until the server actually listens before the tunnel cell points zrok at it.
    if _wait_for_port(SERVER_PORT):
        print(f"✅ Uvicorn is accepting connections on port {SERVER_PORT}")
    else:
        print(f"⚠️ Uvicorn did not open port {SERVER_PORT} within 30s; check the log output")

In [None]:
import collections
import re
import subprocess
import threading
import time

# Public share URLs from the hosted zrok.io service look like
# https://<token>.share.zrok.io — adjust this for a self-hosted zrok instance.
ZROK_URL_PATTERN = r"https://[A-Za-z0-9-]+\.share\.zrok\.io"


def start_zrok_tunnel(port=8000, url_timeout=30.0, url_pattern=ZROK_URL_PATTERN):
    """Start `zrok share public` for localhost:<port> and print its public URL.

    The child's stdout and stderr are merged and continuously drained by a daemon
    thread. Left as unread pipes, zrok's log output fills the OS pipe buffer;
    zrok then blocks on write and the tunnel stalls.

    Args:
        port: Local port to expose.
        url_timeout: Seconds to wait for a public URL to appear in zrok's output.
        url_pattern: Regex that matches the public URL in zrok's output.

    Returns:
        The running ``subprocess.Popen`` for ``zrok share``. Its output is consumed
        internally, so callers must not read ``process.stdout``;
        ``process.stderr`` is None. The detected URL, or None, is stored on
        ``process.public_url``.
    """
    process = subprocess.Popen(
        ["./zrok", "share", "public", f"localhost:{port}", "--headless"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
    )
    url_regex = re.compile(url_pattern)
    recent_lines = collections.deque(maxlen=20)  # output tail for error reports
    url_found = threading.Event()
    process.public_url = None

    def _drain_output():
        """Read zrok's output until EOF, remembering the tail and the first URL seen."""
        for line in process.stdout:
            recent_lines.append(line.rstrip())
            if process.public_url is None:
                match = url_regex.search(line)
                if match:
                    process.public_url = match.group(0)
                    url_found.set()

    drain_thread = threading.Thread(target=_drain_output, daemon=True)
    drain_thread.start()

    # Wait for the URL, but stop early if zrok exits.
    deadline = time.monotonic() + url_timeout
    while not url_found.wait(0.5):
        if process.poll() is not None or time.monotonic() >= deadline:
            break

    if process.public_url is not None:
        print(f"🌍 Public URL: {process.public_url}")
    elif process.poll() is not None:
        drain_thread.join(timeout=1)  # let the reader collect the final lines
        print(f"❌ zrok share exited with code {process.returncode}. Last output:")
        print("\n".join(recent_lines))
    else:
        print(f"⚠️ No public URL seen in zrok output after {url_timeout:.0f}s. Last output:")
        print("\n".join(recent_lines))

    return process


# Start the tunnel
tunnel_process = start_zrok_tunnel(8000)
if tunnel_process.poll() is None:
    print("Zrok tunnel started! The public URL (if detected) is printed above.")

In [None]:
# Ask zrok for an account overview for manual inspection.
# NOTE(review): presumably lists environments and their shares — confirm with `./zrok overview --help`.
!./zrok overview

In [None]:
import subprocess
import time

print("Server and zrok tunnel are running. Keeping the notebook alive...")

try:
    # Block this cell so the session stays busy while uvicorn and zrok run in the background.
    while True:
        time.sleep(60)
except KeyboardInterrupt:
    print("Shutting down.")
    # Stop the `zrok share` child started by the tunnel cell. Otherwise that
    # process keeps running after the interrupt.
    _tunnel = globals().get("tunnel_process")
    if _tunnel is not None and _tunnel.poll() is None:
        _tunnel.terminate()
        try:
            _tunnel.wait(timeout=10)
        except subprocess.TimeoutExpired:
            _tunnel.kill()