In [None]:
!pip install -q fastapi uvicorn torch diffusers transformers accelerate pyngrok python-multipart pydantic

In [None]:
# Cell for importing dependencies
import asyncio
import base64
import logging
import os
import queue
import re
import secrets
import threading
import time
from collections import defaultdict
from datetime import datetime
from getpass import getpass
from io import BytesIO
from typing import Optional, List

import nest_asyncio
import torch
import uvicorn
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field, field_validator, validator
from pyngrok import ngrok

In [None]:
# Configure logging: INFO-level basic configuration, plus the module-level
# `logger` that later cells use for status and error messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
# Shared work queue between the HTTP layer and the inference worker:
# the /generate/ endpoint puts request tuples on it and worker_thread()
# consumes them one at a time. A `None` item tells the worker to exit.
request_queue = queue.Queue()

In [None]:
# Initialize FastAPI app
app = FastAPI(
    title="Stable Diffusion API",
    description="Image generation API running on Google Colab GPU",
    version="1.0.0"
)

# CORS configuration
# Authentication uses the X-API-Key header, not cookies, so credentialed
# cross-origin requests are not needed. Combining wildcard origins with
# allow_credentials=True is not a valid CORS setup, so credentials stay off.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security
# Prefer a key supplied through the SD_API_KEY environment variable so the
# secret does not have to live in the notebook. The placeholder is kept as a
# fallback for backward compatibility, but its use is logged loudly.
_PLACEHOLDER_API_KEY = "your-secret-key-here"
API_KEY = os.environ.get("SD_API_KEY", _PLACEHOLDER_API_KEY)
if API_KEY == _PLACEHOLDER_API_KEY:
    logger.warning("SD_API_KEY is not set; falling back to the placeholder API key")
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=True)

# Rate limiting setup
# NOTE(review): these settings are declared here but nothing in the visible
# notebook reads them, so no rate limit is actually enforced yet.
RATE_LIMIT_MINUTES = 5
MAX_REQUESTS_PER_WINDOW = 10
request_counts = defaultdict(list)

In [None]:
class GenerationRequest(BaseModel):
    """Request body accepted by the /generate/ endpoint.

    Fields:
        prompt: Text to render, 1-500 characters. Whitespace is normalised
            by ``validate_prompt`` before use.
        negative_prompt: Optional text to steer away from, up to 500 characters.
        guidance_scale: Guidance strength, between 1.0 and 20.0 (default 7.5).
        num_inference_steps: Number of denoising steps, between 1 and 100
            (default 30).
    """
    prompt: str = Field(..., min_length=1, max_length=500)
    negative_prompt: Optional[str] = Field(default="", max_length=500)
    guidance_scale: float = Field(default=7.5, ge=1.0, le=20.0)
    num_inference_steps: int = Field(default=30, ge=1, le=100)

    # Pydantic V2 validator style; the V1 `@validator` decorator is deprecated
    # and scheduled for removal in V3.
    @field_validator('prompt')
    @classmethod
    def validate_prompt(cls, v):
        """Collapse whitespace runs to single spaces and trim the ends.

        Raises:
            ValueError: If nothing is left after normalisation (the prompt
                consisted only of whitespace).
        """
        v = re.sub(r'\s+', ' ', v).strip()
        if not v:
            raise ValueError("Prompt cannot be empty")
        return v

class GenerationResponse(BaseModel):
    """Response body returned by the /generate/ endpoint.

    Fields:
        status: Outcome marker; the worker emits "success" or "error".
        message: Error description, present only for failed generations.
        image: Base64-encoded image data, present only on success.
        generated_at: Timestamp recorded when the result was produced.
    """
    status: str
    message: Optional[str] = Field(default=None)
    image: Optional[str] = Field(default=None)
    generated_at: datetime


<ipython-input-19-9956405ae0de>:7: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  @validator('prompt')


In [None]:
# Pipeline initialization
def initialize_pipeline():
    """Load the Stable Diffusion v1.5 pipeline and place it on the best device.

    Half precision is only requested when a CUDA device is available; on the
    CPU fallback path the weights are loaded as float32, because float16
    inference is not generally usable on CPU.

    Returns:
        The loaded ``StableDiffusionPipeline`` with attention slicing enabled
        and the safety checker disabled.
    """
    use_cuda = torch.cuda.is_available()

    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        safety_checker=None
    )

    if use_cuda:
        pipeline = pipeline.to("cuda")
        logger.info("Using GPU for inference")
    else:
        logger.warning("GPU not available, using CPU")

    # Trades a little speed for lower peak memory during attention.
    pipeline.enable_attention_slicing()
    return pipeline


