import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import uvicorn
from mcp.server.fastmcp import FastMCP
import logging


# Module-wide logger for this demo server, with its threshold lowered to DEBUG.
# NOTE(review): no handler is attached here, so where records go depends on
# logging configured elsewhere. Records emitted before any handler exists
# fall back to the logging module's last-resort handler, which only prints
# WARNING and above — confirm early INFO messages are actually visible.
logger = logging.getLogger("MCP-test")
logger.setLevel("DEBUG")

# Lifespan handler for the FastAPI app: keeps the FastMCP session manager
# running for as long as the application is up.
@asynccontextmanager
async def mcp_lifespan(app):
    """Run the MCP session manager for the lifetime of the FastAPI app.

    The FastMCP HTTP app is mounted as a sub-application, and Starlette does
    not run a mounted sub-app's own lifespan. The outer app therefore has to
    enter ``mcp.session_manager.run()`` itself.

    Args:
        app: The application this lifespan belongs to. It is unused, but the
            lifespan protocol requires the parameter.

    Yields:
        Control to the server while the session manager is active.
    """
    logger.info("🚀 Starting MCP server lifespan...")
    try:
        # The session manager stays active until the application shuts down
        # and this context is exited.
        async with mcp.session_manager.run():
            # Report success only once the session manager is actually running.
            logger.info("✅ MCP server lifespan started")
            yield
    finally:
        logger.info("🛑 MCP server lifespan ending...")

# The FastAPI app takes ``mcp_lifespan`` as its lifespan, so the MCP session
# manager's task group starts and stops together with the application.
logger.info("🔧 Creating MCP Proxy Server with shared lifespan...")

app = FastAPI(
    lifespan=mcp_lifespan,
    title="FastAPI server",
    description="A simple FastAPI server with MCP integration",
    # Trailing-slash redirects are disabled: the redirected request would
    # lose the Authorization header.
    redirect_slashes=False,
)

@app.middleware("http")
async def mcp_auth_middleware(request: Request, call_next):
    """Gate requests under ``/test/mcp`` behind an (as yet simulated) auth step.

    Every request is forwarded to ``call_next`` and its response is returned
    unchanged. MCP-path requests first pass through a placeholder async
    "authentication" that only sleeps and logs; no credentials are checked yet.
    """
    # Anything outside the MCP mount is passed straight through.
    if not request.url.path.startswith("/test/mcp"):
        logger.info("Non-MCP request, skipping MCP authentication")
        return await call_next(request)

    logger.info("🔐 Authenticating MCP request...")
    await asyncio.sleep(1)  # placeholder for a real asynchronous credential check
    logger.info("🔐 MCP request authenticated")
    return await call_next(request)

# FastMCP server instance shared by the lifespan handler and the HTTP mount.
# stateless_http=False selects the stateful (session-tracking) streamable
# HTTP mode; debug/log_level enable verbose diagnostics.
mcp = FastMCP(
    name="My MCP Server",
    stateless_http=False,
    log_level='DEBUG',
    debug=True,
)


@mcp.tool()
def hello() -> str:
    """A simple hello tool"""
    # NOTE: FastMCP presumably surfaces the docstring above as the tool
    # description, so it is intentionally left verbatim.
    greeting = "Hello from MCP!"
    return greeting


@app.get("/hello")
def read_hello():
    # Plain FastAPI endpoint served outside the MCP mount. It stays
    # docstring-free so the generated OpenAPI description is unchanged.
    return dict(message="Hello from mcp-server-demo!")

# Build FastMCP's streamable-HTTP ASGI app and mount it under /test, which
# exposes the MCP endpoint at /test/mcp. A mounted sub-app's own lifespan is
# not run, so the session manager is started by the FastAPI app's lifespan.
logger.info("🔧 Creating FastMCP streamable HTTP app...")
mcp_http_app = mcp.streamable_http_app()

logger.info("📌 Mounting MCP server with shared lifespan...")
app.mount("/test/", mcp_http_app, name="mcp_server")
# Log the mounted sub-app's routes and the outer app's routes under separate,
# accurate labels. Lazy %-formatting defers the repr until a record is emitted.
logger.info("🔍 MCP HTTP app routes: %s", mcp_http_app.routes)
logger.info("🔍 FastAPI app routes: %s", app.routes)
logger.info("✅ MCP server mounted at /test/mcp")



if __name__ == "__main__":
    # Script entry point: serve the app on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")
