diff --git a/.changeset/gentle-cycles-change.md b/.changeset/gentle-cycles-change.md new file mode 100644 index 0000000..955e9db --- /dev/null +++ b/.changeset/gentle-cycles-change.md @@ -0,0 +1,5 @@ +--- +"mcp-handler": patch +--- + +Fix memory leak from SSE streams not closing properly diff --git a/src/handler/mcp-api-handler.ts b/src/handler/mcp-api-handler.ts index 78e7805..2e708a6 100644 --- a/src/handler/mcp-api-handler.ts +++ b/src/handler/mcp-api-handler.ts @@ -22,6 +22,17 @@ import { EventEmittingResponse } from "../lib/event-emitter.js"; import { AuthInfo } from "@modelcontextprotocol/sdk/server/auth/types"; import { getAuthContext } from "../auth/auth-context"; import { ServerOptions } from "."; +import { + createSessionManager, + subscribeWithTracking, + unsubscribeWithTracking, + checkConnectionLimit, + getConnectionStats, + setMaxConnections, + cleanupStaleSessions, + getSession, + type SessionManager +} from "../lib/session-manager"; interface SerializedRequest { requestId: string; @@ -125,6 +136,31 @@ export type Config = { * @default false */ disableSse?: boolean; + + /** + * Maximum number of concurrent connections allowed. + * Helps prevent memory leaks from too many open connections. + * @default 100 + */ + maxConnections?: number; + + /** + * Maximum age of a session in milliseconds before automatic cleanup. + * @default 3600000 (1 hour) + */ + maxSessionAge?: number; + + /** + * Timeout for individual requests in milliseconds. + * @default 30000 (30 seconds) + */ + requestTimeout?: number; + + /** + * If true, enables enhanced memory management and connection monitoring. + * @default true + */ + enableMemoryManagement?: boolean; }; /** @@ -226,6 +262,10 @@ export function initializeMcpApiHandler( maxDuration: 60, verboseLogs: false, disableSse: false, + maxConnections: 100, + maxSessionAge: 60 * 60 * 1000, // 1 hour + requestTimeout: 30 * 1000, // 30 seconds + enableMemoryManagement: true, } ) { const { @@ -237,6 +277,10 @@ export function initializeMcpApiHandler( maxDuration, verboseLogs, disableSse, + maxConnections, + maxSessionAge, + requestTimeout, + enableMemoryManagement, } = config; const { @@ -258,6 +302,30 @@ export function initializeMcpApiHandler( const logger = createLogger(verboseLogs); + // Initialize memory management + if (enableMemoryManagement && maxConnections) { + setMaxConnections(maxConnections); + } + + // Start periodic cleanup of stale sessions + if (enableMemoryManagement && maxSessionAge) { + const cleanupInterval = setInterval(async () => { + try { + const cleaned = await cleanupStaleSessions(maxSessionAge); + if (cleaned > 0) { + logger.log(`Cleaned up ${cleaned} stale sessions`); + } + } catch (error) { + logger.error('Error during stale session cleanup:', error); + } + }, Math.min(maxSessionAge / 4, 15 * 60 * 1000)); // Clean every 15 minutes or 1/4 of max age + + // Clear interval on process exit + process.on('exit', () => clearInterval(cleanupInterval)); + process.on('SIGINT', () => clearInterval(cleanupInterval)); + process.on('SIGTERM', () => clearInterval(cleanupInterval)); + } + let servers: McpServer[] = []; let statelessServer: McpServer; @@ -266,6 +334,20 @@ export function initializeMcpApiHandler( }); return async function mcpApiHandler(req: Request, res: ServerResponse) { + // Early connection limit check to prevent memory exhaustion + if (enableMemoryManagement && !checkConnectionLimit()) { + res.statusCode = 503; + res.end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Service temporarily unavailable: connection limit reached" + }, + id: null + })); + return; + } + const url = new URL(req.url || "", "https://example.com"); if (url.pathname === streamableHttpEndpoint) { if (req.method === "GET") { @@ -403,6 +485,12 @@ export function initializeMcpApiHandler( const transport = new SSEServerTransport(sseMessageEndpoint, res); const sessionId = transport.sessionId; + // Create session manager for enhanced cleanup + let sessionManager: SessionManager | null = null; + if (enableMemoryManagement) { + sessionManager = createSessionManager(sessionId, redis, logger); + } + const eventRes = new EventEmittingResponse( createFakeIncomingMessage(), config.onEvent, @@ -421,10 +509,15 @@ export function initializeMcpApiHandler( servers.push(server); - server.server.onclose = () => { + server.server.onclose = async () => { logger.log("SSE connection closed"); eventRes.endSession("SSE"); servers = servers.filter((s) => s !== server); + + // Enhanced cleanup using session manager + if (sessionManager) { + await sessionManager.cleanup(); + } }; let logs: { @@ -528,8 +621,13 @@ export function initializeMcpApiHandler( logs = []; }, 100); - await redis.subscribe(`requests:${sessionId}`, handleMessage); - logger.log(`Subscribed to requests:${sessionId}`); + // Subscribe with tracking for enhanced cleanup + if (sessionManager) { + await subscribeWithTracking(redis, sessionManager, `requests:${sessionId}`, handleMessage, logger); + } else { + await redis.subscribe(`requests:${sessionId}`, handleMessage); + logger.log(`Subscribed to requests:${sessionId}`); + } let timeout: NodeJS.Timeout; let resolveTimeout: (value: unknown) => void; @@ -540,11 +638,19 @@ export function initializeMcpApiHandler( }, (maxDuration ?? 60) * 1000); }); - // eslint-disable-next-line no-inner-declarations + // Enhanced cleanup function async function cleanup() { clearTimeout(timeout); clearInterval(interval); - await redis.unsubscribe(`requests:${sessionId}`, handleMessage); + + if (sessionManager) { + // Session manager handles all subscription cleanup + await sessionManager.cleanup(); + } else { + // Fallback to original cleanup + await redis.unsubscribe(`requests:${sessionId}`, handleMessage); + } + logger.log("Done"); res.statusCode = 200; res.end(); @@ -593,9 +699,11 @@ export function initializeMcpApiHandler( headers: Object.fromEntries(req.headers.entries()), }; - // Declare timeout and response handling state before subscription + // Enhanced memory management for SSE message endpoint + let sessionManager: SessionManager | null = null; let timeout: NodeJS.Timeout; let hasResponded = false; + // Safe response handler to prevent double res.end() const sendResponse = (status: number, body: string) => { if (!hasResponded) { @@ -606,10 +714,35 @@ export function initializeMcpApiHandler( } }; - // Handles responses from the /sse endpoint. - await redis.subscribe( - `responses:${sessionId}:${requestId}`, - (message) => { + // Enhanced cleanup function for this request + const cleanup = async () => { + if (sessionManager) { + await unsubscribeWithTracking(redis, sessionManager, `responses:${sessionId}:${requestId}`, logger); + } else { + try { + await redis.unsubscribe(`responses:${sessionId}:${requestId}`); + } catch (error) { + logger.error(`Failed to unsubscribe from responses:${sessionId}:${requestId}:`, error); + } + } + }; + + try { + // Get or create session manager for this session + if (enableMemoryManagement) { + // Try to get existing session manager or create a temporary one for this request + const existingSession = getSession(sessionId); + if (existingSession) { + sessionManager = existingSession; + } else { + // Create a lightweight session manager just for this request + sessionManager = createSessionManager(sessionId, redis, logger); + } + } + + // Subscribe to responses with tracking + const responseChannel = `responses:${sessionId}:${requestId}`; + const responseHandler = (message: string) => { try { const response = JSON.parse(message) as { status: number; @@ -620,28 +753,61 @@ export function initializeMcpApiHandler( logger.error("Failed to parse response message:", error); sendResponse(500, "Internal server error"); } + }; + + if (sessionManager) { + await subscribeWithTracking(redis, sessionManager, responseChannel, responseHandler, logger); + } else { + await redis.subscribe(responseChannel, responseHandler); } - ); - // Queue the request in Redis so that a subscriber can pick it up. - // One queue per session. - await redisPublisher.publish( - `requests:${sessionId}`, - JSON.stringify(serializedRequest) - ); - logger.log(`Published requests:${sessionId}`, serializedRequest); + // Queue the request in Redis so that a subscriber can pick it up + await redisPublisher.publish( + `requests:${sessionId}`, + JSON.stringify(serializedRequest) + ); + logger.log(`Published requests:${sessionId}`, serializedRequest); - // Set timeout after subscription is established - timeout = setTimeout(async () => { - await redis.unsubscribe(`responses:${sessionId}:${requestId}`); - sendResponse(408, "Request timed out"); - }, 10 * 1000); + // Set timeout with enhanced cleanup + const timeoutDuration = requestTimeout || 10 * 1000; + timeout = setTimeout(async () => { + await cleanup(); + sendResponse(408, "Request timed out"); + }, timeoutDuration); - res.on("close", async () => { - hasResponded = true; - clearTimeout(timeout); - await redis.unsubscribe(`responses:${sessionId}:${requestId}`); - }); + // Enhanced connection close handling + res.on("close", async () => { + hasResponded = true; + clearTimeout(timeout); + await cleanup(); + }); + + res.on("error", async (error) => { + logger.error(`Response error for ${sessionId}:${requestId}:`, error); + hasResponded = true; + clearTimeout(timeout); + await cleanup(); + }); + + } catch (error) { + logger.error(`Error in SSE message endpoint for ${sessionId}:${requestId}:`, error); + if (sessionManager) { + await sessionManager.cleanup(); + } + if (!hasResponded) { + sendResponse(500, "Internal server error"); + } + } + } else if (url.pathname === "/mcp-stats" && enableMemoryManagement) { + // Connection monitoring endpoint + res.statusCode = 200; + res.setHeader("Content-Type", "application/json"); + res.end(JSON.stringify({ + connectionStats: getConnectionStats(), + memoryUsage: process.memoryUsage(), + uptime: process.uptime(), + timestamp: new Date().toISOString() + })); } else { res.statusCode = 404; res.end("Not found"); diff --git a/src/lib/session-manager.ts b/src/lib/session-manager.ts new file mode 100644 index 0000000..2e155ac --- /dev/null +++ b/src/lib/session-manager.ts @@ -0,0 +1,178 @@ +import type { RedisClientType } from 'redis'; + +export interface SessionManager { + sessionId: string; + subscriptions: Set; + cleanup: () => Promise; + isActive: boolean; + createdAt: number; +} + +interface ConnectionMonitor { + activeConnections: number; + maxConnections: number; + connectionHistory: Map; +} + +const activeSessions = new Map(); +const monitor: ConnectionMonitor = { + activeConnections: 0, + maxConnections: 100, // Default limit + connectionHistory: new Map(), +}; + +export function createSessionManager( + sessionId: string, + redis: RedisClientType, + logger?: { log: (...args: unknown[]) => void; error: (...args: unknown[]) => void } +): SessionManager { + const subscriptions = new Set(); + let isActive = true; + const createdAt = Date.now(); + + const cleanup = async (): Promise => { + if (!isActive) return; + + try { + isActive = false; + + // Unsubscribe from all channels for this session + const channelsToUnsubscribe = Array.from(subscriptions); + if (channelsToUnsubscribe.length > 0) { + for (const channel of channelsToUnsubscribe) { + try { + await redis.unsubscribe(channel); + } catch (error) { + logger?.error(`Failed to unsubscribe from ${channel}:`, error); + } + } + logger?.log(`Cleaned up ${channelsToUnsubscribe.length} subscriptions for session ${sessionId}`); + } + + subscriptions.clear(); + activeSessions.delete(sessionId); + releaseConnection(sessionId); + } catch (error) { + logger?.error(`Failed to cleanup session ${sessionId}:`, error); + } + }; + + const manager: SessionManager = { + sessionId, + subscriptions, + cleanup, + isActive, + createdAt + }; + + activeSessions.set(sessionId, manager); + trackConnection(sessionId); + + return manager; +} + +// Enhanced subscription function with tracking +export async function subscribeWithTracking( + redis: RedisClientType, + sessionManager: SessionManager, + channel: string, + callback: (message: string) => void, + logger?: { log: (...args: unknown[]) => void; error: (...args: unknown[]) => void } +): Promise { + if (!sessionManager.isActive) { + throw new Error(`Cannot subscribe to ${channel}: session ${sessionManager.sessionId} is not active`); + } + + await redis.subscribe(channel, callback); + sessionManager.subscriptions.add(channel); + + logger?.log(`Subscribed to ${channel} for session ${sessionManager.sessionId}`); +} + +// Enhanced unsubscription with tracking +export async function unsubscribeWithTracking( + redis: RedisClientType, + sessionManager: SessionManager, + channel: string, + logger?: { log: (...args: unknown[]) => void; error: (...args: unknown[]) => void } +): Promise { + try { + await redis.unsubscribe(channel); + sessionManager.subscriptions.delete(channel); + + logger?.log(`Unsubscribed from ${channel} for session ${sessionManager.sessionId}`); + } catch (error) { + logger?.error(`Failed to unsubscribe from ${channel}:`, error); + } +} + +// Connection monitoring functions +export function checkConnectionLimit(): boolean { + return monitor.activeConnections < monitor.maxConnections; +} + +export function trackConnection(sessionId: string): void { + monitor.activeConnections++; + monitor.connectionHistory.set(sessionId, Date.now()); + + // Clean old connection history (older than 1 hour) + const oneHourAgo = Date.now() - (60 * 60 * 1000); + for (const [id, timestamp] of monitor.connectionHistory.entries()) { + if (timestamp < oneHourAgo) { + monitor.connectionHistory.delete(id); + } + } +} + +export function releaseConnection(sessionId: string): void { + monitor.activeConnections = Math.max(0, monitor.activeConnections - 1); + monitor.connectionHistory.delete(sessionId); +} + +export function getConnectionStats() { + return { + active: monitor.activeConnections, + max: monitor.maxConnections, + historySize: monitor.connectionHistory.size, + sessions: Array.from(activeSessions.keys()), + oldestSession: Math.min(...Array.from(activeSessions.values()).map(s => s.createdAt)), + }; +} + +export function setMaxConnections(limit: number): void { + monitor.maxConnections = limit; +} + +// Cleanup stale sessions (sessions older than specified time) +export async function cleanupStaleSessions(maxAgeMs: number = 60 * 60 * 1000): Promise { + const now = Date.now(); + const staleSessions: SessionManager[] = []; + + for (const session of activeSessions.values()) { + if (now - session.createdAt > maxAgeMs) { + staleSessions.push(session); + } + } + + const cleanupPromises = staleSessions.map(session => session.cleanup()); + await Promise.all(cleanupPromises); + + return staleSessions.length; +} + +// Cleanup all sessions (for graceful shutdown) +export async function cleanupAllSessions(): Promise { + const cleanupPromises = Array.from(activeSessions.values()).map(session => session.cleanup()); + await Promise.all(cleanupPromises); + console.log('All sessions cleaned up'); +} + +// Get session by ID +export function getSession(sessionId: string): SessionManager | undefined { + return activeSessions.get(sessionId); +} + +// Get all active sessions +export function getAllSessions(): SessionManager[] { + return Array.from(activeSessions.values()); +} \ No newline at end of file