In [None]:
# Worker thread function
def worker_thread(pipeline):
    """Consume generation requests from ``request_queue`` until told to stop.

    Each queue item is a tuple
    ``(prompt, negative_prompt, guidance_scale, num_inference_steps, result_queue)``.
    The generated image is JPEG-encoded, base64-encoded and delivered as a
    dict on the item's own ``result_queue``. Failures are delivered on the
    same queue as a dict with ``"status": "error"``. A ``None`` item ends
    the loop.

    Args:
        pipeline: Callable pipeline whose result exposes ``.images``.
    """
    while True:
        # Reset on every iteration so the error handler below can neither hit
        # an unbound name (failure before unpacking on the first request) nor
        # reply on the queue that belonged to the previous request.
        result_queue = None
        try:
            # Get request from queue
            request_data = request_queue.get()
            if request_data is None:
                break

            prompt, negative_prompt, guidance_scale, num_inference_steps, result_queue = request_data

            # Generate image
            with torch.inference_mode():
                image = pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    num_images_per_prompt=1,
                ).images[0]

            # Convert to base64
            buffered = BytesIO()
            image.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode()

            # Put result in result queue
            result_queue.put({
                "status": "success",
                "image": img_str,
                "generated_at": datetime.utcnow()
            })

        except Exception as e:
            logger.error(f"Error in worker thread: {str(e)}")
            # Only reply when this iteration actually got as far as knowing
            # which caller is waiting.
            if result_queue is not None:
                result_queue.put({
                    "status": "error",
                    "message": str(e),
                    "generated_at": datetime.utcnow()
                })
        finally:
            # Runs for the stop sentinel as well, keeping the queue's
            # unfinished-task count balanced.
            request_queue.task_done()

async def verify_api_key(api_key: str = Depends(api_key_header)):
    """FastAPI dependency that authenticates a request by its X-API-Key header.

    Args:
        api_key: Header value extracted by ``api_key_header``.

    Returns:
        The supplied key when it matches ``API_KEY``.

    Raises:
        HTTPException: 401 when the key does not match.
    """
    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key was correct. Comparing bytes also means a non-ASCII header
    # value is simply rejected instead of raising TypeError.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(
            status_code=401,
            detail="Invalid API key"
        )
    return api_key

In [None]:
@app.post("/generate/", response_model=GenerationResponse)
async def generate_image(
    request: GenerationRequest,
    background_tasks: BackgroundTasks,
    api_key: str = Depends(verify_api_key)
):
    """Queue an image generation request and wait for the worker's result.

    The request is handed to the worker thread through ``request_queue``
    together with a private result queue, and the handler waits up to 60
    seconds for the reply.

    Args:
        request: Validated generation parameters.
        background_tasks: Injected by FastAPI; not used by this handler.
        api_key: Authenticated key from ``verify_api_key``.

    Returns:
        The result dict produced by the worker (success or error payload).

    Raises:
        HTTPException: 504 if no result arrives within 60 seconds, 500 for
            any other failure.
    """
    try:
        # Create result queue for this request
        result_queue = queue.Queue()

        # Put request in queue
        request_queue.put((
            request.prompt,
            request.negative_prompt,
            request.guidance_scale,
            request.num_inference_steps,
            result_queue
        ))

        # Wait for the result in a helper thread. A direct blocking get()
        # here would stall the event loop, and with it every other request
        # (including /health/), for the whole wait.
        result = await asyncio.to_thread(result_queue.get, timeout=60)
        return result

    except queue.Empty:
        raise HTTPException(
            status_code=504,
            detail="Request timed out"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )

In [None]:
@app.get("/health/")
async def health_check():
    """Report service liveness, CUDA availability and the current UTC time."""
    has_gpu = torch.cuda.is_available()
    checked_at = datetime.utcnow()
    return dict(
        status="healthy",
        gpu_available=has_gpu,
        timestamp=checked_at,
    )


In [None]:
# Cell for starting the server
def start_server(port: int = 8000):
    """Load the model, start the worker, open an ngrok tunnel and run uvicorn.

    The call blocks while uvicorn is serving. When the server stops, for any
    reason including a kernel interrupt, the worker thread is asked to exit
    and the ngrok tunnel is closed. Re-running the cell therefore does not
    leave behind an extra worker (holding its own pipeline) or an extra tunnel.

    Args:
        port: Local port shared by uvicorn and the ngrok tunnel.
    """
    # Initialize pipeline
    pipeline = initialize_pipeline()

    # Start worker thread
    worker = threading.Thread(
        target=worker_thread,
        args=(pipeline,),
        daemon=True
    )
    worker.start()

    ngrok_tunnel = None
    try:
        # Start ngrok
        ngrok_tunnel = ngrok.connect(port)
        logger.info(f'Public URL: {ngrok_tunnel.public_url}')
        print(f'Public URL: {ngrok_tunnel.public_url}')
        # Enable notebook asyncio support
        nest_asyncio.apply()

        # Start server
        uvicorn.run(app, host="0.0.0.0", port=port)
    finally:
        # `None` is the stop sentinel worker_thread() checks for.
        request_queue.put(None)
        if ngrok_tunnel is not None:
            try:
                ngrok.disconnect(ngrok_tunnel.public_url)
            except Exception as e:
                # Cleanup is best-effort; do not mask the reason the server stopped.
                logger.warning(f"Failed to close ngrok tunnel: {e}")


