In [None]:
!pip install fastapi uvicorn pyngrok python-multipart nest-asyncio diffusers transformers accelerate

import base64
import io
import os
import threading
import time
from typing import Optional

import nest_asyncio
import numpy as np
import torch
import uvicorn
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field
from pyngrok import ngrok, conf

# Patch asyncio so event loops can be nested (the notebook already runs one).
nest_asyncio.apply()

# Credentials. An empty HF_TOKEN means "use whatever login is cached, if any".
HF_TOKEN = ""
NGROK_AUTH_TOKEN = ""  # <--- PASTE YOUR NGROK TOKEN HERE

# Placeholder text some copies of this notebook ship with; treated like "unset".
_NGROK_TOKEN_PLACEHOLDER = "YOUR_NGROK_AUTHTOKEN_HERE"

# Set ngrok token only when a real one was provided. Checking the placeholder
# alone let the default empty string through to ngrok.set_auth_token(), and the
# warning below could never fire.
if NGROK_AUTH_TOKEN and NGROK_AUTH_TOKEN != _NGROK_TOKEN_PLACEHOLDER:
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)
else:
    print("WARNING: Ngrok auth token not set. The tunnel might expire quickly.")

device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): the runwayml/stable-diffusion-v1-5 repo was taken down from the
# Hugging Face Hub in 2024; if loading fails, try the community mirror
# "stable-diffusion-v1-5/stable-diffusion-v1-5".
model_id = "runwayml/stable-diffusion-v1-5"

print(f"Loading Model to {device}...")

# Half-precision weights (and the fp16 variant files) are only requested on GPU;
# on CPU the full float32 weights are loaded.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    variant="fp16" if device == "cuda" else None,
    # Pass None instead of "" so huggingface_hub falls back to a cached login
    # rather than sending an empty token.
    token=HF_TOKEN or None,
    safety_checker=None,
    requires_safety_checker=False,
)
pipe = pipe.to(device)
print("Model Loaded.")


def apply_watermark_logic(image: Image.Image, target_radius=40, strength=2.5):
    """Embed a frequency-domain ring watermark in the image's luminance.

    The Y channel of the image's YCbCr representation is transformed with a 2-D
    FFT. Every frequency bin whose distance from the centred DC bin lies in
    ``[target_radius - 1, target_radius + 1]`` is multiplied by ``strength``,
    and the result is transformed back. The Cb/Cr channels are reused unchanged.

    Args:
        image: Source image; converted to YCbCr internally.
        target_radius: Radius, in frequency bins, of the amplified ring.
        strength: Multiplicative gain applied to the ring's FFT coefficients.

    Returns:
        The watermarked image, converted to RGB mode.
    """
    img_ycbcr = image.convert("YCbCr")
    y, cb, cr = img_ycbcr.split()

    y_arr = np.array(y)
    fshift = np.fft.fftshift(np.fft.fft2(y_arr))

    rows, cols = y_arr.shape
    crow, ccol = rows // 2, cols // 2
    y_grid, x_grid = np.ogrid[:rows, :cols]
    dist_sq = (x_grid - ccol) ** 2 + (y_grid - crow) ** 2

    r_min, r_max = target_radius - 1, target_radius + 1
    mask = (dist_sq >= r_min**2) & (dist_sq <= r_max**2)
    # Never scale the DC term (reachable when target_radius == 1): it would
    # change the overall brightness instead of adding a ring.
    mask[crow, ccol] = False

    fshift[mask] *= strength

    img_back = np.fft.ifft2(np.fft.ifftshift(fshift))
    # The mask is symmetric under the FFT's conjugate pairing, so the inverse
    # is real up to floating-point noise. Take the real part (np.abs would turn
    # negative undershoot into positive brightness instead of clipping it to 0),
    # and round before the uint8 cast (astype truncates: 99.9999 -> 99).
    img_back = np.rint(np.clip(img_back.real, 0, 255)).astype(np.uint8)
    watermarked_y = Image.fromarray(img_back)

    final_image = Image.merge("YCbCr", (watermarked_y, cb, cr))
    return final_image.convert("RGB")

