From b6b5645a2fcd22d5707067df6a92c5131dcca21f Mon Sep 17 00:00:00 2001 From: tech4242 <5933291+tech4242@users.noreply.github.com> Date: Mon, 4 Aug 2025 21:15:33 +0200 Subject: [PATCH 1/4] fix: centralize db connections --- mcphawk/logger.py | 251 ++++++++++++++++++++----------------- tests/test_logger.py | 209 ++++++++++++++++++++---------- tests/test_sniffer.py | 14 +-- tests/test_traffic_type.py | 49 ++++---- tests/test_web.py | 9 +- 5 files changed, 315 insertions(+), 217 deletions(-) diff --git a/mcphawk/logger.py b/mcphawk/logger.py index b7a27bb..c8679d0 100644 --- a/mcphawk/logger.py +++ b/mcphawk/logger.py @@ -1,5 +1,7 @@ import logging import sqlite3 +from collections.abc import Generator +from contextlib import contextmanager from datetime import datetime, timezone from pathlib import Path from typing import Any @@ -16,6 +18,34 @@ _db_initialized = False +@contextmanager +def get_db_connection(db_path: Path | None = None) -> Generator[sqlite3.Connection, None, None]: + """ + Context manager for SQLite database connections. + + Args: + db_path: Optional path to database. Uses DB_PATH if not provided. + + Yields: + sqlite3.Connection: Database connection with row_factory set. + + Example: + with get_db_connection() as conn: + cur = conn.cursor() + cur.execute("SELECT * FROM logs") + """ + path = db_path or DB_PATH + if not path: + raise ValueError("No database path provided") + + conn = sqlite3.connect(path) + conn.row_factory = sqlite3.Row + try: + yield conn + finally: + conn.close() + + def init_db() -> None: """ Initialize the SQLite database and ensure the logs table exists. @@ -24,27 +54,26 @@ def init_db() -> None: if not DB_PATH or not str(DB_PATH).strip(): raise ValueError("DB_PATH is not set or is empty") - conn = sqlite3.connect(DB_PATH) - cur = conn.cursor() - cur.execute( - """ - CREATE TABLE IF NOT EXISTS logs ( - log_id TEXT PRIMARY KEY, - timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, - src_ip TEXT, - dst_ip TEXT, - src_port INTEGER, - dst_port INTEGER, - direction TEXT CHECK(direction IN ('incoming', 'outgoing', 'unknown')), - message TEXT, - transport_type TEXT, - metadata TEXT, - pid INTEGER + with get_db_connection() as conn: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS logs ( + log_id TEXT PRIMARY KEY, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + src_ip TEXT, + dst_ip TEXT, + src_port INTEGER, + dst_port INTEGER, + direction TEXT CHECK(direction IN ('incoming', 'outgoing', 'unknown')), + message TEXT, + transport_type TEXT, + metadata TEXT, + pid INTEGER + ) + """ ) - """ - ) - conn.commit() - conn.close() + conn.commit() def log_message(entry: dict[str, Any]) -> None: @@ -63,35 +92,35 @@ def log_message(entry: dict[str, Any]) -> None: message (str) transport_type (str): 'streamable_http', 'http_sse', 'stdio', or 'unknown' (optional, defaults to 'unknown') metadata (str): JSON string with additional metadata (optional) + pid (int): Process ID for stdio transport (optional) """ timestamp = entry.get("timestamp", datetime.now(tz=timezone.utc)) log_id = entry.get("log_id") if not log_id: raise ValueError("log_id is required") - conn = sqlite3.connect(DB_PATH) - cur = conn.cursor() - cur.execute( - """ - INSERT INTO logs (log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, transport_type, metadata, pid) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - log_id, - timestamp.isoformat(), - entry.get("src_ip"), - entry.get("dst_ip"), - entry.get("src_port"), - entry.get("dst_port"), - entry.get("direction", "unknown"), - entry.get("message"), - entry.get("transport_type", "unknown"), - entry.get("metadata"), - entry.get("pid"), - ), - ) - conn.commit() - conn.close() + with get_db_connection() as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT INTO logs (log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, transport_type, metadata, pid) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + log_id, + timestamp.isoformat(), + entry.get("src_ip"), + entry.get("dst_ip"), + entry.get("src_port"), + entry.get("dst_port"), + entry.get("direction", "unknown"), + entry.get("message"), + entry.get("transport_type", "unknown"), + entry.get("metadata"), + entry.get("pid"), + ), + ) + conn.commit() def fetch_logs(limit: int = 100) -> list[dict[str, Any]]: @@ -113,20 +142,18 @@ def fetch_logs(limit: int = 100) -> list[dict[str, Any]]: logger.warning(f"Database file not found at {current_path}") return [] - conn = sqlite3.connect(current_path) - conn.row_factory = sqlite3.Row - cur = conn.cursor() - cur.execute( - """ - SELECT log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, transport_type, metadata, pid - FROM logs - ORDER BY timestamp DESC - LIMIT ? - """, - (limit,), - ) - rows = cur.fetchall() - conn.close() + with get_db_connection(current_path) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, transport_type, metadata, pid + FROM logs + ORDER BY timestamp DESC + LIMIT ? + """, + (limit,), + ) + rows = cur.fetchall() return [ { @@ -165,11 +192,10 @@ def clear_logs() -> None: Clear all logs from the database. Mainly used in tests. """ - conn = sqlite3.connect(DB_PATH) - cur = conn.cursor() - cur.execute("DELETE FROM logs;") - conn.commit() - conn.close() + with get_db_connection() as conn: + cur = conn.cursor() + cur.execute("DELETE FROM logs;") + conn.commit() def get_log_by_id(log_id: str) -> dict[str, Any] | None: @@ -186,19 +212,17 @@ def get_log_by_id(log_id: str) -> dict[str, Any] | None: if not current_path.exists(): return None - conn = sqlite3.connect(current_path) - conn.row_factory = sqlite3.Row - cur = conn.cursor() - cur.execute( - """ - SELECT log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, transport_type, metadata - FROM logs - WHERE log_id = ? - """, - (log_id,), - ) - row = cur.fetchone() - conn.close() + with get_db_connection(current_path) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, transport_type, metadata, pid + FROM logs + WHERE log_id = ? + """, + (log_id,), + ) + row = cur.fetchone() if not row: return None @@ -214,6 +238,7 @@ def get_log_by_id(log_id: str) -> dict[str, Any] | None: "message": row["message"], "transport_type": row["transport_type"] if row["transport_type"] is not None else "unknown", "metadata": row["metadata"], + "pid": row["pid"], } @@ -232,20 +257,18 @@ def fetch_logs_with_offset(limit: int = 100, offset: int = 0) -> list[dict[str, if not current_path.exists(): return [] - conn = sqlite3.connect(current_path) - conn.row_factory = sqlite3.Row - cur = conn.cursor() - cur.execute( - """ - SELECT log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, transport_type, metadata, pid - FROM logs - ORDER BY log_id DESC - LIMIT ? OFFSET ? - """, - (limit, offset), - ) - rows = cur.fetchall() - conn.close() + with get_db_connection(current_path) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message, transport_type, metadata, pid + FROM logs + ORDER BY log_id DESC + LIMIT ? OFFSET ? + """, + (limit, offset), + ) + rows = cur.fetchall() return [ { @@ -283,27 +306,25 @@ def search_logs(search_term: str = "", message_type: str | None = None, if not current_path.exists(): return [] - conn = sqlite3.connect(current_path) - conn.row_factory = sqlite3.Row - cur = conn.cursor() + with get_db_connection(current_path) as conn: + cur = conn.cursor() - query = "SELECT * FROM logs WHERE 1=1" - params = [] + query = "SELECT * FROM logs WHERE 1=1" + params = [] - if search_term: - query += " AND message LIKE ?" - params.append(f"%{search_term}%") + if search_term: + query += " AND message LIKE ?" + params.append(f"%{search_term}%") - if transport_type: - query += " AND transport_type = ?" - params.append(transport_type) + if transport_type: + query += " AND transport_type = ?" + params.append(transport_type) - query += " ORDER BY log_id DESC LIMIT ?" - params.append(limit) + query += " ORDER BY log_id DESC LIMIT ?" + params.append(limit) - cur.execute(query, params) - rows = cur.fetchall() - conn.close() + cur.execute(query, params) + rows = cur.fetchall() # Filter by message type if specified results = [] @@ -351,12 +372,12 @@ def get_traffic_stats() -> dict[str, Any]: "by_transport_type": {} } - conn = sqlite3.connect(current_path) - cur = conn.cursor() + with get_db_connection(current_path) as conn: + cur = conn.cursor() - # Get all messages for analysis - cur.execute("SELECT message, transport_type FROM logs") - logs = cur.fetchall() + # Get all messages for analysis + cur.execute("SELECT message, transport_type FROM logs") + logs = cur.fetchall() stats = { "total_logs": len(logs), @@ -392,7 +413,6 @@ def get_traffic_stats() -> dict[str, Any]: if transport_type: stats["by_transport_type"][transport_type] = stats["by_transport_type"].get(transport_type, 0) + 1 - conn.close() return stats @@ -407,11 +427,10 @@ def get_unique_methods() -> list[str]: if not current_path.exists(): return [] - conn = sqlite3.connect(current_path) - cur = conn.cursor() - cur.execute("SELECT message FROM logs") - logs = cur.fetchall() - conn.close() + with get_db_connection(current_path) as conn: + cur = conn.cursor() + cur.execute("SELECT message FROM logs") + logs = cur.fetchall() methods = set() for (message,) in logs: diff --git a/tests/test_logger.py b/tests/test_logger.py index 883860e..9b2a266 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -7,7 +7,88 @@ import pytest -from mcphawk.logger import fetch_logs, init_db, log_message, set_db_path +from mcphawk.logger import ( + fetch_logs, + get_db_connection, + init_db, + log_message, + set_db_path, +) + + +class TestDBConnection: + """Test the database connection context manager.""" + + @pytest.fixture + def temp_db(self): + """Create a temporary database for testing.""" + with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f: + temp_path = Path(f.name) + + yield temp_path + + # Cleanup + temp_path.unlink(missing_ok=True) + + def test_context_manager_basic(self, temp_db): + """Test basic context manager functionality.""" + # Create a simple table for testing + with get_db_connection(temp_db) as conn: + cur = conn.cursor() + cur.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, data TEXT)") + conn.commit() + + # Verify connection is closed by trying to use it + with pytest.raises(sqlite3.ProgrammingError): + cur.execute("SELECT * FROM test") + + def test_context_manager_with_error(self, temp_db): + """Test context manager closes connection even on error.""" + # Create a simple table + with get_db_connection(temp_db) as conn: + cur = conn.cursor() + cur.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)") + conn.commit() + + # Force an error inside context + try: + with get_db_connection(temp_db) as conn: + cur = conn.cursor() + # This will raise an error (table already exists) + cur.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)") + except sqlite3.OperationalError: + pass + + # Connection should still be closed + with pytest.raises(sqlite3.ProgrammingError): + cur.execute("SELECT * FROM test") + + def test_context_manager_row_factory(self, temp_db): + """Test that row_factory is set correctly.""" + # Create and populate a table + with get_db_connection(temp_db) as conn: + cur = conn.cursor() + cur.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)") + cur.execute("INSERT INTO test (name) VALUES (?)", ("test_name",)) + conn.commit() + + # Verify row_factory allows dict-like access + with get_db_connection(temp_db) as conn: + cur = conn.cursor() + cur.execute("SELECT * FROM test") + row = cur.fetchone() + assert row["name"] == "test_name" + assert row["id"] == 1 + + def test_context_manager_uses_default_path(self): + """Test that context manager uses default DB_PATH when None provided.""" + # When None is provided, it should use the default DB_PATH + with get_db_connection(None) as conn: + # Should connect successfully using default path + cur = conn.cursor() + cur.execute("SELECT 1") + result = cur.fetchone() + assert result[0] == 1 class TestLogger: @@ -34,33 +115,31 @@ def test_init_db(self, temp_db): assert Path(temp_db).exists() # Check schema - conn = sqlite3.connect(temp_db) - cursor = conn.cursor() - - # Get table info - cursor.execute("PRAGMA table_info(logs)") - columns = {col[1]: col[2] for col in cursor.fetchall()} - - # Check all expected columns exist - expected_columns = { - "log_id": "TEXT", - "timestamp": "DATETIME", - "src_ip": "TEXT", - "dst_ip": "TEXT", - "src_port": "INTEGER", - "dst_port": "INTEGER", - "direction": "TEXT", - "message": "TEXT", - "transport_type": "TEXT", - "metadata": "TEXT", - "pid": "INTEGER" - } - - for col, dtype in expected_columns.items(): - assert col in columns - assert columns[col] == dtype + with get_db_connection(Path(temp_db)) as conn: + cursor = conn.cursor() + + # Get table info + cursor.execute("PRAGMA table_info(logs)") + columns = {col[1]: col[2] for col in cursor.fetchall()} + + # Check all expected columns exist + expected_columns = { + "log_id": "TEXT", + "timestamp": "DATETIME", + "src_ip": "TEXT", + "dst_ip": "TEXT", + "src_port": "INTEGER", + "dst_port": "INTEGER", + "direction": "TEXT", + "message": "TEXT", + "transport_type": "TEXT", + "metadata": "TEXT", + "pid": "INTEGER" + } - conn.close() + for col, dtype in expected_columns.items(): + assert col in columns + assert columns[col] == dtype def test_log_message_basic(self, temp_db): """Test basic message logging.""" @@ -137,17 +216,15 @@ def temp_db(self): def test_schema_includes_pid(self, temp_db): """Test that database schema includes PID column.""" - conn = sqlite3.connect(temp_db) - cursor = conn.cursor() + with get_db_connection(Path(temp_db)) as conn: + cursor = conn.cursor() - # Get table info - cursor.execute("PRAGMA table_info(logs)") - columns = {col[1]: col[2] for col in cursor.fetchall()} + # Get table info + cursor.execute("PRAGMA table_info(logs)") + columns = {col[1]: col[2] for col in cursor.fetchall()} - assert "pid" in columns - assert columns["pid"] == "INTEGER" - - conn.close() + assert "pid" in columns + assert columns["pid"] == "INTEGER" def test_log_message_with_pid(self, temp_db): """Test logging message with PID.""" @@ -269,44 +346,40 @@ def test_query_by_pid(self, temp_db): log_message(entry) # Direct SQL query to filter by PID - conn = sqlite3.connect(temp_db) - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - - cursor.execute("SELECT * FROM logs WHERE pid = ?", (12345,)) - rows = cursor.fetchall() + with get_db_connection(Path(temp_db)) as conn: + cursor = conn.cursor() - assert len(rows) == 2 - for row in rows: - assert row["pid"] == 12345 + cursor.execute("SELECT * FROM logs WHERE pid = ?", (12345,)) + rows = cursor.fetchall() - conn.close() + assert len(rows) == 2 + for row in rows: + assert row["pid"] == 12345 def test_backward_compatibility(self, temp_db): """Test that old logs without PID field still work.""" # Directly insert an old-style log without PID - conn = sqlite3.connect(temp_db) - cursor = conn.cursor() - - # Insert without specifying PID (should be NULL) - cursor.execute(""" - INSERT INTO logs (log_id, timestamp, src_ip, dst_ip, src_port, dst_port, - direction, message, transport_type, metadata) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - "old-style-1", - datetime.now(tz=timezone.utc).isoformat(), - "127.0.0.1", - "127.0.0.1", - 3000, - 3001, - "outgoing", - '{"jsonrpc":"2.0","method":"test","id":1}', - "streamable_http", - '{}' - )) - conn.commit() - conn.close() + with get_db_connection(Path(temp_db)) as conn: + cursor = conn.cursor() + + # Insert without specifying PID (should be NULL) + cursor.execute(""" + INSERT INTO logs (log_id, timestamp, src_ip, dst_ip, src_port, dst_port, + direction, message, transport_type, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + "old-style-1", + datetime.now(tz=timezone.utc).isoformat(), + "127.0.0.1", + "127.0.0.1", + 3000, + 3001, + "outgoing", + '{"jsonrpc":"2.0","method":"test","id":1}', + "streamable_http", + '{}' + )) + conn.commit() # Fetch logs should work logs = fetch_logs(1) diff --git a/tests/test_sniffer.py b/tests/test_sniffer.py index 980cf9b..ca8f873 100644 --- a/tests/test_sniffer.py +++ b/tests/test_sniffer.py @@ -1,6 +1,5 @@ import os import socket -import sqlite3 import threading import time from unittest.mock import MagicMock, patch @@ -11,7 +10,7 @@ from scapy.layers.inet6 import IPv6 from scapy.packet import Raw -from mcphawk.logger import init_db, set_db_path +from mcphawk.logger import get_db_connection, init_db, set_db_path from mcphawk.sniffer import packet_callback, start_sniffer # --- TEST DB PATH --- @@ -84,11 +83,12 @@ def test_packet_callback(clean_db, dummy_server): ) packet_callback(pkt) - conn = sqlite3.connect(TEST_DB) - cur = conn.cursor() - cur.execute("SELECT message FROM logs ORDER BY log_id DESC LIMIT 1;") - logged_msg = cur.fetchone()[0] - conn.close() + from pathlib import Path + + with get_db_connection(Path(TEST_DB)) as conn: + cur = conn.cursor() + cur.execute("SELECT message FROM logs ORDER BY log_id DESC LIMIT 1;") + logged_msg = cur.fetchone()[0] assert "weather" in logged_msg assert "Berlin" in logged_msg diff --git a/tests/test_traffic_type.py b/tests/test_traffic_type.py index b74b299..564fab8 100644 --- a/tests/test_traffic_type.py +++ b/tests/test_traffic_type.py @@ -5,7 +5,14 @@ import pytest -from mcphawk.logger import clear_logs, fetch_logs, init_db, log_message, set_db_path +from mcphawk.logger import ( + clear_logs, + fetch_logs, + get_db_connection, + init_db, + log_message, + set_db_path, +) @pytest.fixture @@ -64,29 +71,27 @@ def test_unknown_transport_type(test_db): def test_legacy_entries_without_transport_type(test_db): """Test that we can handle legacy entries without transport_type column.""" # This tests the backward compatibility in fetch_logs - import sqlite3 # Insert a row without transport_type using direct SQL - conn = sqlite3.connect(str(test_db)) - cur = conn.cursor() - cur.execute( - """ - INSERT INTO logs (log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - str(uuid.uuid4()), - datetime.now(tz=timezone.utc).isoformat(), - "127.0.0.1", - "127.0.0.1", - 12345, - 54321, - "outgoing", - json.dumps({"jsonrpc": "2.0", "method": "test", "id": 1}), - ), - ) - conn.commit() - conn.close() + with get_db_connection(test_db) as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT INTO logs (log_id, timestamp, src_ip, dst_ip, src_port, dst_port, direction, message) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + str(uuid.uuid4()), + datetime.now(tz=timezone.utc).isoformat(), + "127.0.0.1", + "127.0.0.1", + 12345, + 54321, + "outgoing", + json.dumps({"jsonrpc": "2.0", "method": "test", "id": 1}), + ), + ) + conn.commit() logs = fetch_logs(limit=1) assert len(logs) == 1 diff --git a/tests/test_web.py b/tests/test_web.py index f611812..5756280 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -47,12 +47,13 @@ def clean_db(setup_test_db): """ Ensure the database is cleared before each test. """ - import sqlite3 + from pathlib import Path - conn = sqlite3.connect(TEST_DB_PATH) - with conn: + from mcphawk.logger import get_db_connection + + with get_db_connection(Path(TEST_DB_PATH)) as conn: conn.execute("DELETE FROM logs;") - conn.close() + conn.commit() yield From 57fca8ccebf7bc47075cdfc0b0040e06fc86c1f3 Mon Sep 17 00:00:00 2001 From: tech4242 <5933291+tech4242@users.noreply.github.com> Date: Mon, 4 Aug 2025 21:29:03 +0200 Subject: [PATCH 2/4] fix: refactor tests structure --- Makefile | 23 ++++++++++++++++++- pyproject.toml | 7 ++++++ tests/integration/__init__.py | 0 tests/integration/cli/__init__.py | 0 tests/{ => integration/cli}/test_cli.py | 0 tests/{ => integration/cli}/test_wrapper.py | 0 tests/integration/db/__init__.py | 0 tests/{ => integration/db}/test_logger.py | 0 .../{ => integration/db}/test_traffic_type.py | 0 tests/integration/mcp/__init__.py | 0 .../mcp}/test_mcp_http_simple.py | 0 .../{ => integration/mcp}/test_mcp_server.py | 0 .../mcp}/test_mcp_stdio_integration.py | 0 tests/integration/network/__init__.py | 0 .../network}/test_ipv4_ipv6_capture.py | 0 .../{ => integration/network}/test_sniffer.py | 0 .../network}/test_sniffer_traffic_type.py | 0 tests/integration/web/__init__.py | 0 .../{ => integration/web}/test_broadcaster.py | 0 tests/{ => integration/web}/test_web.py | 0 .../{ => integration/web}/test_web_server.py | 0 tests/unit/__init__.py | 0 tests/{ => unit}/test_tcp_reassembly.py | 0 tests/{ => unit}/test_transport_detector.py | 0 tests/{ => unit}/test_utils.py | 0 25 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/cli/__init__.py rename tests/{ => integration/cli}/test_cli.py (100%) rename tests/{ => integration/cli}/test_wrapper.py (100%) create mode 100644 tests/integration/db/__init__.py rename tests/{ => integration/db}/test_logger.py (100%) rename tests/{ => integration/db}/test_traffic_type.py (100%) create mode 100644 tests/integration/mcp/__init__.py rename tests/{ => integration/mcp}/test_mcp_http_simple.py (100%) rename tests/{ => integration/mcp}/test_mcp_server.py (100%) rename tests/{ => integration/mcp}/test_mcp_stdio_integration.py (100%) create mode 100644 tests/integration/network/__init__.py rename tests/{ => integration/network}/test_ipv4_ipv6_capture.py (100%) rename tests/{ => integration/network}/test_sniffer.py (100%) rename tests/{ => integration/network}/test_sniffer_traffic_type.py (100%) create mode 100644 tests/integration/web/__init__.py rename tests/{ => integration/web}/test_broadcaster.py (100%) rename tests/{ => integration/web}/test_web.py (100%) rename tests/{ => integration/web}/test_web_server.py (100%) create mode 100644 tests/unit/__init__.py rename tests/{ => unit}/test_tcp_reassembly.py (100%) rename tests/{ => unit}/test_transport_detector.py (100%) rename tests/{ => unit}/test_utils.py (100%) diff --git a/Makefile b/Makefile index d14399c..bb8e260 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: install install-frontend build build-frontend dev dev-backend dev-frontend test test-watch coverage coverage-report lint format format-unsafe clean +.PHONY: install install-frontend build build-frontend dev dev-backend dev-frontend test test-unit test-integration test-db test-network test-cli test-web test-mcp test-watch coverage coverage-report lint format format-unsafe clean # Install all dependencies install: install-backend install-frontend @@ -31,6 +31,27 @@ dev-frontend: test: python3 -m pytest -v +test-unit: + python3 -m pytest tests/unit -v + +test-integration: + python3 -m pytest tests/integration -v + +test-db: + python3 -m pytest tests/integration/db -v + +test-network: + python3 -m pytest tests/integration/network -v + +test-cli: + python3 -m pytest tests/integration/cli -v + +test-web: + python3 -m pytest tests/integration/web -v + +test-mcp: + python3 -m pytest tests/integration/mcp -v + test-watch: python3 -m pytest -v --watch diff --git a/pyproject.toml b/pyproject.toml index a2e557b..7e161ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,9 +86,16 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" # Ignore assert statements in tests "test_*.py" = ["S101", "E712"] "tests/*.py" = ["S101", "E712"] +"tests/**/*.py" = ["S101", "E712"] # Ignore Typer argument pattern in CLI "mcphawk/cli.py" = ["B008"] +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] + [tool.coverage.run] omit = [ "mcphawk/models.py", diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/cli/__init__.py b/tests/integration/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_cli.py b/tests/integration/cli/test_cli.py similarity index 100% rename from tests/test_cli.py rename to tests/integration/cli/test_cli.py diff --git a/tests/test_wrapper.py b/tests/integration/cli/test_wrapper.py similarity index 100% rename from tests/test_wrapper.py rename to tests/integration/cli/test_wrapper.py diff --git a/tests/integration/db/__init__.py b/tests/integration/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_logger.py b/tests/integration/db/test_logger.py similarity index 100% rename from tests/test_logger.py rename to tests/integration/db/test_logger.py diff --git a/tests/test_traffic_type.py b/tests/integration/db/test_traffic_type.py similarity index 100% rename from tests/test_traffic_type.py rename to tests/integration/db/test_traffic_type.py diff --git a/tests/integration/mcp/__init__.py b/tests/integration/mcp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_mcp_http_simple.py b/tests/integration/mcp/test_mcp_http_simple.py similarity index 100% rename from tests/test_mcp_http_simple.py rename to tests/integration/mcp/test_mcp_http_simple.py diff --git a/tests/test_mcp_server.py b/tests/integration/mcp/test_mcp_server.py similarity index 100% rename from tests/test_mcp_server.py rename to tests/integration/mcp/test_mcp_server.py diff --git a/tests/test_mcp_stdio_integration.py b/tests/integration/mcp/test_mcp_stdio_integration.py similarity index 100% rename from tests/test_mcp_stdio_integration.py rename to tests/integration/mcp/test_mcp_stdio_integration.py diff --git a/tests/integration/network/__init__.py b/tests/integration/network/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_ipv4_ipv6_capture.py b/tests/integration/network/test_ipv4_ipv6_capture.py similarity index 100% rename from tests/test_ipv4_ipv6_capture.py rename to tests/integration/network/test_ipv4_ipv6_capture.py diff --git a/tests/test_sniffer.py b/tests/integration/network/test_sniffer.py similarity index 100% rename from tests/test_sniffer.py rename to tests/integration/network/test_sniffer.py diff --git a/tests/test_sniffer_traffic_type.py b/tests/integration/network/test_sniffer_traffic_type.py similarity index 100% rename from tests/test_sniffer_traffic_type.py rename to tests/integration/network/test_sniffer_traffic_type.py diff --git a/tests/integration/web/__init__.py b/tests/integration/web/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_broadcaster.py b/tests/integration/web/test_broadcaster.py similarity index 100% rename from tests/test_broadcaster.py rename to tests/integration/web/test_broadcaster.py diff --git a/tests/test_web.py b/tests/integration/web/test_web.py similarity index 100% rename from tests/test_web.py rename to tests/integration/web/test_web.py diff --git a/tests/test_web_server.py b/tests/integration/web/test_web_server.py similarity index 100% rename from tests/test_web_server.py rename to tests/integration/web/test_web_server.py diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_tcp_reassembly.py b/tests/unit/test_tcp_reassembly.py similarity index 100% rename from tests/test_tcp_reassembly.py rename to tests/unit/test_tcp_reassembly.py diff --git a/tests/test_transport_detector.py b/tests/unit/test_transport_detector.py similarity index 100% rename from tests/test_transport_detector.py rename to tests/unit/test_transport_detector.py diff --git a/tests/test_utils.py b/tests/unit/test_utils.py similarity index 100% rename from tests/test_utils.py rename to tests/unit/test_utils.py From 84c1ee634c11c2f670aea2513afd3ef204fa293d Mon Sep 17 00:00:00 2001 From: tech4242 <5933291+tech4242@users.noreply.github.com> Date: Mon, 4 Aug 2025 21:34:03 +0200 Subject: [PATCH 3/4] fix: remove overrides --- .github/workflows/ci.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 55e9adf..dc6c3e1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,7 +59,4 @@ jobs: flags: unittests name: codecov-${{ matrix.python-version }} token: ${{ secrets.CODECOV_TOKEN }} - override_commit: ${{ github.event.pull_request.head.sha }} - override_pr: ${{ github.event.number }} - override_branch: ${{ github.head_ref }} verbose: true \ No newline at end of file From d3877d673c950d686a27230f1c49cc8316a171fc Mon Sep 17 00:00:00 2001 From: tech4242 <5933291+tech4242@users.noreply.github.com> Date: Mon, 4 Aug 2025 22:02:31 +0200 Subject: [PATCH 4/4] fix: remove overrides --- codecov.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/codecov.yml b/codecov.yml index be76128..34e6559 100644 --- a/codecov.yml +++ b/codecov.yml @@ -4,6 +4,10 @@ coverage: default: target: 80 threshold: 5 + patch: + default: + target: 80 + threshold: 5 comment: layout: "reach, diff, flags, files"