In [None]:
# Register the ngrok authtoken without storing it in the notebook: take it
# from the NGROK_AUTHTOKEN environment variable, or prompt for it (input is
# hidden) when the variable is not set.
# NOTE: the token that used to be hardcoded on this line has been exposed to
# anyone with a copy of the notebook and should be revoked in the ngrok dashboard.
ngrok.set_auth_token(os.environ.get("NGROK_AUTHTOKEN") or getpass("ngrok authtoken: "))

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


In [None]:
# Start the server. start_server() loads the pipeline, opens the ngrok tunnel
# and then runs uvicorn, so this cell keeps running while the API is being
# served; the public URL is printed in the cell output.
start_server()

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

ERROR:asyncio:Task exception was never retrieved
future: <Task finished name='Task-1' coro=<Server.serve() done, defined at /usr/local/lib/python3.10/dist-packages/uvicorn/server.py:67> exception=KeyboardInterrupt()>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/uvicorn/main.py", line 579, in run
    server.run()
  File "/usr/local/lib/python3.10/dist-packages/uvicorn/server.py", line 65, in run
    return asyncio.run(self.serve(sockets=sockets))
  File "/usr/local/lib/python3.10/dist-packages/nest_asyncio.py", line 30, in run
    return loop.run_until_complete(task)
  File "/usr/local/lib/python3.10/dist-packages/nest_asyncio.py", line 92, in run_until_complete
    self._run_once()
  File "/usr/local/lib/python3.10/dist-packages/nest_asyncio.py", line 133, in _run_once
    handle._run()
  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/usr/lib/python3.10/asyncio/tasks.py", l

Public URL: https://99c3-35-185-172-23.ngrok-free.app


INFO:     Started server process [591]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


INFO:     105.108.176.68:0 - "GET / HTTP/1.1" 404 Not Found


  0%|          | 0/30 [00:00<?, ?it/s]

INFO:     105.108.176.68:0 - "POST /generate/ HTTP/1.1" 200 OK


  0%|          | 0/30 [00:00<?, ?it/s]

INFO:     105.108.176.68:0 - "POST /generate/ HTTP/1.1" 200 OK


  0%|          | 0/30 [00:00<?, ?it/s]

INFO:     105.108.176.68:0 - "POST /generate/ HTTP/1.1" 200 OK
INFO:     105.108.176.68:0 - "POST /generate/ HTTP/1.1" 422 Unprocessable Entity


Token indices sequence length is longer than the specified maximum sequence length for this model (88 > 77). Running this sequence through the model will result in indexing errors
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['masterpiece , evoking intense drama and historical grandeur .']


  0%|          | 0/30 [00:00<?, ?it/s]

INFO:     105.108.176.68:0 - "POST /generate/ HTTP/1.1" 200 OK


# TEST FUNCTION OUTSIDE THE SERVER

In [None]:
# Example usage cell
import requests
from IPython.display import Image as IPImage
from PIL import Image

def test_api(
    prompt,
    api_url,
    negative_prompt="",
    guidance_scale=7.5,
    num_inference_steps=30,
    timeout=120,
):
    """Call the /generate/ endpoint and return the image for notebook display.

    Args:
        prompt: Text prompt to render.
        api_url: Base URL of the API (e.g. the ngrok public URL); a trailing
            slash is tolerated.
        negative_prompt: Optional text to steer away from.
        guidance_scale: Guidance strength forwarded to the API.
        num_inference_steps: Number of denoising steps forwarded to the API.
        timeout: Seconds to wait for the HTTP response before giving up.
            The default exceeds the server's own 60-second wait so a server
            timeout is reported as an HTTP error rather than a client one.

    Returns:
        An ``IPython.display.Image`` built from the decoded image bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
        RuntimeError: If the server reports a failed generation.
    """
    headers = {
        "X-API-Key": API_KEY,
        "Content-Type": "application/json"
    }

    data = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps
    }

    # Strip a trailing slash so the path never becomes "//generate/".
    endpoint = f"{api_url.rstrip('/')}/generate/"
    # Without an explicit timeout, requests would wait indefinitely on a
    # stalled tunnel or server.
    response = requests.post(endpoint, headers=headers, json=data, timeout=timeout)
    response.raise_for_status()

    result = response.json()
    if result["status"] != "success":
        # `message` may be present but null, hence `or` rather than a default.
        raise RuntimeError(result.get("message") or "Unknown error")

    image_data = base64.b64decode(result["image"])
    return IPImage(image_data)

In [None]:
# Test the API
# Get the ngrok URL from the server output and replace it here
NGROK_URL = "https://your-ngrok-url"

# Test image generation
result = test_api(
    prompt="a beautiful sunset over the ocean, hyperrealistic, 8k",
    api_url=NGROK_URL
)
display(result)