# Colab GPU Polling Worker
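Minimal notebook that installs Ollama (starting it is optional), then runs a background loop that polls your stack’s `/gpu-jobs` queue. Each job runs on the Colab GPU, either as Ollama text generation (`ollamaChat`) or as Stable Diffusion image generation (`imagenGenerate`), and the results are posted back.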

> Set these env vars before running (or fill in the `ENV_OVERRIDES` dict at the top of the setup cell for the keys it lists):
> * `STACK_API_BASE` – HTTPS base URL of your stack (e.g., `https://abc123.ngrok-free.app`)
> * `STACK_WORKER_TOKEN` – must match `GPU_WORKER_TOKEN` in `.env`
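> * Optional: `WORKER_ID`, `POLL_INTERVAL`, `OLLAMA_URL`, `START_OLLAMA`, `OLLAMA_PRELOAD_MODELS`, `SERVICE_AUTH_TOKEN`, `TUNNEL_PORT`, `IMAGEN_MODEL_ID`, `IMAGEN_STEPS`, `IMAGEN_MAX_STEPS`, `IMAGEN_WIDTH`, `IMAGEN_HEIGHT`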


In [None]:
!nvidia-smi
!pip install -q fastapi uvicorn python-multipart requests diffusers accelerate transformers safetensors pillow
!curl -fsSL https://ollama.com/install.sh | sh

import base64
import io
import os
import subprocess
import sys
import textwrap
import threading
import time
from pathlib import Path

import requests
import torch
from diffusers import AutoPipelineForText2Image

# Inline overrides for quick reruns without restarting the runtime: every non-empty value
# below is copied into os.environ before the configuration is read. Blank ("") entries are
# skipped, so whatever the environment already holds for that key is kept.
ENV_OVERRIDES = {
    "STACK_API_BASE": "",
    "STACK_WORKER_TOKEN": "",
    "WORKER_ID": "",
    "POLL_INTERVAL": "",
    "OLLAMA_PRELOAD_MODELS": "",
    "IMAGEN_MODEL_ID": "",
}
os.environ.update({name: override for name, override in ENV_OVERRIDES.items() if override})

# Runtime configuration, resolved from the environment after ENV_OVERRIDES has been applied.
STACK_API_BASE = os.environ.get("STACK_API_BASE")
STACK_WORKER_TOKEN = os.environ.get("STACK_WORKER_TOKEN")
WORKER_ID = os.environ.get("WORKER_ID", "colab-worker")
POLL_INTERVAL = float(os.environ.get("POLL_INTERVAL", "5"))  # seconds to wait when the queue is idle or a poll fails
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
# Bearer token checked by the local FastAPI bridge; app.py reads the same env var with the same default.
SERVICE_AUTH_TOKEN = os.environ.get("SERVICE_AUTH_TOKEN", "colab-shared-12345")
START_OLLAMA = os.environ.get("START_OLLAMA", "1") == "1"
# Port the local FastAPI bridge (uvicorn) listens on; the env var name is kept for backwards compat.
TUNNEL_PORT = int(os.environ.get("TUNNEL_PORT", "8000"))
# Comma-separated Ollama models to `ollama pull` at startup; blank entries are ignored.
OLLAMA_PRELOAD_MODELS = [m.strip() for m in os.environ.get("OLLAMA_PRELOAD_MODELS", "llama3:8b").split(",") if m.strip()]
IMAGEN_MODEL_ID = os.environ.get("IMAGEN_MODEL_ID", "runwayml/stable-diffusion-v1-5")
IMAGEN_DEFAULT_STEPS = int(os.environ.get("IMAGEN_STEPS", "25"))
IMAGEN_MAX_STEPS = int(os.environ.get("IMAGEN_MAX_STEPS", "50"))
IMAGEN_DEFAULT_WIDTH = int(os.environ.get("IMAGEN_WIDTH", "512"))
IMAGEN_DEFAULT_HEIGHT = int(os.environ.get("IMAGEN_HEIGHT", "512"))
# Diffusion runs on the GPU in fp16 when CUDA is available, otherwise on CPU in fp32.
IMAGEN_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGEN_DTYPE = torch.float16 if IMAGEN_DEVICE == "cuda" else torch.float32

if not STACK_API_BASE or not STACK_WORKER_TOKEN:
    raise ValueError("STACK_API_BASE and STACK_WORKER_TOKEN must be set before running this cell.")
# Zero, a negative value, or NaN would make the worker hit /gpu-jobs/next back-to-back with no delay.
if not POLL_INTERVAL > 0:
    raise ValueError("POLL_INTERVAL must be a positive number of seconds.")

# Shut down anything spawned by earlier runs of this cell so a rerun starts clean. Without
# this, a rerun's uvicorn/ollama processes fail to bind their ports while the old ones keep
# running, no longer referenced by `server` / `ollama_proc`.
_prev_stop_event = globals().get("worker_stop_event")
_prev_worker_thread = globals().get("worker_thread")
if _prev_stop_event:
    _prev_stop_event.set()
if _prev_worker_thread and _prev_worker_thread.is_alive():
    _prev_worker_thread.join(timeout=5)
    if _prev_worker_thread.is_alive():
        print("Warning: previous worker thread is still running (likely mid-job); it will exit once that job finishes.")

# Restart Ollama only if this run is going to start it again; otherwise leave it serving.
_prev_process_names = ["server"] + (["ollama_proc"] if START_OLLAMA else [])
for _proc_name in _prev_process_names:
    _prev_proc = globals().get(_proc_name)
    if isinstance(_prev_proc, subprocess.Popen) and _prev_proc.poll() is None:
        print(f"Stopping previous {_proc_name} (PID {_prev_proc.pid})")
        _prev_proc.terminate()
        try:
            _prev_proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            # Escalate if the process ignores SIGTERM, then reap it.
            _prev_proc.kill()
            _prev_proc.wait(timeout=5)

globals()["worker_stop_event"] = None
globals()["worker_thread"] = None

# Launch `ollama serve` in the background if desired. Its output goes to a log file, not an
# unread PIPE: once the OS pipe buffer fills, the daemon would block on its next log write.
OLLAMA_LOG_PATH = Path("ollama.log")
ollama_proc = None


def _wait_for_ollama(url, proc, timeout=30.0):
    """Poll ``{url}/api/tags`` until Ollama answers.

    Returns True once the API responds successfully. Returns False if *timeout* seconds
    pass first, or if *proc* has exited and the API is still unreachable.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            requests.get(f"{url}/api/tags", timeout=2).raise_for_status()
            return True
        except requests.RequestException:
            pass
        # Checked after the API probe: if our process exited because another Ollama already
        # holds the port, that instance still answers above.
        if proc.poll() is not None:
            return False
        time.sleep(0.5)
    return False


if START_OLLAMA:
    with OLLAMA_LOG_PATH.open("ab") as _ollama_log:
        # The child keeps its own duplicate of the log fd, so closing ours here is safe.
        ollama_proc = subprocess.Popen(["ollama", "serve"], stdout=_ollama_log, stderr=subprocess.STDOUT)
    # Wait until the daemon actually accepts requests, then ensure requested models are available.
    if OLLAMA_PRELOAD_MODELS:
        if not _wait_for_ollama(OLLAMA_URL, ollama_proc):
            print(f"Warning: Ollama did not become ready at {OLLAMA_URL}; see {OLLAMA_LOG_PATH}")
        for model_name in OLLAMA_PRELOAD_MODELS:
            print(f"Ensuring Ollama model '{model_name}' is available…")
            try:
                subprocess.run(["ollama", "pull", model_name], check=True)
            except subprocess.CalledProcessError as pull_exc:
                print(f"Warning: failed to pull {model_name}: {pull_exc}")

# Simple FastAPI bridge (local only) to expose health + tool endpoints if needed.
# The source below is written to ./app.py and launched as a separate uvicorn process.
app_py = textwrap.dedent(
    """
import os
import secrets

import httpx
from fastapi import FastAPI, Header, HTTPException
from fastapi.responses import JSONResponse

SERVICE_AUTH_TOKEN = os.environ.get("SERVICE_AUTH_TOKEN", "colab-shared-12345")
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")

app = FastAPI(title="Colab GPU Local Service")

@app.get("/health")
def health():
    try:
        with httpx.Client(timeout=5.0) as client:
            resp = client.get(f"{OLLAMA_URL}/api/tags")
            resp.raise_for_status()
            models = resp.json().get("models", [])
    except Exception as exc:  # noqa: BLE001
        return JSONResponse({"status": "error", "detail": str(exc)})
    return {"status": "ok", "models": models}

@app.post("/tools/ollamaChat")
def run_ollama_chat(payload: dict, authorization: str = Header(default="")):
    expected = f"Bearer {SERVICE_AUTH_TOKEN}"
    # Constant-time comparison so the token cannot be probed via response timing.
    if not secrets.compare_digest(authorization.encode(), expected.encode()):
        raise HTTPException(status_code=401, detail="Unauthorized")
    # httpx's post() has no `stream` argument; instead ask Ollama for a single JSON object
    # (it streams NDJSON chunks by default), so resp.json() can parse the reply.
    body = {**payload, "stream": False}
    with httpx.Client(timeout=120.0) as client:
        resp = client.post(f"{OLLAMA_URL}/api/generate", json=body)
    if resp.is_error:
        raise HTTPException(status_code=502, detail=f"Ollama error {resp.status_code}: {resp.text}")
    return resp.json()
"""
)

Path("app.py").write_text(app_py, encoding="utf-8")

# Launch the bridge as a separate uvicorn process. Output goes to a log file, not an unread
# PIPE: once the OS pipe buffer fills, uvicorn would block on its next log write.
SERVER_LOG_PATH = Path("uvicorn.log")
with SERVER_LOG_PATH.open("ab") as _server_log:
    # The child keeps its own duplicate of the log fd, so closing ours here is safe.
    server = subprocess.Popen(
        [
            sys.executable,
            "-m",
            "uvicorn",
            "app:app",
            "--host",
            "0.0.0.0",
            f"--port={TUNNEL_PORT}",
        ],
        stdout=_server_log,
        stderr=subprocess.STDOUT,
    )

# Best-effort early-failure check (e.g. the port is still held by an older server).
time.sleep(1)
if server.poll() is not None:
    print(f"Warning: local FastAPI server exited early (code {server.returncode}); see {SERVER_LOG_PATH}")

print("Local FastAPI server PID:", server.pid)
print("STACK_API_BASE:", STACK_API_BASE)
print("Worker ID:", WORKER_ID)


def image_to_base64(image):
    """Encode *image* as PNG and return it as a ``data:image/png;base64,...`` URI string.

    *image* only needs a ``save(fp, format=...)`` method (e.g. a PIL image).
    """
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"


def clamp_steps(value, fallback, max_value):
    """Coerce *value* to an inference-step count between 5 and *max_value*.

    Anything ``int()`` rejects (including ``None``) is replaced by *fallback* before clamping.
    """
    try:
        candidate = int(value)
    except (TypeError, ValueError):
        candidate = fallback
    capped = min(candidate, max_value)
    return max(5, capped)


def multiple_of_eight(value, fallback):
    """Coerce *value* to an image dimension in [256, 1024], rounded down to a multiple of 8.

    Anything ``int()`` rejects (including ``None``) is replaced by *fallback* before clamping.
    Because the lower bound 256 is itself a multiple of 8, the result is always >= 256.
    """
    try:
        size = int(value)
    except (TypeError, ValueError):
        size = fallback
    bounded = max(256, min(size, 1024))
    return (bounded // 8) * 8


# Lazily-created text-to-image pipeline shared by all imagen jobs; the lock serializes creation.
diffusers_lock = threading.Lock()
diffusers_pipeline = None


def ensure_diffusers_pipeline():
    """Return the shared diffusers pipeline, loading ``IMAGEN_MODEL_ID`` on first use.

    The pipeline is loaded with ``IMAGEN_DTYPE``, moved to ``IMAGEN_DEVICE``, and has its
    ``safety_checker`` set to None before being cached in ``diffusers_pipeline``.
    """
    global diffusers_pipeline
    with diffusers_lock:
        if diffusers_pipeline is not None:
            return diffusers_pipeline
        loaded = AutoPipelineForText2Image.from_pretrained(IMAGEN_MODEL_ID, torch_dtype=IMAGEN_DTYPE).to(IMAGEN_DEVICE)
        # Drop the safety checker component (presumably to skip NSFW filtering/blanking).
        loaded.safety_checker = None
        diffusers_pipeline = loaded
        return loaded


def run_imagen_job(job_payload):
    """Render one image for an ``imagenGenerate`` job with the shared diffusers pipeline.

    The job's ``payload`` dict supplies these fields:

    * ``prompt`` (required, non-blank string) and ``negative_prompt``.
    * ``guidance_scale``, which defaults to 7.0.
    * ``num_inference_steps``, ``width`` and ``height``, sanitized by ``clamp_steps`` /
      ``multiple_of_eight``.
    * ``seed``. A seed that ``int()`` rejects is ignored, and generation runs unseeded.

    Returns a dict with the image as a PNG ``data:`` URI (``image_base64``), the
    parameters used, the pipeline call's duration in milliseconds, and the device.

    Raises:
        ValueError: if ``prompt`` is missing or blank. ``float()`` may also raise
            ValueError/TypeError for a non-numeric ``guidance_scale``.
    """
    params = job_payload.get("payload") or {}
    prompt = params.get("prompt")
    if not (isinstance(prompt, str) and prompt.strip()):
        raise ValueError("prompt is required for imagenGenerate")

    negative_prompt = params.get("negative_prompt")
    guidance_scale = float(params.get("guidance_scale", 7.0))
    steps = clamp_steps(params.get("num_inference_steps"), IMAGEN_DEFAULT_STEPS, IMAGEN_MAX_STEPS)
    width = multiple_of_eight(params.get("width"), IMAGEN_DEFAULT_WIDTH)
    height = multiple_of_eight(params.get("height"), IMAGEN_DEFAULT_HEIGHT)
    raw_seed = params.get("seed")

    pipeline = ensure_diffusers_pipeline()

    used_seed = None
    generator = None
    if raw_seed is not None:
        try:
            used_seed = int(raw_seed)
            generator = torch.Generator(device=IMAGEN_DEVICE).manual_seed(used_seed)
        except (TypeError, ValueError):
            used_seed = None  # unparseable seed: fall back to unseeded generation

    cleaned_negative = negative_prompt.strip() if isinstance(negative_prompt, str) else None
    t0 = time.time()
    output = pipeline(
        prompt=prompt.strip(),
        negative_prompt=cleaned_negative,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        width=width,
        height=height,
        generator=generator,
    )
    elapsed_ms = int((time.time() - t0) * 1000)
    first_image = output.images[0]

    return {
        "image_base64": image_to_base64(first_image),
        "model": IMAGEN_MODEL_ID,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": steps,
        "width": width,
        "height": height,
        "seed": used_seed,
        "duration_ms": elapsed_ms,
        "device": IMAGEN_DEVICE,
    }


def run_gpu_job(job_payload):
    """Dispatch a leased job to the matching local backend and return its result.

    Args:
        job_payload: Job dict from ``/gpu-jobs/next``; ``"tool"`` selects the backend and
            ``"payload"`` carries the tool arguments. It is not modified.

    Returns:
        For ``ollamaChat``, Ollama's ``/api/generate`` JSON response. For
        ``imagenGenerate``, the dict built by ``run_imagen_job``.

    Raises:
        ValueError: for an unsupported tool (or an invalid imagen payload).
        requests.HTTPError: if Ollama returns an error status.
    """
    tool = job_payload.get("tool")
    if tool == "ollamaChat":
        # Copy so the leased job dict isn't mutated. Always force "stream": false: the
        # worker returns one JSON result, and Ollama would otherwise send NDJSON chunks
        # that response.json() cannot parse.
        body = {**(job_payload.get("payload") or {}), "stream": False}
        response = requests.post(f"{OLLAMA_URL}/api/generate", json=body, timeout=180)
        response.raise_for_status()
        return response.json()
    if tool == "imagenGenerate":
        return run_imagen_job(job_payload)
    raise ValueError(f"Unsupported tool: {tool}")


def start_worker_loop():
    """Start a daemon thread that leases jobs from the stack's ``/gpu-jobs`` queue.

    The thread repeatedly calls ``GET {STACK_API_BASE}/gpu-jobs/next?worker=WORKER_ID``.
    It runs each leased job with :func:`run_gpu_job` and reports exactly one outcome
    through ``POST /gpu-jobs/{id}/complete``: either completed or error.

    It waits ``POLL_INTERVAL`` seconds (interruptibly) before polling again when:

    * no job is available (HTTP 204 or an empty ``job``), or
    * a poll fails.

    A 401 from the poll endpoint stops the loop. The stop event and thread are stored in
    the ``worker_stop_event`` / ``worker_thread`` globals so later cells can stop the loop.
    """
    import json  # local: only used to pre-validate result payloads

    base = STACK_API_BASE.rstrip("/")
    headers = {"Authorization": f"Bearer {STACK_WORKER_TOKEN}"}
    stop_event = threading.Event()

    def report(job_id, outcome):
        """POST *outcome* to the job's completion endpoint; return True if the stack accepted it."""
        try:
            resp = requests.post(
                f"{base}/gpu-jobs/{job_id}/complete",
                headers=headers,
                json=outcome,
                timeout=30,
            )
            resp.raise_for_status()
        except requests.RequestException as exc:
            print(f"[worker] Could not report job {job_id}: {exc}")
            return False
        return True

    def handle(job):
        """Run one leased job and report either its result or its error (never both)."""
        job_id = job["id"]
        job_tool = job.get("tool", "unknown")
        print(f"[worker] Leased job {job_id} ({job_tool}) at {time.strftime('%H:%M:%S')}")
        try:
            outcome = {"status": "completed", "result": run_gpu_job(job)}
            # Serialize up front so an unencodable result (NaN, non-JSON types) is reported as
            # a job error instead of failing inside the completion POST.
            json.dumps(outcome, allow_nan=False)
        except Exception as job_exc:  # noqa: BLE001 - any job failure is reported to the stack
            print(f"[worker] Job failed {job_id}: {job_exc}")
            report(job_id, {"status": "error", "detail": str(job_exc)})
            return
        if report(job_id, outcome):
            print(f"[worker] Completed job {job_id}")

    def poll_forever():
        print(f"Worker loop started for {WORKER_ID}; polling {base} every {POLL_INTERVAL}s")
        while not stop_event.is_set():
            try:
                resp = requests.get(
                    f"{base}/gpu-jobs/next",
                    params={"worker": WORKER_ID},
                    headers=headers,
                    timeout=30,
                )
                if resp.status_code == 204:
                    stop_event.wait(POLL_INTERVAL)  # returns early when a stop is requested
                    continue
                resp.raise_for_status()
                job = resp.json().get("job")
                if not job:
                    stop_event.wait(POLL_INTERVAL)
                    continue
                handle(job)
            except requests.HTTPError as http_exc:
                if http_exc.response is not None and http_exc.response.status_code == 401:
                    print("Worker auth failed; stopping loop.")
                    break
                print("Polling HTTP error:", http_exc)
                stop_event.wait(POLL_INTERVAL)
            except Exception as exc:  # noqa: BLE001 - keep the loop alive on unexpected errors
                print("Polling error:", exc)
                stop_event.wait(POLL_INTERVAL)
        print("Worker loop exiting…")

    thread = threading.Thread(target=poll_forever, name="gpu-worker", daemon=True)
    thread.start()

    globals()["worker_stop_event"] = stop_event
    globals()["worker_thread"] = thread


start_worker_loop()


## Usage
1. Set `STACK_API_BASE` and `STACK_WORKER_TOKEN` (match `.env` → `GPU_WORKER_TOKEN`).
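2. Run the setup cell. It installs Ollama (and starts it unless `START_OLLAMA` is not `1`), starts a local FastAPI bridge, and begins polling `/gpu-jobs` on your stack. Rerunning the cell stops the previous worker loop before starting a new one.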
3. Submit jobs via `POST /gpu-jobs` on the stack; the worker picks them up automatically and reports completion via `/gpu-jobs/{id}/complete`.
4. Use the cleanup cell to close any ngrok tunnel and terminate the local server (optional if the runtime is about to reset).


In [None]:
# Cleanup: stop the polling worker, close any ngrok tunnel, and terminate the processes started
# by the setup cell. Each step is best-effort so one failure doesn't skip the rest.
import subprocess

_stop_event = globals().get("worker_stop_event")
if _stop_event is not None:
    _stop_event.set()
    _thread = globals().get("worker_thread")
    if _thread is not None and _thread.is_alive():
        _thread.join(timeout=10)  # returns as soon as the current poll/job finishes
    if _thread is not None and _thread.is_alive():
        print("Worker loop stop requested (still finishing a job)")
    else:
        print("Worker loop stopped")
else:
    print("No worker loop found")

if "public_tunnel" in globals():
    try:
        # Imported lazily: pyngrok is not installed by the setup cell and is only needed
        # when a tunnel was actually opened.
        from pyngrok import ngrok

        ngrok.disconnect(public_tunnel.public_url)
        print("ngrok tunnel closed")
    except Exception as exc:  # noqa: BLE001 - best-effort cleanup
        print("ngrok cleanup warning:", exc)
else:
    print("No tunnel to close")

for _var_name, _label in (("server", "Server"), ("ollama_proc", "Ollama")):
    _proc = globals().get(_var_name)
    if _proc is None:
        print(f"No {_label.lower()} process found")
        continue
    try:
        if _proc.poll() is None:
            _proc.terminate()
            try:
                _proc.wait(timeout=10)  # reap the child so it doesn't linger as a zombie
            except subprocess.TimeoutExpired:
                _proc.kill()
                _proc.wait(timeout=5)
        print(f"{_label} stopped")
    except Exception as exc:  # noqa: BLE001 - best-effort cleanup
        print(f"{_label} cleanup warning:", exc)