def get_radial_profile(fft_shift, center, max_r):
    """Average a 2-D array over integer-radius rings around ``center``.

    Each element is assigned to ring ``floor(distance to center)``. Entry ``k``
    of the result is the mean of the elements in ring ``k``, or 0.0 when that
    ring contains no elements.

    Args:
        fft_shift: 2-D real-valued array (here, a centred FFT magnitude).
        center: ``(row, col)`` coordinates of the ring centre.
        max_r: Maximum number of rings to return.

    Returns:
        1-D float array of per-ring means, at most ``max_r`` long. It is
        shorter when the array does not extend that far from ``center``.
    """
    row_idx, col_idx = np.indices(fft_shift.shape)
    distance = np.sqrt((col_idx - center[1]) ** 2 + (row_idx - center[0]) ** 2)
    ring_of = distance.astype(int).ravel()

    ring_sums = np.bincount(ring_of, fft_shift.ravel())
    ring_counts = np.bincount(ring_of)

    # Empty rings stay 0.0 rather than becoming 0/0 = NaN.
    ring_means = np.divide(
        ring_sums,
        ring_counts,
        out=np.zeros(ring_sums.shape, dtype=float),
        where=ring_counts > 0,
    )
    return ring_means[:max_r]

def detect_logic(image: Image.Image, target_radius=40):
    """Score how strongly an image's spectrum peaks at ``target_radius``.

    The grayscale image is FFT-transformed and its centred magnitude spectrum is
    averaged over integer-radius rings (``get_radial_profile``). The ring at
    ``target_radius`` is compared against the rings at radii r-5..r-3 and
    r+2..r+4, which serve as a local noise estimate.

    Args:
        image: Image to analyse; converted to 8-bit grayscale ("L") first.
        target_radius: Ring radius, in frequency bins, to test for a peak.

    Returns:
        The z-score ``(signal - mean(neighbors)) / (std(neighbors) + 1e-8)`` as a
        Python float. Larger values mean the ring stands out more.

    Raises:
        ValueError: If ``target_radius`` < 5 (the lower noise window would start
            at a negative index), or the image is too small for its spectrum to
            reach ``target_radius + 4``.
    """
    if target_radius < 5:
        raise ValueError(
            f"target_radius must be >= 5 so the lower noise window "
            f"(r-5..r-3) stays in range; got {target_radius}"
        )

    img_gray = image.convert("L")
    img_arr = np.array(img_gray)

    magnitude = np.abs(np.fft.fftshift(np.fft.fft2(img_arr)))

    center = (magnitude.shape[0] // 2, magnitude.shape[1] // 2)
    profile = get_radial_profile(magnitude, center, max_r=target_radius + 20)

    # The upper noise window reads bins up to target_radius + 4; a shorter
    # profile would raise IndexError or silently shrink the window.
    if len(profile) < target_radius + 5:
        raise ValueError(
            f"Image of size {image.size} is too small to measure radius "
            f"{target_radius}; its spectrum only reaches radius {len(profile) - 1}"
        )

    signal = profile[target_radius]

    # NOTE(review): the lower window skips r-2 and r-1 while the upper window
    # skips only r+1; confirm this asymmetry is intended.
    neighbors = np.concatenate([
        profile[target_radius - 5 : target_radius - 2],
        profile[target_radius + 2 : target_radius + 5],
    ])
    noise_mean = np.mean(neighbors)
    noise_std = np.std(neighbors)

    # The epsilon keeps the ratio finite when all neighbouring rings are equal.
    z_score = (signal - noise_mean) / (noise_std + 1e-8)

    return float(z_score)

def image_to_base64(image: Image.Image) -> str:
    """Encode *image* as PNG and return those bytes as base64 text.

    The PNG is rendered into an in-memory buffer, so nothing touches disk.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")


app = FastAPI(title="Diffusion Watermarker API")

# --- Pydantic Models ---
class PromptRequest(BaseModel):
    """Request body for ``POST /prompt``."""

    prompt: str
    # Gain applied to the watermark ring. It must be positive: 0 erases the ring
    # and a negative value inverts it, so the returned image would not carry
    # the watermark the response reports.
    strength: float = Field(2.5, gt=0)
    # Ring radius in frequency bins. Non-positive values select an empty
    # (negative radius) or degenerate (radius 0) ring.
    target_radius: int = Field(40, ge=1)

class GenerationResponse(BaseModel):
    """Response of ``POST /prompt``: both images as base64 PNG plus the parameters used."""

    clean_image_base64: str
    watermarked_image_base64: str
    watermark_strength: float
    target_radius: int

class HealthResponse(BaseModel):
    """Response of ``GET /health``."""

    status: str
    gpu_available: bool
    gpu_name: Optional[str]  # None when no CUDA device is available

class DetectionResponse(BaseModel):
    """Response of ``POST /detect``."""

    z_score: float
    is_generated: bool
    confidence: str  # one of "Very High", "High", "Suspect", "Low"

# --- Endpoints ---

@app.get("/health", response_model=HealthResponse)
def health_check():
    """Report that the server is up and whether torch can see a CUDA GPU."""
    cuda_ok = torch.cuda.is_available()
    return {
        "status": "online",
        "gpu_available": cuda_ok,
        "gpu_name": torch.cuda.get_device_name(0) if cuda_ok else None,
    }

# FastAPI runs plain ``def`` endpoints in a worker thread pool, but ``pipe`` (and
# its stateful scheduler) is one shared object. Serialize generation so that
# concurrent requests cannot interleave inside the pipeline.
_PIPE_LOCK = threading.Lock()


@app.post("/prompt", response_model=GenerationResponse)
def generate_image(request: PromptRequest):
    """Generate an image for ``request.prompt`` and return clean and watermarked copies.

    Both images are returned as base64-encoded PNGs, together with the
    watermark strength and radius that were applied.

    Raises:
        HTTPException: 500 carrying the underlying error message if generation
            or watermarking fails.
    """
    try:
        print(f"Generating: {request.prompt}")
        neg_prompt = "blurry, low quality, deformed, ugly, bad anatomy, jpeg artifacts"

        with _PIPE_LOCK:
            clean_img = pipe(
                request.prompt,
                negative_prompt=neg_prompt,
                num_inference_steps=30,
                guidance_scale=7.5,
            ).images[0]

        wm_img = apply_watermark_logic(
            clean_img,
            target_radius=request.target_radius,
            strength=request.strength,
        )

        return {
            "clean_image_base64": image_to_base64(clean_img),
            "watermarked_image_base64": image_to_base64(wm_img),
            "watermark_strength": request.strength,
            "target_radius": request.target_radius,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/detect", response_model=DetectionResponse)
async def detect_image(file: UploadFile = File(...)):
    """Score an uploaded image for the ring watermark at radius 40.

    Returns the z-score from ``detect_logic``, whether it exceeds the 3.0
    detection threshold, and a coarse confidence label.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image or the
            detector rejects it with ``ValueError``; 500 for any other failure.
    """
    try:
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents))
            # Image.open is lazy; force a decode so corrupt or truncated
            # uploads are rejected here as a client error.
            image.load()
        except (OSError, ValueError) as e:
            # PIL's UnidentifiedImageError is a subclass of OSError.
            raise HTTPException(status_code=400, detail=f"Could not read image: {e}") from e

        # Fixed radius matching the default used by /prompt.
        # NOTE(review): images watermarked via /prompt with a non-default
        # target_radius will not be detected here.
        try:
            z_score = detect_logic(image, target_radius=40)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e

        # Detection threshold: the ring must exceed its neighbours by 3 std devs.
        is_detected = z_score > 3.0

        if z_score > 6.0:
            confidence = "Very High"
        elif z_score > 3.0:
            confidence = "High"
        elif z_score > 2.0:
            confidence = "Suspect"
        else:
            confidence = "Low"

        return {
            "z_score": z_score,
            "is_generated": is_detected,
            "confidence": confidence,
        }
    except HTTPException:
        # Already a deliberate HTTP error; don't re-wrap it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


print("Cleaning up old tunnels...")
ngrok.kill()
os.system("pkill ngrok")
time.sleep(2) # Give it a moment to release the port

try:
    public_url = ngrok.connect(8000).public_url
    print(f"\n🚀 SERVER IS LIVE! Access Swagger UI here: {public_url}/docs\n")
except Exception as e:
    print(f"Ngrok connection failed: {e}")
    print("Trying again with default config...")
    # Fallback in case of config issues
    conf.get_default().region = "us"
    public_url = ngrok.connect(8000).public_url
    print(f"\n🚀 SERVER IS LIVE! Access Swagger UI here: {public_url}/docs\n")

config = uvicorn.Config(app, port=8000, log_level="info")
server = uvicorn.Server(config)
await server.serve()
