diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index fb10195..8edfe65 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -31,7 +31,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+        python-version: ["3.10", "3.11", "3.12", "3.13"]
     steps:
       - uses: actions/checkout@v4
diff --git a/README.md b/README.md
index dbef113..83eebd7 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
 [](https://github.com/tech4242/mcphawk/actions/workflows/ci.yml)
 [](https://codecov.io/gh/tech4242/mcphawk)
-[](https://www.python.org/downloads/)
+[](https://www.python.org/downloads/)
 [](https://fastapi.tiangolo.com/)
 [](https://vuejs.org/)
 [](https://github.com/astral-sh/ruff)
@@ -11,38 +11,67 @@
 [](https://opensource.org/licenses/MIT)
 
-MCPHawk is a passive sniffer for **Model Context Protocol (MCP)** traffic, similar to Wireshark but MCP-focused. It's Wireshark x mcpinspector.
-
-- It captures JSON-RPC traffic between MCP clients and WebSocket/TCP-based MCP servers (IPv4 and IPv6) e.g. from any tool, agent, or LLM
-- MCPHawk can reconstruct full JSON-RPC messages from raw TCP traffic without requiring a handshake.
-- It captures traffic "on the wire" between any MCP client and server - does not require client/server modification.
-
-
-
-## Features
-
-Non-exhaustive list:
-- **Proper JSON-RPC 2.0 message type detection**:
-  - Requests (method + id)
-  - Responses (result/error + id)
-  - Notifications (method without id)
-  - Error responses
-- **Auto-detect mode** - automatically discovers MCP traffic on any port without prior configuration
-- **Flexible traffic filtering**:
-  - Monitor specific ports with `--port`
-  - Use custom BPF filters with `--filter`
-  - Auto-detect MCP traffic on all ports with `--auto-detect`
-- **Chronological message display** - messages shown in order as captured
-- **Message filtering** - view all, requests only, responses only, or notifications only
-- **Optional ID-based pairing visualization** - see which requests and responses belong together
-- **Real-time statistics** - message counts by type
-- **Console-only mode** - use `mcphawk sniff` for terminal output without web UI
-- **Historical log viewing** - use `mcphawk web --no-sniffer` to view past captures without active sniffing
-- **Chill UX**
-  - dark mode
-  - expand mode to directly see JSON without detailed view
-  - filtering
-  - always see if WS connection is up for live updates
+MCPHawk is a passive network analyzer for **Model Context Protocol (MCP)** traffic, providing deep visibility into MCP client-server interactions. Think Wireshark meets mcpinspector, purpose-built for the MCP ecosystem.
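The protocol awareness described above ultimately rests on JSON-RPC 2.0's field conventions: a `method` plus `id` is a request, a `method` without `id` is a notification, and a `result` or `error` paired with an `id` is a response. A minimal classifier sketch of that logic (illustrative only, not MCPHawk's actual implementation):

```python
import json


def classify_jsonrpc(raw: str) -> str:
    """Classify a JSON-RPC 2.0 payload the way a protocol-aware sniffer might.

    Returns one of: "request", "notification", "response", "error", "invalid".
    """
    try:
        msg = json.loads(raw)
    except json.JSONDecodeError:
        return "invalid"
    if not isinstance(msg, dict) or msg.get("jsonrpc") != "2.0":
        return "invalid"
    if "method" in msg:
        # Requests carry an id for response correlation; notifications do not.
        return "request" if "id" in msg else "notification"
    if "error" in msg:
        return "error"
    if "result" in msg:
        return "response"
    return "invalid"


print(classify_jsonrpc('{"jsonrpc":"2.0","method":"tools/list","id":2}'))                 # request
print(classify_jsonrpc('{"jsonrpc":"2.0","method":"notifications/progress","params":{}}'))  # notification
print(classify_jsonrpc('{"jsonrpc":"2.0","result":{"tools":[]},"id":2}'))                 # response
```

The real sniffer applies checks like this to each payload only after reassembling it from the TCP stream; the sketch assumes one complete JSON document per message.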
+
+**Key Capabilities:**
+- **Protocol-Aware Capture**: Understands MCP's JSON-RPC 2.0 transport layer, capturing and reassembling messages from raw TCP streams
+- **Transport Agnostic**: Monitors MCP traffic across all standard transports
+- **Zero-Configuration Monitoring**: Passively observes MCP communication without proxies, certificates, or modifications to clients/servers
+- **Full Message Reconstruction**: Advanced TCP stream reassembly handles fragmented packets, chunked HTTP transfers, and SSE streams
+
+
+
+## Core Features
+
+### MCP Protocol Analysis
+- **Complete JSON-RPC 2.0 Support**: Correctly identifies and categorizes all MCP message types
+  - **Requests**: Method calls with unique IDs for correlation
+  - **Responses**: Success results and error responses with matching IDs
+  - **Notifications**: Fire-and-forget method calls without IDs
+  - **Batch Operations**: Support for JSON-RPC batch requests/responses
+- **Transport-Specific Handling**:
+  - **HTTP/SSE**: Full support for MCP's streaming HTTP transport with Server-Sent Events
+  - **TCP Direct**: Raw TCP stream reconstruction for custom implementations
+  - **Chunked Transfer**: Handles HTTP chunked transfer encoding transparently
+- **Protocol Compliance**: Validates JSON-RPC 2.0 structure and MCP-specific extensions
+
+### Advanced Capture Capabilities
+- **Auto-Discovery Mode**: Intelligently detects MCP traffic on any port using pattern matching
+- **TCP Stream Reassembly**: Reconstructs complete messages from fragmented packets
+- **Multi-Stream Tracking**: Simultaneously monitors multiple MCP client-server connections
+- **IPv4/IPv6 Dual Stack**: Native support for both IP protocols
+- **Zero-Copy Architecture**: Efficient packet processing without client/server overhead
+
+### Analysis & Visualization
+- **Real-Time Web Dashboard**: Live traffic visualization with WebSocket updates
+- **Message Flow Visualization**: Track request-response pairs using JSON-RPC IDs
+- **Traffic Statistics**: Method frequency, error rates, response times
+- **Search & Filter**: Query by method name, message type, content patterns
+- **Export Capabilities**: Save captured sessions for offline analysis
+
+### Developer Experience
+- **MCP Server Integration**: Query captured data using the MCP protocol itself
+  - FastMCP-based implementation for maximum compatibility
+  - Available tools: `query_traffic`, `search_traffic`, `get_stats`, `list_methods`
+  - Supports both stdio and HTTP transports
+- **Multiple Interfaces**:
+  - Web UI for interactive exploration
+  - CLI for scripting and automation
+  - MCP server for programmatic access
+- **Flexible Deployment**:
+  - Standalone sniffer mode
+  - Integrated web + sniffer
+  - Historical log analysis without active capture
+
+### MCP Transport Support
+
+| Official MCP Transport | Protocol Version | Capture Support | Details |
+|------------------------|------------------|:---------------:|---------|
+| **stdio** | All versions | coming soon :) | secret |
+| **HTTP** (Streamable HTTP) | 2025-03-26+ | ✅ Full | HTTP POST with optional SSE streaming responses |
+| **HTTP+SSE** (deprecated) | 2024-11-05 | ✅ Full | Legacy transport with separate SSE endpoint |
+
+Disclaimer: direct TCP traffic carrying JSON-RPC is also captured and marked as unknown (in case you run custom implementations, which you probably shouldn't).
 
 ## Comparison with Similar Tools
 
@@ -50,23 +79,20 @@ Non-exhaustive list:
 |-----------------------------------------------|:---------:|:------------:|:---------:|
 | Passive sniffing (no proxy needed) | ✅ | ❌ | ✅ |
 | MCP/JSON-RPC protocol awareness | ✅ | ✅ | ❌ |
-| Auto-detect MCP traffic on any port | ✅ | ❌ | ❌ |
+| SSE/Chunked HTTP support | ✅ | ❌ | ❌ |
+| TCP stream reassembly | ✅ | ❌ | ✅ |
+| Auto-detect MCP traffic | ✅ | ❌ | ❌ |
 | Web UI for live/historical traffic | ✅ | ✅ | ❌ |
-| Can capture any traffic (not just via proxy) | ✅ | ❌ | ✅ |
 | JSON-RPC message type detection | ✅ | ✅ | ❌ |
-| Message filtering by type | ✅ | ✅ | ✅ |
-| Console-only mode (no web UI needed) | ✅ | ❌ | ✅ |
-| Manual request crafting/testing | ❌ | ✅ | ❌ |
-| Interactive tool/prompt testing | ❌ | ✅ | ❌ |
-| Proxy/bridge between client/server | ❌ | ✅ | ❌ |
-| No client/server config changes required | ✅ | ❌ | ✅ |
-| General protocol analysis | ❌ | ❌ | ✅ |
-| MCP-specific features | ✅ | ✅ | ❌ |
+| MCP server for data access | ✅ | ❌ | ❌ |
+| No client/server config needed | ✅ | ❌ | ✅ |
+| Interactive testing/debugging | ❌ | ✅ | ❌ |
+| Proxy/MITM capabilities | ❌ | ✅ | ❌ |
 
 **When to use each tool:**
-- **MCPHawk**: Best for passively monitoring MCP traffic, debugging live connections, understanding protocol flow
-- **mcpinspector**: Best for actively testing MCP servers, crafting custom requests, interactive debugging
-- **Wireshark**: Best for general network analysis, non-MCP protocols, deep packet inspection
+- **MCPHawk**: Passive monitoring, protocol analysis, debugging MCP implementations, understanding traffic patterns
+- **mcpinspector**: Active testing, crafting requests, interactive debugging with proxy
+- **Wireshark**: General network analysis, non-MCP protocols, packet-level inspection
 
 ## TLS/HTTPS Limitations
 
@@ -81,11 +107,6 @@ MCPHawk captures **unencrypted** MCP traffic only. It cannot decrypt:
 - **Troubleshooting local tools** - Monitor Claude Desktop, Cline, etc. with YOUR local MCP servers
 - **Development/staging environments** - Where TLS is often disabled
 
-**Not suitable for:**
-- Production traffic analysis (usually encrypted)
-- Cloud MCP services (HTTPS/WSS)
-- Third-party MCP servers with TLS
-
 ## Installation
 
 ### For Users
 
@@ -141,8 +162,85 @@ sudo mcphawk web --port 3000 --host 0.0.0.0 --web-port 9000
 
 # Enable debug output for troubleshooting
 sudo mcphawk sniff --port 3000 --debug
 sudo mcphawk web --port 3000 --debug
+
+# Start MCP server with Streamable HTTP transport (default)
+mcphawk mcp --transport http --mcp-port 8765
+
+# Start MCP server with stdio transport (for Claude Desktop integration)
+mcphawk mcp --transport stdio
+
+# Start sniffer with integrated MCP server (HTTP transport)
+sudo mcphawk sniff --port 3000 --with-mcp --mcp-transport http
+
+# Start web UI with integrated MCP server
+sudo mcphawk web --port 3000 --with-mcp --mcp-transport http --mcp-port 8765
+```
+
+## MCP Server Integration
+
+MCPHawk includes a built-in MCP server, allowing you to query captured traffic through the Model Context Protocol itself.
+This creates powerful possibilities:
+
+- **AI-Powered Analysis**: Connect Claude or other LLMs to analyze traffic patterns
+- **Automated Monitoring**: Build agents that detect anomalies or specific behaviors
+- **Integration Testing**: Programmatically verify MCP interactions in CI/CD pipelines
+
+
+
+### Available Tools
+
+The MCP server exposes these tools for traffic analysis:
+
+| Tool | Description | Parameters |
+|------|-------------|------------|
+| `query_traffic` | Fetch captured logs with pagination | `limit`, `offset` |
+| `get_log` | Retrieve specific log entry | `log_id` |
+| `search_traffic` | Search logs by content or type | `search_term`, `message_type`, `traffic_type`, `limit` |
+| `get_stats` | Get traffic statistics | None |
+| `list_methods` | List unique JSON-RPC methods | None |
+
+### Transport Options
+
+#### HTTP Transport (Development & Testing)
+
+The HTTP transport uses Server-Sent Events (SSE) for streaming responses:
+
+```bash
+# Start MCP server
+mcphawk mcp --transport http --mcp-port 8765
+
+# Initialize session (note: returns SSE stream)
+curl -N -X POST http://localhost:8765/mcp \
+  -H 'Accept: text/event-stream' \
+  -d '{"jsonrpc":"2.0","method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}},"id":1}'
+
+# Example response (SSE format):
+# event: message
+# data: {"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05",...}}
+```
+
+#### stdio Transport (Production & Claude Desktop)
+
+For Claude Desktop integration:
+
+```json
+{
+  "mcpServers": {
+    "mcphawk": {
+      "command": "mcphawk",
+      "args": ["mcp", "--transport", "stdio"]
+    }
+  }
+}
+```
+
+The stdio transport follows the standard MCP communication pattern:
+1. Client sends `initialize` request
+2. Server responds with capabilities
+3. Client sends `initialized` notification
+4.
Normal tool calls can proceed
+
+See [examples/mcp_sdk_client.py](examples/mcp_sdk_client.py) for an HTTP client example or [examples/stdio_client.py](examples/stdio_client.py) for stdio communication.
+
 ## Platform Support
 
 ### Tested Platforms
@@ -154,7 +252,7 @@ sudo mcphawk web --port 3000 --debug
 
 - Requires elevated privileges (`sudo`) on macOS/Linux for packet capture
 - Limited to localhost/loopback interface monitoring
-- WebSocket capture requires traffic to be uncompressed
+- Cannot decrypt TLS/HTTPS traffic (WSS, HTTPS)
 - IPv6 support requires explicit interface configuration on some systems
 - High traffic volumes (>1000 msgs/sec) may impact performance
 
@@ -170,18 +268,21 @@ sudo mcphawk web --auto-detect
 
 - Ensure the MCP server/client is using localhost (127.0.0.1 or ::1)
 - Check if traffic is on the expected port
 - Try auto-detect mode to find MCP traffic: `--auto-detect`
-- On macOS, ensure you're allowing the terminal to capture packets in System Preferences
+- Verify traffic is unencrypted (not HTTPS/TLS)
+- On macOS, ensure Terminal has permission to capture packets in System Preferences
 
-**WebSocket Traffic Not Showing:**
-- Verify the WebSocket connection is uncompressed
-- Check if the server is using IPv6 (::1) - MCPHawk supports both IPv4 and IPv6
-- Ensure the WebSocket frames contain valid JSON-RPC messages
+**SSE/HTTP Responses Not Showing:**
+- Confirm the server uses the standard SSE format (`event: message\ndata: {...}\n\n`)
+- Check if responses use chunked transfer encoding
+- Enable debug mode to see detailed packet analysis: `--debug`
 
 ## Potential Upcoming Features
 
 Vote for features by opening a GitHub issue!
- [x] **Auto-detect MCP traffic** - Automatically discover MCP traffic on any port without prior configuration
+- [x] **MCP Server Interface** - Expose captured traffic via MCP server for AI agents to query and analyze traffic patterns
+- [ ] **Stdio capture** - eBPF integration (Linux/macOS) to trace read/write system calls for pipe communication
 - [ ] **Protocol Version Detection** - Identify and display the MCP protocol version from captured traffic
 - [ ] **Smart Search & Filtering** - Search by method name, params, or any JSON field with regex support
 - [ ] **Performance Analytics** - Request/response timing, method frequency charts, and latency distribution
@@ -192,7 +293,6 @@ Vote for features by opening a GitHub issue!
 - [ ] **Interactive Replay** - Click any request to re-send it, edit and replay captured messages
 - [ ] **Real-time Alerts** - Alert on specific methods or error patterns with webhook support
 - [ ] **Visualization** - Sequence diagrams, resource heat maps, method dependency graphs
-- [ ] **MCP Server Interface** - Expose captured traffic via MCP server for AI agents to query and analyze traffic patterns
 
 ...
and a few more off the deep end: - [ ] **TLS/HTTPS Support (MITM Proxy Mode)** - Optional man-in-the-middle proxy with certificate installation for encrypted traffic @@ -236,10 +336,3 @@ mcphawk web --port 3000 cd frontend && npm run build:watch # Auto-rebuild on changes mcphawk web --port 3000 # In another terminal ``` - -### Testing with Dummy Server - -```bash -# Generate various MCP patterns -python3 examples/generate_traffic/generate_all.py -``` diff --git a/examples/branding/mcphawk_claudedesktop.png b/examples/branding/mcphawk_claudedesktop.png new file mode 100644 index 0000000..bf47532 Binary files /dev/null and b/examples/branding/mcphawk_claudedesktop.png differ diff --git a/examples/branding/mcphawk_screenshot.png b/examples/branding/mcphawk_screenshot.png index 86d3a61..ca5ba1b 100644 Binary files a/examples/branding/mcphawk_screenshot.png and b/examples/branding/mcphawk_screenshot.png differ diff --git a/examples/generate_traffic/README.md b/examples/generate_traffic/README.md deleted file mode 100644 index 3c36f15..0000000 --- a/examples/generate_traffic/README.md +++ /dev/null @@ -1,63 +0,0 @@ -# Traffic Generation Examples - -This directory contains example servers and clients for generating MCP traffic to test MCPHawk. - -## TCP-based MCP - -### Server -```bash -python3 tcp_server.py -``` -Starts a TCP MCP server on port 12345. - -### Client -```bash -python3 tcp_client.py -``` -Sends various MCP messages to the TCP server. - -## WebSocket-based MCP - -### Server -```bash -python3 ws_server.py -``` -Starts a WebSocket MCP server on port 8765. - -### Client -```bash -python3 ws_client.py -``` -Sends various MCP messages to the WebSocket server. - -## Generate All Traffic - -To generate both TCP and WebSocket traffic for testing: - -```bash -python3 generate_all.py -``` - -This will: -1. Start both TCP and WebSocket servers -2. Send a variety of MCP messages to each -3. Display the traffic being generated -4. 
Clean up when done - -## Testing with MCPHawk - -In another terminal, run MCPHawk to capture the traffic: - -```bash -# Capture TCP traffic -sudo mcphawk sniff --port 12345 - -# Capture WebSocket traffic -sudo mcphawk sniff --port 8765 - -# Capture both -sudo mcphawk sniff --filter "tcp port 12345 or tcp port 8765" - -# Or use auto-detect -sudo mcphawk sniff --auto-detect -``` \ No newline at end of file diff --git a/examples/generate_traffic/generate_all.py b/examples/generate_traffic/generate_all.py deleted file mode 100755 index c12dc0a..0000000 --- a/examples/generate_traffic/generate_all.py +++ /dev/null @@ -1,161 +0,0 @@ -#!/usr/bin/env python3 -"""Generate both TCP and WebSocket MCP traffic for testing MCPHawk.""" - -import asyncio -import contextlib -import json -import logging -import socket -import subprocess -import sys -import time -from pathlib import Path - -import websockets - -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s [%(levelname)s] %(message)s' -) -logger = logging.getLogger(__name__) - - -def run_tcp_server(): - """Run the TCP server in a subprocess.""" - script_path = Path(__file__).parent / "tcp_server.py" - return subprocess.Popen([sys.executable, str(script_path)]) - - -def run_ws_server(): - """Run the WebSocket server in a subprocess.""" - script_path = Path(__file__).parent / "ws_server.py" - return subprocess.Popen([sys.executable, str(script_path)]) - - -def send_tcp_traffic(): - """Send TCP MCP traffic.""" - logger.info("Sending TCP traffic...") - - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect(('localhost', 12345)) - - messages = [ - {"jsonrpc": "2.0", "method": "initialize", "params": {"protocolVersion": "2024-11-05"}, "id": 1}, - {"jsonrpc": "2.0", "method": "tools/list", "id": 2}, - {"jsonrpc": "2.0", "method": "notifications/tcp_test", "params": {"source": "tcp"}}, - {"jsonrpc": "2.0", "method": "tools/call", "params": {"name": "test_tool", "arguments": {}}, "id": 3}, - ] - - 
for msg in messages: - data = json.dumps(msg).encode('utf-8') - sock.sendall(data) - logger.info(f" TCP sent: {msg.get('method')}") - - # Read response if it has an ID - if "id" in msg: - try: - response = sock.recv(1024) - if response: - logger.info(" TCP received response") - except Exception: - pass - - time.sleep(0.2) - - # Keep connection open a bit longer - time.sleep(1) - sock.close() - - except Exception as e: - logger.error(f"TCP client error: {e}") - - -async def send_ws_traffic(): - """Send WebSocket MCP traffic.""" - logger.info("Sending WebSocket traffic...") - - try: - async with websockets.connect("ws://localhost:8765", compression=None) as ws: - messages = [ - {"jsonrpc": "2.0", "method": "initialize", "params": {"protocolVersion": "2024-11-05"}, "id": 1}, - {"jsonrpc": "2.0", "method": "tools/list", "id": 2}, - {"jsonrpc": "2.0", "method": "notifications/ws_test", "params": {"source": "websocket"}}, - {"jsonrpc": "2.0", "method": "tools/call", "params": {"name": "calculator", "arguments": {"a": 10, "b": 20}}, "id": 3}, - ] - - for msg in messages: - await ws.send(json.dumps(msg)) - logger.info(f" WS sent: {msg.get('method')}") - - # Wait for response if it has an ID - if "id" in msg: - await ws.recv() - logger.info(" WS received response") - else: - await asyncio.sleep(0.2) - - except Exception as e: - logger.error(f"WebSocket client error: {e}") - - -def check_port(port): - """Check if a port is available.""" - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - result = sock.connect_ex(('localhost', port)) - sock.close() - return result != 0 - - -async def main(): - """Generate all traffic types.""" - logger.info("MCPHawk Traffic Generator") - logger.info("========================\n") - - # Check if ports are available - if not check_port(12345): - logger.error("Port 12345 is already in use. Please stop the existing TCP server.") - return - - if not check_port(8765): - logger.error("Port 8765 is already in use. 
Please stop the existing WebSocket server.") - return - - logger.info("Starting servers...") - - # Start servers - tcp_server = run_tcp_server() - ws_server = run_ws_server() - - # Give servers time to start - logger.info("Waiting for servers to start...") - await asyncio.sleep(2) - - try: - # Send TCP traffic - send_tcp_traffic() - - # Send WebSocket traffic - await send_ws_traffic() - - logger.info("\nβ All traffic sent successfully!") - logger.info("\nServers will continue running. Press Ctrl+C to stop.") - - # Keep running - await asyncio.Future() - - except KeyboardInterrupt: - logger.info("\nStopping servers...") - finally: - # Clean up - tcp_server.terminate() - ws_server.terminate() - tcp_server.wait() - ws_server.wait() - logger.info("Servers stopped.") - - -if __name__ == "__main__": - with contextlib.suppress(KeyboardInterrupt): - asyncio.run(main()) - diff --git a/examples/generate_traffic/tcp_client.py b/examples/generate_traffic/tcp_client.py deleted file mode 100755 index 83d3785..0000000 --- a/examples/generate_traffic/tcp_client.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python3 -"""TCP MCP client for testing.""" - -import json -import socket -import time - - -def send_mcp_message(sock, message): - """Send a JSON-RPC message.""" - data = json.dumps(message).encode('utf-8') - sock.sendall(data) - print(f"Sent: {message.get('method', message.get('result', 'response'))}") - - # Read response if message has an ID - if "id" in message: - try: - response = sock.recv(1024) - if response: - print(" Received response") - except Exception: - pass - - -def main(): - """Connect to TCP MCP server and send test messages.""" - print("Connecting to TCP MCP server on localhost:12345...") - - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.connect(('localhost', 12345)) - print("Connected!") - - # Send initialize - send_mcp_message(sock, { - "jsonrpc": "2.0", - "method": "initialize", - "params": { - "protocolVersion": 
"2024-11-05", - "capabilities": {} - }, - "id": 1 - }) - time.sleep(0.5) - - # Send tools/list - send_mcp_message(sock, { - "jsonrpc": "2.0", - "method": "tools/list", - "id": 2 - }) - time.sleep(0.5) - - # Send a notification (no id) - send_mcp_message(sock, { - "jsonrpc": "2.0", - "method": "notifications/progress", - "params": { - "progress": 50, - "operation": "processing" - } - }) - time.sleep(0.5) - - # Send tools/call - send_mcp_message(sock, { - "jsonrpc": "2.0", - "method": "tools/call", - "params": { - "name": "calculator", - "arguments": {"a": 5, "b": 3} - }, - "id": 3 - }) - time.sleep(0.5) - - # Send a batch of messages quickly - print("\nSending batch of messages...") - for i in range(4, 8): - send_mcp_message(sock, { - "jsonrpc": "2.0", - "method": f"test/message_{i}", - "params": {"value": i * 10}, - "id": i - }) - time.sleep(0.1) - - print("\nAll messages sent!") - # Keep connection open briefly to avoid reset - time.sleep(2) - - except ConnectionRefusedError: - print("Error: Could not connect to server. 
Make sure tcp_server.py is running.") - except Exception as e: - print(f"Error: {e}") - - -if __name__ == "__main__": - main() - diff --git a/examples/generate_traffic/tcp_server.py b/examples/generate_traffic/tcp_server.py deleted file mode 100755 index 98cfcc4..0000000 --- a/examples/generate_traffic/tcp_server.py +++ /dev/null @@ -1,60 +0,0 @@ -import json -import socket -import threading - -HOST = "127.0.0.1" -PORT = 12345 # MCPHawk should sniff this port - - -def handle_client(conn, addr): - print(f"[DUMMY MCP] Connection from {addr}") - try: - while True: - data = conn.recv(1024) - if not data: - break - - raw_msg = data.decode(errors="ignore").strip() - print(f"[DUMMY MCP] Received: {raw_msg}") - - try: - # Parse incoming JSON-RPC request - request = json.loads(raw_msg) - request_id = request.get("id") - - # Build realistic JSON-RPC response - response = { - "jsonrpc": "2.0", - "result": "ok", - "id": request_id # echo back same id if present - } - - except json.JSONDecodeError: - print("[DUMMY MCP] Invalid JSON received, sending error response") - response = { - "jsonrpc": "2.0", - "error": {"code": -32700, "message": "Parse error"}, - "id": None - } - - # Send back response - conn.sendall((json.dumps(response) + "\n").encode()) - - finally: - conn.close() - - -def start_server(): - server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server.bind((HOST, PORT)) - server.listen() - print(f"[DUMMY MCP] Listening on {HOST}:{PORT}") - while True: - conn, addr = server.accept() - threading.Thread(target=handle_client, args=(conn, addr), daemon=True).start() - - -if __name__ == "__main__": - start_server() - diff --git a/examples/generate_traffic/test_capture.py b/examples/generate_traffic/test_capture.py deleted file mode 100755 index 4a49c7a..0000000 --- a/examples/generate_traffic/test_capture.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python3 -"""Test that MCPHawk can capture both 
TCP and WebSocket traffic.""" - -import json -import sqlite3 - - -def check_captured_messages(): - """Check the MCPHawk database for captured messages.""" - try: - conn = sqlite3.connect("mcphawk_logs.db") - cursor = conn.cursor() - - # Get total count - cursor.execute("SELECT COUNT(*) FROM logs") - total = cursor.fetchone()[0] - print(f"Total messages in database: {total}") - - # Check TCP messages (port 12345) - cursor.execute(""" - SELECT COUNT(*) FROM logs - WHERE (src_port = 12345 OR dst_port = 12345) - AND message LIKE '%jsonrpc%' - """) - tcp_count = cursor.fetchone()[0] - print(f"TCP MCP messages (port 12345): {tcp_count}") - - # Check WebSocket messages (port 8765) - cursor.execute(""" - SELECT COUNT(*) FROM logs - WHERE (src_port = 8765 OR dst_port = 8765) - AND message LIKE '%jsonrpc%' - """) - ws_count = cursor.fetchone()[0] - print(f"WebSocket MCP messages (port 8765): {ws_count}") - - # Show sample messages - if tcp_count > 0: - print("\nSample TCP messages:") - cursor.execute(""" - SELECT message FROM logs - WHERE (src_port = 12345 OR dst_port = 12345) - AND message LIKE '%jsonrpc%' - LIMIT 3 - """) - for row in cursor: - msg = json.loads(row[0]) - print(f" - {msg.get('method', msg.get('result', '?'))}") - - if ws_count > 0: - print("\nSample WebSocket messages:") - cursor.execute(""" - SELECT message FROM logs - WHERE (src_port = 8765 OR dst_port = 8765) - AND message LIKE '%jsonrpc%' - LIMIT 3 - """) - for row in cursor: - msg = json.loads(row[0]) - print(f" - {msg.get('method', msg.get('result', '?'))}") - - conn.close() - - # Summary - print("\n" + "="*50) - if tcp_count > 0 and ws_count > 0: - print("β SUCCESS: Both TCP and WebSocket MCP traffic captured!") - elif tcp_count > 0: - print("β οΈ Only TCP traffic captured") - elif ws_count > 0: - print("β οΈ Only WebSocket traffic captured") - else: - print("β No MCP traffic captured") - - except Exception as e: - print(f"Error checking database: {e}") - - -if __name__ == "__main__": - 
print("MCPHawk Capture Test") - print("="*50) - print("\nChecking captured messages...") - print("(Make sure MCPHawk is running and traffic has been generated)\n") - - check_captured_messages() - diff --git a/examples/generate_traffic/ws_client.py b/examples/generate_traffic/ws_client.py deleted file mode 100755 index 8379800..0000000 --- a/examples/generate_traffic/ws_client.py +++ /dev/null @@ -1,121 +0,0 @@ -#!/usr/bin/env python3 -"""WebSocket MCP client for testing.""" - -import asyncio -import json -import logging - -import websockets - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -async def send_mcp_message(websocket, message): - """Send a JSON-RPC message and optionally wait for response.""" - await websocket.send(json.dumps(message)) - method = message.get('method', message.get('result', 'response')) - logger.info(f"Sent: {method}") - - # If it has an ID, wait for response - if "id" in message: - response = await websocket.recv() - response_data = json.loads(response) - logger.info(f"Received response: {response_data.get('result', response_data.get('error'))}") - return response_data - - -async def main(): - """Connect to WebSocket MCP server and send test messages.""" - uri = "ws://localhost:8765" - logger.info(f"Connecting to WebSocket MCP server at {uri}...") - - try: - async with websockets.connect(uri, compression=None) as websocket: - logger.info("Connected!") - - # Send initialize - await send_mcp_message(websocket, { - "jsonrpc": "2.0", - "method": "initialize", - "params": { - "protocolVersion": "2024-11-05", - "capabilities": {} - }, - "id": 1 - }) - - # Send tools/list - await send_mcp_message(websocket, { - "jsonrpc": "2.0", - "method": "tools/list", - "id": 2 - }) - - # Send a notification (no id, no response expected) - await send_mcp_message(websocket, { - "jsonrpc": "2.0", - "method": "notifications/progress", - "params": { - "progress": 50, - "operation": "processing" - } - }) - await 
asyncio.sleep(0.5) # Small delay after notification - - # Send tools/call - await send_mcp_message(websocket, { - "jsonrpc": "2.0", - "method": "tools/call", - "params": { - "name": "calculator", - "arguments": {"a": 5, "b": 3} - }, - "id": 3 - }) - - # Send a large message to test extended length - large_data = "x" * 1000 - await send_mcp_message(websocket, { - "jsonrpc": "2.0", - "method": "test/large_message", - "params": {"data": large_data}, - "id": 4 - }) - - # Send a batch of messages quickly - logger.info("\nSending batch of messages...") - tasks = [] - for i in range(5, 10): - message = { - "jsonrpc": "2.0", - "method": f"test/message_{i}", - "params": {"value": i * 10}, - "id": i - } - tasks.append(send_mcp_message(websocket, message)) - - # Wait for all responses - await asyncio.gather(*tasks) - - # Send one more notification before closing - await send_mcp_message(websocket, { - "jsonrpc": "2.0", - "method": "notifications/closing", - "params": {"reason": "test complete"} - }) - - logger.info("\nAll messages sent!") - await asyncio.sleep(0.5) - - except websockets.exceptions.WebSocketException as e: - logger.error(f"WebSocket error: {e}") - except ConnectionRefusedError: - logger.error("Could not connect to server. 
Make sure ws_server.py is running.") - except Exception as e: - logger.error(f"Error: {e}") - - -if __name__ == "__main__": - asyncio.run(main()) - diff --git a/examples/generate_traffic/ws_server.py b/examples/generate_traffic/ws_server.py deleted file mode 100755 index 759dbea..0000000 --- a/examples/generate_traffic/ws_server.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 -"""WebSocket MCP server for testing.""" - -import asyncio -import json -import logging - -import websockets - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -async def handle_mcp_message(websocket, message): - """Handle incoming MCP message and send response.""" - try: - data = json.loads(message) - method = data.get("method") - msg_id = data.get("id") - - logger.info(f"Received: {method} (id: {msg_id})") - - # Only respond to requests (with id), not notifications - if msg_id is None: - logger.info(f"Notification received: {method}") - return - - # Generate response based on method - if method == "initialize": - response = { - "jsonrpc": "2.0", - "result": { - "protocolVersion": "2024-11-05", - "capabilities": { - "tools": {"listChanged": True}, - "resources": {"subscribe": True} - }, - "serverInfo": { - "name": "test-ws-server", - "version": "1.0.0" - } - }, - "id": msg_id - } - elif method == "tools/list": - response = { - "jsonrpc": "2.0", - "result": { - "tools": [ - { - "name": "calculator", - "description": "Perform calculations", - "inputSchema": { - "type": "object", - "properties": { - "a": {"type": "number"}, - "b": {"type": "number"} - } - } - }, - { - "name": "echo", - "description": "Echo back input", - "inputSchema": { - "type": "object", - "properties": { - "message": {"type": "string"} - } - } - } - ] - }, - "id": msg_id - } - elif method == "tools/call": - params = data.get("params", {}) - tool_name = params.get("name") - args = params.get("arguments", {}) - - if tool_name == "calculator": - result = args.get("a", 0) + 
args.get("b", 0) - response = { - "jsonrpc": "2.0", - "result": {"value": result}, - "id": msg_id - } - else: - response = { - "jsonrpc": "2.0", - "result": {"echo": str(args)}, - "id": msg_id - } - else: - # Generic response - response = { - "jsonrpc": "2.0", - "result": {"status": "ok", "method": method}, - "id": msg_id - } - - await websocket.send(json.dumps(response)) - logger.info(f"Sent response for {method}") - - except json.JSONDecodeError: - logger.error(f"Invalid JSON: {message}") - except Exception as e: - logger.error(f"Error handling message: {e}") - if msg_id: - error_response = { - "jsonrpc": "2.0", - "error": { - "code": -32603, - "message": str(e) - }, - "id": msg_id - } - await websocket.send(json.dumps(error_response)) - - -async def mcp_server(websocket): - """Handle WebSocket connection.""" - logger.info("Client connected") - - try: - async for message in websocket: - await handle_mcp_message(websocket, message) - except websockets.exceptions.ConnectionClosed: - logger.info("Client disconnected") - except Exception as e: - logger.error(f"Connection error: {e}") - - -async def main(): - """Start WebSocket MCP server.""" - port = 8765 - logger.info(f"Starting WebSocket MCP server on ws://localhost:{port}") - - async with websockets.serve(mcp_server, "localhost", port, compression=None): - logger.info("Server ready. Press Ctrl+C to stop.") - await asyncio.Future() # Run forever - - -if __name__ == "__main__": - try: - asyncio.run(main()) - except KeyboardInterrupt: - logger.info("Server stopped.") - diff --git a/examples/http_sse_example.py b/examples/http_sse_example.py new file mode 100644 index 0000000..c484c62 --- /dev/null +++ b/examples/http_sse_example.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +""" +Proper HTTP+SSE MCP client traffic generator. + +This simulates the legacy HTTP+SSE transport pattern as documented: +1. GET request to /sse endpoint to establish SSE connection +2. 
Server sends "endpoint" event with the message endpoint URL +3. Client POSTs JSON-RPC messages to the message endpoint + +This example creates traffic that MCPHawk can properly detect as HTTP+SSE. +""" + +import json +import threading +import time +from http.server import BaseHTTPRequestHandler, HTTPServer + +import requests + + +class MockHTTPSSEServer(BaseHTTPRequestHandler): + """Mock server that implements HTTP+SSE pattern.""" + + def log_message(self, format, *args): + """Suppress default logging.""" + pass + + def do_GET(self): + """Handle GET request for SSE connection.""" + if self.path == '/sse': + print("[Mock Server] Received GET /sse - sending SSE response") + self.send_response(200) + self.send_header('Content-Type', 'text/event-stream') + self.send_header('Cache-Control', 'no-cache') + self.send_header('Connection', 'keep-alive') + self.end_headers() + + # Send the endpoint event as per HTTP+SSE spec + endpoint_event = 'event: endpoint\ndata: {"url": "/messages"}\n\n' + self.wfile.write(endpoint_event.encode()) + self.wfile.flush() + + # For this example, close after sending endpoint + # Real servers would keep the connection open for streaming + return + else: + self.send_error(404) + + def do_POST(self): + """Handle POST request to message endpoint.""" + if self.path == '/messages': + print("[Mock Server] Received POST /messages") + content_length = int(self.headers['Content-Length']) + post_data = self.rfile.read(content_length) + + # Parse the JSON-RPC request + try: + request = json.loads(post_data) + print(f"[Mock Server] Request: {request}") + + # Send a simple response + response = { + "jsonrpc": "2.0", + "result": {"initialized": True}, + "id": request.get("id") + } + + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps(response).encode()) + except Exception as e: + print(f"[Mock Server] Error: {e}") + self.send_error(400) + else: + self.send_error(404) + + +def 
run_mock_server(port=8766): + """Run the mock HTTP+SSE server in a thread.""" + server = HTTPServer(('localhost', port), MockHTTPSSEServer) + server_thread = threading.Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + return server + + +def simulate_http_sse_client(server_port=8766): + """Simulate HTTP+SSE client traffic pattern.""" + + print("\nSimulating HTTP+SSE MCP Client (Legacy Pattern)") + print("=" * 50) + + server_url = f"http://localhost:{server_port}" + + # Step 1: Establish SSE connection with GET request + print("\n1. Establishing SSE connection...") + print(f" GET {server_url}/sse") + print(" Accept: text/event-stream") + + # Use requests library for better control + session = requests.Session() + + endpoint_url = None + + try: + # Make GET request with SSE accept header + headers = {'Accept': 'text/event-stream'} + response = session.get(f"{server_url}/sse", headers=headers, stream=True, timeout=2) + + print(f" Response: {response.status_code} {response.reason}") + print(f" Content-Type: {response.headers.get('Content-Type')}") + + # Read the endpoint event with timeout on iter_lines + try: + for line in response.iter_lines(decode_unicode=True, chunk_size=1): + if line: + print(f" SSE: {line}") + if line.startswith('data:'): + data = json.loads(line[5:].strip()) + endpoint_url = data.get('url') + break + except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectionError): + print(" SSE connection closed/timed out (expected)") + + response.close() # Close the streaming connection + + except Exception as e: + print(f" Error during GET: {e}") + + # Always try the POST request, even if GET had issues + if not endpoint_url: + endpoint_url = "/messages" # Default endpoint for HTTP+SSE + print(f"\n2. Using default endpoint URL: {endpoint_url}") + else: + print(f"\n2. Server sent endpoint URL: {endpoint_url}") + + # Step 2: Send JSON-RPC request to the endpoint + print("\n3. 
Sending JSON-RPC request to endpoint...")
+    print(f" POST {server_url}{endpoint_url}")
+    print(" Content-Type: application/json")
+
+    try:
+        # Send initialize request
+        initialize_request = {
+            "jsonrpc": "2.0",
+            "method": "initialize",
+            "params": {
+                "protocolVersion": "2024-11-05",
+                "capabilities": {},
+                "clientInfo": {
+                    "name": "http-sse-test",
+                    "version": "1.0.0"
+                }
+            },
+            "id": 1
+        }
+
+        # Important: No Accept header with dual types for HTTP+SSE POST
+        post_response = session.post(
+            f"{server_url}{endpoint_url}",
+            json=initialize_request,
+            headers={'Content-Type': 'application/json'},
+            timeout=5
+        )
+
+        print(f" Response: {post_response.status_code}")
+        if post_response.status_code == 200:
+            print(f" Result: {post_response.json()}")
+
+    except Exception as e:
+        print(f" Error during POST: {e}")
+
+    print("\n" + "=" * 50)
+    print("HTTP+SSE pattern demonstration complete")
+    print("Check MCPHawk to see the detected transport type:")
+    print("- GET /sse with Accept: text/event-stream → HTTP+SSE")
+    print("- Server sends 'endpoint' event → Confirms HTTP+SSE")
+    print("- POST to endpoint without dual Accept → HTTP+SSE pattern")
+
+
+if __name__ == "__main__":
+    print("HTTP+SSE MCP Client Example (Proper Implementation)")
+    print("This demonstrates the legacy HTTP+SSE transport pattern")
+    print("with a mock server that properly implements the protocol\n")
+
+    # Start mock server
+    port = 8766
+    print(f"Starting mock HTTP+SSE server on port {port}...")
+    server = run_mock_server(port)
+    time.sleep(1)  # Give server time to start
+
+    # Run client simulation
+    simulate_http_sse_client(port)
+
+    # Keep server running briefly for any remaining packets
+    time.sleep(1)
+    print("\nDone!")
diff --git a/examples/mcp_sdk_client.py b/examples/mcp_sdk_client.py
new file mode 100755
index 0000000..20637fc
--- /dev/null
+++ b/examples/mcp_sdk_client.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+"""Test MCPHawk server using the official MCP SDK client."""
+
+import asyncio
+import json + +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client + + +async def test_with_sdk_client(): + """Test using the SDK's official client.""" + + # Connect to the server + async with streamablehttp_client("http://localhost:8765/mcp") as (read_stream, write_stream, session_id): + print(f"Connected with session ID: {session_id}") + + # Create a session + async with ClientSession(read_stream, write_stream) as session: + # Initialize + print("\n1. Initializing...") + await session.initialize() + print("Initialized successfully") + + # List tools + print("\n2. Listing tools...") + tools_result = await session.list_tools() + print(f"Available tools: {len(tools_result.tools)}") + for tool in tools_result.tools: + print(f" - {tool.name}: {tool.description}") + + # Call a tool + print("\n3. Calling get_stats...") + result = await session.call_tool("get_stats", arguments={}) + print("Result:") + for content in result.content: + print(f" {content.text}") + + # Try query_traffic + print("\n4. Querying recent traffic...") + result = await session.call_tool("query_traffic", arguments={"limit": 5}) + print("Result:") + for content in result.content: + data = json.loads(content.text) + print(f" Found {len(data)} log entries") + + +if __name__ == "__main__": + print("Testing MCPHawk MCP Server with SDK Client") + print("==========================================") + print("Make sure the server is running with:") + print(" mcphawk mcp --transport http --mcp-port 8765") + print() + + try: + asyncio.run(test_with_sdk_client()) + except Exception as e: + print(f"Error: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + diff --git a/examples/stdio_client.py b/examples/stdio_client.py new file mode 100644 index 0000000..b483088 --- /dev/null +++ b/examples/stdio_client.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +""" +Example stdio client for MCPHawk MCP server. 
+ +This demonstrates how to communicate with MCPHawk's MCP server using the stdio transport. +The MCP protocol requires: +1. Initialize request +2. Initialized notification +3. Then you can make tool calls +""" + +import json +import queue +import subprocess +import threading +from typing import Any, Optional + + +class MCPHawkStdioClient: + """Client for communicating with MCPHawk MCP server over stdio.""" + + def __init__(self, debug: bool = False): + self.debug = debug + self.proc = None + self.stderr_queue = queue.Queue() + self.request_id = 0 + + def connect(self) -> bool: + """Start the MCP server process and initialize connection.""" + try: + # Start the MCP server + self.proc = subprocess.Popen( + ["mcphawk", "mcp", "--transport", "stdio"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=0 # Unbuffered + ) + + # Start stderr reader thread + self.stderr_thread = threading.Thread(target=self._read_stderr, daemon=True) + self.stderr_thread.start() + + # Send initialize request + init_response = self._send_request({ + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "mcphawk-stdio-client", "version": "1.0"} + } + }) + + if not init_response or "error" in init_response: + print(f"Failed to initialize: {init_response}") + return False + + # Send initialized notification + self._send_notification({ + "method": "notifications/initialized", + "params": {} + }) + + if self.debug: + print(f"Connected to server: {init_response['result']['serverInfo']}") + + return True + + except Exception as e: + print(f"Failed to connect: {e}") + return False + + def _read_stderr(self): + """Read stderr in a separate thread.""" + while self.proc and self.proc.poll() is None: + line = self.proc.stderr.readline() + if line: + self.stderr_queue.put(line.strip()) + + def _send_request(self, request: dict[str, Any]) -> Optional[dict[str, Any]]: + """Send a JSON-RPC 
request and wait for response.""" + self.request_id += 1 + request["jsonrpc"] = "2.0" + request["id"] = self.request_id + + request_str = json.dumps(request) + if self.debug: + print(f">>> {request_str}") + + self.proc.stdin.write(request_str + "\n") + self.proc.stdin.flush() + + # Read response + response_line = self.proc.stdout.readline() + if response_line: + try: + response = json.loads(response_line) + if self.debug: + print(f"<<< {json.dumps(response, indent=2)}") + return response + except json.JSONDecodeError as e: + print(f"Failed to decode response: {e}") + print(f"Raw: {response_line}") + return None + return None + + def _send_notification(self, notification: dict[str, Any]) -> None: + """Send a JSON-RPC notification (no response expected).""" + notification["jsonrpc"] = "2.0" + + notification_str = json.dumps(notification) + if self.debug: + print(f">>> {notification_str}") + + self.proc.stdin.write(notification_str + "\n") + self.proc.stdin.flush() + + def list_tools(self) -> Optional[list]: + """Get list of available tools.""" + response = self._send_request({ + "method": "tools/list", + "params": {} + }) + + if response and "result" in response: + return response["result"]["tools"] + return None + + def call_tool(self, tool_name: str, arguments: Optional[dict[str, Any]] = None) -> Optional[Any]: + """Call a tool with given arguments.""" + response = self._send_request({ + "method": "tools/call", + "params": { + "name": tool_name, + "arguments": arguments or {} + } + }) + + if response and "result" in response: + # Extract the text content from the response + content = response["result"]["content"] + if content and len(content) > 0: + text = content[0]["text"] + try: + # Try to parse as JSON + return json.loads(text) + except json.JSONDecodeError: + # Return as plain text if not JSON + return text + return None + + def close(self): + """Close the connection.""" + if self.proc: + self.proc.terminate() + self.proc.wait() + self.proc = None + + +def 
main(): + """Example usage of the MCPHawk stdio client.""" + client = MCPHawkStdioClient(debug=True) + + print("Connecting to MCPHawk MCP server...") + if not client.connect(): + print("Failed to connect!") + return + + print("\n1. Listing available tools:") + tools = client.list_tools() + if tools: + for tool in tools: + print(f" - {tool['name']}: {tool['description']}") + + print("\n2. Getting traffic statistics:") + stats = client.call_tool("get_stats") + if stats: + print(f" Total logs: {stats['total']}") + print(f" Requests: {stats['requests']}") + print(f" Responses: {stats['responses']}") + print(f" Notifications: {stats['notifications']}") + print(f" Errors: {stats['errors']}") + + print("\n3. Querying recent traffic:") + logs = client.call_tool("query_traffic", {"limit": 5}) + if logs: + print(f" Found {len(logs)} recent log entries") + for log in logs: + msg_preview = log['message'][:50] + "..." if len(log['message']) > 50 else log['message'] + print(f" - {log['timestamp']}: {msg_preview}") + + print("\n4. Listing captured methods:") + methods = client.call_tool("list_methods") + if methods: + print(f" Found {len(methods)} unique methods:") + for method in methods[:10]: # Show first 10 + print(f" - {method}") + if len(methods) > 10: + print(f" ... and {len(methods) - 10} more") + + print("\nClosing connection...") + client.close() + + +if __name__ == "__main__": + main() diff --git a/examples/streamable_http_example.py b/examples/streamable_http_example.py new file mode 100644 index 0000000..aaeaa66 --- /dev/null +++ b/examples/streamable_http_example.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +""" +Example Streamable HTTP MCP client traffic generator. + +This demonstrates the Streamable HTTP transport pattern for testing MCPHawk's transport detection. +Streamable HTTP uses: +1. POST request with dual Accept headers (application/json, text/event-stream) +2. 
Server can respond with either JSON or SSE +""" + +import json +import urllib.error +import urllib.request + + +def simulate_streamable_http_client(): + """Simulate Streamable HTTP client traffic pattern.""" + + print("Simulating Streamable HTTP MCP Client") + print("=" * 50) + + server_url = "http://localhost:8765" + + print("\n1. Sending request with dual Accept headers (Streamable HTTP pattern)...") + print(f" POST {server_url}/mcp") + print(" Accept: application/json, text/event-stream") + print(" Content-Type: application/json") + + # Send a sample initialize request + initialize_request = { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", # New protocol version + "capabilities": {}, + "clientInfo": { + "name": "streamable-http-test", + "version": "1.0.0" + } + }, + "id": 1 + } + + try: + data = json.dumps(initialize_request).encode('utf-8') + req = urllib.request.Request( + f"{server_url}/mcp", + data=data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream" # Key difference! + } + ) + + print(f"\n Request body: {json.dumps(initialize_request, indent=2)}") + + with urllib.request.urlopen(req) as response: + print(f"\n Response: {response.status}") + content_type = response.headers.get('Content-Type', '') + print(f" Content-Type: {content_type}") + + if 'text/event-stream' in content_type: + print(" Server returned SSE response (streaming)") + else: + print(" Server returned JSON response") + + except urllib.error.URLError as e: + print(f" Connection failed: {e}") + + # Send another request that might get a different response type + print("\n2. 
Sending tool call request...")
+    print(f" POST {server_url}/mcp")
+    print(" Accept: application/json, text/event-stream")
+
+    tool_request = {
+        "jsonrpc": "2.0",
+        "method": "tools/call",
+        "params": {
+            "name": "long_running_tool",
+            "arguments": {}
+        },
+        "id": 2
+    }
+
+    try:
+        data = json.dumps(tool_request).encode('utf-8')
+        req = urllib.request.Request(
+            f"{server_url}/mcp",
+            data=data,
+            headers={
+                "Content-Type": "application/json",
+                "Accept": "application/json, text/event-stream"
+            }
+        )
+
+        with urllib.request.urlopen(req) as response:
+            print(f"\n Response: {response.status}")
+            content_type = response.headers.get('Content-Type', '')
+            print(f" Content-Type: {content_type}")
+
+    except urllib.error.URLError as e:
+        print(f" Connection failed: {e}")
+
+    print("\n" + "=" * 50)
+    print("Streamable HTTP pattern demonstration complete")
+    print("Check MCPHawk to see how it detected the transport type:")
+    print("- POST with Accept: application/json, text/event-stream → Streamable HTTP")
+
+
+if __name__ == "__main__":
+    print("Streamable HTTP MCP Client Example")
+    print("This demonstrates the Streamable HTTP transport pattern")
+    print("This should work with our MCP server\n")
+
+    simulate_streamable_http_client()
diff --git a/examples/test_mcp_http.sh b/examples/test_mcp_http.sh
new file mode 100755
index 0000000..d3f58f5
--- /dev/null
+++ b/examples/test_mcp_http.sh
@@ -0,0 +1,177 @@
+#!/bin/bash
+
+# Test MCPHawk MCP Server with various requests
+# Make sure MCPHawk is running with:
+# sudo mcphawk web --auto-detect --with-mcp --mcp-transport http --mcp-port 8765 --debug
+
+SESSION_ID="test-session-$(date +%s)"
+MCP_URL="http://localhost:8765/mcp"
+
+echo "Testing MCPHawk MCP Server with session: $SESSION_ID"
+echo "================================================"
+
+# 1. Initialize session
+echo -e "\n1. Initializing MCP session..."
+curl -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0" + } + }, + "id": 1 + }' | jq . + +sleep 1 + +# 2. Send initialized notification +echo -e "\n2. Sending initialized notification..." +curl -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {} + }' | jq . + +sleep 1 + +# 3. List available tools +echo -e "\n3. Listing available tools..." +curl -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "tools/list", + "params": {}, + "id": 2 + }' | jq . + +sleep 1 + +# 4. Get traffic statistics +echo -e "\n4. Getting traffic statistics..." +curl -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": "get_stats", + "arguments": {} + }, + "id": 3 + }' | jq . + +sleep 1 + +# 5. Query recent traffic +echo -e "\n5. Querying recent traffic (limit 10)..." +curl -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": "query_traffic", + "arguments": { + "limit": 10 + } + }, + "id": 4 + }' | jq . + +sleep 1 + +# 6. List unique methods captured +echo -e "\n6. Listing unique JSON-RPC methods..." 
+curl -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": "list_methods", + "arguments": {} + }, + "id": 5 + }' | jq . + +sleep 1 + +# 7. Search for specific traffic +echo -e "\n7. Searching for 'initialize' in traffic..." +curl -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": "search_traffic", + "arguments": { + "search_term": "initialize" + } + }, + "id": 6 + }' | jq . + +sleep 1 + +# 8. Test error handling - call non-existent tool +echo -e "\n8. Testing error handling (calling non-existent tool)..." +curl -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": "non_existent_tool", + "arguments": {} + }, + "id": 7 + }' | jq . + +sleep 1 + +# 9. Send a notification (no ID, no response expected) +echo -e "\n9. Sending a notification..." +curl -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": { + "progress": 50, + "message": "Test progress notification" + } + }' + +echo -e "\n\nAll tests completed!" 
+echo "Check the MCPHawk web UI at http://localhost:8000 to see:" +echo "- All requests marked with purple 'MCP' badges" +echo "- Use the MCPHawk toggle button to filter these messages" +echo "- Click on messages to see full JSON details" \ No newline at end of file diff --git a/examples/test_mcp_sdk.sh b/examples/test_mcp_sdk.sh new file mode 100755 index 0000000..e5dfc38 --- /dev/null +++ b/examples/test_mcp_sdk.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +# Test MCPHawk MCP Server (SDK version) with proper session flow +# The SDK requires: +# 1. First initialize request WITHOUT session ID +# 2. Server returns session ID in response +# 3. Use that session ID for subsequent requests + +MCP_URL="http://localhost:8765/mcp" + +echo "Testing MCPHawk MCP Server (SDK version)" +echo "========================================" + +# 1. Initialize WITHOUT session ID - server will assign one +echo -e "\n1. Initializing MCP session (no session ID)..." +response=$(curl -s -i -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -d '{ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0" + } + }, + "id": 1 + }') + +echo "$response" + +# Extract session ID from response headers +SESSION_ID=$(echo "$response" | grep -i "mcp-session-id:" | sed 's/.*: //' | tr -d '\r\n') + +if [ -z "$SESSION_ID" ]; then + echo "ERROR: No session ID received from server" + exit 1 +fi + +echo -e "\nReceived session ID: $SESSION_ID" + +sleep 1 + +# 2. Now use the server-provided session ID for subsequent requests +echo -e "\n2. Listing tools with session ID..." +curl -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "tools/list", + "params": {}, + "id": 2 + }' | jq . + +sleep 1 + +# 3. 
Test a notification (should return no response) +echo -e "\n3. Sending notification..." +curl -i -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": { + "progress": 50, + "message": "Test progress notification" + } + }' + +echo -e "\n\nCheck http://localhost:8000 for captured traffic" \ No newline at end of file diff --git a/examples/test_mcp_sdk_sse.sh b/examples/test_mcp_sdk_sse.sh new file mode 100755 index 0000000..8be5813 --- /dev/null +++ b/examples/test_mcp_sdk_sse.sh @@ -0,0 +1,91 @@ +#!/bin/bash + +# Test MCPHawk MCP Server (SDK version) with SSE response handling + +MCP_URL="http://localhost:8765/mcp" + +echo "Testing MCPHawk MCP Server (SDK version) with SSE" +echo "=================================================" + +# 1. Initialize WITHOUT session ID +echo -e "\n1. Initializing MCP session..." +response=$(curl -s -i -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -d '{ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0" + } + }, + "id": 1 + }') + +# Extract session ID +SESSION_ID=$(echo "$response" | grep -i "mcp-session-id:" | sed 's/.*: //' | tr -d '\r\n') +echo "Session ID: $SESSION_ID" + +# Extract JSON from SSE response +json_data=$(echo "$response" | grep "^data: " | sed 's/^data: //') +echo "Response JSON:" +echo "$json_data" | jq . + +sleep 1 + +# 2. List tools +echo -e "\n2. Listing tools..." 
+response=$(curl -s -i -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "tools/list", + "params": {}, + "id": 2 + }') + +# Extract JSON from SSE +json_data=$(echo "$response" | grep "^data: " | sed 's/^data: //') +echo "$json_data" | jq . + +sleep 1 + +# 3. Call a tool +echo -e "\n3. Getting stats..." +response=$(curl -s -i -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": "get_stats", + "arguments": {} + }, + "id": 3 + }') + +json_data=$(echo "$response" | grep "^data: " | sed 's/^data: //') +echo "$json_data" | jq . + +# 4. Test standard MCP notification (not custom) +echo -e "\n4. Sending standard initialized notification..." +curl -s -i -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {} + }' | head -10 + +echo -e "\n\nNote: Responses use Server-Sent Events (SSE) format" +echo "The sniffer might not capture SSE responses properly" \ No newline at end of file diff --git a/examples/test_mcp_sdk_wait.sh b/examples/test_mcp_sdk_wait.sh new file mode 100755 index 0000000..ebeea64 --- /dev/null +++ b/examples/test_mcp_sdk_wait.sh @@ -0,0 +1,85 @@ +#!/bin/bash + +# Test MCPHawk MCP Server with proper initialization wait + +MCP_URL="http://localhost:8765/mcp" + +echo "Testing MCPHawk MCP Server (SDK version)" +echo "========================================" + +# 1. Initialize +echo -e "\n1. Initializing MCP session..." 
+response=$(curl -s -i -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -d '{ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0" + } + }, + "id": 1 + }') + +SESSION_ID=$(echo "$response" | grep -i "mcp-session-id:" | sed 's/.*: //' | tr -d '\r\n') +echo "Session ID: $SESSION_ID" + +# 2. Send initialized notification (this might be required by SDK) +echo -e "\n2. Sending initialized notification to complete handshake..." +curl -s -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {} + }' + +echo "Waiting for initialization to complete..." +sleep 2 + +# 3. Now try listing tools +echo -e "\n3. Listing tools..." +response=$(curl -s -i -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "tools/list", + "params": {}, + "id": 2 + }') + +json_data=$(echo "$response" | grep "^data: " | sed 's/^data: //') +if [ -n "$json_data" ]; then + echo "$json_data" | jq . +else + echo "Raw response:" + echo "$response" | tail -20 +fi + +# 4. Test calling a tool +echo -e "\n4. Calling get_stats tool..." +response=$(curl -s -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": "get_stats", + "arguments": {} + }, + "id": 3 + }') + +# Try to extract JSON - SDK might return plain JSON for errors +echo "$response" | jq . 
2>/dev/null || echo "$response" \ No newline at end of file diff --git a/examples/test_mcp_simple.sh b/examples/test_mcp_simple.sh new file mode 100755 index 0000000..cb8c054 --- /dev/null +++ b/examples/test_mcp_simple.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Simple test for MCPHawk MCP Server +# Run MCPHawk with: sudo mcphawk web --auto-detect --with-mcp --mcp-transport http --mcp-port 8765 --debug + +SESSION_ID="test-$(date +%s)" +echo "Testing with session: $SESSION_ID" + +# Single test request +echo "Sending initialize request..." +curl -v -X POST http://localhost:8765/mcp \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0" + } + }, + "id": 1 + }' 2>&1 + +echo -e "\n\nCheck http://localhost:8000 for captured traffic" \ No newline at end of file diff --git a/frontend/src/components/LogTable/LogFilters.vue b/frontend/src/components/LogTable/LogFilters.vue index 8bc7ac6..00e34e3 100644 --- a/frontend/src/components/LogTable/LogFilters.vue +++ b/frontend/src/components/LogTable/LogFilters.vue @@ -48,6 +48,21 @@ {{ logStore.expandAll ? 
'Collapse' : 'Expand' }} + + + + MCPHawk + + import { computed, ref, watch } from 'vue' import { useLogStore } from '@/stores/logs' -import { MagnifyingGlassIcon, TrashIcon, ArrowPathIcon, CodeBracketIcon } from '@heroicons/vue/24/outline' +import { MagnifyingGlassIcon, TrashIcon, ArrowPathIcon, CodeBracketIcon, FunnelIcon } from '@heroicons/vue/24/outline' const logStore = useLogStore() diff --git a/frontend/src/components/LogTable/LogRow.vue b/frontend/src/components/LogTable/LogRow.vue index 000beeb..446eb24 100644 --- a/frontend/src/components/LogTable/LogRow.vue +++ b/frontend/src/components/LogTable/LogRow.vue @@ -17,14 +17,26 @@ {{ formatTimestamp(log.timestamp) }} - - + + + + + MCPπ¦ + + {{ messageSummary }} - {{ log.traffic_type || 'N/A' }} + + {{ formattedTransportType }} + @@ -59,6 +71,7 @@ \ No newline at end of file diff --git a/frontend/src/components/LogTable/LogTable.vue b/frontend/src/components/LogTable/LogTable.vue index 995eb0b..86bef8f 100644 --- a/frontend/src/components/LogTable/LogTable.vue +++ b/frontend/src/components/LogTable/LogTable.vue @@ -16,14 +16,14 @@ Time - + Type Message - Traffic + Transport Source β Dest @@ -36,10 +36,10 @@ @@ -75,10 +75,6 @@ const logStore = useLogStore() const displayLogs = computed(() => logStore.filteredLogs) function handleLogClick(log) { - // Generate ID if missing (for compatibility) - if (!log.id) { - log.id = `${log.timestamp}-${log.src_port}-${Math.random()}` - } - logStore.selectLog(log.id) + logStore.selectLog(log.log_id) } \ No newline at end of file diff --git a/frontend/src/components/MessageDetail/MessageDetailModal.vue b/frontend/src/components/MessageDetail/MessageDetailModal.vue index c08ef56..64728ef 100644 --- a/frontend/src/components/MessageDetail/MessageDetailModal.vue +++ b/frontend/src/components/MessageDetail/MessageDetailModal.vue @@ -41,11 +41,19 @@ Type: + + MCPπ¦ + - Traffic: - - {{ logStore.selectedLog.traffic_type || 'N/A' }} + Transport: + + {{ formattedTransportType }} @@ 
-118,6 +126,7 @@ import { Dialog, DialogPanel, DialogTitle, TransitionChild, TransitionRoot } fro import { ClipboardDocumentIcon } from '@heroicons/vue/24/outline' import { useLogStore } from '@/stores/logs' import { getMessageType, parseMessage } from '@/utils/messageParser' +import { formatTransportType, getTransportTypeColor } from '@/utils/transportFormatter' import MessageTypeBadge from '@/components/LogTable/MessageTypeBadge.vue' import PairedMessages from '@/components/common/PairedMessages.vue' @@ -135,6 +144,26 @@ const formattedJson = computed(() => { return JSON.stringify(parsed, null, 2) }) +const isMcpHawkTraffic = computed(() => { + if (!logStore.selectedLog?.metadata) return false + try { + const meta = JSON.parse(logStore.selectedLog.metadata) + return meta.source === 'mcphawk-mcp' + } catch { + return false + } +}) + +const formattedTransportType = computed(() => { + if (!logStore.selectedLog) return 'Unknown' + return formatTransportType(logStore.selectedLog.transport_type || logStore.selectedLog.traffic_type || 'unknown') +}) + +const transportTypeColor = computed(() => { + if (!logStore.selectedLog) return '' + return getTransportTypeColor(logStore.selectedLog.transport_type || logStore.selectedLog.traffic_type || 'unknown') +}) + async function copyToClipboard() { try { await navigator.clipboard.writeText(formattedJson.value) diff --git a/frontend/src/components/common/ConnectionStatus.vue b/frontend/src/components/common/ConnectionStatus.vue index 96ff22d..da87116 100644 --- a/frontend/src/components/common/ConnectionStatus.vue +++ b/frontend/src/components/common/ConnectionStatus.vue @@ -1,22 +1,48 @@ - - - - {{ statusText }} - + + + + + + {{ wsStatusText }} + + + + + + + + MCP Server + + \ No newline at end of file diff --git a/frontend/src/stores/logs.js b/frontend/src/stores/logs.js index ece3888..6266d5b 100644 --- a/frontend/src/stores/logs.js +++ b/frontend/src/stores/logs.js @@ -13,11 +13,25 @@ export const useLogStore = 
defineStore('logs', () => { const loading = ref(false) const error = ref(null) const expandAll = ref(false) + const showMcpHawkTraffic = ref(false) // Computed const filteredLogs = computed(() => { let result = logs.value + // Filter out MCPHawk's own traffic if toggle is off + if (!showMcpHawkTraffic.value) { + result = result.filter(log => { + if (!log.metadata) return true + try { + const meta = JSON.parse(log.metadata) + return meta.source !== 'mcphawk-mcp' + } catch { + return true + } + }) + } + // Apply type filter if (filter.value !== 'all') { result = result.filter(log => { @@ -43,7 +57,8 @@ export const useLogStore = defineStore('logs', () => { requests: 0, responses: 0, notifications: 0, - errors: 0 + errors: 0, + mcphawk: 0 } logs.value.forEach(log => { @@ -52,13 +67,23 @@ export const useLogStore = defineStore('logs', () => { else if (msgType === 'response') stats.responses++ else if (msgType === 'notification') stats.notifications++ else if (msgType === 'error') stats.errors++ + + // Count MCPHawk's own traffic + if (log.metadata) { + try { + const meta = JSON.parse(log.metadata) + if (meta.source === 'mcphawk-mcp') stats.mcphawk++ + } catch { + // ignore parse errors + } + } }) return stats }) const selectedLog = computed(() => { - return logs.value.find(log => log.id === selectedLogId.value) + return logs.value.find(log => log.log_id === selectedLogId.value) }) const pairedLogs = computed(() => { @@ -71,7 +96,7 @@ export const useLogStore = defineStore('logs', () => { logs.value.forEach(log => { const logParsed = parseMessage(log.message) if (logParsed && logParsed.id === parsed.id) { - paired.add(log.id) + paired.add(log.log_id) } }) @@ -131,6 +156,10 @@ export const useLogStore = defineStore('logs', () => { expandAll.value = !expandAll.value } + function toggleMcpHawkTraffic() { + showMcpHawkTraffic.value = !showMcpHawkTraffic.value + } + return { // State logs, @@ -141,6 +170,7 @@ export const useLogStore = defineStore('logs', () => { loading, 
error, expandAll, + showMcpHawkTraffic, // Computed filteredLogs, @@ -156,6 +186,7 @@ export const useLogStore = defineStore('logs', () => { setFilter, setSearchQuery, togglePairing, - toggleExpandAll + toggleExpandAll, + toggleMcpHawkTraffic } }) \ No newline at end of file diff --git a/frontend/src/utils/transportFormatter.js b/frontend/src/utils/transportFormatter.js new file mode 100644 index 0000000..e847b7b --- /dev/null +++ b/frontend/src/utils/transportFormatter.js @@ -0,0 +1,30 @@ +/** + * Format transport type for display + */ +export function formatTransportType(transportType) { + const transportMap = { + 'streamable_http': 'Streamable HTTP', + 'http_sse': 'HTTP+SSE', + 'stdio': 'stdio', + 'unknown': 'Unknown', + // Legacy values + 'TCP/Direct': 'Unknown', + 'N/A': 'Unknown' + } + + return transportMap[transportType] || 'Unknown' +} + +/** + * Get transport type badge color + */ +export function getTransportTypeColor(transportType) { + const colorMap = { + 'streamable_http': 'bg-blue-100 text-blue-800 dark:bg-blue-900 dark:text-blue-200', + 'http_sse': 'bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200', + 'stdio': 'bg-purple-100 text-purple-800 dark:bg-purple-900 dark:text-purple-200', + 'unknown': 'bg-gray-100 text-gray-800 dark:bg-gray-900 dark:text-gray-200' + } + + return colorMap[transportType] || colorMap['unknown'] +} \ No newline at end of file diff --git a/mcphawk/__init__.py b/mcphawk/__init__.py index 3dc1f76..e69de29 100644 --- a/mcphawk/__init__.py +++ b/mcphawk/__init__.py @@ -1 +0,0 @@ -__version__ = "0.1.0" diff --git a/mcphawk/cli.py b/mcphawk/cli.py index 7a13a5c..fac3500 100644 --- a/mcphawk/cli.py +++ b/mcphawk/cli.py @@ -1,15 +1,21 @@ +import asyncio import logging import sys +import threading import typer from mcphawk.logger import init_db +from mcphawk.mcp_server.server import MCPHawkServer from mcphawk.sniffer import start_sniffer from mcphawk.web.server import run_web # Suppress Scapy warnings about network 
interfaces logging.getLogger("scapy.runtime").setLevel(logging.ERROR) +# Setup logger for CLI +logger = logging.getLogger("mcphawk.cli") + +# Typer multi-command app app = typer.Typer(help="MCPHawk: Passive MCP traffic sniffer + dashboard") @@ -22,16 +28,26 @@ def sniff( port: int = typer.Option(None, "--port", "-p", help="TCP port to monitor"), filter: str = typer.Option(None, "--filter", "-f", help="Custom BPF filter expression"), auto_detect: bool = typer.Option(False, "--auto-detect", "-a", help="Auto-detect MCP traffic on any port"), + with_mcp: bool = typer.Option(False, "--with-mcp", help="Start MCP server alongside sniffer"), + mcp_transport: str = typer.Option("http", "--mcp-transport", help="MCP transport type: stdio or http"), + mcp_port: int = typer.Option(8765, "--mcp-port", help="Port for MCP HTTP server (ignored for stdio)"), debug: bool = typer.Option(False, "--debug", "-d", help="Enable debug output") ): """Start sniffing MCP traffic (console output only).""" + # Configure logging - clear existing handlers first + logger.handlers.clear() + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter('[MCPHawk] %(message)s')) + handler.addHandler = None if False else None + logger.addHandler(handler) + logger.setLevel(logging.DEBUG if debug else logging.INFO) + # Validate that user specified either port, filter, or auto-detect if not any([port, filter, auto_detect]): - print("[ERROR] You must specify either --port, --filter, or --auto-detect") - print("Examples:") - print(" mcphawk sniff --port 3000") - print(" mcphawk sniff --filter 'tcp port 3000 or tcp port 3001'") - print(" mcphawk sniff --auto-detect") + logger.error("You must specify either --port, --filter, or --auto-detect") + logger.error("Examples:") + logger.error(" mcphawk sniff --port 3000") + logger.error(" mcphawk sniff --filter 'tcp port 3000 or tcp port 3001'") + logger.error(" mcphawk sniff --auto-detect") raise typer.Exit(1) if filter: @@ -43,14 +59,48 @@ def sniff( else: # Auto-detect mode - capture
all TCP traffic filter_expr = "tcp" - print("[MCPHawk] Auto-detect mode: monitoring all TCP traffic for MCP messages") + logger.info("Auto-detect mode: monitoring all TCP traffic for MCP messages") + + # Start MCP server if requested + mcp_thread = None + excluded_ports = [] + mcphawk_mcp_ports = [] + if with_mcp: + server = MCPHawkServer() + + if mcp_transport == "http": + logger.info(f"Starting MCP HTTP server on http://localhost:{mcp_port}/mcp") + # Only exclude MCP port if not in debug mode + if not debug: + excluded_ports = [mcp_port] + else: + logger.info("Debug mode: HTTP MCP traffic will be captured and tagged") + mcphawk_mcp_ports = [mcp_port] + def run_mcp(): + asyncio.run(server.run_http(port=mcp_port)) + else: + logger.info("Starting MCP server on stdio (configure in your MCP client)") + if debug: + logger.info("Note: stdio MCP traffic cannot be captured (use HTTP transport for debugging)") + def run_mcp(): + asyncio.run(server.run_stdio()) + + mcp_thread = threading.Thread(target=run_mcp, daemon=True) + mcp_thread.start() + + logger.info(f"Starting sniffer with filter: {filter_expr}") + logger.info("Press Ctrl+C to stop...") - print(f"[MCPHawk] Starting sniffer with filter: {filter_expr}") - print("[MCPHawk] Press Ctrl+C to stop...") try: - start_sniffer(filter_expr=filter_expr, auto_detect=auto_detect, debug=debug) + start_sniffer( + filter_expr=filter_expr, + auto_detect=auto_detect, + debug=debug, + excluded_ports=excluded_ports, + mcphawk_mcp_ports=mcphawk_mcp_ports + ) except KeyboardInterrupt: - print("\n[MCPHawk] Sniffer stopped.") + logger.info("Sniffer stopped.") sys.exit(0) @@ -62,17 +112,27 @@ def web( no_sniffer: bool = typer.Option(False, "--no-sniffer", help="Disable sniffer (view historical logs only)"), host: str = typer.Option("127.0.0.1", "--host", help="Web server host"), web_port: int = typer.Option(8000, "--web-port", help="Web server port"), + with_mcp: bool = typer.Option(False, "--with-mcp", help="Start MCP server alongside web 
UI"), + mcp_transport: str = typer.Option("http", "--mcp-transport", help="MCP transport type: stdio or http"), + mcp_port: int = typer.Option(8765, "--mcp-port", help="Port for MCP HTTP server (ignored for stdio)"), debug: bool = typer.Option(False, "--debug", "-d", help="Enable debug output") ): """Start the MCPHawk dashboard with sniffer.""" + # Configure logging - clear existing handlers first + logger.handlers.clear() + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter('[MCPHawk] %(message)s')) + logger.addHandler(handler) + logger.setLevel(logging.DEBUG if debug else logging.INFO) + # If sniffer is enabled, validate that user specified either port, filter, or auto-detect if not no_sniffer and not any([port, filter, auto_detect]): - print("[ERROR] You must specify either --port, --filter, or --auto-detect (or use --no-sniffer)") - print("Examples:") - print(" mcphawk web --port 3000") - print(" mcphawk web --filter 'tcp port 3000 or tcp port 3001'") - print(" mcphawk web --auto-detect") - print(" mcphawk web --no-sniffer # View historical logs only") + logger.error("You must specify either --port, --filter, or --auto-detect (or use --no-sniffer)") + logger.error("Examples:") + logger.error(" mcphawk web --port 3000") + logger.error(" mcphawk web --filter 'tcp port 3000 or tcp port 3001'") + logger.error(" mcphawk web --auto-detect") + logger.error(" mcphawk web --no-sniffer # View historical logs only") raise typer.Exit(1) # Prepare filter expression @@ -85,11 +145,101 @@ def web( else: filter_expr = None # No sniffer + # Start MCP server if requested + mcp_thread = None + excluded_ports = [] + mcphawk_mcp_ports = [] + if with_mcp: + server = MCPHawkServer() + + if mcp_transport == "http": + logger.info(f"Starting MCP HTTP server on http://localhost:{mcp_port}/mcp") + # Only exclude MCP port if not in debug mode + if not debug: + excluded_ports = [mcp_port] + else: + logger.info("Debug mode: HTTP MCP traffic will be captured 
and tagged") + mcphawk_mcp_ports = [mcp_port] + def run_mcp(): + asyncio.run(server.run_http(port=mcp_port)) + else: + logger.info("Starting MCP server on stdio (configure in your MCP client)") + if debug: + logger.info("Note: stdio MCP traffic cannot be captured (use HTTP transport for debugging)") + def run_mcp(): + asyncio.run(server.run_stdio()) + + mcp_thread = threading.Thread(target=run_mcp, daemon=True) + mcp_thread.start() + run_web( sniffer=not no_sniffer, host=host, port=web_port, filter_expr=filter_expr, auto_detect=auto_detect, - debug=debug + debug=debug, + excluded_ports=excluded_ports, + with_mcp=with_mcp, + mcphawk_mcp_ports=mcphawk_mcp_ports ) + + +@app.command() +def mcp( + transport: str = typer.Option("stdio", "--transport", "-t", help="Transport type: stdio or http"), + mcp_port: int = typer.Option(8765, "--mcp-port", help="Port for HTTP transport (ignored for stdio)"), + debug: bool = typer.Option(False, "--debug", "-d", help="Enable debug output") +): + """Run MCPHawk MCP server standalone (query existing captured data).""" + # Configure logging based on transport and debug flag - clear existing handlers first + logger.handlers.clear() + if transport == "stdio": + # For stdio, all logs must go to stderr to avoid interfering with JSON-RPC on stdout + handler = logging.StreamHandler(sys.stderr) + else: + # For other transports, use stdout + handler = logging.StreamHandler(sys.stdout) + + handler.setFormatter(logging.Formatter('[MCPHawk] %(message)s')) + logger.addHandler(handler) + logger.setLevel(logging.DEBUG if debug else logging.INFO) + + logger.info(f"Starting MCP server (transport: {transport})") + + if transport == "stdio": + logger.debug("Use this server with MCP clients by configuring stdio transport") + logger.debug("Example MCP client configuration:") + logger.debug(""" +{ + "mcpServers": { + "mcphawk": { + "command": "mcphawk", + "args": ["mcp"] + } + } +} + """) + elif transport == "http": + logger.info(f"MCP server will listen on
http://localhost:{mcp_port}/mcp") + logger.debug("Example test command:") + logger.debug(f"curl -X POST http://localhost:{mcp_port}/mcp -H 'Content-Type: application/json' -d '{{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"params\":{{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{{}},\"clientInfo\":{{\"name\":\"test\",\"version\":\"1.0\"}}}},\"id\":1}}'") + else: + logger.error(f"Unknown transport: {transport}") + raise typer.Exit(1) + + # Create and run MCP server + server = MCPHawkServer() + + try: + if transport == "stdio": + asyncio.run(server.run_stdio()) + elif transport == "http": + asyncio.run(server.run_http(port=mcp_port)) + except KeyboardInterrupt: + logger.info("MCP server stopped.") + sys.exit(0) + + +if __name__ == "__main__": + app() diff --git a/mcphawk/logger.py b/mcphawk/logger.py index 1ba1c78..1b5a583 100644 --- a/mcphawk/logger.py +++ b/mcphawk/logger.py @@ -29,7 +29,7 @@ def init_db() -> None: cur.execute( """ CREATE TABLE IF NOT EXISTS logs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, + log_id TEXT PRIMARY KEY, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, src_ip TEXT, dst_ip TEXT, @@ -37,16 +37,11 @@ def init_db() -> None: dst_port INTEGER, direction TEXT CHECK(direction IN ('incoming', 'outgoing', 'unknown')), message TEXT, - traffic_type TEXT + transport_type TEXT, + metadata TEXT ) """ ) - - # Add traffic_type column to existing tables - cur.execute("PRAGMA table_info(logs)") - columns = [col[1] for col in cur.fetchall()] - if "traffic_type" not in columns: - cur.execute("ALTER TABLE logs ADD COLUMN traffic_type TEXT") conn.commit() conn.close() @@ -57,6 +52,7 @@ def log_message(entry: dict[str, Any]) -> None: Args: entry (Dict[str, Any]): Must contain MCPMessageLog fields: + log_id (str) - UUID for the log entry timestamp (datetime) - If missing, current time is used src_ip (str) dst_ip (str) @@ -64,17 +60,23 @@ def log_message(entry: dict[str, Any]) -> None: dst_port (int) direction (str): 'incoming', 'outgoing', or 'unknown' 
message (str) - traffic_type (str): 'TCP', 'WS', or 'N/A' (optional, defaults to 'N/A') + transport_type (str): 'streamable_http', 'http_sse', 'stdio', or 'unknown' (optional, defaults to 'unknown') + metadata (str): JSON string with additional metadata (optional) """ timestamp = entry.get("timestamp", datetime.now(tz=timezone.utc)) + log_id = entry.get("log_id") + if not log_id: + raise ValueError("log_id is required") + conn = sqlite3.connect(DB_PATH) cur = conn.cursor() cur.execute( """ - INSERT INTO logs (timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, traffic_type) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO logs (log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, transport_type, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( + log_id, timestamp.isoformat(), entry.get("src_ip"), entry.get("dst_ip"), @@ -82,7 +84,8 @@ def log_message(entry: dict[str, Any]) -> None: entry.get("dst_port"), entry.get("direction", "unknown"), entry.get("message"), - entry.get("traffic_type", "N/A"), + entry.get("transport_type", "unknown"), + entry.get("metadata"), ), ) conn.commit() @@ -113,9 +116,9 @@ def fetch_logs(limit: int = 100) -> list[dict[str, Any]]: cur = conn.cursor() cur.execute( """ - SELECT timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, traffic_type + SELECT log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, transport_type, metadata FROM logs - ORDER BY id DESC + ORDER BY timestamp DESC LIMIT ? 
""", (limit,), @@ -125,6 +128,7 @@ def fetch_logs(limit: int = 100) -> list[dict[str, Any]]: return [ { + "log_id": row["log_id"], "timestamp": datetime.fromisoformat(row["timestamp"]), "src_ip": row["src_ip"], "dst_ip": row["dst_ip"], @@ -132,7 +136,8 @@ def fetch_logs(limit: int = 100) -> list[dict[str, Any]]: "dst_port": row["dst_port"], "direction": row["direction"], "message": row["message"], - "traffic_type": row["traffic_type"] if row["traffic_type"] is not None else "N/A", + "transport_type": row["transport_type"] if row["transport_type"] is not None else "unknown", + "metadata": row["metadata"], } for row in rows ] @@ -162,3 +167,255 @@ def clear_logs() -> None: cur.execute("DELETE FROM logs;") conn.commit() conn.close() + + +def get_log_by_id(log_id: str) -> dict[str, Any] | None: + """ + Fetch a specific log entry by ID. + + Args: + log_id (str): The UUID of the log entry to retrieve. + + Returns: + Dictionary matching MCPMessageLog format or None if not found. + """ + current_path = DB_PATH if DB_PATH else _DEFAULT_DB_PATH + if not current_path.exists(): + return None + + conn = sqlite3.connect(current_path) + conn.row_factory = sqlite3.Row + cur = conn.cursor() + cur.execute( + """ + SELECT log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, transport_type, metadata + FROM logs + WHERE log_id = ? + """, + (log_id,), + ) + row = cur.fetchone() + conn.close() + + if not row: + return None + + return { + "log_id": row["log_id"], + "timestamp": datetime.fromisoformat(row["timestamp"]), + "src_ip": row["src_ip"], + "dst_ip": row["dst_ip"], + "src_port": row["src_port"], + "dst_port": row["dst_port"], + "direction": row["direction"], + "message": row["message"], + "transport_type": row["transport_type"] if row["transport_type"] is not None else "unknown", + "metadata": row["metadata"], + } + + +def fetch_logs_with_offset(limit: int = 100, offset: int = 0) -> list[dict[str, Any]]: + """ + Fetch logs with limit and offset for pagination. 
+ + Args: + limit: Maximum number of logs to return + offset: Number of logs to skip + + Returns: + List of dictionaries matching MCPMessageLog format. + """ + current_path = DB_PATH if DB_PATH else _DEFAULT_DB_PATH + if not current_path.exists(): + return [] + + conn = sqlite3.connect(current_path) + conn.row_factory = sqlite3.Row + cur = conn.cursor() + # Order by timestamp: log_id is a UUID and does not sort chronologically + cur.execute( + """ + SELECT log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, transport_type, metadata + FROM logs + ORDER BY timestamp DESC + LIMIT ? OFFSET ? + """, + (limit, offset), + ) + rows = cur.fetchall() + conn.close() + + return [ + { + "log_id": row["log_id"], + "timestamp": datetime.fromisoformat(row["timestamp"]), + "src_ip": row["src_ip"], + "dst_ip": row["dst_ip"], + "src_port": row["src_port"], + "dst_port": row["dst_port"], + "direction": row["direction"], + "message": row["message"], + "transport_type": row["transport_type"] if row["transport_type"] is not None else "unknown", + "metadata": row["metadata"], + } + for row in rows + ] + + +def search_logs(search_term: str = "", message_type: str | None = None, + transport_type: str | None = None, limit: int = 100) -> list[dict[str, Any]]: + """ + Search logs by various criteria. + + Args: + search_term: Text to search for in messages + message_type: Filter by JSON-RPC message type (request, response, notification) + transport_type: Filter by transport type (streamable_http, http_sse, stdio, unknown) + limit: Maximum number of results + + Returns: + List of matching log entries. + """ + current_path = DB_PATH if DB_PATH else _DEFAULT_DB_PATH + if not current_path.exists(): + return [] + + conn = sqlite3.connect(current_path) + conn.row_factory = sqlite3.Row + cur = conn.cursor() + + query = "SELECT * FROM logs WHERE 1=1" + params = [] + + if search_term: + query += " AND message LIKE ?" + params.append(f"%{search_term}%") + + if transport_type: + query += " AND transport_type = ?"
+ params.append(transport_type) + + # Order by timestamp: log_id is a UUID and does not sort chronologically + query += " ORDER BY timestamp DESC LIMIT ?" + params.append(limit) + + cur.execute(query, params) + rows = cur.fetchall() + conn.close() + + # Filter by message type if specified + results = [] + for row in rows: + log_dict = { + "log_id": row["log_id"], + "timestamp": datetime.fromisoformat(row["timestamp"]), + "src_ip": row["src_ip"], + "dst_ip": row["dst_ip"], + "src_port": row["src_port"], + "dst_port": row["dst_port"], + "direction": row["direction"], + "message": row["message"], + "transport_type": row["transport_type"] if row["transport_type"] is not None else "unknown", + "metadata": row["metadata"], + } + + # If message_type filter is specified, check it + if message_type: + from .utils import get_message_type + if get_message_type(row["message"]) != message_type: + continue + + results.append(log_dict) + + return results + + +def get_traffic_stats() -> dict[str, Any]: + """ + Get statistics about captured traffic. + + Returns: + Dictionary with traffic statistics.
+ """ + current_path = DB_PATH if DB_PATH else _DEFAULT_DB_PATH + if not current_path.exists(): + return { + "total_logs": 0, + "requests": 0, + "responses": 0, + "notifications": 0, + "errors": 0, + "by_transport_type": {} + } + + conn = sqlite3.connect(current_path) + cur = conn.cursor() + + # Get all messages for analysis + cur.execute("SELECT message, transport_type FROM logs") + logs = cur.fetchall() + + stats = { + "total_logs": len(logs), + "requests": 0, + "responses": 0, + "notifications": 0, + "errors": 0, + "by_transport_type": {} + } + + from .utils import get_message_type + + for message, transport_type in logs: + # Count by message type + msg_type = get_message_type(message) + if msg_type == "request": + stats["requests"] += 1 + elif msg_type == "response": + stats["responses"] += 1 + elif msg_type == "notification": + stats["notifications"] += 1 + + # Check for errors + try: + import json + msg_data = json.loads(message) + if "error" in msg_data: + stats["errors"] += 1 + except Exception: + pass + + # Count by transport type + if transport_type: + stats["by_transport_type"][transport_type] = stats["by_transport_type"].get(transport_type, 0) + 1 + + conn.close() + return stats + + +def get_unique_methods() -> list[str]: + """ + Get all unique JSON-RPC methods from captured traffic. + + Returns: + Sorted list of unique method names. 
+ """ + current_path = DB_PATH if DB_PATH else _DEFAULT_DB_PATH + if not current_path.exists(): + return [] + + conn = sqlite3.connect(current_path) + cur = conn.cursor() + cur.execute("SELECT message FROM logs") + logs = cur.fetchall() + conn.close() + + methods = set() + for (message,) in logs: + try: + import json + msg_data = json.loads(message) + if "method" in msg_data: + methods.add(msg_data["method"]) + except Exception: + pass + + return sorted(methods) diff --git a/mcphawk/mcp_server/__init__.py b/mcphawk/mcp_server/__init__.py new file mode 100644 index 0000000..6e3a591 --- /dev/null +++ b/mcphawk/mcp_server/__init__.py @@ -0,0 +1 @@ +"""MCPHawk MCP Server - Query and analyze captured MCP traffic.""" diff --git a/mcphawk/mcp_server/server.py b/mcphawk/mcp_server/server.py new file mode 100644 index 0000000..de1530b --- /dev/null +++ b/mcphawk/mcp_server/server.py @@ -0,0 +1,122 @@ +"""MCP server implementation using SDK's built-in HTTP transport.""" + +import json +import logging +from typing import Optional + +from mcp.server.fastmcp import FastMCP + +from .. 
import logger as mcphawk_logger + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class MCPHawkServer: + """MCP server using SDK's HTTP transport.""" + + def __init__(self, db_path: Optional[str] = None, host: str = "127.0.0.1", port: int = 8765): + # Store configuration + self.http_host = host + self.http_port = port + + # FastMCP accepts host and port in constructor + self.mcp = FastMCP("mcphawk-mcp", host=host, port=port) + if db_path: + mcphawk_logger.set_db_path(db_path) + self._setup_handlers() + + def _setup_handlers(self): + """Setup MCP protocol handlers.""" + + @self.mcp.tool() + async def query_traffic(limit: int = 100, offset: int = 0) -> str: + """Query captured MCP traffic with optional limit and offset.""" + logs = mcphawk_logger.fetch_logs_with_offset(limit=limit, offset=offset) + + # Convert timestamps to ISO format for JSON serialization + for log in logs: + if log.get("timestamp"): + log["timestamp"] = log["timestamp"].isoformat() + + return json.dumps(logs, indent=2) + + @self.mcp.tool() + async def get_log(log_id: str) -> str: + """Get a specific log entry by ID.""" + log = mcphawk_logger.get_log_by_id(log_id) + + if not log: + return f"No log found with ID: {log_id}" + + # Convert timestamp to ISO format + if log.get("timestamp"): + log["timestamp"] = log["timestamp"].isoformat() + + return json.dumps(log, indent=2) + + @self.mcp.tool() + async def search_traffic( + search_term: str = "", + message_type: Optional[str] = None, + transport_type: Optional[str] = None, + limit: int = 100 + ) -> str: + """Search traffic by message content or type. 
+ + Args: + search_term: Term to search for in message content + message_type: Filter by message type (request, response, notification) + transport_type: Filter by transport type (streamable_http/http_sse/stdio/unknown) + limit: Maximum number of results + """ + logs = mcphawk_logger.search_logs( + search_term=search_term, + message_type=message_type, + transport_type=transport_type, + limit=limit + ) + + # Convert timestamps to ISO format for JSON serialization + for log in logs: + if log.get("timestamp"): + log["timestamp"] = log["timestamp"].isoformat() + + return json.dumps(logs, indent=2) + + @self.mcp.tool() + async def get_stats() -> str: + """Get statistics about captured traffic.""" + stats = mcphawk_logger.get_traffic_stats() + return json.dumps(stats, indent=2) + + @self.mcp.tool() + async def list_methods() -> str: + """List all unique JSON-RPC methods seen in traffic.""" + methods = mcphawk_logger.get_unique_methods() + + result = { + "methods": methods, + "count": len(methods) + } + + return json.dumps(result, indent=2) + + async def run_stdio(self): + """Run the MCP server using stdio transport.""" + await self.mcp.run_stdio_async() + + async def run_http(self, host: str = "127.0.0.1", port: int = 8765): + """Run the MCP server using SDK's Streamable HTTP transport.""" + # If different host/port specified, we need to recreate FastMCP + if host != self.http_host or port != self.http_port: + self.http_host = host + self.http_port = port + # Recreate FastMCP with new settings + self.mcp = FastMCP("mcphawk-mcp", host=host, port=port) + self._setup_handlers() + + # The SDK handles all the HTTP server setup internally + await self.mcp.run_streamable_http_async() + diff --git a/mcphawk/sniffer.py b/mcphawk/sniffer.py index f25f107..41c4e56 100644 --- a/mcphawk/sniffer.py +++ b/mcphawk/sniffer.py @@ -1,6 +1,7 @@ import asyncio import logging import platform +import uuid from datetime import datetime, timezone # Suppress Scapy warnings before importing @@ 
-9,8 +10,8 @@ from scapy.all import IP, TCP, IPv6, Raw, conf, sniff # noqa: E402 from mcphawk.logger import log_message # noqa: E402 +from mcphawk.tcp_reassembly import TCPStreamReassembler # noqa: E402 from mcphawk.web.broadcaster import broadcast_new_log # noqa: E402 -from mcphawk.ws_reassembly import process_ws_packet # noqa: E402 # Set up logger for this module logger = logging.getLogger(__name__) @@ -34,106 +35,141 @@ def _broadcast_in_any_loop(log_entry: dict) -> None: # Global variable to track auto-detect mode _auto_detect_mode = False -# Track established WebSocket connections -_ws_connections = set() +# Global variable to track excluded ports +_excluded_ports = set() + +# Global variable to track MCPHawk's own MCP server ports for metadata tagging +_mcphawk_mcp_ports = set() + +# TCP stream reassembler for handling SSE and chunked responses +_tcp_reassembler = TCPStreamReassembler() +logger.info("TCP stream reassembler initialized") def packet_callback(pkt): """ Callback for every sniffed packet. Extract JSON-RPC messages from raw TCP payloads. 
""" + # Skip packets to/from excluded ports + if pkt.haslayer(TCP) and _excluded_ports: + tcp_layer = pkt[TCP] + if tcp_layer.sport in _excluded_ports or tcp_layer.dport in _excluded_ports: + return + + # Try TCP stream reassembly first for SSE/chunked responses + if pkt.haslayer(TCP) and pkt.haslayer(Raw): + tcp = pkt[TCP] + if tcp.sport == 8765 or tcp.dport == 8765: + raw_data = bytes(pkt[Raw]) + logger.debug(f"TCP packet for reassembly: {tcp.sport}->{tcp.dport}, {len(raw_data)} bytes") + # Log first 100 bytes to see what we're getting + preview = raw_data[:100].decode('utf-8', errors='replace') + logger.debug(f"Packet preview: {preview!r}") + + reassembled_messages = _tcp_reassembler.process_packet(pkt) + if reassembled_messages: + logger.info(f"TCP reassembler found {len(reassembled_messages)} messages!") + elif pkt.haslayer(TCP) and pkt.haslayer(Raw): + tcp = pkt[TCP] + if tcp.sport == 8765 or tcp.dport == 8765: + logger.debug(f"TCP reassembler returned no messages for {tcp.sport}->{tcp.dport}") + for msg_info in reassembled_messages: + # Process reassembled SSE messages + if "jsonrpc" in msg_info["message"]: + logger.debug(f"Reassembled {msg_info['type']}: {msg_info['message'][:100]}...") + + ts = datetime.now(tz=timezone.utc) + log_id = str(uuid.uuid4()) + + # Use transport type from TCP reassembler + transport = msg_info.get("transport", "unknown") + + entry = { + "log_id": log_id, + "timestamp": ts, + "src_ip": msg_info["src_ip"], + "src_port": msg_info["src_port"], + "dst_ip": msg_info["dst_ip"], + "dst_port": msg_info["dst_port"], + "direction": "unknown", + "message": msg_info["message"], + "transport_type": transport, + } + + # Add metadata if this is MCPHawk's own MCP traffic + if msg_info["src_port"] in _mcphawk_mcp_ports or msg_info["dst_port"] in _mcphawk_mcp_ports: + entry["metadata"] = '{"source": "mcphawk-mcp"}' + + log_message(entry) + + # Convert timestamp to ISO only for WebSocket broadcast + broadcast_entry = dict(entry) + 
broadcast_entry["timestamp"] = ts.isoformat() + _broadcast_in_any_loop(broadcast_entry) + + # In auto-detect mode, log when we find MCP traffic + if _auto_detect_mode: + transport_name = { + "streamable_http": "Streamable HTTP", + "http_sse": "HTTP+SSE", + "stdio": "stdio", + "unknown": "Unknown" + }.get(transport, "Unknown") + print(f"[MCPHawk] Detected {transport_name} MCP traffic on port {msg_info['src_port']} -> {msg_info['dst_port']}") + if pkt.haslayer(Raw): raw_payload = pkt[Raw].load + + # Check if this looks like SSE data + if raw_payload.startswith(b"data: "): + logger.debug(f"SSE data packet detected: {raw_payload[:100]}...") + if not _auto_detect_mode: # Less verbose in auto-detect mode logger.debug(f"Raw payload: {raw_payload[:60]}...") - # First, try to process as WebSocket traffic - if pkt.haslayer(TCP) and (pkt.haslayer(IP) or pkt.haslayer(IPv6)): - # Get IP addresses (IPv4 or IPv6) - if pkt.haslayer(IP): - src_ip = pkt[IP].src - dst_ip = pkt[IP].dst - else: # IPv6 - src_ip = pkt[IPv6].src - dst_ip = pkt[IPv6].dst - - src_port = pkt[TCP].sport - dst_port = pkt[TCP].dport - - # Check if this might be WebSocket traffic - is_ws_frame = False - is_http_upgrade = False - - if len(raw_payload) > 0: - first_byte = raw_payload[0] - # Check for WebSocket frames (masked or unmasked) - # Valid first bytes: 0x80-0x8F (FIN=1) or 0x00-0x0F (FIN=0) - # With common opcodes: 0x1 (text), 0x2 (binary), 0x8 (close), 0x9 (ping), 0xa (pong) - # Masked client frames: 0x81 -> 0xC1, 0x82 -> 0xC2, etc. 
- is_ws_frame = ( - (0x80 <= first_byte <= 0x8F) or # Unmasked frames - (0x00 <= first_byte <= 0x0F) or # Fragmented frames - (0xC0 <= first_byte <= 0xCF) # Masked frames (common case) - ) - - # Also check for HTTP upgrade - is_http_upgrade = ( - raw_payload[:4] == b'HTTP' or - raw_payload[:3] == b'GET' or - b'Upgrade: websocket' in raw_payload - ) - - # Check if this is a known WebSocket connection - conn_key = (src_ip, src_port, dst_ip, dst_port) - reverse_key = (dst_ip, dst_port, src_ip, src_port) - is_known_ws = conn_key in _ws_connections or reverse_key in _ws_connections - - if is_ws_frame or is_http_upgrade or is_known_ws: - logger.debug(f"Detected WebSocket traffic: is_frame={is_ws_frame}, is_http={is_http_upgrade}, is_known={is_known_ws}, first_byte={hex(first_byte) if len(raw_payload) > 0 else 'N/A'}") - - # Mark this as a WebSocket connection - if is_ws_frame or is_http_upgrade: - _ws_connections.add(conn_key) - _ws_connections.add(reverse_key) - - # Process WebSocket frames - messages = process_ws_packet(src_ip, src_port, dst_ip, dst_port, raw_payload) - logger.debug(f"process_ws_packet returned {len(messages)} messages") - - for msg in messages: - logger.debug(f"WebSocket message captured: {msg}") - - ts = datetime.now(tz=timezone.utc) - - # In auto-detect mode, log when we find MCP traffic on a new port - if _auto_detect_mode and "jsonrpc" in msg: - print(f"[MCPHawk] Detected WebSocket MCP traffic on port {src_port} -> {dst_port}") - - entry = { - "timestamp": ts, - "src_ip": src_ip, - "src_port": src_port, - "dst_ip": dst_ip, - "dst_port": dst_port, - "direction": "unknown", - "message": msg, - "traffic_type": "TCP/WS", - } - - log_message(entry) - - # Convert timestamp to ISO only for WebSocket broadcast - broadcast_entry = dict(entry) - broadcast_entry["timestamp"] = ts.isoformat() - _broadcast_in_any_loop(broadcast_entry) - - # If this was identified as WebSocket traffic, return early - # even if no complete messages were extracted (could be 
buffering) - return - - # Otherwise, try to process as raw JSON-RPC + # Try to process as raw JSON-RPC or HTTP POST with JSON-RPC try: decoded = raw_payload.decode("utf-8", errors="ignore") + + # Debug log all HTTP traffic + if decoded.startswith("HTTP/1.1"): + logger.debug(f"HTTP Response: {decoded[:200]}...") + + # Check for standalone SSE data (not part of HTTP response) + if decoded.startswith("data: ") and "jsonrpc" in decoded: + # This is a standalone SSE data packet + sse_data = decoded[6:] # Skip "data: " + if "\n" in sse_data: + sse_data = sse_data[:sse_data.index("\n")] + if sse_data.startswith("{"): + logger.debug(f"Found standalone SSE data: {sse_data[:100]}...") + decoded = sse_data + # Process as regular JSON-RPC + + # Check for HTTP request/response with JSON-RPC content + if (decoded.startswith("POST") or decoded.startswith("HTTP/1.1")) and "\r\n\r\n" in decoded: + # Extract JSON body from HTTP request/response + body_start = decoded.find("\r\n\r\n") + 4 + json_body = decoded[body_start:] + + # Debug log HTTP responses + if decoded.startswith("HTTP/1.1") and "text/event-stream" in decoded: + logger.debug(f"SSE Response detected, body length: {len(json_body)}, body preview: {json_body[:100]}") + + # Check for Server-Sent Events (SSE) format used by MCP SDK + if "text/event-stream" in decoded and json_body.startswith("data: "): + # Extract JSON from SSE format: "data: {...}\n\n" + sse_data = json_body[6:] # Skip "data: " + if "\n" in sse_data: + sse_data = sse_data[:sse_data.index("\n")] + if sse_data.startswith("{") and "jsonrpc" in sse_data: + decoded = sse_data + logger.debug(f"Extracted JSON-RPC from SSE: {decoded[:100]}...") + elif json_body.startswith("{") and "jsonrpc" in json_body: + decoded = json_body # Use just the JSON body + logger.debug(f"Extracted JSON-RPC from HTTP: {decoded[:100]}...") + + # Process if we have JSON-RPC content if decoded.startswith("{") and "jsonrpc" in decoded: logger.debug(f"Sniffer captured: {decoded}") @@ -156,7 
+192,18 @@ def packet_callback(pkt): src_ip = "" dst_ip = "" + log_id = str(uuid.uuid4()) + + # Check if we know the transport type for this connection + transport = _tcp_reassembler.transport_tracker.get_transport( + src_ip, src_port, dst_ip, dst_port + ).value + + if _auto_detect_mode and transport != "unknown": + logger.debug(f"Auto-detect: Found transport {transport} for {src_ip}:{src_port} -> {dst_ip}:{dst_port}") + entry = { + "log_id": log_id, "timestamp": ts, "src_ip": src_ip, "src_port": src_port, @@ -164,9 +211,13 @@ def packet_callback(pkt): "dst_port": dst_port, "direction": "unknown", "message": decoded, - "traffic_type": "TCP/Direct", + "transport_type": transport, } + # Add metadata if this is MCPHawk's own MCP traffic + if src_port in _mcphawk_mcp_ports or dst_port in _mcphawk_mcp_ports: + entry["metadata"] = '{"source": "mcphawk-mcp"}' + log_message(entry) # Convert timestamp to ISO only for WebSocket broadcast @@ -177,7 +228,7 @@ def packet_callback(pkt): logger.debug(f"JSON decode failed: {e}") -def start_sniffer(filter_expr: str = "tcp and port 12345", auto_detect: bool = False, debug: bool = False) -> None: +def start_sniffer(filter_expr: str = "tcp and port 12345", auto_detect: bool = False, debug: bool = False, excluded_ports: list[int] | None = None, mcphawk_mcp_ports: list[int] | None = None) -> None: """ Start sniffing packets on the appropriate interface. 
- On macOS: use `lo0` @@ -187,9 +238,12 @@ def start_sniffer(filter_expr: str = "tcp and port 12345", auto_detect: bool = F filter_expr: BPF filter expression auto_detect: If True, automatically detect MCP traffic on any port debug: If True, enable debug logging + mcphawk_mcp_ports: List of ports where MCPHawk's own MCP server is running """ - global _auto_detect_mode + global _auto_detect_mode, _excluded_ports, _mcphawk_mcp_ports _auto_detect_mode = auto_detect + _excluded_ports = set(excluded_ports) if excluded_ports else set() + _mcphawk_mcp_ports = set(mcphawk_mcp_ports) if mcphawk_mcp_ports else set() # Configure logging based on debug flag if debug: diff --git a/mcphawk/tcp_reassembly.py b/mcphawk/tcp_reassembly.py new file mode 100644 index 0000000..566c85f --- /dev/null +++ b/mcphawk/tcp_reassembly.py @@ -0,0 +1,423 @@ +"""TCP stream reassembly for capturing complete HTTP/SSE responses.""" + +import logging +from typing import Optional + +from scapy.all import IP, TCP, IPv6, Raw + +from .transport_detector import ( + MCPTransport, + TransportTracker, + detect_transport_from_http, + extract_endpoint_from_sse, +) + +logger = logging.getLogger(__name__) +# Ensure we're using the same log level as the parent +logger.setLevel(logging.DEBUG) + + +class StreamKey: + """Key for identifying TCP streams.""" + + def __init__(self, src_ip: str, src_port: int, dst_ip: str, dst_port: int): + # Always order the tuple consistently (lower IP/port first) + if (src_ip, src_port) < (dst_ip, dst_port): + self.key = (src_ip, src_port, dst_ip, dst_port) + else: + self.key = (dst_ip, dst_port, src_ip, src_port) + + def __hash__(self): + return hash(self.key) + + def __eq__(self, other): + return self.key == other.key + + def __repr__(self): + return f"StreamKey{self.key}" + + +class HTTPStream: + """Tracks HTTP request/response pairs in a TCP stream.""" + + def __init__(self): + self.pending_request: Optional[bytes] = None + self.pending_response: Optional[bytes] = None + 
self.response_headers: Optional[dict[str, str]] = None + self.content_length: Optional[int] = None + self.is_chunked: bool = False + self.is_sse: bool = False + self.buffer: bytes = b"" + self.request_method: Optional[str] = None + self.request_path: Optional[str] = None + self.request_headers: dict[str, str] = {} + self.detected_transport: MCPTransport = MCPTransport.UNKNOWN + + def add_request(self, data: bytes): + """Add HTTP request data and parse headers.""" + self.pending_request = data + self.buffer = b"" + logger.debug(f"New HTTP request: {data[:100]}") + + # Parse request line and headers + try: + request_str = data.decode('utf-8', errors='ignore') + lines = request_str.split('\r\n') + if lines: + # Parse request line + parts = lines[0].split(' ') + if len(parts) >= 2: + self.request_method = parts[0] + self.request_path = parts[1] + + # Parse headers + self.request_headers = {} + for line in lines[1:]: + if ': ' in line: + key, value = line.split(': ', 1) + self.request_headers[key.lower()] = value + elif line == '': + break # End of headers + + # Try to detect transport type from request alone + if self.request_method and self.request_path: + self.detected_transport = detect_transport_from_http( + self.request_method, + self.request_path, + self.request_headers, + False # No response yet + ) + logger.debug(f"Transport detected from request: {self.detected_transport}") + except Exception as e: + logger.debug(f"Error parsing request: {e}") + + def add_response_data(self, data: bytes): + """Add HTTP response data, handling headers and body.""" + self.buffer += data + logger.debug(f"HTTPStream: Added {len(data)} bytes to buffer, total buffer size: {len(self.buffer)}") + + # If we don't have headers yet, try to parse them + if self.response_headers is None and b"\r\n\r\n" in self.buffer: + header_end = self.buffer.find(b"\r\n\r\n") + 4 + header_data = self.buffer[:header_end].decode('utf-8', errors='ignore') + self.buffer = self.buffer[header_end:] # Keep 
only body in buffer + + # Parse headers + lines = header_data.split('\r\n') + self.response_headers = {} + for line in lines[1:]: # Skip status line + if ': ' in line: + key, value = line.split(': ', 1) + self.response_headers[key.lower()] = value + + # Check for SSE + content_type = self.response_headers.get('content-type', '') + self.is_sse = 'text/event-stream' in content_type + + # Check for chunked encoding + transfer_encoding = self.response_headers.get('transfer-encoding', '') + self.is_chunked = 'chunked' in transfer_encoding + + # Get content length if not chunked + if not self.is_chunked and 'content-length' in self.response_headers: + self.content_length = int(self.response_headers['content-length']) + + logger.debug(f"Response headers parsed: SSE={self.is_sse}, chunked={self.is_chunked}") + + # Try to detect transport type + if self.request_method and self.request_path: + self.detected_transport = detect_transport_from_http( + self.request_method, + self.request_path, + self.request_headers, + self.is_sse + ) + + def extract_sse_messages(self) -> list[str]: + """Extract complete SSE messages from buffer.""" + messages = [] + logger.debug(f"extract_sse_messages called, buffer size: {len(self.buffer)}, is_chunked: {self.is_chunked}") + if self.buffer: + logger.debug(f"Buffer content preview: {self.buffer[:100]}") + + # If chunked encoding, we need to handle chunk sizes + data_to_process = self.buffer + if self.is_chunked and self.buffer: + # Try to extract chunks + logger.debug("Attempting to extract chunks from buffer") + chunk_data = self.extract_chunked_data() + if chunk_data: + # Process the extracted chunk data, but don't replace the buffer + # The buffer still contains any remaining chunk data + data_to_process = chunk_data + logger.debug(f"Extracted {len(chunk_data)} bytes from chunks to process") + else: + logger.debug("No complete chunks extracted yet") + return messages # Wait for more data + + # SSE messages are separated by double newlines 
(could be \r\n\r\n or \n\n) + logger.debug(f"Looking for SSE messages in {len(data_to_process)} bytes of data") + + # Look for either \r\n\r\n or \n\n + while b"\r\n\r\n" in data_to_process or b"\n\n" in data_to_process: + # Find the first occurrence of either separator + crlf_pos = data_to_process.find(b"\r\n\r\n") + lf_pos = data_to_process.find(b"\n\n") + + if crlf_pos >= 0 and (lf_pos < 0 or crlf_pos < lf_pos): + msg_end = crlf_pos + 4 # +4 for \r\n\r\n + else: + msg_end = lf_pos + 2 # +2 for \n\n + + msg_data = data_to_process[:msg_end].decode('utf-8', errors='ignore') + data_to_process = data_to_process[msg_end:] + logger.debug(f"Found SSE message block: {msg_data[:100]!r}") + + # Check for endpoint event (HTTP+SSE transport) + if "event: endpoint" in msg_data: + endpoint_url = extract_endpoint_from_sse(msg_data) + if endpoint_url: + logger.debug(f"Found endpoint event, URL: {endpoint_url}") + self.detected_transport = MCPTransport.HTTP_SSE + + # Extract data lines + for line in msg_data.split('\n'): + if line.startswith('data: '): + json_data = line[6:].strip() + if json_data and json_data.startswith('{'): + messages.append(json_data) + logger.debug(f"Extracted SSE message: {json_data[:100]}") + + return messages + + def extract_chunked_data(self) -> Optional[bytes]: + """Extract data from chunked transfer encoding.""" + # This is a simplified implementation + # Real chunked parsing is more complex + complete_data = b"" + original_buffer = self.buffer + + while True: + # Find chunk size + if b"\r\n" not in self.buffer: + logger.debug("extract_chunked_data: No CRLF in buffer, need more data") + break + + size_end = self.buffer.find(b"\r\n") + chunk_size_str = self.buffer[:size_end].decode('utf-8', errors='ignore').strip() + logger.debug(f"extract_chunked_data: Chunk size string: '{chunk_size_str}'") + + try: + chunk_size = int(chunk_size_str, 16) + logger.debug(f"extract_chunked_data: Parsed chunk size: {chunk_size}") + except ValueError: + 
logger.debug("extract_chunked_data: Failed to parse chunk size") + break + + if chunk_size == 0: + # Last chunk - consume it from buffer + self.buffer = self.buffer[size_end + 4:] # Skip "0\r\n\r\n" + logger.debug(f"extract_chunked_data: Found last chunk (size 0), returning {len(complete_data)} bytes") + return complete_data if complete_data else None + + # Check if we have the full chunk + chunk_start = size_end + 2 + chunk_end = chunk_start + chunk_size + 2 # +2 for trailing \r\n + + if len(self.buffer) < chunk_end: + # Need more data - restore original buffer and return what we have so far + logger.debug(f"extract_chunked_data: Need more data, have {len(self.buffer)}, need {chunk_end}") + self.buffer = original_buffer + break + + chunk_data = self.buffer[chunk_start:chunk_start + chunk_size] + logger.debug(f"extract_chunked_data: Extracted chunk of {len(chunk_data)} bytes") + complete_data += chunk_data + self.buffer = self.buffer[chunk_end:] + + if complete_data: + logger.debug(f"extract_chunked_data: Returning {len(complete_data)} bytes of unchunked data") + return complete_data + return None # Not complete yet + + +class TCPStreamReassembler: + """Reassembles TCP streams to capture complete HTTP messages.""" + + def __init__(self): + self.streams: dict[StreamKey, HTTPStream] = {} + self.transport_tracker = TransportTracker() + + def process_packet(self, pkt) -> list[dict]: + """Process a packet and return any complete messages.""" + messages = [] + + # Only process TCP packets with data + if not pkt.haslayer(TCP) or not pkt.haslayer(Raw): + return messages + + tcp = pkt[TCP] + + # Debug: Log that we're processing a packet + if pkt.haslayer(Raw): + payload = bytes(pkt[Raw]) + if len(payload) > 0 and (tcp.sport == 8765 or tcp.dport == 8765): + logger.debug(f"TCP reassembly: Processing {tcp.sport}->{tcp.dport} packet with {len(payload)} bytes") + logger.debug(f"TCP reassembly: Packet content: {payload[:100]}") + + # Get stream key + if pkt.haslayer(IP): + src_ip 
= pkt[IP].src + dst_ip = pkt[IP].dst + elif pkt.haslayer(IPv6): + src_ip = pkt[IPv6].src + dst_ip = pkt[IPv6].dst + else: + return messages + + src_port = pkt[TCP].sport + dst_port = pkt[TCP].dport + payload = bytes(pkt[Raw]) + + # Create stream key + stream_key = StreamKey(src_ip, src_port, dst_ip, dst_port) + + # Debug stream key for port 8765 + if src_port == 8765 or dst_port == 8765: + logger.debug(f"TCP reassembly: Stream key for port 8765: {stream_key}") + + # Get or create stream + if stream_key not in self.streams: + self.streams[stream_key] = HTTPStream() + logger.debug(f"TCP reassembly: Created new stream for {stream_key}") + stream = self.streams[stream_key] + + # Check if this is an HTTP request + if payload.startswith(b"POST ") or payload.startswith(b"GET "): + stream.add_request(payload) + logger.debug(f"TCP reassembly: HTTP request on {src_port}->{dst_port}") + logger.debug(f"TCP reassembly: Request method: {stream.request_method}, path: {stream.request_path}") + logger.debug(f"TCP reassembly: Request headers: {stream.request_headers}") + + # If we detected transport from request, update tracker + if stream.detected_transport != MCPTransport.UNKNOWN: + self.transport_tracker.update_transport( + src_ip, src_port, dst_ip, dst_port, + stream.detected_transport + ) + logger.debug(f"TCP reassembly: Updated transport tracker with {stream.detected_transport} from request") + + # For HTTP+SSE, log detection in auto-detect mode + if stream.detected_transport == MCPTransport.HTTP_SSE: + logger.info(f"[MCPHawk] Detected HTTP+SSE transport on {src_ip}:{src_port} -> {dst_ip}:{dst_port} (GET {stream.request_path} with Accept: text/event-stream)") + else: + logger.debug("TCP reassembly: Transport still unknown after request parsing") + + # Check if this is an HTTP response + elif payload.startswith(b"HTTP/1."): + stream.add_response_data(payload) + logger.debug(f"TCP reassembly: HTTP response data on {src_port}->{dst_port}, SSE={stream.is_sse}, 
buffer_len={len(stream.buffer)}") + logger.debug(f"TCP reassembly: Stream {stream_key} now has headers: {stream.response_headers}") + + # Try to extract messages if this is SSE + if stream.is_sse: + sse_messages = stream.extract_sse_messages() + logger.debug(f"TCP reassembly: Extracted {len(sse_messages)} SSE messages") + for msg in sse_messages: + # Update transport tracker if detected + if stream.detected_transport != MCPTransport.UNKNOWN: + self.transport_tracker.update_transport( + src_ip, src_port, dst_ip, dst_port, + stream.detected_transport + ) + + messages.append({ + "src_ip": src_ip, + "src_port": src_port, + "dst_ip": dst_ip, + "dst_port": dst_port, + "message": msg, + "type": "sse_response", + "transport": stream.detected_transport.value + }) + + # Check for standalone SSE data (no HTTP headers) + elif payload.startswith(b"data: ") and b"jsonrpc" in payload: + # This might be a continuation of an SSE stream + logger.debug(f"TCP reassembly: Standalone SSE data on {src_port}->{dst_port}") + stream.buffer = payload + sse_messages = stream.extract_sse_messages() + for msg in sse_messages: + # Get transport from tracker + transport = self.transport_tracker.get_transport(src_ip, src_port, dst_ip, dst_port) + messages.append({ + "src_ip": src_ip, + "src_port": src_port, + "dst_ip": dst_ip, + "dst_port": dst_port, + "message": msg, + "type": "sse_data", + "transport": transport.value + }) + + # For any packet, if we have an SSE stream, try to process it + elif stream.response_headers is not None and stream.is_sse and len(payload) > 0: + # This is SSE data following headers + logger.debug(f"TCP reassembly: SSE data on tracked stream {src_port}->{dst_port}, payload preview: {payload[:50]}") + stream.add_response_data(payload) + sse_messages = stream.extract_sse_messages() + if sse_messages: + logger.debug(f"TCP reassembly: Found {len(sse_messages)} messages in SSE stream") + for msg in sse_messages: + # Update transport tracker if detected + if 
stream.detected_transport != MCPTransport.UNKNOWN: + self.transport_tracker.update_transport( + src_ip, src_port, dst_ip, dst_port, + stream.detected_transport + ) + + messages.append({ + "src_ip": src_ip, + "src_port": src_port, + "dst_ip": dst_ip, + "dst_port": dst_port, + "message": msg, + "type": "sse_continuation", + "transport": stream.detected_transport.value + }) + + # Catch-all: Handle any data on a stream where we've seen headers + elif len(payload) > 0: + logger.debug(f"TCP reassembly: Catch-all for {src_port}->{dst_port}") + # Check if this stream has headers and might be waiting for body data + if stream.response_headers is not None: + logger.debug(f"TCP reassembly: Data on stream with headers {src_port}->{dst_port}, is_sse={stream.is_sse}") + logger.debug(f"TCP reassembly: Data preview: {payload[:50]}") + + # Add data to the stream + stream.add_response_data(payload) + + # If it's an SSE stream, try to extract messages + if stream.is_sse: + logger.debug(f"TCP reassembly: Processing SSE stream, buffer now has {len(stream.buffer)} bytes") + sse_messages = stream.extract_sse_messages() + logger.debug(f"TCP reassembly: Extracted {len(sse_messages)} SSE messages from continuation") + for msg in sse_messages: + messages.append({ + "src_ip": src_ip, + "src_port": src_port, + "dst_ip": dst_ip, + "dst_port": dst_port, + "message": msg, + "type": "sse_continuation" + }) + elif src_port == 8765 or dst_port == 8765: + logger.debug(f"TCP reassembly: Unhandled packet {src_port}->{dst_port}, no headers yet") + logger.debug(f"TCP reassembly: Payload starts with: {payload[:50]}") + + return messages + + def cleanup_old_streams(self, timeout: int = 300): + """Remove old streams to prevent memory leaks.""" + # TODO: Implement timeout-based cleanup + pass diff --git a/mcphawk/transport_detector.py b/mcphawk/transport_detector.py new file mode 100644 index 0000000..9d592e6 --- /dev/null +++ b/mcphawk/transport_detector.py @@ -0,0 +1,197 @@ +"""MCP transport detection 
logic.""" + +import logging +from enum import Enum +from typing import Optional + +logger = logging.getLogger(__name__) + + +class MCPTransport(Enum): + """MCP transport types.""" + STDIO = "stdio" + STREAMABLE_HTTP = "streamable_http" # Single endpoint, 2025-03-26+ + HTTP_SSE = "http_sse" # Separate endpoints, deprecated + UNKNOWN = "unknown" + + +def detect_transport_from_http( + method: str, + path: str, + headers: dict[str, str], + is_sse_response: bool = False, + response_contains_endpoint_event: bool = False +) -> MCPTransport: + """ + Detect MCP transport type from HTTP traffic. + + HTTP+SSE (legacy): + - GET request with Accept: text/event-stream + - Server sends "endpoint" event with POST URL + - Separate endpoints for SSE and POST + - POST requests don't have special Accept headers + + Streamable HTTP: + - POST requests with Accept: application/json, text/event-stream (BOTH required) + - Single endpoint for all requests + - No "endpoint" event + - Dynamic SSE upgrade when needed + + Key differences in Accept headers: + - HTTP+SSE: GET with Accept: text/event-stream (single type) + - Streamable HTTP: POST with Accept: application/json, text/event-stream (dual types) + + Args: + method: HTTP method (GET, POST) + path: HTTP path + headers: HTTP headers + is_sse_response: Whether response has SSE content-type + response_contains_endpoint_event: Whether SSE stream contains endpoint event + + Returns: + Detected MCP transport type + """ + # HTTP+SSE: GET with Accept: text/event-stream (single type) + if method == "GET" and headers.get("accept", "").lower() == "text/event-stream": + # This is HTTP+SSE establishing SSE connection + if response_contains_endpoint_event: + logger.debug("Detected HTTP+SSE transport (endpoint event found)") + return MCPTransport.HTTP_SSE + # Could still be HTTP+SSE without seeing the response yet + logger.debug("Possible HTTP+SSE transport (GET with SSE accept)") + return MCPTransport.HTTP_SSE + + # Streamable HTTP: POST with Accept: 
application/json, text/event-stream (BOTH types) + accept_header = headers.get("accept", "").lower() + if method == "POST" and "application/json" in accept_header and "text/event-stream" in accept_header: + logger.debug("Detected Streamable HTTP transport (POST with dual accept)") + return MCPTransport.STREAMABLE_HTTP + + # POST without proper dual Accept headers cannot be identified + # It could be either: + # - Streamable HTTP missing Accept headers (incorrect client) + # - HTTP+SSE posting to the endpoint URL (correct behavior) + + # If we see SSE response without prior transport detection, make best guess + if is_sse_response: + # In Streamable HTTP, SSE is a response to POST + # In HTTP+SSE, SSE is response to GET + logger.debug(f"SSE response detected, method was {method}") + if method == "POST": + return MCPTransport.STREAMABLE_HTTP + else: + return MCPTransport.HTTP_SSE + + logger.debug("Unable to definitively detect transport type") + return MCPTransport.UNKNOWN + + +def extract_endpoint_from_sse(sse_data: str) -> Optional[str]: + """ + Extract endpoint URL from SSE data (for HTTP+SSE transport). 
+ + The endpoint event in HTTP+SSE looks like: + event: endpoint + data: {"url": "http://localhost:8080/messages"} + + Args: + sse_data: SSE message data + + Returns: + Endpoint URL if found, None otherwise + """ + import json + + lines = sse_data.strip().split('\n') + event_type = None + data_content = None + + for line in lines: + if line.startswith('event:'): + event_type = line[6:].strip() + elif line.startswith('data:'): + data_content = line[5:].strip() + + if event_type == "endpoint" and data_content: + try: + data = json.loads(data_content) + return data.get("url") + except json.JSONDecodeError: + logger.debug(f"Failed to parse endpoint data: {data_content}") + + return None + + +class TransportTracker: + """Track transport type per connection and server.""" + + def __init__(self): + self.transports: dict[tuple[str, int, str, int], MCPTransport] = {} + self.endpoint_urls: dict[tuple[str, int, str, int], str] = {} + # Track HTTP+SSE servers by IP:port + self.http_sse_servers: dict[tuple[str, int], MCPTransport] = {} + + def update_transport( + self, + src_ip: str, + src_port: int, + dst_ip: str, + dst_port: int, + transport: MCPTransport + ) -> None: + """Update transport type for a connection.""" + key = (src_ip, src_port, dst_ip, dst_port) + # Also store reverse direction + reverse_key = (dst_ip, dst_port, src_ip, src_port) + + if transport != MCPTransport.UNKNOWN: + self.transports[key] = transport + self.transports[reverse_key] = transport + logger.debug(f"Updated transport for {key}: {transport.value}") + + # For HTTP+SSE, also track the server endpoint + if transport == MCPTransport.HTTP_SSE: + server_key = (dst_ip, dst_port) + self.http_sse_servers[server_key] = transport + logger.debug(f"Marked server {dst_ip}:{dst_port} as HTTP+SSE") + + def get_transport( + self, + src_ip: str, + src_port: int, + dst_ip: str, + dst_port: int + ) -> MCPTransport: + """Get transport type for a connection.""" + key = (src_ip, src_port, dst_ip, dst_port) + + # First 
check if we have transport for this exact connection + if key in self.transports: + return self.transports[key] + + # For HTTP+SSE, check if the destination OR source is a known HTTP+SSE server + dst_server_key = (dst_ip, dst_port) + if dst_server_key in self.http_sse_servers: + logger.debug(f"Found HTTP+SSE server {dst_ip}:{dst_port} for new connection") + return self.http_sse_servers[dst_server_key] + + # Also check if source is an HTTP+SSE server (for responses) + src_server_key = (src_ip, src_port) + if src_server_key in self.http_sse_servers: + logger.debug(f"Response from HTTP+SSE server {src_ip}:{src_port}") + return self.http_sse_servers[src_server_key] + + return MCPTransport.UNKNOWN + + def store_endpoint_url( + self, + src_ip: str, + src_port: int, + dst_ip: str, + dst_port: int, + endpoint_url: str + ) -> None: + """Store endpoint URL for HTTP+SSE transport.""" + key = (src_ip, src_port, dst_ip, dst_port) + self.endpoint_urls[key] = endpoint_url + logger.debug(f"Stored endpoint URL for {key}: {endpoint_url}") diff --git a/mcphawk/utils.py b/mcphawk/utils.py new file mode 100644 index 0000000..13519a5 --- /dev/null +++ b/mcphawk/utils.py @@ -0,0 +1,53 @@ +"""Utility functions for MCPHawk.""" + +import json +from typing import Any, Optional + + +def parse_message(message: str) -> Optional[dict[str, Any]]: + """Parse a JSON message string.""" + try: + if isinstance(message, str): + return json.loads(message) + return message + except (json.JSONDecodeError, TypeError): + return None + + +def get_message_type(message: str) -> str: + """ + Determine the type of a JSON-RPC message. 
+ + Returns: 'request', 'response', 'notification', 'error', or 'unknown' + """ + parsed = parse_message(message) + if not parsed: + return "unknown" + + # Check if it's a valid JSON-RPC 2.0 message + if parsed.get("jsonrpc") != "2.0": + return "unknown" + + # Error response (has error and id) + if "error" in parsed and "id" in parsed: + return "error" + + # Response (has result and id) + if "result" in parsed and "id" in parsed: + return "response" + + # Request (has method and id) + if "method" in parsed and "id" in parsed: + return "request" + + # Notification (has method but no id) + if "method" in parsed and "id" not in parsed: + return "notification" + + return "unknown" + + +def get_method_name(message: str) -> Optional[str]: + """Extract method name from a JSON-RPC message.""" + parsed = parse_message(message) + return parsed.get("method") if parsed else None diff --git a/mcphawk/web/server.py b/mcphawk/web/server.py index 7987fd1..9243020 100644 --- a/mcphawk/web/server.py +++ b/mcphawk/web/server.py @@ -15,6 +15,9 @@ # Set up logger for this module logger = logging.getLogger(__name__) +# Global flag to track if web server was started with MCP +_with_mcp = False + app = FastAPI() # Allow local UI dev or CDN-based dashboard @@ -27,6 +30,16 @@ ) +@app.get("/status") +def get_status(): + """ + Get server status including MCP server status. 
+ """ + return JSONResponse(content={ + "with_mcp": _with_mcp + }) + + @app.get("/logs") def get_logs(limit: int = 50): """ @@ -44,7 +57,7 @@ def get_logs(limit: int = 50): { **log, "timestamp": log["timestamp"].isoformat(), # ensure JSON-friendly - "traffic_type": log.get("traffic_type", "N/A") # ensure traffic_type is included + "transport_type": log.get("transport_type", "unknown") # ensure transport_type is included } for log in logs ]) @@ -83,7 +96,7 @@ async def websocket_endpoint(websocket: WebSocket): logger.debug(f"WebSocket disconnected: {len(active_clients)} active clients") -def _start_sniffer_thread(filter_expr: str, auto_detect: bool = False, debug: bool = False): +def _start_sniffer_thread(filter_expr: str, auto_detect: bool = False, debug: bool = False, excluded_ports: Optional[list[int]] = None, mcphawk_mcp_ports: Optional[list[int]] = None): """ Start the sniffer in a dedicated daemon thread. @@ -91,18 +104,20 @@ def _start_sniffer_thread(filter_expr: str, auto_detect: bool = False, debug: bo filter_expr: BPF filter expression for the sniffer. auto_detect: Whether to auto-detect MCP traffic. debug: Whether to enable debug logging. + excluded_ports: List of ports to exclude from capture. + mcphawk_mcp_ports: List of ports where MCPHawk's own MCP server is running. 
""" from mcphawk.sniffer import start_sniffer def safe_start(): logger.debug(f"Sniffer thread starting with filter: {filter_expr}, auto_detect: {auto_detect}") - return start_sniffer(filter_expr=filter_expr, auto_detect=auto_detect, debug=debug) + return start_sniffer(filter_expr=filter_expr, auto_detect=auto_detect, debug=debug, excluded_ports=excluded_ports, mcphawk_mcp_ports=mcphawk_mcp_ports) thread = threading.Thread(target=safe_start, daemon=True) thread.start() -def run_web(sniffer: bool = True, host: str = "127.0.0.1", port: int = 8000, filter_expr: Optional[str] = None, auto_detect: bool = False, debug: bool = False): +def run_web(sniffer: bool = True, host: str = "127.0.0.1", port: int = 8000, filter_expr: Optional[str] = None, auto_detect: bool = False, debug: bool = False, excluded_ports: Optional[list[int]] = None, with_mcp: bool = False, mcphawk_mcp_ports: Optional[list[int]] = None): """ Run the web server and optionally the sniffer. @@ -114,10 +129,14 @@ def run_web(sniffer: bool = True, host: str = "127.0.0.1", port: int = 8000, fil auto_detect: Whether to auto-detect MCP traffic. debug: Whether to enable debug logging. """ + # Set global MCP flag + global _with_mcp + _with_mcp = with_mcp + if sniffer: if not filter_expr: raise ValueError("filter_expr is required when sniffer is enabled") - _start_sniffer_thread(filter_expr, auto_detect, debug) + _start_sniffer_thread(filter_expr, auto_detect, debug, excluded_ports, mcphawk_mcp_ports) if sniffer: print(f"[MCPHawk] Starting sniffer and dashboard on http://{host}:{port}") diff --git a/mcphawk/ws_reassembly.py b/mcphawk/ws_reassembly.py deleted file mode 100644 index d8392c4..0000000 --- a/mcphawk/ws_reassembly.py +++ /dev/null @@ -1,163 +0,0 @@ -# Track buffers for fragmented WebSocket messages per connection -ws_buffers: dict[tuple[str, int, str, int], bytearray] = {} - - -def _parse_ws_frames(data: bytes) -> tuple[list[tuple[bool, str]], int]: - """ - Parse one or more WebSocket frames. 
- - Returns (messages, consumed_bytes) where messages is a list of (fin, text) tuples. - """ - messages = [] - i = 0 - buffer_len = len(data) - - # Debug logging - import logging - logger = logging.getLogger(__name__) - logger.debug(f"Parsing WebSocket frames from {buffer_len} bytes") - - while i < buffer_len: - start_pos = i # Remember where this frame started - - # Need at least 2 bytes for header - if i + 2 > buffer_len: - break - - first_byte = data[i] - fin = (first_byte & 0x80) != 0 - opcode = first_byte & 0x0F - i += 1 - - length_byte = data[i] - masked = (length_byte & 0x80) != 0 - length = length_byte & 0x7F - i += 1 - - # Extended length handling - if length == 126: - if i + 2 > buffer_len: - i = start_pos # Reset to frame start - break - length = int.from_bytes(data[i:i+2], "big") - i += 2 - elif length == 127: - if i + 8 > buffer_len: - i = start_pos # Reset to frame start - break - length = int.from_bytes(data[i:i+8], "big") - i += 8 - - # Check if we have the complete frame - if masked: - if i + 4 + length > buffer_len: - i = start_pos # Reset to frame start - break - mask = data[i:i+4] - i += 4 - - payload = bytearray(data[i:i+length]) - # Unmask the payload - for j in range(length): - payload[j] ^= mask[j % 4] - i += length - else: - if i + length > buffer_len: - i = start_pos # Reset to frame start - break - payload = data[i:i+length] - i += length - - # Process the payload based on opcode - if opcode == 0x1: # Text frame - try: - msg = payload.decode("utf-8") - messages.append((fin, msg)) - logger.debug(f"Decoded text frame: {msg[:100]}...") - except UnicodeDecodeError: - logger.debug("Failed to decode text frame") - continue - elif opcode == 0x0: # Continuation frame - try: - msg = payload.decode("utf-8") - messages.append((fin, msg)) - except UnicodeDecodeError: - continue - - return messages, i - - -def process_ws_packet( - src_ip: str, - src_port: int, - dst_ip: str, - dst_port: int, - payload: bytes -) -> list[str]: - """ - Process TCP 
payloads and reconstruct WebSocket text messages. - - Args: - src_ip: Source IP address. - src_port: Source port. - dst_ip: Destination IP address. - dst_port: Destination port. - payload: Raw TCP payload bytes. - - Returns: - List of completed JSON strings (if any). - """ - import logging - logger = logging.getLogger(__name__) - logger.debug(f"Processing WS packet from {src_ip}:{src_port} to {dst_ip}:{dst_port}, {len(payload)} bytes") - - # Check if this is HTTP upgrade request/response - if payload.startswith(b'GET ') or payload.startswith(b'HTTP/'): - logger.debug("Detected HTTP upgrade, skipping frame parsing") - return [] - - key = (src_ip, src_port, dst_ip, dst_port) - - if key not in ws_buffers: - ws_buffers[key] = bytearray() - - # Append new data to buffer - ws_buffers[key].extend(payload) - logger.debug(f"Buffer size for {key}: {len(ws_buffers[key])} bytes") - - # Try to parse all complete frames from the buffer - try: - parsed_messages, consumed = _parse_ws_frames(bytes(ws_buffers[key])) - - # Remove consumed bytes from buffer - if consumed > 0: - ws_buffers[key] = ws_buffers[key][consumed:] - logger.debug(f"Consumed {consumed} bytes, {len(ws_buffers[key])} bytes remain in buffer") - - # Extract complete messages from parsed frames - combined = [] - current = "" - for fin, part in parsed_messages: - current += part - if fin: - combined.append(current) - current = "" - - # Filter for JSON-RPC messages - json_messages = [m for m in combined if "jsonrpc" in m] - if json_messages: - logger.debug(f"Found {len(json_messages)} JSON-RPC messages") - - # Clean up large buffers to prevent memory issues - if len(ws_buffers[key]) > 1024 * 1024: # 1MB limit - logger.warning(f"Buffer for {key} exceeded 1MB, clearing") - ws_buffers[key] = bytearray() - - return json_messages - - except Exception as e: - logger.error(f"Error parsing WebSocket frames: {e}") - # Clear buffer on error to recover - ws_buffers[key] = bytearray() - return [] - diff --git 
a/requirements-dev.txt b/requirements-dev.txt index d6ff5d5..a0a5e1f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,4 +4,5 @@ pytest>=8.0.0 pytest-asyncio==1.1.0 pytest-cov>=5.0.0 ruff>=0.8.0 -websockets>=12.0 \ No newline at end of file +websockets>=12.0 +requests>=2.31.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d1f8b5b..6f7aab6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ scapy==2.5.0 psutil==6.0.0 typer>=0.12.0 uvicorn==0.35.0 -wsproto==1.2.0 \ No newline at end of file +wsproto==1.2.0 +mcp>=1.0.0 \ No newline at end of file diff --git a/tests/test_cli.py b/tests/test_cli.py index f39a838..a47036f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -11,11 +11,12 @@ def test_cli_help(): - """Test that CLI help shows both commands.""" + """Test that CLI help shows all commands.""" result = runner.invoke(app, ["--help"]) assert result.exit_code == 0 assert "sniff" in result.stdout assert "web" in result.stdout + assert "mcp" in result.stdout assert "MCPHawk: Passive MCP traffic sniffer + dashboard" in result.stdout @@ -45,12 +46,262 @@ def test_sniff_command_requires_flags(): """Test sniff command requires port, filter, or auto-detect.""" result = runner.invoke(app, ["sniff"]) assert result.exit_code == 1 - assert "[ERROR] You must specify either --port, --filter, or --auto-detect" in result.stdout + assert "You must specify either --port, --filter, or --auto-detect" in result.stdout assert "mcphawk sniff --port 3000" in result.stdout assert "mcphawk sniff --filter 'tcp port 3000 or tcp port 3001'" in result.stdout assert "mcphawk sniff --auto-detect" in result.stdout +def test_mcp_command_help(): + """Test mcp command help.""" + result = runner.invoke(app, ["mcp", "--help"]) + assert result.exit_code == 0 + assert "Run MCPHawk MCP server" in result.stdout + # Check for transport option (may be formatted as --transport or -transport) + assert "transport" in result.stdout + 
assert "stdio or tcp" in result.stdout + # Check for mcp port option (may be formatted as --mcp-port, -mcp-port, or broken across lines) + assert "mcp" in result.stdout and "port" in result.stdout + assert "Port for TCP transport" in result.stdout + + +def test_mcp_command_stdio_transport(): + """Test mcp command with stdio transport.""" + with patch('mcphawk.cli.MCPHawkServer') as mock_server_class, \ + patch('mcphawk.cli.asyncio.run') as mock_asyncio_run: + + mock_server_instance = mock_server_class.return_value + + result = runner.invoke(app, ["mcp", "--transport", "stdio"]) + + # Check output + assert "Starting MCP server (transport: stdio)" in result.stdout + # The debug output with mcpServers only shows up with debug flag + # So we don't check for it here + + # Verify server was created and run_stdio was called + mock_server_class.assert_called_once() + mock_asyncio_run.assert_called_once() + + # Verify that asyncio.run was called with the server's run_stdio method + assert mock_asyncio_run.called + # The coroutine passed should be from run_stdio + assert mock_server_instance.run_stdio.called + + +def test_mcp_command_http_transport(): + """Test mcp command with HTTP transport.""" + with patch('mcphawk.cli.MCPHawkServer') as mock_server_class, \ + patch('mcphawk.cli.asyncio.run') as mock_asyncio_run: + + mock_server_instance = mock_server_class.return_value + + result = runner.invoke(app, ["mcp", "--transport", "http", "--mcp-port", "8765"]) + + # Check output + assert "Starting MCP server (transport: http)" in result.stdout + assert "http://localhost:8765/mcp" in result.stdout + # curl example only shows in debug mode + + # Verify server was created and run_http was called + mock_server_class.assert_called_once() + mock_asyncio_run.assert_called_once() + + # Verify that asyncio.run was called with the server's run_http method + assert mock_asyncio_run.called + # The coroutine passed should be from run_http + assert mock_server_instance.run_http.called + + 
+def test_mcp_command_unknown_transport(): + """Test mcp command with unknown transport.""" + result = runner.invoke(app, ["mcp", "--transport", "websocket"]) + assert result.exit_code == 1 + assert "Unknown transport: websocket" in result.stdout + + +def test_sniff_with_mcp_http(): + """Test sniff command with MCP HTTP transport.""" + with patch('mcphawk.cli.start_sniffer') as mock_start_sniffer, \ + patch('mcphawk.cli.MCPHawkServer'), \ + patch('mcphawk.cli.threading.Thread') as mock_thread: + + mock_thread_instance = mock_thread.return_value + + result = runner.invoke(app, [ + "sniff", + "--port", "3000", + "--with-mcp", + "--mcp-transport", "http", + "--mcp-port", "8765" + ]) + + # Check MCP server startup message + assert "[MCPHawk] Starting MCP HTTP server on http://localhost:8765/mcp" in result.stdout + + # Verify thread was started for MCP server + mock_thread.assert_called_once() + mock_thread_instance.start.assert_called_once() + + # Verify sniffer was called with excluded ports + mock_start_sniffer.assert_called_once() + call_args = mock_start_sniffer.call_args[1] + assert call_args['excluded_ports'] == [8765] + + +def test_sniff_with_mcp_stdio(): + """Test sniff command with MCP stdio transport.""" + with patch('mcphawk.cli.start_sniffer') as mock_start_sniffer, \ + patch('mcphawk.cli.MCPHawkServer'), \ + patch('mcphawk.cli.threading.Thread'): + + result = runner.invoke(app, [ + "sniff", + "--port", "3000", + "--with-mcp", + "--mcp-transport", "stdio" + ]) + + # Check MCP server startup message + assert "[MCPHawk] Starting MCP server on stdio" in result.stdout + + # Verify sniffer was called with empty excluded ports + mock_start_sniffer.assert_called_once() + call_args = mock_start_sniffer.call_args[1] + assert call_args['excluded_ports'] == [] + + +def test_web_with_mcp_http(): + """Test web command with MCP HTTP transport.""" + with patch('mcphawk.cli.run_web') as mock_run_web, \ + patch('mcphawk.cli.MCPHawkServer'), \ + 
patch('mcphawk.cli.threading.Thread'): + + result = runner.invoke(app, [ + "web", + "--port", "3000", + "--with-mcp", + "--mcp-transport", "http", + "--mcp-port", "8766" + ]) + + # Check MCP server startup message + assert "[MCPHawk] Starting MCP HTTP server on http://localhost:8766/mcp" in result.stdout + + # Verify web was called with excluded ports + mock_run_web.assert_called_once() + call_args = mock_run_web.call_args[1] + assert call_args['excluded_ports'] == [8766] + + +def test_mcp_command_custom_port(): + """Test mcp command with custom HTTP port.""" + with patch('mcphawk.cli.MCPHawkServer') as mock_server_class, \ + patch('mcphawk.cli.asyncio.run') as mock_asyncio_run: + + mock_server_instance = mock_server_class.return_value + + result = runner.invoke(app, ["mcp", "--transport", "http", "--mcp-port", "9999"]) + + # Check output shows custom port + assert "Starting MCP server (transport: http)" in result.stdout + assert "http://localhost:9999/mcp" in result.stdout + + # Verify server was created + mock_server_class.assert_called_once() + + # Verify run_http was called with custom port + mock_asyncio_run.assert_called_once() + # Check it was called with port=9999 + assert mock_server_instance.run_http.call_args[1]['port'] == 9999 + + +def test_sniff_with_mcp_custom_port(): + """Test sniff command with MCP on custom port.""" + with patch('mcphawk.cli.start_sniffer') as mock_start_sniffer, \ + patch('mcphawk.cli.MCPHawkServer'), \ + patch('mcphawk.cli.threading.Thread') as mock_thread: + + mock_thread_instance = mock_thread.return_value + + result = runner.invoke(app, [ + "sniff", + "--port", "3000", + "--with-mcp", + "--mcp-transport", "http", + "--mcp-port", "7777" + ]) + + # Check MCP server startup message with custom port + assert "[MCPHawk] Starting MCP HTTP server on http://localhost:7777/mcp" in result.stdout + + # Verify thread was started for MCP server + mock_thread.assert_called_once() + mock_thread_instance.start.assert_called_once() + + # 
Verify sniffer was called with custom port excluded + mock_start_sniffer.assert_called_once() + call_args = mock_start_sniffer.call_args[1] + assert call_args['excluded_ports'] == [7777] + + +def test_mcp_stdio_ignores_port(): + """Test that stdio transport ignores the mcp-port parameter.""" + with patch('mcphawk.cli.MCPHawkServer') as mock_server_class, \ + patch('mcphawk.cli.asyncio.run'): + + mock_server_instance = mock_server_class.return_value + + # Even with --mcp-port specified, stdio should ignore it + result = runner.invoke(app, ["mcp", "--transport", "stdio", "--mcp-port", "9999"]) + + # Check output doesn't mention the port + assert "Starting MCP server (transport: stdio)" in result.stdout + assert "9999" not in result.stdout + # mcpServers only shows in debug output + + # Verify run_stdio was called (not run_http) + assert mock_server_instance.run_stdio.called + assert not mock_server_instance.run_http.called + + +def test_web_with_mcp_default_vs_custom_port(): + """Test that default port 8765 is used when not specified.""" + with patch('mcphawk.cli.run_web') as mock_run_web, \ + patch('mcphawk.cli.MCPHawkServer') as mock_server_class, \ + patch('mcphawk.cli.threading.Thread') as mock_thread: + + # Test 1: Default port + result = runner.invoke(app, [ + "web", + "--port", "3000", + "--with-mcp", + "--mcp-transport", "http" + ]) + + assert "[MCPHawk] Starting MCP HTTP server on http://localhost:8765/mcp" in result.stdout + call_args = mock_run_web.call_args[1] + assert call_args['excluded_ports'] == [8765] + + # Reset mocks + mock_run_web.reset_mock() + mock_server_class.reset_mock() + mock_thread.reset_mock() + + # Test 2: Custom port + result = runner.invoke(app, [ + "web", + "--port", "3000", + "--with-mcp", + "--mcp-transport", "http", + "--mcp-port", "5555" + ]) + + assert "[MCPHawk] Starting MCP HTTP server on http://localhost:5555/mcp" in result.stdout + call_args = mock_run_web.call_args[1] + assert call_args['excluded_ports'] == [5555] + + 
@patch('mcphawk.cli.start_sniffer') def test_sniff_command_with_port(mock_start_sniffer): """Test sniff command with port option.""" @@ -60,7 +311,7 @@ def test_sniff_command_with_port(mock_start_sniffer): assert result.exit_code == 0 assert "[MCPHawk] Starting sniffer with filter: tcp port 3000" in result.stdout assert "[MCPHawk] Sniffer stopped." in result.stdout - mock_start_sniffer.assert_called_once_with(filter_expr="tcp port 3000", auto_detect=False, debug=False) + mock_start_sniffer.assert_called_once_with(filter_expr="tcp port 3000", auto_detect=False, debug=False, excluded_ports=[], mcphawk_mcp_ports=[]) @patch('mcphawk.cli.start_sniffer') @@ -71,7 +322,7 @@ def test_sniff_command_custom_filter(mock_start_sniffer): result = runner.invoke(app, ["sniff", "--filter", "tcp port 8080"]) assert result.exit_code == 0 assert "[MCPHawk] Starting sniffer with filter: tcp port 8080" in result.stdout - mock_start_sniffer.assert_called_once_with(filter_expr="tcp port 8080", auto_detect=False, debug=False) + mock_start_sniffer.assert_called_once_with(filter_expr="tcp port 8080", auto_detect=False, debug=False, excluded_ports=[], mcphawk_mcp_ports=[]) @patch('mcphawk.cli.start_sniffer') @@ -83,14 +334,14 @@ def test_sniff_command_auto_detect(mock_start_sniffer): assert result.exit_code == 0 assert "[MCPHawk] Auto-detect mode: monitoring all TCP traffic for MCP messages" in result.stdout assert "[MCPHawk] Starting sniffer with filter: tcp" in result.stdout - mock_start_sniffer.assert_called_once_with(filter_expr="tcp", auto_detect=True, debug=False) + mock_start_sniffer.assert_called_once_with(filter_expr="tcp", auto_detect=True, debug=False, excluded_ports=[], mcphawk_mcp_ports=[]) def test_web_command_requires_flags(): """Test web command requires port, filter, auto-detect, or no-sniffer.""" result = runner.invoke(app, ["web"]) assert result.exit_code == 1 - assert "[ERROR] You must specify either --port, --filter, or --auto-detect (or use --no-sniffer)" in 
result.stdout + assert "You must specify either --port, --filter, or --auto-detect (or use --no-sniffer)" in result.stdout assert "mcphawk web --port 3000" in result.stdout assert "mcphawk web --filter 'tcp port 3000 or tcp port 3001'" in result.stdout assert "mcphawk web --auto-detect" in result.stdout @@ -102,7 +353,7 @@ def test_web_command_with_port(mock_run_web): """Test web command with port option.""" result = runner.invoke(app, ["web", "--port", "3000"]) assert result.exit_code == 0 - mock_run_web.assert_called_once_with(sniffer=True, host="127.0.0.1", port=8000, filter_expr="tcp port 3000", auto_detect=False, debug=False) + mock_run_web.assert_called_once_with(sniffer=True, host="127.0.0.1", port=8000, filter_expr="tcp port 3000", auto_detect=False, debug=False, excluded_ports=[], with_mcp=False, mcphawk_mcp_ports=[]) @patch('mcphawk.cli.run_web') @@ -110,7 +361,7 @@ def test_web_command_no_sniffer(mock_run_web): """Test web command with --no-sniffer.""" result = runner.invoke(app, ["web", "--no-sniffer"]) assert result.exit_code == 0 - mock_run_web.assert_called_once_with(sniffer=False, host="127.0.0.1", port=8000, filter_expr=None, auto_detect=False, debug=False) + mock_run_web.assert_called_once_with(sniffer=False, host="127.0.0.1", port=8000, filter_expr=None, auto_detect=False, debug=False, excluded_ports=[], with_mcp=False, mcphawk_mcp_ports=[]) @patch('mcphawk.cli.run_web') @@ -118,7 +369,7 @@ def test_web_command_custom_host_web_port(mock_run_web): """Test web command with custom host and web-port.""" result = runner.invoke(app, ["web", "--port", "3000", "--host", "0.0.0.0", "--web-port", "9000"]) assert result.exit_code == 0 - mock_run_web.assert_called_once_with(sniffer=True, host="0.0.0.0", port=9000, filter_expr="tcp port 3000", auto_detect=False, debug=False) + mock_run_web.assert_called_once_with(sniffer=True, host="0.0.0.0", port=9000, filter_expr="tcp port 3000", auto_detect=False, debug=False, excluded_ports=[], with_mcp=False, 
mcphawk_mcp_ports=[]) @patch('mcphawk.cli.run_web') @@ -126,7 +377,7 @@ def test_web_command_with_filter(mock_run_web): """Test web command with custom filter.""" result = runner.invoke(app, ["web", "--filter", "tcp port 8080 or tcp port 8081"]) assert result.exit_code == 0 - mock_run_web.assert_called_once_with(sniffer=True, host="127.0.0.1", port=8000, filter_expr="tcp port 8080 or tcp port 8081", auto_detect=False, debug=False) + mock_run_web.assert_called_once_with(sniffer=True, host="127.0.0.1", port=8000, filter_expr="tcp port 8080 or tcp port 8081", auto_detect=False, debug=False, excluded_ports=[], with_mcp=False, mcphawk_mcp_ports=[]) @patch('mcphawk.cli.run_web') @@ -134,7 +385,7 @@ def test_web_command_auto_detect(mock_run_web): """Test web command with auto-detect mode.""" result = runner.invoke(app, ["web", "--auto-detect"]) assert result.exit_code == 0 - mock_run_web.assert_called_once_with(sniffer=True, host="127.0.0.1", port=8000, filter_expr="tcp", auto_detect=True, debug=False) + mock_run_web.assert_called_once_with(sniffer=True, host="127.0.0.1", port=8000, filter_expr="tcp", auto_detect=True, debug=False, excluded_ports=[], with_mcp=False, mcphawk_mcp_ports=[]) def test_scapy_warnings_suppressed(): @@ -162,7 +413,7 @@ def test_sniff_command_with_debug_flag(mock_start_sniffer): result = runner.invoke(app, ["sniff", "--port", "3000", "--debug"]) assert result.exit_code == 0 - mock_start_sniffer.assert_called_once_with(filter_expr="tcp port 3000", auto_detect=False, debug=True) + mock_start_sniffer.assert_called_once_with(filter_expr="tcp port 3000", auto_detect=False, debug=True, excluded_ports=[], mcphawk_mcp_ports=[]) @patch('mcphawk.cli.run_web') @@ -170,4 +421,42 @@ def test_web_command_with_debug_flag(mock_run_web): """Test web command with debug flag.""" result = runner.invoke(app, ["web", "--port", "3000", "--debug"]) assert result.exit_code == 0 - mock_run_web.assert_called_once_with(sniffer=True, host="127.0.0.1", port=8000, 
filter_expr="tcp port 3000", auto_detect=False, debug=True)
+    mock_run_web.assert_called_once_with(sniffer=True, host="127.0.0.1", port=8000, filter_expr="tcp port 3000", auto_detect=False, debug=True, excluded_ports=[], with_mcp=False, mcphawk_mcp_ports=[])
+
+
+@patch('mcphawk.cli.run_web')
+@patch('mcphawk.cli.MCPHawkServer')
+@patch('mcphawk.cli.threading.Thread')
+def test_web_command_with_mcp(mock_thread, mock_mcp_server, mock_run_web):
+    """Test web command with MCP server integration."""
+    result = runner.invoke(app, ["web", "--port", "3000", "--with-mcp"])
+    assert result.exit_code == 0
+
+    # Check MCP server was created
+    mock_mcp_server.assert_called_once()
+
+    # Check thread was started
+    mock_thread.assert_called_once()
+    mock_thread.return_value.start.assert_called_once()
+
+    # Check run_web was called with excluded ports
+    # Default MCP transport is HTTP on port 8765
+    # In non-debug mode, mcphawk_mcp_ports is empty
+    mock_run_web.assert_called_once_with(
+        sniffer=True,
+        host="127.0.0.1",
+        port=8000,
+        filter_expr="tcp port 3000",
+        auto_detect=False,
+        debug=False,
+        excluded_ports=[8765],  # Default HTTP MCP port is excluded
+        with_mcp=True,
+        mcphawk_mcp_ports=[]  # Empty in non-debug mode
+    )
+
+
+def test_mcp_command():
+    """Test standalone MCP command."""
+    result = runner.invoke(app, ["mcp", "--help"])
+    assert result.exit_code == 0
+    assert "Run MCPHawk MCP server standalone" in result.stdout
diff --git a/tests/test_ipv4_ipv6_capture.py b/tests/test_ipv4_ipv6_capture.py
index 37f46cf..f1e4712 100644
--- a/tests/test_ipv4_ipv6_capture.py
+++ b/tests/test_ipv4_ipv6_capture.py
@@ -1,4 +1,4 @@
-"""Test IPv4 and IPv6 traffic capture for both TCP/Direct and TCP/WS."""
+"""Test IPv4 and IPv6 traffic capture with unknown transport type."""
 import json
 import socket
 import time
@@ -26,7 +26,7 @@ class TestIPv4Capture:
     """Test IPv4 traffic capture."""
 
     def test_ipv4_tcp_direct(self, test_db):
-        """Test IPv4 TCP/Direct traffic capture."""
+        """Test IPv4 traffic capture with unknown transport type."""
         json_rpc = json.dumps({"jsonrpc": "2.0", "method": "ipv4_tcp_test", "id": 1})
 
         # Create IPv4 packet
@@ -37,37 +37,18 @@ def test_ipv4_tcp_direct(self, test_db):
         logs = fetch_logs(limit=1)
         assert len(logs) == 1
-        assert logs[0]["traffic_type"] == "TCP/Direct"
+        assert logs[0]["transport_type"] == "unknown"
         assert logs[0]["src_ip"] == "127.0.0.1"
         assert logs[0]["dst_ip"] == "127.0.0.1"
         assert "ipv4_tcp_test" in logs[0]["message"]
-
-    def test_ipv4_tcp_ws(self, test_db):
-        """Test IPv4 TCP/WS (WebSocket) traffic capture."""
-        json_rpc = json.dumps({"jsonrpc": "2.0", "method": "ipv4_ws_test", "id": 2})
-
-        # Create WebSocket frame
-        frame = bytes([0x81, len(json_rpc)]) + json_rpc.encode()
-
-        # Create IPv4 packet
-        pkt = IP(src="127.0.0.1", dst="127.0.0.1") / TCP(sport=8765, dport=54321) / Raw(load=frame)
-        packet_callback(pkt)
-
-        time.sleep(0.1)
-
-        logs = fetch_logs(limit=1)
-        assert len(logs) == 1
-        assert logs[0]["traffic_type"] == "TCP/WS"
-        assert logs[0]["src_ip"] == "127.0.0.1"
-        assert logs[0]["dst_ip"] == "127.0.0.1"
-        assert "ipv4_ws_test" in logs[0]["message"]
 
 
 class TestIPv6Capture:
     """Test IPv6 traffic capture."""
 
     def test_ipv6_tcp_direct(self, test_db):
-        """Test IPv6 TCP/Direct traffic capture."""
+        """Test IPv6 traffic capture with unknown transport type."""
        json_rpc = json.dumps({"jsonrpc": "2.0", "method": "ipv6_tcp_test", "id": 3})
 
         # Create IPv6 packet
@@ -81,33 +62,11 @@ def test_ipv6_tcp_direct(self, test_db):
         if len(logs) == 0:
             pytest.skip("IPv6 traffic not captured - known limitation")
         else:
-            assert logs[0]["traffic_type"] == "TCP/Direct"
+            assert logs[0]["transport_type"] == "unknown"
             assert logs[0]["src_ip"] == "::1"
             assert logs[0]["dst_ip"] == "::1"
             assert "ipv6_tcp_test" in logs[0]["message"]
-
-    def test_ipv6_tcp_ws(self, test_db):
-        """Test IPv6 TCP/WS (WebSocket) traffic capture."""
-        json_rpc = json.dumps({"jsonrpc": "2.0", "method": "ipv6_ws_test", "id": 4})
-
-        # Create WebSocket frame
-        frame = bytes([0x81, len(json_rpc)]) + json_rpc.encode()
-
-        
# Create IPv6 packet - pkt = IPv6(src="::1", dst="::1") / TCP(sport=8765, dport=54321) / Raw(load=frame) - packet_callback(pkt) - - time.sleep(0.1) - - logs = fetch_logs(limit=1) - # This test will show if IPv6 is captured - if len(logs) == 0: - pytest.skip("IPv6 traffic not captured - known limitation") - else: - assert logs[0]["traffic_type"] == "TCP/WS" - assert logs[0]["src_ip"] == "::1" - assert logs[0]["dst_ip"] == "::1" - assert "ipv6_ws_test" in logs[0]["message"] class TestRealSocketCapture: diff --git a/tests/test_mcp_http_simple.py b/tests/test_mcp_http_simple.py new file mode 100644 index 0000000..639dc4b --- /dev/null +++ b/tests/test_mcp_http_simple.py @@ -0,0 +1,133 @@ +"""Simple HTTP tests for MCP server with mocks.""" + +import json +from unittest.mock import MagicMock, patch + + +class TestMCPHTTPSimple: + """Test MCP server HTTP transport with mocks.""" + + def test_sse_response_format(self): + """Test that we correctly parse SSE format responses.""" + # Mock SSE response + sse_response = 'data: {"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05"}}\n\n' + + # Parse SSE response + lines = sse_response.strip().split('\n') + assert lines[0].startswith('data: ') + data = json.loads(lines[0][6:]) + assert data["id"] == 1 + assert data["result"]["protocolVersion"] == "2024-11-05" + + @patch('requests.post') + def test_http_basic_flow_mocked(self, mock_post): + """Test basic HTTP flow with mocked responses.""" + # Mock responses for each call + mock_responses = [ + # Initialize response + MagicMock( + status_code=200, + text='data: {"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{},"serverInfo":{"name":"mcphawk-mcp","version":"1.0.0"}}}\n\n', + headers={"mcp-session-id": "test-session-123"} + ), + # List tools response + MagicMock( + status_code=200, + text='data: 
{"jsonrpc":"2.0","id":2,"result":{"tools":[{"name":"query_traffic"},{"name":"get_log"},{"name":"search_traffic"},{"name":"get_stats"},{"name":"list_methods"}]}}\n\n' + ), + # Get stats response + MagicMock( + status_code=200, + text='data: {"jsonrpc":"2.0","id":3,"result":{"content":[{"type":"text","text":"{\\"total\\":10,\\"requests\\":3,\\"responses\\":3,\\"notifications\\":2,\\"errors\\":2}"}]}}\n\n' + ) + ] + mock_post.side_effect = mock_responses + + # Import here to avoid issues + import requests + + # 1. Initialize + response = requests.post( + "http://localhost:8765/mcp", + json={"jsonrpc": "2.0", "method": "initialize", "params": {}, "id": 1}, + headers={"Accept": "application/json, text/event-stream"} + ) + assert response.status_code == 200 + session_id = response.headers.get("mcp-session-id") + assert session_id == "test-session-123" + + # 2. List tools + response = requests.post( + "http://localhost:8765/mcp", + json={"jsonrpc": "2.0", "method": "tools/list", "params": {}, "id": 2}, + headers={"Accept": "application/json, text/event-stream", "mcp-session-id": session_id} + ) + assert response.status_code == 200 + + # Parse and check tools + lines = response.text.strip().split('\n') + data = json.loads(lines[0][6:]) + assert len(data["result"]["tools"]) == 5 + + # 3. 
Call get_stats + response = requests.post( + "http://localhost:8765/mcp", + json={"jsonrpc": "2.0", "method": "tools/call", "params": {"name": "get_stats"}, "id": 3}, + headers={"Accept": "application/json, text/event-stream", "mcp-session-id": session_id} + ) + assert response.status_code == 200 + + # Parse stats + lines = response.text.strip().split('\n') + data = json.loads(lines[0][6:]) + stats = json.loads(data["result"]["content"][0]["text"]) + assert stats["total"] == 10 + assert stats["requests"] == 3 + + @patch('requests.post') + def test_http_error_handling_mocked(self, mock_post): + """Test HTTP error handling with mocked responses.""" + # Mock error responses + mock_responses = [ + # Session not initialized error + MagicMock( + status_code=200, + text='data: {"jsonrpc":"2.0","id":1,"error":{"code":-32602,"message":"Session not initialized"}}\n\n' + ), + # Unknown method error + MagicMock( + status_code=200, + text='data: {"jsonrpc":"2.0","id":2,"error":{"code":-32601,"message":"Unknown method: unknown/method"}}\n\n' + ) + ] + mock_post.side_effect = mock_responses + + import requests + + # Try to call tool without initialization + response = requests.post( + "http://localhost:8765/mcp", + json={"jsonrpc": "2.0", "method": "tools/list", "params": {}, "id": 1}, + headers={"Accept": "application/json, text/event-stream", "mcp-session-id": "uninitialized"} + ) + assert response.status_code == 200 + + # Parse SSE response + lines = response.text.strip().split('\n') + data = json.loads(lines[0][6:]) + assert "error" in data + assert data["error"]["message"] == "Session not initialized" + + # Unknown method + response = requests.post( + "http://localhost:8765/mcp", + json={"jsonrpc": "2.0", "method": "unknown/method", "params": {}, "id": 2}, + headers={"Accept": "application/json, text/event-stream"} + ) + assert response.status_code == 200 + + # Parse SSE response + lines = response.text.strip().split('\n') + data = json.loads(lines[0][6:]) + assert "error" 
in data + assert "Unknown method" in data["error"]["message"] diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py new file mode 100644 index 0000000..9e3d86d --- /dev/null +++ b/tests/test_mcp_server.py @@ -0,0 +1,270 @@ +"""Tests for MCP server functionality.""" + +import asyncio +import contextlib +import json +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import pytest + +from mcphawk import logger +from mcphawk.mcp_server.server import MCPHawkServer + + +@pytest.fixture +def test_db(tmp_path): + """Create a temporary test database.""" + db_path = tmp_path / "test_mcp.db" + logger.set_db_path(str(db_path)) + logger.init_db() + yield db_path + + +@pytest.fixture +def sample_logs(test_db): + """Create sample log entries.""" + test_messages = [ + { + "jsonrpc": "2.0", + "method": "tools/list", + "id": "req-1" + }, + { + "jsonrpc": "2.0", + "result": {"tools": ["query", "search"]}, + "id": "req-1" + }, + { + "jsonrpc": "2.0", + "method": "progress/update", + "params": {"progress": 50} + }, + { + "jsonrpc": "2.0", + "error": {"code": -32601, "message": "Method not found"}, + "id": "req-2" + } + ] + + log_ids = [] + for i, msg in enumerate(test_messages): + log_id = str(uuid.uuid4()) + log_ids.append(log_id) + entry = { + "log_id": log_id, + "timestamp": datetime.now(tz=timezone.utc), + "src_ip": "127.0.0.1", + "dst_ip": "127.0.0.1", + "src_port": 3000 + i, + "dst_port": 8000, + "direction": "unknown", + "message": json.dumps(msg), + "transport_type": "unknown" + } + logger.log_message(entry) + + return log_ids + + +class TestMCPServer: + """Test MCP server functionality.""" + + def test_server_initialization(self, test_db): + """Test server initializes correctly.""" + server = MCPHawkServer(str(test_db)) + assert server.mcp.name == "mcphawk-mcp" + + # Check that FastMCP instance was created + assert hasattr(server, 'mcp') + assert hasattr(server.mcp, 'tool') + + def test_list_tools(self, test_db): + 
"""Test that tools are registered correctly.""" + server = MCPHawkServer(str(test_db)) + + # The FastMCP instance exposes tools through _tool_manager._tools + tool_names = list(server.mcp._tool_manager._tools.keys()) + + assert len(tool_names) == 5 + assert "query_traffic" in tool_names + assert "get_log" in tool_names + assert "search_traffic" in tool_names + assert "get_stats" in tool_names + assert "list_methods" in tool_names + + @pytest.mark.asyncio + async def test_call_tools(self, sample_logs): + """Test calling tools directly.""" + server = MCPHawkServer() + + # Test query_traffic + query_tool = server.mcp._tool_manager._tools["query_traffic"] + result = await query_tool.fn(limit=2, offset=0) + data = json.loads(result) + assert len(data) == 2 + assert all("log_id" in log for log in data) + + # Test get_log + get_log_tool = server.mcp._tool_manager._tools["get_log"] + result = await get_log_tool.fn(log_id=sample_logs[0]) + data = json.loads(result) + assert data["log_id"] == sample_logs[0] + assert "tools/list" in data["message"] + + # Test get_log with invalid ID + result = await get_log_tool.fn(log_id="invalid") + assert "No log found" in result + + # Test search_traffic + search_tool = server.mcp._tool_manager._tools["search_traffic"] + result = await search_tool.fn(search_term="tools/list") + data = json.loads(result) + assert len(data) == 1 + assert "tools/list" in data[0]["message"] + + # Test get_stats + stats_tool = server.mcp._tool_manager._tools["get_stats"] + result = await stats_tool.fn() + stats = json.loads(result) + assert stats["total_logs"] == 4 + assert stats["requests"] == 1 + assert stats["responses"] == 1 + assert stats["notifications"] == 1 + assert stats["errors"] == 1 + + # Test list_methods + methods_tool = server.mcp._tool_manager._tools["list_methods"] + result = await methods_tool.fn() + methods_data = json.loads(result) + # The result format is {"methods": [...], "count": 2} + assert methods_data["count"] == 2 + assert 
"tools/list" in methods_data["methods"] + assert "progress/update" in methods_data["methods"] + + @pytest.mark.asyncio + async def test_search_with_filters(self, sample_logs): + """Test search functionality with various filters.""" + server = MCPHawkServer() + + search_tool = server.mcp._tool_manager._tools["search_traffic"] + + # Test message type filter + result = await search_tool.fn( + search_term="jsonrpc", + message_type="notification" + ) + data = json.loads(result) + assert len(data) == 1 + assert "progress/update" in data[0]["message"] + + # Test traffic type filter + result = await search_tool.fn( + search_term="jsonrpc", + transport_type="unknown" + ) + data = json.loads(result) + # All 4 test messages match the search criteria + assert len(data) == 4 + assert all(log["transport_type"] == "unknown" for log in data) + + @pytest.mark.asyncio + async def test_error_handling(self, test_db): + """Test error handling in tool functions.""" + server = MCPHawkServer(str(test_db)) + + # The SDK handles parameter validation automatically + # So we'll test the actual error cases in the tool implementations + + get_log_tool = server.mcp._tool_manager._tools["get_log"] + + # Test with non-existent log ID + result = await get_log_tool.fn(log_id="non-existent-id") + assert "No log found with ID" in result + + @pytest.mark.asyncio + async def test_run_stdio(self): + """Test that run_stdio properly handles stdio transport.""" + server = MCPHawkServer() + + # Mock the run_stdio_async method to avoid actual stdio operations + with patch.object(server.mcp, 'run_stdio_async', new_callable=AsyncMock) as mock_run: + await server.run_stdio() + + # Verify run was called + mock_run.assert_called_once() + + @pytest.mark.asyncio + async def test_run_http(self, test_db): + """Test that run_http properly handles HTTP transport.""" + server = MCPHawkServer(str(test_db)) + + # Create test server config that immediately shuts down + with patch('uvicorn.Server') as mock_server_class: + 
mock_server_instance = AsyncMock() + mock_server_instance.serve = AsyncMock() + mock_server_class.return_value = mock_server_instance + + # Run the server in a task that we'll cancel + task = asyncio.create_task(server.run_http(port=8765)) + + # Give it a moment to set up + await asyncio.sleep(0.1) + + # Cancel the task + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + + # Verify uvicorn was configured correctly + mock_server_class.assert_called_once() + config = mock_server_class.call_args[0][0] + assert config.host == "127.0.0.1" + assert config.port == 8765 + + @pytest.mark.asyncio + async def test_fastmcp_integration(self, sample_logs): + """Test that FastMCP handles requests correctly.""" + server = MCPHawkServer() + + # Test that we can access tool metadata + query_tool = server.mcp._tool_manager._tools["query_traffic"] + assert query_tool.description + # Tool has parameters but not exposed as input_schema + assert hasattr(query_tool, 'fn') + + # Test tool execution + result = await query_tool.fn(limit=1, offset=0) + data = json.loads(result) + assert len(data) == 1 + assert "log_id" in data[0] + + @pytest.mark.asyncio + async def test_notification_handling_concept(self): + """Test the concept of notification handling (SDK handles the actual protocol).""" + # The FastMCP SDK handles JSON-RPC protocol details including notifications + # This test verifies our understanding of the concept + + # In JSON-RPC 2.0: + # - Requests have an 'id' field and expect a response + # - Notifications have no 'id' field and should not receive a response + + notification = { + "jsonrpc": "2.0", + "method": "progress/update", + "params": {"progress": 50} + # Note: no 'id' field + } + + request = { + "jsonrpc": "2.0", + "method": "tools/list", + "params": {}, + "id": 1 + } + + # Verify our test data structure + assert "id" not in notification + assert "id" in request + diff --git a/tests/test_mcp_stdio_integration.py 
b/tests/test_mcp_stdio_integration.py new file mode 100644 index 0000000..9387312 --- /dev/null +++ b/tests/test_mcp_stdio_integration.py @@ -0,0 +1,287 @@ +"""Integration tests for MCP stdio transport.""" + +import contextlib +import json +import subprocess +import sys +from datetime import datetime, timezone + +import pytest + +from mcphawk import logger + + +class TestMCPStdioIntegration: + """Test MCP server stdio transport with real subprocess communication.""" + + @pytest.fixture + def test_db(self, tmp_path, monkeypatch): + """Create a temporary test database and patch the logger to use it.""" + # Create temp database + db_path = tmp_path / "test_stdio.db" + + # Monkeypatch the logger module to use our test database + monkeypatch.setattr(logger, "DB_PATH", str(db_path)) + logger.init_db() + + # Add some test data + test_messages = [ + { + "jsonrpc": "2.0", + "method": "tools/list", + "id": "test-1" + }, + { + "jsonrpc": "2.0", + "result": {"tools": ["test"]}, + "id": "test-1" + } + ] + + for i, msg in enumerate(test_messages): + entry = { + "log_id": f"test-{i}", + "timestamp": datetime.now(tz=timezone.utc), + "src_ip": "127.0.0.1", + "dst_ip": "127.0.0.1", + "src_port": 3000, + "dst_port": 8000, + "direction": "unknown", + "message": json.dumps(msg), + "transport_type": "unknown" + } + logger.log_message(entry) + + return db_path + + def test_stdio_initialize_handshake(self, test_db, monkeypatch): + """Test proper MCP initialization handshake over stdio.""" + # For this test, we'll use the simpler approach without database dependency + # The key is to test the protocol handshake works correctly + + # Prepare all requests + requests = [ + { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + }, + "id": 1 + }, + { + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {} + }, + { + "jsonrpc": "2.0", + "method": "tools/list", 
+ "params": {}, + "id": 2 + } + ] + + input_data = "\n".join(json.dumps(req) for req in requests) + "\n" + + # Start the MCP server + proc = subprocess.Popen( + [sys.executable, "-m", "mcphawk.cli", "mcp", "--transport", "stdio"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + + stdout, stderr = proc.communicate(input=input_data, timeout=5) + + # Parse responses + responses = [] + for line in stdout.strip().split('\n'): + if line: + with contextlib.suppress(json.JSONDecodeError): + responses.append(json.loads(line)) + + # Should have at least 2 responses (no response for notification) + assert len(responses) >= 2, f"Expected at least 2 responses, got {len(responses)}" + + # Check initialize response + assert responses[0]["id"] == 1 + assert "result" in responses[0] + assert responses[0]["result"]["protocolVersion"] == "2024-11-05" + assert responses[0]["result"]["serverInfo"]["name"] == "mcphawk-mcp" + + # Check tools/list response + assert responses[1]["id"] == 2 + assert "result" in responses[1] + assert "tools" in responses[1]["result"] + assert len(responses[1]["result"]["tools"]) == 5 # We have 5 tools + + def test_stdio_basic_communication(self): + """Test basic stdio communication without select.""" + # Use communicate() for simpler testing + requests = [ + { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + }, + "id": 1 + }, + { + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {} + }, + { + "jsonrpc": "2.0", + "method": "tools/list", + "params": {}, + "id": 2 + } + ] + + input_data = "\n".join(json.dumps(req) for req in requests) + "\n" + + proc = subprocess.Popen( + [sys.executable, "-m", "mcphawk.cli", "mcp", "--transport", "stdio"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=0 + ) + + stdout, stderr = 
proc.communicate(input=input_data, timeout=5) + + # Parse responses + responses = [] + for line in stdout.strip().split('\n'): + if line: + with contextlib.suppress(json.JSONDecodeError): + responses.append(json.loads(line)) + + # Debug output + if not responses: + print(f"STDOUT: {stdout}") + print(f"STDERR: {stderr}") + + # Should have 2 responses (no response for notification) + assert len(responses) >= 2, f"Expected at least 2 responses, got {len(responses)}" + + # Check initialize response + assert responses[0]["id"] == 1 + assert responses[0]["result"]["protocolVersion"] == "2024-11-05" + + # Check tools/list response + assert responses[1]["id"] == 2 + assert "tools" in responses[1]["result"] + + def test_stdio_logging_to_stderr(self): + """Test that all logging goes to stderr, not stdout.""" + # Use communicate for simpler test + input_data = json.dumps({ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + }, + "id": 1 + }) + "\n" + + proc = subprocess.Popen( + [sys.executable, "-m", "mcphawk.cli", "mcp", "--transport", "stdio", "--debug"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + + stdout, stderr = proc.communicate(input=input_data, timeout=5) + + # Response should be valid JSON (no log messages mixed in stdout) + response_line = stdout.strip() + response = json.loads(response_line) + assert response["id"] == 1 + assert "result" in response + + # Check stderr has log messages + assert "[MCPHawk]" in stderr + assert "Starting MCP server" in stderr + + @pytest.mark.parametrize("tool_name,args,expected_in_result", [ + ("get_stats", {}, "total"), + ("list_methods", {}, []), # Empty list if no data + ]) + def test_stdio_tool_calls_basic(self, tool_name, args, expected_in_result): + """Test basic tool calls that don't require specific data.""" + # Use communicate for simpler test + requests = [ 
+ { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + }, + "id": 1 + }, + { + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {} + }, + { + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": tool_name, + "arguments": args + }, + "id": 2 + } + ] + + input_data = "\n".join(json.dumps(req) for req in requests) + "\n" + + proc = subprocess.Popen( + [sys.executable, "-m", "mcphawk.cli", "mcp", "--transport", "stdio"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + + stdout, stderr = proc.communicate(input=input_data, timeout=5) + + # Parse responses + responses = [] + for line in stdout.strip().split('\n'): + if line: + with contextlib.suppress(json.JSONDecodeError): + responses.append(json.loads(line)) + + # Should have 2 responses (init and tool call) + assert len(responses) >= 2 + + # Check tool response + tool_response = responses[-1] # Last response should be tool call + assert tool_response["id"] == 2 + assert "result" in tool_response + assert "content" in tool_response["result"] + + # Check content + content = tool_response["result"]["content"][0]["text"] + if isinstance(expected_in_result, str): + assert expected_in_result in content + else: + # For list_methods, just check it's valid JSON + json.loads(content) diff --git a/tests/test_sniffer.py b/tests/test_sniffer.py index d9ea081..1e45767 100644 --- a/tests/test_sniffer.py +++ b/tests/test_sniffer.py @@ -8,6 +8,7 @@ import pytest from scapy.all import Ether from scapy.layers.inet import IP, TCP +from scapy.layers.inet6 import IPv6 from scapy.packet import Raw from mcphawk.logger import init_db, set_db_path @@ -85,7 +86,7 @@ def test_packet_callback(clean_db, dummy_server): conn = sqlite3.connect(TEST_DB) cur = conn.cursor() - cur.execute("SELECT message FROM logs ORDER BY id DESC LIMIT 1;") + 
cur.execute("SELECT message FROM logs ORDER BY log_id DESC LIMIT 1;") logged_msg = cur.fetchone()[0] conn.close() @@ -166,3 +167,191 @@ def test_start_sniffer_auto_detect_flag(self, mock_sniff): mock_sniff.assert_called_once() call_kwargs = mock_sniff.call_args[1] assert call_kwargs["filter"] == "tcp" + + +class TestHTTPParsing: + """Test HTTP request/response parsing for MCP over HTTP.""" + + def setup_method(self): + """Reset global state before each test.""" + import mcphawk.sniffer + # Clear MCPHawk MCP ports + mcphawk.sniffer._mcphawk_mcp_ports.clear() + + @patch('mcphawk.sniffer.log_message') + @patch('mcphawk.sniffer._broadcast_in_any_loop') + def test_http_post_request_parsing(self, mock_broadcast, mock_log): + """Test parsing of HTTP POST request with JSON-RPC body.""" + http_request = ( + b'POST /mcp HTTP/1.1\r\n' + b'Host: localhost:8765\r\n' + b'Content-Type: application/json\r\n' + b'Content-Length: 89\r\n' + b'\r\n' + b'{"jsonrpc":"2.0","method":"initialize","params":{"protocolVersion":"2024-11-05"},"id":1}' + ) + + mock_pkt = MagicMock() + mock_pkt.haslayer.side_effect = lambda layer: layer in [Raw, IP, TCP] + mock_pkt.__getitem__.side_effect = lambda layer: { + Raw: MagicMock(load=http_request), + IP: MagicMock(src="127.0.0.1", dst="127.0.0.1"), + TCP: MagicMock(sport=54321, dport=8765) + }[layer] + + packet_callback(mock_pkt) + + # Verify the JSON-RPC body was extracted and logged + assert mock_log.called + logged_entry = mock_log.call_args[0][0] + assert logged_entry["message"] == '{"jsonrpc":"2.0","method":"initialize","params":{"protocolVersion":"2024-11-05"},"id":1}' + assert logged_entry["transport_type"] == "unknown" + assert logged_entry["src_port"] == 54321 + assert logged_entry["dst_port"] == 8765 + + @patch('mcphawk.sniffer.log_message') + @patch('mcphawk.sniffer._broadcast_in_any_loop') + def test_http_response_parsing(self, mock_broadcast, mock_log): + """Test parsing of HTTP response with JSON-RPC body.""" + http_response = ( + 
b'HTTP/1.1 200 OK\r\n' + b'Content-Type: application/json\r\n' + b'Content-Length: 50\r\n' + b'\r\n' + b'{"jsonrpc":"2.0","result":{"status":"ok"},"id":1}' + ) + + mock_pkt = MagicMock() + mock_pkt.haslayer.side_effect = lambda layer: layer in [Raw, IP, TCP] + mock_pkt.__getitem__.side_effect = lambda layer: { + Raw: MagicMock(load=http_response), + IP: MagicMock(src="127.0.0.1", dst="127.0.0.1"), + TCP: MagicMock(sport=8765, dport=54321) + }[layer] + + packet_callback(mock_pkt) + + # Verify the JSON-RPC body was extracted and logged + assert mock_log.called + logged_entry = mock_log.call_args[0][0] + assert logged_entry["message"] == '{"jsonrpc":"2.0","result":{"status":"ok"},"id":1}' + assert logged_entry["transport_type"] == "unknown" + assert logged_entry["src_port"] == 8765 + assert logged_entry["dst_port"] == 54321 + + @patch('mcphawk.sniffer.log_message') + @patch('mcphawk.sniffer._broadcast_in_any_loop') + def test_http_without_jsonrpc_ignored(self, mock_broadcast, mock_log): + """Test that HTTP requests without JSON-RPC content are ignored.""" + http_request = ( + b'POST /api/test HTTP/1.1\r\n' + b'Host: localhost:8765\r\n' + b'Content-Type: application/json\r\n' + b'\r\n' + b'{"data":"not json-rpc"}' + ) + + mock_pkt = MagicMock() + mock_pkt.haslayer.side_effect = lambda layer: layer in [Raw, IP, TCP] + mock_pkt.__getitem__.side_effect = lambda layer: { + Raw: MagicMock(load=http_request), + IP: MagicMock(src="127.0.0.1", dst="127.0.0.1"), + TCP: MagicMock(sport=54321, dport=8765) + }[layer] + + packet_callback(mock_pkt) + + # Should not log non-JSON-RPC content + assert not mock_log.called + + + @patch('mcphawk.sniffer.log_message') + @patch('mcphawk.sniffer._broadcast_in_any_loop') + def test_mcphawk_mcp_traffic_metadata(self, mock_broadcast, mock_log): + """Test that MCPHawk's own MCP traffic is tagged with metadata.""" + import mcphawk.sniffer + # Set up MCPHawk MCP ports for this test + mcphawk.sniffer._mcphawk_mcp_ports = {8765} + + http_request = ( 
+ b'POST /mcp HTTP/1.1\r\n' + b'Host: localhost:8765\r\n' + b'Content-Type: application/json\r\n' + b'\r\n' + b'{"jsonrpc":"2.0","method":"test","id":1}' + ) + + mock_pkt = MagicMock() + # Note: haslayer should return True for IP but False for IPv6 + mock_pkt.haslayer.side_effect = lambda layer: { + Raw: True, + IP: True, + TCP: True, + IPv6: False + }.get(layer, False) + + mock_pkt.__getitem__.side_effect = lambda layer: { + Raw: MagicMock(load=http_request), + IP: MagicMock(src="127.0.0.1", dst="127.0.0.1"), + TCP: MagicMock(sport=54321, dport=8765) + }[layer] + + packet_callback(mock_pkt) + + # Verify metadata was added + assert mock_log.called + logged_entry = mock_log.call_args[0][0] + assert logged_entry["metadata"] == '{"source": "mcphawk-mcp"}' + + def test_state_isolation_between_tests(self): + """Test that state is properly isolated between tests.""" + import mcphawk.sniffer + # Verify state is clean at the start of the test + assert len(mcphawk.sniffer._mcphawk_mcp_ports) == 0 + + # Modify state + mcphawk.sniffer._mcphawk_mcp_ports.add(9999) + + # State will be cleaned up by setup_method before next test + + @patch('mcphawk.sniffer.log_message') + @patch('mcphawk.sniffer._broadcast_in_any_loop') + def test_http_sse_response_parsing(self, mock_broadcast, mock_log): + """Test parsing of Server-Sent Events (SSE) responses with JSON-RPC.""" + import json + + json_content = json.dumps({ + "jsonrpc": "2.0", + "result": { + "tools": [ + {"name": "get_stats", "description": "Get traffic stats"} + ] + }, + "id": 2 + }) + + sse_response = ( + f"HTTP/1.1 200 OK\r\n" + f"Content-Type: text/event-stream\r\n" + f"Cache-Control: no-cache\r\n" + f"\r\n" + f"data: {json_content}\n\n" + ).encode() + + mock_pkt = MagicMock() + mock_pkt.haslayer.side_effect = lambda layer: layer in [Raw, IP, TCP] + mock_pkt.__getitem__.side_effect = lambda layer: { + Raw: MagicMock(load=sse_response), + IP: MagicMock(src="127.0.0.1", dst="127.0.0.1"), + TCP: MagicMock(sport=8765, 
dport=54321) + }[layer] + + packet_callback(mock_pkt) + + # Verify the JSON-RPC was extracted from SSE format and logged + assert mock_log.called + logged_entry = mock_log.call_args[0][0] + assert logged_entry["message"] == json_content + assert logged_entry["transport_type"] == "unknown" + assert logged_entry["src_port"] == 8765 + assert logged_entry["dst_port"] == 54321 diff --git a/tests/test_sniffer_traffic_type.py b/tests/test_sniffer_traffic_type.py index bda916f..ddbfc2d 100644 --- a/tests/test_sniffer_traffic_type.py +++ b/tests/test_sniffer_traffic_type.py @@ -1,4 +1,4 @@ -"""Test that sniffer properly sets traffic_type for captured packets.""" +"""Test that sniffer properly sets transport_type for captured packets.""" import json import time @@ -13,15 +13,15 @@ @pytest.fixture def test_db(tmp_path): """Create a test database.""" - db_path = tmp_path / "test_sniffer_traffic_type.db" + db_path = tmp_path / "test_sniffer_transport_type.db" set_db_path(str(db_path)) init_db() yield db_path clear_logs() -def test_sniffer_tcp_traffic_type(test_db): - """Test that sniffer marks TCP JSON-RPC traffic with traffic_type='TCP'.""" +def test_sniffer_tcp_transport_type(test_db): + """Test that the sniffer logs TCP JSON-RPC traffic with transport_type='unknown'.""" # Create a mock TCP packet with JSON-RPC content json_rpc = json.dumps({"jsonrpc": "2.0", "method": "test", "id": 1}) @@ -42,40 +42,11 @@ def test_sniffer_tcp_traffic_type(test_db): # Check the logged entry logs = fetch_logs(limit=1) assert len(logs) == 1 - assert logs[0]["traffic_type"] == "TCP/Direct" + assert logs[0]["transport_type"] == "unknown" assert logs[0]["src_port"] == 12345 assert logs[0]["dst_port"] == 54321 -def test_sniffer_websocket_traffic_type(test_db): - """Test that sniffer marks WebSocket traffic with traffic_type='WS'.""" - # Create a WebSocket text frame (0x81 = FIN + TEXT) - json_rpc = json.dumps({"jsonrpc": "2.0", "method": "test", "id": 1}) - - # WebSocket frame: FIN=1, RSV=0, opcode=1
(text), unmasked - frame_header = bytes([0x81, len(json_rpc)]) # Text frame, payload length - ws_frame = frame_header + json_rpc.encode() - - # Create packet layers - ip = IP(src="127.0.0.1", dst="127.0.0.1") - tcp = TCP(sport=8765, dport=54321) - raw = Raw(load=ws_frame) - - # Construct the packet - pkt = ip / tcp / raw - - # Process the packet - packet_callback(pkt) - - # Give it a moment to process - time.sleep(0.1) - - # Check the logged entry - logs = fetch_logs(limit=1) - assert len(logs) == 1 - assert logs[0]["traffic_type"] == "TCP/WS" - assert logs[0]["src_port"] == 8765 - assert logs[0]["dst_port"] == 54321 def test_sniffer_non_jsonrpc_ignored(test_db): diff --git a/tests/test_tcp_reassembly.py b/tests/test_tcp_reassembly.py new file mode 100644 index 0000000..bcf11fd --- /dev/null +++ b/tests/test_tcp_reassembly.py @@ -0,0 +1,226 @@ +"""Tests for TCP stream reassembly and SSE parsing.""" + +from scapy.all import IP, TCP, Ether, Raw + +from mcphawk.tcp_reassembly import HTTPStream, TCPStreamReassembler + + +class TestHTTPStream: + """Test HTTPStream parsing.""" + + def test_extract_sse_messages_with_lf_separator(self): + """Test SSE message extraction with LF separator (\n\n).""" + stream = HTTPStream() + stream.is_sse = True + stream.is_chunked = False + stream.response_headers = {"content-type": "text/event-stream"} + + # Add SSE data with LF separator + stream.buffer = b'event: message\ndata: {"jsonrpc":"2.0","id":1,"result":{}}\n\n' + + messages = stream.extract_sse_messages() + assert len(messages) == 1 + assert messages[0] == '{"jsonrpc":"2.0","id":1,"result":{}}' + + def test_extract_sse_messages_with_crlf_separator(self): + """Test SSE message extraction with CRLF separator (\r\n\r\n).""" + stream = HTTPStream() + stream.is_sse = True + stream.is_chunked = False + stream.response_headers = {"content-type": "text/event-stream"} + + # Add SSE data with CRLF separator (common with HTTP) + stream.buffer = b'event: message\r\ndata: 
{"jsonrpc":"2.0","id":1,"result":{}}\r\n\r\n' + + messages = stream.extract_sse_messages() + assert len(messages) == 1 + assert messages[0] == '{"jsonrpc":"2.0","id":1,"result":{}}' + + def test_extract_sse_messages_mixed_separators(self): + """Test SSE message extraction with mixed separators.""" + stream = HTTPStream() + stream.is_sse = True + stream.is_chunked = False + stream.response_headers = {"content-type": "text/event-stream"} + + # Multiple messages with different separators + stream.buffer = ( + b'event: message\ndata: {"id":1}\n\n' # LF separator + b'event: message\r\ndata: {"id":2}\r\n\r\n' # CRLF separator + ) + + messages = stream.extract_sse_messages() + assert len(messages) == 2 + assert messages[0] == '{"id":1}' + assert messages[1] == '{"id":2}' + + def test_extract_sse_messages_with_chunked_encoding(self): + """Test SSE message extraction from chunked transfer encoding.""" + stream = HTTPStream() + stream.is_sse = True + stream.is_chunked = True + stream.response_headers = { + "content-type": "text/event-stream", + "transfer-encoding": "chunked" + } + + # Chunked data: size in hex, CRLF, data, CRLF + sse_data = b'event: message\r\ndata: {"jsonrpc":"2.0","id":1,"result":{}}\r\n\r\n' + chunk_size = hex(len(sse_data))[2:].encode() + stream.buffer = chunk_size + b'\r\n' + sse_data + b'\r\n0\r\n\r\n' + + messages = stream.extract_sse_messages() + assert len(messages) == 1 + assert messages[0] == '{"jsonrpc":"2.0","id":1,"result":{}}' + + def test_extract_sse_messages_incomplete_chunk(self): + """Test SSE message extraction with incomplete chunked data.""" + stream = HTTPStream() + stream.is_sse = True + stream.is_chunked = True + stream.response_headers = { + "content-type": "text/event-stream", + "transfer-encoding": "chunked" + } + + # Incomplete chunk - missing data + stream.buffer = b'10\r\nonly5bytes' + + messages = stream.extract_sse_messages() + assert len(messages) == 0 # Should wait for more data + + def 
test_parse_http_response_headers(self): + """Test HTTP response header parsing.""" + stream = HTTPStream() + + response = ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/event-stream\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + ) + + stream.add_response_data(response) + + assert stream.response_headers is not None + assert stream.is_sse is True + assert stream.is_chunked is True + assert stream.response_headers["content-type"] == "text/event-stream" + assert stream.response_headers["transfer-encoding"] == "chunked" + + +class TestTCPStreamReassembler: + """Test TCP stream reassembly.""" + + def test_http_request_response_flow(self): + """Test complete HTTP request/response flow.""" + reassembler = TCPStreamReassembler() + + # HTTP request + req_pkt = Ether() / IP(src="127.0.0.1", dst="127.0.0.1") / TCP(sport=12345, dport=8765) / Raw( + b"POST /mcp HTTP/1.1\r\n" + b"Content-Type: application/json\r\n" + b"\r\n" + b'{"jsonrpc":"2.0","method":"test","id":1}' + ) + + messages = reassembler.process_packet(req_pkt) + assert len(messages) == 0 # Requests are not returned as messages + + # HTTP response with SSE + resp_pkt = Ether() / IP(src="127.0.0.1", dst="127.0.0.1") / TCP(sport=8765, dport=12345) / Raw( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/event-stream\r\n" + b"\r\n" + b'event: message\r\n' + b'data: {"jsonrpc":"2.0","id":1,"result":"ok"}\r\n' + b'\r\n' + ) + + messages = reassembler.process_packet(resp_pkt) + assert len(messages) == 1 + assert messages[0]["message"] == '{"jsonrpc":"2.0","id":1,"result":"ok"}' + assert messages[0]["type"] == "sse_response" + + def test_chunked_sse_response_reassembly(self): + """Test reassembly of chunked SSE response across multiple packets.""" + reassembler = TCPStreamReassembler() + + # First, send request to establish stream + req_pkt = Ether() / IP(src="127.0.0.1", dst="127.0.0.1") / TCP(sport=12345, dport=8765) / Raw( + b"POST /mcp HTTP/1.1\r\n\r\n" + ) + reassembler.process_packet(req_pkt) + + # Response 
headers + resp_hdr = Ether() / IP(src="127.0.0.1", dst="127.0.0.1") / TCP(sport=8765, dport=12345) / Raw( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/event-stream\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + ) + + messages = reassembler.process_packet(resp_hdr) + assert len(messages) == 0 # Just headers, no data yet + + # Chunked SSE data + sse_data = b'event: message\r\ndata: {"jsonrpc":"2.0","id":1,"result":"test"}\r\n\r\n' + chunk_size = hex(len(sse_data))[2:].encode() + + data_pkt = Ether() / IP(src="127.0.0.1", dst="127.0.0.1") / TCP(sport=8765, dport=12345) / Raw( + chunk_size + b'\r\n' + sse_data + b'\r\n' + ) + + messages = reassembler.process_packet(data_pkt) + assert len(messages) == 1 + assert messages[0]["message"] == '{"jsonrpc":"2.0","id":1,"result":"test"}' + + # End chunk + end_pkt = Ether() / IP(src="127.0.0.1", dst="127.0.0.1") / TCP(sport=8765, dport=12345) / Raw( + b"0\r\n\r\n" + ) + + messages = reassembler.process_packet(end_pkt) + assert len(messages) == 0 # End chunk doesn't produce messages + + def test_multiple_streams(self): + """Test handling multiple concurrent TCP streams.""" + reassembler = TCPStreamReassembler() + + # Stream 1: Client A -> Server + req1 = Ether() / IP(src="10.0.0.1", dst="10.0.0.2") / TCP(sport=1111, dport=8765) / Raw( + b"POST /api HTTP/1.1\r\n\r\n" + ) + reassembler.process_packet(req1) + + # Stream 2: Client B -> Server + req2 = Ether() / IP(src="10.0.0.3", dst="10.0.0.2") / TCP(sport=2222, dport=8765) / Raw( + b"POST /api HTTP/1.1\r\n\r\n" + ) + reassembler.process_packet(req2) + + # Response to Client A + resp1 = Ether() / IP(src="10.0.0.2", dst="10.0.0.1") / TCP(sport=8765, dport=1111) / Raw( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/event-stream\r\n" + b"\r\n" + b'data: {"client":"A"}\r\n\r\n' + ) + + messages = reassembler.process_packet(resp1) + assert len(messages) == 1 + assert messages[0]["message"] == '{"client":"A"}' + assert messages[0]["dst_port"] == 1111 + + # Response to Client B + 
resp2 = Ether() / IP(src="10.0.0.2", dst="10.0.0.3") / TCP(sport=8765, dport=2222) / Raw( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/event-stream\r\n" + b"\r\n" + b'data: {"client":"B"}\r\n\r\n' + ) + + messages = reassembler.process_packet(resp2) + assert len(messages) == 1 + assert messages[0]["message"] == '{"client":"B"}' + assert messages[0]["dst_port"] == 2222 diff --git a/tests/test_traffic_type.py b/tests/test_traffic_type.py index a035a96..b74b299 100644 --- a/tests/test_traffic_type.py +++ b/tests/test_traffic_type.py @@ -1,5 +1,6 @@ -"""Test traffic_type field is properly set for TCP and WebSocket traffic.""" +"""Test transport_type field is properly set for TCP traffic.""" import json +import uuid from datetime import datetime, timezone import pytest @@ -10,16 +11,17 @@ @pytest.fixture def test_db(tmp_path): """Create a test database.""" - db_path = tmp_path / "test_traffic_type.db" + db_path = tmp_path / "test_transport_type.db" set_db_path(str(db_path)) init_db() yield db_path clear_logs() -def test_tcp_traffic_type(test_db): - """Test that TCP traffic is marked with traffic_type='TCP'.""" +def test_tcp_transport_type(test_db): + """Test that TCP traffic is logged with transport_type='unknown'.""" entry = { + "log_id": str(uuid.uuid4()), "timestamp": datetime.now(tz=timezone.utc), "src_ip": "127.0.0.1", "src_port": 12345, @@ -27,39 +29,21 @@ def test_tcp_traffic_type(test_db): "dst_port": 54321, "direction": "outgoing", "message": json.dumps({"jsonrpc": "2.0", "method": "test", "id": 1}), - "traffic_type": "TCP/Direct", + "transport_type": "unknown", } log_message(entry) logs = fetch_logs(limit=1) assert len(logs) == 1 - assert logs[0]["traffic_type"] == "TCP/Direct" + assert logs[0]["transport_type"] == "unknown" -def test_websocket_traffic_type(test_db): - """Test that WebSocket traffic is marked with traffic_type='WS'.""" - entry = { - "timestamp": datetime.now(tz=timezone.utc), - "src_ip": "127.0.0.1", - "src_port": 8765, - "dst_ip": "127.0.0.1",
- "dst_port": 54321, - "direction": "outgoing", - "message": json.dumps({"jsonrpc": "2.0", "method": "test", "id": 1}), - "traffic_type": "TCP/WS", - } - - log_message(entry) - - logs = fetch_logs(limit=1) - assert len(logs) == 1 - assert logs[0]["traffic_type"] == "TCP/WS" - -def test_unknown_traffic_type(test_db): - """Test that unknown traffic is marked with traffic_type='N/A'.""" +def test_unknown_transport_type(test_db): + """Test that traffic without a transport_type defaults to 'unknown'.""" entry = { + "log_id": str(uuid.uuid4()), "timestamp": datetime.now(tz=timezone.utc), "src_ip": "127.0.0.1", "src_port": 9999, @@ -67,30 +51,31 @@ def test_unknown_traffic_type(test_db): "dst_port": 54321, "direction": "outgoing", "message": json.dumps({"jsonrpc": "2.0", "method": "test", "id": 1}), - # Omit traffic_type to test default + # Omit transport_type to test default } log_message(entry) logs = fetch_logs(limit=1) assert len(logs) == 1 - assert logs[0]["traffic_type"] == "N/A" + assert logs[0]["transport_type"] == "unknown" -def test_legacy_entries_without_traffic_type(test_db): - """Test that we can handle legacy entries without traffic_type column.""" +def test_legacy_entries_without_transport_type(test_db): - wait - """Test that we can handle legacy entries without transport_type column.""" # This tests the backward compatibility in fetch_logs import sqlite3 - # Insert a row without traffic_type using direct SQL + # Insert a row without transport_type using direct SQL conn = sqlite3.connect(str(test_db)) cur = conn.cursor() cur.execute( """ - INSERT INTO logs (timestamp, src_ip, dst_ip, src_port, dst_port, direction, message) - VALUES (?, ?, ?, ?, ?, ?, ?) + INSERT INTO logs (log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", ( + str(uuid.uuid4()), datetime.now(tz=timezone.utc).isoformat(), "127.0.0.1", "127.0.0.1", @@ -105,5 +90,5 @@ def test_legacy_entries_without_traffic_type(test_db): logs = fetch_logs(limit=1) assert len(logs) == 1 - # Should get 'N/A' for entries without traffic_type - assert logs[0]["traffic_type"] == "N/A" + # Should get 'unknown' for entries without transport_type + assert logs[0]["transport_type"] == "unknown" diff --git a/tests/test_transport_detector.py b/tests/test_transport_detector.py new file mode 100644 index 0000000..8082fce --- /dev/null +++ b/tests/test_transport_detector.py @@ -0,0 +1,179 @@ +"""Test transport detection logic.""" + + +from mcphawk.transport_detector import ( + MCPTransport, + TransportTracker, + detect_transport_from_http, + extract_endpoint_from_sse, +) + + +class TestTransportDetection: + """Test MCP transport detection.""" + + def test_streamable_http_dual_accept(self): + """Test Streamable HTTP detection with dual accept headers.""" + transport = detect_transport_from_http( + method="POST", + path="/mcp", + headers={ + "accept": "application/json, text/event-stream", + "content-type": "application/json" + } + ) + assert transport == MCPTransport.STREAMABLE_HTTP + + def test_post_without_accept_headers(self): + """Test that POST without Accept headers returns UNKNOWN. + + POST to /mcp without proper Accept headers cannot be definitively + identified as Streamable HTTP. It could be HTTP+SSE posting to + the endpoint URL.
+ """ + transport = detect_transport_from_http( + method="POST", + path="/mcp", + headers={"content-type": "application/json"} + ) + assert transport == MCPTransport.UNKNOWN + + def test_streamable_http_sse_response_to_post(self): + """Test Streamable HTTP detection when SSE response to POST.""" + transport = detect_transport_from_http( + method="POST", + path="/api", + headers={"content-type": "application/json"}, + is_sse_response=True + ) + assert transport == MCPTransport.STREAMABLE_HTTP + + def test_http_sse_get_with_accept(self): + """Test HTTP+SSE detection with GET and SSE accept header.""" + transport = detect_transport_from_http( + method="GET", + path="/sse", + headers={"accept": "text/event-stream"} + ) + assert transport == MCPTransport.HTTP_SSE + + def test_http_sse_with_endpoint_event(self): + """Test HTTP+SSE detection with endpoint event.""" + transport = detect_transport_from_http( + method="GET", + path="/sse", + headers={"accept": "text/event-stream"}, + response_contains_endpoint_event=True + ) + assert transport == MCPTransport.HTTP_SSE + + def test_http_sse_response_to_get(self): + """Test HTTP+SSE detection when SSE response to GET.""" + transport = detect_transport_from_http( + method="GET", + path="/api", + headers={}, + is_sse_response=True + ) + assert transport == MCPTransport.HTTP_SSE + + def test_unknown_regular_http(self): + """Test unknown transport for regular HTTP.""" + transport = detect_transport_from_http( + method="POST", + path="/api", + headers={"content-type": "application/json"} + ) + assert transport == MCPTransport.UNKNOWN + + def test_unknown_no_clear_indicators(self): + """Test unknown transport when no clear indicators.""" + transport = detect_transport_from_http( + method="GET", + path="/api", + headers={"accept": "application/json"} + ) + assert transport == MCPTransport.UNKNOWN + + +class TestEndpointExtraction: + """Test SSE endpoint event extraction.""" + + def test_extract_endpoint_url(self): + """Test 
extracting endpoint URL from SSE data.""" + sse_data = 'event: endpoint\ndata: {"url": "http://localhost:8080/messages"}\n\n' + url = extract_endpoint_from_sse(sse_data) + assert url == "http://localhost:8080/messages" + + def test_extract_endpoint_url_with_crlf(self): + """Test extracting endpoint URL with CRLF line endings.""" + sse_data = 'event: endpoint\r\ndata: {"url": "http://localhost:8080/messages"}\r\n\r\n' + url = extract_endpoint_from_sse(sse_data) + assert url == "http://localhost:8080/messages" + + def test_extract_endpoint_url_not_found(self): + """Test when no endpoint event is present.""" + sse_data = 'event: message\ndata: {"jsonrpc": "2.0", "id": 1}\n\n' + url = extract_endpoint_from_sse(sse_data) + assert url is None + + def test_extract_endpoint_invalid_json(self): + """Test when endpoint data is invalid JSON.""" + sse_data = 'event: endpoint\ndata: invalid json\n\n' + url = extract_endpoint_from_sse(sse_data) + assert url is None + + +class TestTransportTracker: + """Test transport tracking across connections.""" + + def test_update_and_get_transport(self): + """Test updating and retrieving transport for a connection.""" + tracker = TransportTracker() + + # Update transport + tracker.update_transport( + "127.0.0.1", 12345, "127.0.0.1", 8080, + MCPTransport.STREAMABLE_HTTP + ) + + # Should work in both directions + assert tracker.get_transport("127.0.0.1", 12345, "127.0.0.1", 8080) == MCPTransport.STREAMABLE_HTTP + assert tracker.get_transport("127.0.0.1", 8080, "127.0.0.1", 12345) == MCPTransport.STREAMABLE_HTTP + + def test_get_unknown_transport(self): + """Test getting transport for unknown connection.""" + tracker = TransportTracker() + assert tracker.get_transport("127.0.0.1", 12345, "127.0.0.1", 8080) == MCPTransport.UNKNOWN + + def test_store_endpoint_url(self): + """Test storing endpoint URL for HTTP+SSE.""" + tracker = TransportTracker() + + tracker.store_endpoint_url( + "127.0.0.1", 12345, "127.0.0.1", 8080, + 
"http://localhost:8080/messages" + ) + + # Verify it was stored + key = ("127.0.0.1", 12345, "127.0.0.1", 8080) + assert tracker.endpoint_urls[key] == "http://localhost:8080/messages" + + def test_dont_update_unknown_transport(self): + """Test that UNKNOWN transport doesn't overwrite existing.""" + tracker = TransportTracker() + + # Set a known transport + tracker.update_transport( + "127.0.0.1", 12345, "127.0.0.1", 8080, + MCPTransport.STREAMABLE_HTTP + ) + + # Try to update with UNKNOWN (should not change) + tracker.update_transport( + "127.0.0.1", 12345, "127.0.0.1", 8080, + MCPTransport.UNKNOWN + ) + + # Should still be STREAMABLE_HTTP + assert tracker.get_transport("127.0.0.1", 12345, "127.0.0.1", 8080) == MCPTransport.STREAMABLE_HTTP diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..410462f --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,165 @@ +"""Tests for utility functions.""" + +import json + +from mcphawk.utils import get_message_type, get_method_name, parse_message + + +class TestParseMessage: + """Test parse_message function.""" + + def test_parse_valid_json_string(self): + """Test parsing valid JSON string.""" + message = '{"jsonrpc": "2.0", "method": "test"}' + result = parse_message(message) + assert result == {"jsonrpc": "2.0", "method": "test"} + + def test_parse_invalid_json_string(self): + """Test parsing invalid JSON string.""" + message = '{"invalid": json}' + result = parse_message(message) + assert result is None + + def test_parse_non_string(self): + """Test parsing non-string input.""" + message = {"already": "parsed"} + result = parse_message(message) + assert result == {"already": "parsed"} + + def test_parse_empty_string(self): + """Test parsing empty string.""" + result = parse_message("") + assert result is None + + +class TestGetMessageType: + """Test get_message_type function.""" + + def test_request_message(self): + """Test identifying request messages.""" + message = json.dumps({ + 
"jsonrpc": "2.0", + "method": "tools/list", + "id": "123" + }) + assert get_message_type(message) == "request" + + def test_response_message(self): + """Test identifying response messages.""" + message = json.dumps({ + "jsonrpc": "2.0", + "result": {"tools": []}, + "id": "123" + }) + assert get_message_type(message) == "response" + + def test_notification_message(self): + """Test identifying notification messages.""" + message = json.dumps({ + "jsonrpc": "2.0", + "method": "progress/update" + }) + assert get_message_type(message) == "notification" + + def test_error_message(self): + """Test identifying error messages.""" + message = json.dumps({ + "jsonrpc": "2.0", + "error": { + "code": -32600, + "message": "Invalid Request" + }, + "id": "123" + }) + assert get_message_type(message) == "error" + + def test_unknown_message_no_jsonrpc(self): + """Test message without jsonrpc field.""" + message = json.dumps({ + "method": "test", + "id": "123" + }) + assert get_message_type(message) == "unknown" + + def test_unknown_message_wrong_version(self): + """Test message with wrong JSON-RPC version.""" + message = json.dumps({ + "jsonrpc": "1.0", + "method": "test", + "id": "123" + }) + assert get_message_type(message) == "unknown" + + def test_unknown_message_invalid_json(self): + """Test invalid JSON.""" + message = "not json" + assert get_message_type(message) == "unknown" + + def test_edge_case_null_id(self): + """Test request with null id (valid in JSON-RPC 2.0).""" + message = json.dumps({ + "jsonrpc": "2.0", + "method": "test", + "id": None + }) + assert get_message_type(message) == "request" + + def test_edge_case_error_without_id(self): + """Test error without id (should be unknown).""" + message = json.dumps({ + "jsonrpc": "2.0", + "error": {"code": -32600, "message": "Error"} + }) + assert get_message_type(message) == "unknown" + + def test_edge_case_result_without_id(self): + """Test result without id (should be unknown).""" + message = json.dumps({ + "jsonrpc": 
"2.0", + "result": "test" + }) + assert get_message_type(message) == "unknown" + + +class TestGetMethodName: + """Test get_method_name function.""" + + def test_get_method_from_request(self): + """Test extracting method from request.""" + message = json.dumps({ + "jsonrpc": "2.0", + "method": "tools/list", + "id": "123" + }) + assert get_method_name(message) == "tools/list" + + def test_get_method_from_notification(self): + """Test extracting method from notification.""" + message = json.dumps({ + "jsonrpc": "2.0", + "method": "progress/update" + }) + assert get_method_name(message) == "progress/update" + + def test_get_method_from_response(self): + """Test extracting method from response (should be None).""" + message = json.dumps({ + "jsonrpc": "2.0", + "result": {"tools": []}, + "id": "123" + }) + assert get_method_name(message) is None + + def test_get_method_from_error(self): + """Test extracting method from error (should be None).""" + message = json.dumps({ + "jsonrpc": "2.0", + "error": {"code": -32600, "message": "Error"}, + "id": "123" + }) + assert get_method_name(message) is None + + def test_get_method_from_invalid_json(self): + """Test extracting method from invalid JSON.""" + message = "not json" + assert get_method_name(message) is None diff --git a/tests/test_web.py b/tests/test_web.py index fc9b4a0..f611812 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -8,6 +8,7 @@ import os import tempfile +import uuid from datetime import datetime, timezone import pytest @@ -55,22 +56,33 @@ def clean_db(setup_test_db): yield +def create_test_log(message, **kwargs): + """Helper to create a log entry with required log_id.""" + entry = { + "log_id": str(uuid.uuid4()), + "timestamp": datetime.now(timezone.utc), + "src_ip": "127.0.0.1", + "dst_ip": "127.0.0.1", + "message": message, + } + entry.update(kwargs) + return entry + + def test_get_logs_limit(setup_test_db): """ Ensure /logs returns valid JSON and respects the limit parameter. 
""" # Insert two logs for testing - log_message( - { - "timestamp": datetime.now(timezone.utc), + log_message({ + "log_id": str(uuid.uuid4()),"timestamp": datetime.now(timezone.utc), "src_ip": "127.0.0.1", "dst_ip": "127.0.0.1", "message": '{"jsonrpc":"2.0","method":"ping"}', } ) - log_message( - { - "timestamp": datetime.now(timezone.utc), + log_message({ + "log_id": str(uuid.uuid4()),"timestamp": datetime.now(timezone.utc), "src_ip": "127.0.0.1", "dst_ip": "127.0.0.1", "message": '{"jsonrpc":"2.0","method":"pong"}', @@ -92,17 +104,15 @@ def test_get_logs_multiple(setup_test_db): Ensure /logs can return multiple rows when more logs are inserted. """ # Insert multiple logs - log_message( - { - "timestamp": datetime.now(timezone.utc), + log_message({ + "log_id": str(uuid.uuid4()),"timestamp": datetime.now(timezone.utc), "src_ip": "127.0.0.1", "dst_ip": "127.0.0.1", "message": '{"jsonrpc":"2.0","method":"ping"}', } ) - log_message( - { - "timestamp": datetime.now(timezone.utc), + log_message({ + "log_id": str(uuid.uuid4()),"timestamp": datetime.now(timezone.utc), "src_ip": "127.0.0.1", "dst_ip": "127.0.0.1", "message": '{"jsonrpc":"2.0","method":"pong"}', @@ -153,7 +163,7 @@ def test_logs_persist_across_requests(setup_test_db): """Test that logs persist between different API requests.""" # Add a log log_message({ - "timestamp": datetime.now(timezone.utc), + "log_id": str(uuid.uuid4()),"timestamp": datetime.now(timezone.utc), "src_ip": "192.168.1.1", "dst_ip": "192.168.1.2", "src_port": 12345, @@ -184,7 +194,7 @@ def test_logs_order_newest_first(setup_test_db): # Add logs with small delays to ensure different timestamps for i in range(3): log_message({ - "timestamp": datetime.now(timezone.utc), + "log_id": str(uuid.uuid4()),"timestamp": datetime.now(timezone.utc), "src_ip": "127.0.0.1", "dst_ip": "127.0.0.1", "src_port": 12345, @@ -210,7 +220,7 @@ def test_logs_default_limit(setup_test_db): # Add 60 logs for i in range(60): log_message({ - "timestamp": 
datetime.now(timezone.utc), + "log_id": str(uuid.uuid4()),"timestamp": datetime.now(timezone.utc), "src_ip": "127.0.0.1", "dst_ip": "127.0.0.1", "src_port": 12345, @@ -229,6 +239,7 @@ def test_logs_default_limit(setup_test_db): def test_log_fields_preserved_in_api(setup_test_db): """Test that all log fields are preserved through the API.""" test_entry = { + "log_id": str(uuid.uuid4()), "timestamp": datetime.now(timezone.utc), "src_ip": "10.0.0.1", "dst_ip": "10.0.0.2", diff --git a/tests/test_web_server.py b/tests/test_web_server.py index 80a098e..89bc655 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -132,7 +132,7 @@ def test_run_web_with_sniffer(): ) # Verify sniffer was started - mock_start_sniffer.assert_called_once_with("tcp port 3000", False, True) + mock_start_sniffer.assert_called_once_with("tcp port 3000", False, True, None, None) # Verify uvicorn was started with correct params mock_uvicorn.assert_called_once() @@ -172,3 +172,32 @@ def test_run_web_sniffer_without_filter(): # Test that ValueError is raised when sniffer=True but no filter_expr with pytest.raises(ValueError, match="filter_expr is required"): run_web(sniffer=True, filter_expr=None) + + +def test_status_endpoint(client): + """Test /status endpoint returns MCP server status.""" + response = client.get("/status") + assert response.status_code == 200 + data = response.json() + assert "with_mcp" in data + assert isinstance(data["with_mcp"], bool) + + +def test_status_endpoint_with_mcp_enabled(): + """Test /status endpoint when MCP is enabled.""" + import mcphawk.web.server + from mcphawk.web.server import app + + # Set MCP flag + original_value = mcphawk.web.server._with_mcp + mcphawk.web.server._with_mcp = True + + try: + with TestClient(app) as client: + response = client.get("/status") + assert response.status_code == 200 + data = response.json() + assert data["with_mcp"] is True + finally: + # Restore original value + mcphawk.web.server._with_mcp = original_value diff 
--git a/tests/test_websocket_integration.py b/tests/test_websocket_integration.py deleted file mode 100644 index 4d1602f..0000000 --- a/tests/test_websocket_integration.py +++ /dev/null @@ -1,264 +0,0 @@ -"""Integration tests for WebSocket MCP capture.""" - -import contextlib -import json -import os -import sqlite3 -import tempfile -from pathlib import Path -from unittest.mock import Mock - -import pytest -from scapy.all import IP, TCP, Raw - -from mcphawk import logger -from mcphawk.sniffer import packet_callback -from mcphawk.ws_reassembly import ws_buffers - - -@pytest.fixture -def temp_db(): - """Create a temporary database for testing.""" - fd, db_path = tempfile.mkstemp(suffix=".db") - os.close(fd) - - # Set up the database - original_path = getattr(logger, 'DB_PATH', None) - logger.set_db_path(db_path) - logger.init_db() - - # Also update the module variable directly - import mcphawk.logger as logger_module - logger_module.DB_PATH = Path(db_path) - - yield db_path - - # Cleanup - with contextlib.suppress(OSError): - os.unlink(db_path) - - # Restore original path if it existed - if original_path: - logger.DB_PATH = original_path - logger_module.DB_PATH = original_path - - -def create_ws_text_frame(text): - """Create a WebSocket text frame.""" - frame = bytearray() - frame.append(0x81) # FIN=1, opcode=1 (text) - payload = text.encode('utf-8') - - if len(payload) < 126: - frame.append(len(payload)) - elif len(payload) < 65536: - frame.append(126) - frame.extend(len(payload).to_bytes(2, 'big')) - else: - frame.append(127) - frame.extend(len(payload).to_bytes(8, 'big')) - - frame.extend(payload) - return bytes(frame) - - -def create_mock_packet(src_ip, src_port, dst_ip, dst_port, payload): - """Create a mock Scapy packet.""" - mock_packet = Mock() - - # Mock Raw layer - mock_raw = Mock() - mock_raw.load = payload - - # Mock TCP layer - mock_tcp = Mock() - mock_tcp.sport = src_port - mock_tcp.dport = dst_port - - # Mock IP layer - mock_ip = Mock() - mock_ip.src = 
src_ip - mock_ip.dst = dst_ip - - # Setup haslayer - def haslayer(layer_type): - return layer_type in (Raw, TCP, IP) - - mock_packet.haslayer = haslayer - - # Setup getitem - def getitem(self, layer_type): - if layer_type == Raw: - return mock_raw - elif layer_type == TCP: - return mock_tcp - elif layer_type == IP: - return mock_ip - raise KeyError(f"No layer {layer_type}") - - mock_packet.__getitem__ = getitem - - return mock_packet - - -def test_websocket_capture_simple(temp_db): - """Test basic WebSocket frame capture.""" - # Create WebSocket frames - msg1 = {"jsonrpc": "2.0", "method": "initialize", "id": 1} - msg2 = {"jsonrpc": "2.0", "result": {"status": "ok"}, "id": 1} - - frame1 = create_ws_text_frame(json.dumps(msg1)) - frame2 = create_ws_text_frame(json.dumps(msg2)) - - # Create mock packets - packet1 = create_mock_packet("127.0.0.1", 12345, "127.0.0.1", 8765, frame1) - packet2 = create_mock_packet("127.0.0.1", 8765, "127.0.0.1", 12345, frame2) - - # Process packets - packet_callback(packet1) - packet_callback(packet2) - - # Check database - conn = sqlite3.connect(temp_db) - cursor = conn.cursor() - cursor.execute("SELECT message FROM logs ORDER BY timestamp") - rows = cursor.fetchall() - conn.close() - - assert len(rows) == 2 - - # Check first message - msg = json.loads(rows[0][0]) - assert msg["method"] == "initialize" - assert msg["id"] == 1 - - # Check second message - msg = json.loads(rows[1][0]) - assert msg["result"]["status"] == "ok" - assert msg["id"] == 1 - - -def test_websocket_capture_with_notification(temp_db): - """Test WebSocket capture with notifications (no ID).""" - # Create frames - request = {"jsonrpc": "2.0", "method": "ping", "id": 1} - response = {"jsonrpc": "2.0", "result": "pong", "id": 1} - notification = {"jsonrpc": "2.0", "method": "status/update", "params": {"value": 42}} - - frames = [ - create_ws_text_frame(json.dumps(request)), - create_ws_text_frame(json.dumps(response)), - create_ws_text_frame(json.dumps(notification)), - 
] - - # Process all frames - for frame in frames: - packet = create_mock_packet("127.0.0.1", 12345, "127.0.0.1", 8765, frame) - packet_callback(packet) - - # Check database - conn = sqlite3.connect(temp_db) - cursor = conn.cursor() - cursor.execute("SELECT message FROM logs") - rows = cursor.fetchall() - conn.close() - - assert len(rows) == 3 - - messages = [json.loads(row[0]) for row in rows] - - # Check we have the notification - notifications = [msg for msg in messages if "method" in msg and "id" not in msg] - assert len(notifications) == 1 - assert notifications[0]["method"] == "status/update" - - -def test_websocket_capture_large_message(temp_db): - """Test WebSocket capture with large messages requiring extended length.""" - # Create a large message - large_data = {"jsonrpc": "2.0", "method": "data", "params": {"value": "x" * 1000}, "id": 1} - frame = create_ws_text_frame(json.dumps(large_data)) - - packet = create_mock_packet("127.0.0.1", 12345, "127.0.0.1", 8765, frame) - packet_callback(packet) - - # Check database - conn = sqlite3.connect(temp_db) - cursor = conn.cursor() - cursor.execute("SELECT message FROM logs") - rows = cursor.fetchall() - conn.close() - - assert len(rows) == 1 - msg = json.loads(rows[0][0]) - assert msg["method"] == "data" - assert len(msg["params"]["value"]) == 1000 - - -def test_websocket_tcp_segmentation(temp_db): - """Test handling of WebSocket frames split across TCP packets.""" - # Clear buffers - ws_buffers.clear() - - # Create a message - msg = {"jsonrpc": "2.0", "method": "test", "id": 1} - frame = create_ws_text_frame(json.dumps(msg)) - - # Split frame into two packets - split_point = len(frame) // 2 - packet1 = create_mock_packet("127.0.0.1", 12345, "127.0.0.1", 8765, frame[:split_point]) - packet2 = create_mock_packet("127.0.0.1", 12345, "127.0.0.1", 8765, frame[split_point:]) - - # Process both packets - print(f"\nProcessing split frame: {len(frame)} bytes total") - print(f"Part 1: {len(frame[:split_point])} bytes") - 
print(f"Part 2: {len(frame[split_point:])} bytes") - - packet_callback(packet1) - packet_callback(packet2) - - # Check database - conn = sqlite3.connect(temp_db) - cursor = conn.cursor() - cursor.execute("SELECT message FROM logs") - rows = cursor.fetchall() - conn.close() - - # Should have reassembled into one message - assert len(rows) == 1 - msg = json.loads(rows[0][0]) - assert msg["method"] == "test" - - -def test_websocket_multiple_frames_in_packet(temp_db): - """Test multiple WebSocket frames in one TCP packet.""" - # Clear buffers - ws_buffers.clear() - - # Create multiple messages - msg1 = {"jsonrpc": "2.0", "method": "ping", "id": 1} - msg2 = {"jsonrpc": "2.0", "method": "ping", "id": 2} - - frame1 = create_ws_text_frame(json.dumps(msg1)) - frame2 = create_ws_text_frame(json.dumps(msg2)) - - # Combine frames into one packet - combined = frame1 + frame2 - packet = create_mock_packet("127.0.0.1", 12345, "127.0.0.1", 8765, combined) - packet_callback(packet) - - # Check database - conn = sqlite3.connect(temp_db) - cursor = conn.cursor() - cursor.execute("SELECT message FROM logs ORDER BY timestamp") - rows = cursor.fetchall() - conn.close() - - assert len(rows) == 2 - - msg1_db = json.loads(rows[0][0]) - msg2_db = json.loads(rows[1][0]) - - assert msg1_db["id"] == 1 - assert msg2_db["id"] == 2 - diff --git a/tests/test_websocket_simple.py b/tests/test_websocket_simple.py deleted file mode 100644 index bd7d71e..0000000 --- a/tests/test_websocket_simple.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Simple WebSocket parsing tests without full integration.""" - -import json - -from mcphawk.ws_reassembly import _parse_ws_frames, process_ws_packet, ws_buffers - - -class TestWebSocketParsing: - """Test WebSocket frame parsing without network integration.""" - - def test_parse_unmasked_text_frame(self): - """Test parsing unmasked text frame (server->client).""" - text = '{"jsonrpc":"2.0","method":"test","id":1}' - payload = text.encode('utf-8') - - # Build WebSocket frame 
- frame = bytearray() - frame.append(0x81) # FIN=1, opcode=1 (text) - frame.append(len(payload)) # No mask - frame.extend(payload) - - messages, consumed = _parse_ws_frames(bytes(frame)) - assert len(messages) == 1 - assert messages[0][1] == text # messages are (fin, text) tuples - assert consumed == len(frame) - - def test_parse_masked_text_frame(self): - """Test parsing masked text frame (client->server).""" - text = '{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}' - payload = text.encode('utf-8') - - # Build masked WebSocket frame - frame = bytearray() - frame.append(0x81) # FIN=1, opcode=1 (text) - frame.append(0x80 | len(payload)) # Masked bit + length - - # Add masking key - mask = b'\x12\x34\x56\x78' - frame.extend(mask) - - # Mask the payload - masked_payload = bytearray() - for i, byte in enumerate(payload): - masked_payload.append(byte ^ mask[i % 4]) - frame.extend(masked_payload) - - messages, consumed = _parse_ws_frames(bytes(frame)) - assert len(messages) == 1 - assert messages[0][1] == text # messages are (fin, text) tuples - assert consumed == len(frame) - - def test_parse_fragmented_frames(self): - """Test parsing fragmented WebSocket frames.""" - # First fragment - part1 = '{"jsonrpc":"2.0",' - frame1 = bytearray() - frame1.append(0x01) # FIN=0, opcode=1 (text) - frame1.append(len(part1)) - frame1.extend(part1.encode('utf-8')) - - # Final fragment - part2 = '"method":"test","id":1}' - frame2 = bytearray() - frame2.append(0x80) # FIN=1, opcode=0 (continuation) - frame2.append(len(part2)) - frame2.extend(part2.encode('utf-8')) - - messages, consumed = _parse_ws_frames(bytes(frame1 + frame2)) - assert len(messages) == 2 # Two fragments - assert messages[0][1] == part1 - assert messages[1][1] == part2 - - def test_process_ws_packet_filters_jsonrpc(self): - """Test that process_ws_packet filters for JSON-RPC messages.""" - # Clear any existing buffers - ws_buffers.clear() - - # JSON-RPC message - jsonrpc_text = 
'{"jsonrpc":"2.0","method":"test","id":1}' - jsonrpc_frame = bytearray() - jsonrpc_frame.append(0x81) - jsonrpc_frame.append(len(jsonrpc_text)) - jsonrpc_frame.extend(jsonrpc_text.encode('utf-8')) - - # Non-JSON-RPC message - other_text = '{"type":"hello","data":"world"}' - other_frame = bytearray() - other_frame.append(0x81) - other_frame.append(len(other_text)) - other_frame.extend(other_text.encode('utf-8')) - - # Process JSON-RPC frame - messages = process_ws_packet("127.0.0.1", 12345, "127.0.0.1", 8765, bytes(jsonrpc_frame)) - assert len(messages) == 1 - assert json.loads(messages[0])["jsonrpc"] == "2.0" - - # Process non-JSON-RPC frame - messages = process_ws_packet("127.0.0.1", 12345, "127.0.0.1", 8765, bytes(other_frame)) - assert len(messages) == 0 - - def test_process_ws_packet_skips_http_upgrade(self): - """Test that HTTP upgrade requests are skipped.""" - http_request = b'GET /ws HTTP/1.1\r\nUpgrade: websocket\r\n\r\n' - messages = process_ws_packet("127.0.0.1", 12345, "127.0.0.1", 8765, http_request) - assert len(messages) == 0 - - http_response = b'HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\n\r\n' - messages = process_ws_packet("127.0.0.1", 8765, "127.0.0.1", 12345, http_response) - assert len(messages) == 0 - - def test_multiple_frames_in_one_packet(self): - """Test parsing multiple frames in one TCP packet.""" - # First message - msg1 = '{"jsonrpc":"2.0","method":"ping","id":1}' - frame1 = bytearray() - frame1.append(0x81) - frame1.append(len(msg1)) - frame1.extend(msg1.encode('utf-8')) - - # Second message - msg2 = '{"jsonrpc":"2.0","result":"pong","id":1}' - frame2 = bytearray() - frame2.append(0x81) - frame2.append(len(msg2)) - frame2.extend(msg2.encode('utf-8')) - - # Parse both frames together - messages, consumed = _parse_ws_frames(bytes(frame1 + frame2)) - assert len(messages) == 2 - assert messages[0][1] == msg1 - assert messages[1][1] == msg2 diff --git a/tests/test_ws_reassembly.py b/tests/test_ws_reassembly.py deleted file 
mode 100644 index afa8e82..0000000 --- a/tests/test_ws_reassembly.py +++ /dev/null @@ -1,64 +0,0 @@ -from mcphawk.ws_reassembly import process_ws_packet - -# Connection identifiers for test -SRC_IP, SRC_PORT = "127.0.0.1", 11111 -DST_IP, DST_PORT = "127.0.0.1", 22222 - - -def build_ws_frame(message: bytes, fin: bool = True, opcode: int = 0x1) -> bytes: - """ - Build a single valid unmasked WebSocket frame (server->client style). - - Args: - message: Payload as bytes. - fin: Whether this is the final frame (FIN bit). - opcode: Frame opcode (0x1 = text, 0x0 = continuation). - - Returns: - Raw bytes representing the WebSocket frame. - """ - first_byte = (0x80 if fin else 0x00) | opcode - length = len(message) - - if length <= 125: - header = bytes([first_byte, length]) - elif length <= 65535: - header = bytes([first_byte, 126]) + length.to_bytes(2, "big") - else: - header = bytes([first_byte, 127]) + length.to_bytes(8, "big") - - return header + message - - -def test_process_ws_packet_complete(): - """Ensure a full single WebSocket frame is reassembled correctly.""" - message = b'{"jsonrpc":"2.0","method":"ping"}' - frame = build_ws_frame(message) - - msgs = process_ws_packet(SRC_IP, SRC_PORT, DST_IP, DST_PORT, frame) - assert len(msgs) == 1 - assert "ping" in msgs[0] - - -def test_process_ws_packet_fragmented(): - """Test that fragmented frames are handled (currently not buffered).""" - # Our simplified implementation doesn't buffer fragmented frames - # This is acceptable for most real-world MCP traffic which uses small messages - full_msg = b'{"jsonrpc":"2.0","method":"pong"}' - mid = len(full_msg) // 2 - - frame1 = build_ws_frame(full_msg[:mid], fin=False, opcode=0x1) - frame2 = build_ws_frame(full_msg[mid:], fin=True, opcode=0x0) - - # Each fragment is processed independently - msgs = process_ws_packet(SRC_IP, SRC_PORT, DST_IP, DST_PORT, frame1) - assert msgs == [] # Partial JSON won't be captured - - msgs = process_ws_packet(SRC_IP, SRC_PORT, DST_IP, DST_PORT, 
frame2) - assert msgs == [] # Continuation frame alone won't be captured - - # For complete capture, send as single frame - complete_frame = build_ws_frame(full_msg) - msgs = process_ws_packet(SRC_IP, SRC_PORT, DST_IP, DST_PORT, complete_frame) - assert len(msgs) == 1 - assert "pong" in msgs[0] diff --git a/tests/test_ws_tcp_classification.py b/tests/test_ws_tcp_classification.py deleted file mode 100644 index ec33d97..0000000 --- a/tests/test_ws_tcp_classification.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Test that WebSocket and TCP traffic are correctly classified.""" -import json -import time - -import pytest -from scapy.layers.inet import IP, TCP -from scapy.packet import Raw - -from mcphawk.logger import clear_logs, fetch_logs, init_db, set_db_path -from mcphawk.sniffer import packet_callback - - -@pytest.fixture -def test_db(tmp_path): - """Create a test database.""" - db_path = tmp_path / "test_ws_tcp_class.db" - set_db_path(str(db_path)) - init_db() - yield db_path - clear_logs() - - -def test_websocket_http_upgrade_not_misclassified(test_db): - """Test that WebSocket HTTP upgrade is not classified as TCP JSON-RPC.""" - # Create HTTP upgrade packet - http_upgrade = ( - b"GET / HTTP/1.1\r\n" - b"Host: localhost:8765\r\n" - b"Upgrade: websocket\r\n" - b"Connection: Upgrade\r\n" - b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n" - b"Sec-WebSocket-Version: 13\r\n\r\n" - ) - - pkt = IP(src="127.0.0.1", dst="127.0.0.1") / TCP(sport=50000, dport=8765) / Raw(load=http_upgrade) - packet_callback(pkt) - - time.sleep(0.1) - - # Should not create any log entries (HTTP upgrade is not JSON-RPC) - logs = fetch_logs(limit=10) - assert len(logs) == 0 - - -def test_websocket_empty_frame_not_misclassified(test_db): - """Test that incomplete WebSocket frames are not classified as TCP.""" - # Send part of a WebSocket frame (incomplete) - partial_frame = bytes([0x81, 0x7e, 0x00]) # Text frame header, but incomplete - - pkt = IP(src="127.0.0.1", dst="127.0.0.1") / TCP(sport=50001, 
dport=8765) / Raw(load=partial_frame) - packet_callback(pkt) - - time.sleep(0.1) - - # Should not create any entries (incomplete frame, no JSON-RPC) - logs = fetch_logs(limit=10) - assert len(logs) == 0 - - -def test_websocket_complete_frame_classified_correctly(test_db): - """Test that complete WebSocket frames are classified as WS.""" - # Create a complete WebSocket text frame with JSON-RPC - json_rpc = json.dumps({"jsonrpc": "2.0", "method": "test", "id": 1}) - frame = bytes([0x81, len(json_rpc)]) + json_rpc.encode() - - pkt = IP(src="127.0.0.1", dst="127.0.0.1") / TCP(sport=50002, dport=8765) / Raw(load=frame) - packet_callback(pkt) - - time.sleep(0.1) - - logs = fetch_logs(limit=10) - assert len(logs) == 1 - assert logs[0]["traffic_type"] == "TCP/WS" - assert "test" in logs[0]["message"] - - -def test_tcp_jsonrpc_classified_correctly(test_db): - """Test that raw TCP JSON-RPC is classified as TCP.""" - json_rpc = json.dumps({"jsonrpc": "2.0", "method": "tcp_test", "id": 1}) - - pkt = IP(src="127.0.0.1", dst="127.0.0.1") / TCP(sport=12345, dport=50003) / Raw(load=json_rpc.encode()) - packet_callback(pkt) - - time.sleep(0.1) - - logs = fetch_logs(limit=10) - assert len(logs) == 1 - assert logs[0]["traffic_type"] == "TCP/Direct" - assert "tcp_test" in logs[0]["message"] - - -def test_mixed_traffic_correct_classification(test_db): - """Test that mixed TCP and WebSocket traffic is correctly classified.""" - # Send TCP JSON-RPC - tcp_msg = json.dumps({"jsonrpc": "2.0", "method": "tcp_method", "id": 1}) - tcp_pkt = IP(src="127.0.0.1", dst="127.0.0.1") / TCP(sport=12345, dport=50004) / Raw(load=tcp_msg.encode()) - packet_callback(tcp_pkt) - - # Send WebSocket frame - ws_msg = json.dumps({"jsonrpc": "2.0", "method": "ws_method", "id": 2}) - ws_frame = bytes([0x81, len(ws_msg)]) + ws_msg.encode() - ws_pkt = IP(src="127.0.0.1", dst="127.0.0.1") / TCP(sport=50005, dport=8765) / Raw(load=ws_frame) - packet_callback(ws_pkt) - - time.sleep(0.1) - - logs = fetch_logs(limit=10) 
- assert len(logs) == 2 - - # Check each log - tcp_log = next((log for log in logs if "tcp_method" in log["message"]), None) - ws_log = next((log for log in logs if "ws_method" in log["message"]), None) - - assert tcp_log is not None - assert tcp_log["traffic_type"] == "TCP/Direct" - - assert ws_log is not None - assert ws_log["traffic_type"] == "TCP/WS"