diff --git a/backend/secuscan/config.py b/backend/secuscan/config.py index a7deae68..57115c01 100644 --- a/backend/secuscan/config.py +++ b/backend/secuscan/config.py @@ -81,6 +81,8 @@ class Settings(BaseSettings): sandbox_timeout: int = 600 # seconds sandbox_cpu_quota: float = 0.5 sandbox_memory_mb: int = 512 + sandbox_max_output_bytes: int = 5_242_880 # 5 MB + sandbox_allow_network: bool = True # Task-start payload limits (tunable via env vars) task_start_max_body_bytes: int = 64_000 # 64 KB total JSON body diff --git a/backend/secuscan/executor.py b/backend/secuscan/executor.py index 01a68020..c92c2cbb 100644 --- a/backend/secuscan/executor.py +++ b/backend/secuscan/executor.py @@ -18,8 +18,9 @@ from .config import settings from .database import get_db from .plugins import get_plugin_manager -from .models import TaskStatus, ScanPhase +from .models import TaskStatus, ScanPhase, SandboxConfig from .ratelimit import concurrent_limiter +from .sandbox_executor import sandbox_execute from .risk_scoring import compute_risk_score, compute_risk_factors @@ -354,28 +355,33 @@ async def execute_task(self, task_id: str): await self._broadcast(task_id, "status", TaskStatus.RUNNING.value) await self._broadcast_phase(task_id, ScanPhase.RUNNING_COMMAND.value) - # Execute command start_time = time.time() - output, exit_code = await self._execute_command( + output, exit_code, violation_reason = await self._execute_command( command, task_id, timeout=self._resolve_execution_timeout(inputs), ) duration = time.time() - start_time - # Save raw output raw_path = Path(settings.raw_output_dir) / f"{task_id}.txt" output = redact(output) with open(raw_path, 'w') as f: f.write(output) - # Some CLI tools use non-zero exit codes for "no result" states while still - # producing a complete, parseable report. Let plugin metadata opt into that. 
- final_status, error_message = self._classify_command_result( - plugin=plugin, - output=output, - exit_code=exit_code, - ) + if violation_reason: + status_map = { + "timeout": TaskStatus.TERMINATED_TIMEOUT.value, + "memory_limit": TaskStatus.TERMINATED_MEMORY.value, + "output_limit": TaskStatus.TERMINATED_OUTPUT.value, + } + final_status = status_map.get(violation_reason, TaskStatus.FAILED.value) + error_message = f"Sandbox violation: {violation_reason}" + else: + final_status, error_message = self._classify_command_result( + plugin=plugin, + output=output, + exit_code=exit_code, + ) await db.execute( """ @@ -499,63 +505,36 @@ async def _execute_command( self, command: list, task_id: str, - timeout: int = 600 + timeout: int = 600, ) -> tuple: - """ - Execute command in subprocess and stream output. + if timeout is None: + timeout = settings.sandbox_timeout + config = SandboxConfig( + timeout_seconds=0, + max_memory_mb=settings.sandbox_memory_mb, + max_output_bytes=settings.sandbox_max_output_bytes, + allow_network=settings.sandbox_allow_network, + ) - Args: - command: Command as list - task_id: Task identifier for logging - timeout: Execution timeout in seconds + async def _on_chunk(data: bytes, stream_name: str): + text = data.decode("utf-8", errors="replace") + await self._broadcast(task_id, "output", text) - Returns: - Tuple of (output, exit_code) - """ try: - process = await asyncio.create_subprocess_exec( - *command, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT + stdout, stderr, exit_code, violation_reason = await asyncio.wait_for( + sandbox_execute(command, config, broadcast_callback=_on_chunk), + timeout=timeout, ) - - output_lines = [] - - async def read_stream(): - stdout = process.stdout - if stdout is None: - return - - while not stdout.at_eof(): - line = await stdout.readline() - if line: - decoded_line = line.decode('utf-8', errors='replace') - output_lines.append(decoded_line) - await self._broadcast(task_id, "output", decoded_line) - - try: 
- await asyncio.wait_for(read_stream(), timeout=timeout) - await process.wait() - return "".join(output_lines), process.returncode if process.returncode is not None else -1 - - except asyncio.TimeoutError: - process.kill() - await process.wait() - return "".join(output_lines) + "\nTask timed out", -1 - - except asyncio.CancelledError: - # Handle task cancellation by killing the subprocess - logger.warning(f"Task {task_id} cancelled. Killing process {process.pid}") - try: - process.kill() - await process.wait() - except Exception as e: - logger.error(f"Error killing process for cancelled task {task_id}: {e}") - raise - + if stderr: + stdout = stdout + "\n" + stderr if stdout else stderr + return stdout, exit_code, violation_reason + except asyncio.TimeoutError: + return "", -1, "timeout" + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Failed to execute command: {e}") - return f"Execution error: {str(e)}", -1 + return f"Execution error: {str(e)}", -1, None def _resolve_execution_timeout(self, inputs: Dict[str, Any]) -> int: """Resolve per-task process timeout from plugin inputs.""" diff --git a/backend/secuscan/models.py b/backend/secuscan/models.py index f1792be2..4b73e1bf 100644 --- a/backend/secuscan/models.py +++ b/backend/secuscan/models.py @@ -24,6 +24,17 @@ class TaskStatus(str, Enum): COMPLETED = "completed" FAILED = "failed" CANCELLED = "cancelled" + TERMINATED_TIMEOUT = "terminated:timeout" + TERMINATED_MEMORY = "terminated:memory_limit" + TERMINATED_OUTPUT = "terminated:output_limit" + + +class SandboxConfig(BaseModel): + """Resource constraints applied to every plugin subprocess execution""" + timeout_seconds: int = Field(default=120, description="Max wall-clock seconds before SIGTERM") + max_memory_mb: int = Field(default=512, description="Max virtual memory in MB (RLIMIT_AS on Linux)") + max_output_bytes: int = Field(default=5_242_880, description="Max bytes captured from stdout/stderr") + allow_network: bool = 
Field(default=True, description="Whether subprocess can make network calls") class ScanPhase(str, Enum): @@ -83,6 +94,8 @@ class PluginMetadata(BaseModel): dependencies: Optional[Dict[str, List[str]]] = None docker_image: Optional[str] = None + sandbox: Optional[SandboxConfig] = None + checksum: Optional[str] = None signature: Optional[str] = None @@ -171,6 +184,14 @@ class PluginListResponse(BaseModel): total: int +class SandboxViolation(Exception): + """Raised when sandbox constraints are violated.""" + + def __init__(self, reason: str): + super().__init__(reason) + self.reason = reason + + class ErrorResponse(BaseModel): """Error response""" error: str diff --git a/backend/secuscan/sandbox_executor.py b/backend/secuscan/sandbox_executor.py new file mode 100644 index 00000000..138897a8 --- /dev/null +++ b/backend/secuscan/sandbox_executor.py @@ -0,0 +1,214 @@ +import asyncio +import logging +import platform +from asyncio import subprocess +from typing import List, Optional, Tuple + +from .models import SandboxConfig + +logger = logging.getLogger(__name__) + +IS_LINUX = platform.system() == "Linux" + +CHUNK_SIZE = 64 * 1024 +SIGTERM_GRACE = 3.0 + + +def resolve_sandbox_config(plugin_sandbox: Optional[SandboxConfig] = None) -> SandboxConfig: + """Merge global settings with optional per-plugin sandbox overrides.""" + from .config import settings + base = SandboxConfig( + timeout_seconds=settings.sandbox_timeout, + max_memory_mb=settings.sandbox_memory_mb, + max_output_bytes=settings.sandbox_max_output_bytes, + allow_network=settings.sandbox_allow_network, + ) + if not plugin_sandbox: + return base + overrides = plugin_sandbox.model_dump(exclude_none=True) + return base.model_copy(update=overrides) + + +def _build_preexec_fn(config: SandboxConfig): + """Build preexec_fn for Linux that applies RLIMIT_AS.""" + mem_limit = config.max_memory_mb * 1024 * 1024 + + def _apply_limits(): + import resource + resource.setrlimit(resource.RLIMIT_AS, (mem_limit, mem_limit)) + + 
return _apply_limits + + +def classify_memory_violation( + exit_code: int, + stderr_text: str, + rss_bytes: int, + limit_bytes: int, +) -> bool: + """Post-mortem heuristic to classify whether failure was caused by memory exhaustion.""" + if exit_code in (-11, 139): + return True + if "MemoryError" in stderr_text or "Cannot allocate memory" in stderr_text: + return True + if rss_bytes >= limit_bytes * 95 // 100 and exit_code != 0: + return True + return False + + +async def _terminate_process(process): + """Graceful SIGTERM -> 3s grace -> SIGKILL escalation. Always reaps.""" + try: + process.terminate() + except ProcessLookupError: + return + try: + await asyncio.wait_for(process.wait(), timeout=SIGTERM_GRACE) + except asyncio.TimeoutError: + try: + process.kill() + except ProcessLookupError: + pass + await process.wait() + + +async def _read_stream(stream, buffer, state, broadcast_callback=None, stream_name=""): + """Read from a stream in 64KB chunks, respecting max_output_bytes limit.""" + while True: + chunk = await stream.read(CHUNK_SIZE) + if not chunk: + break + async with state["lock"]: + if state["limit_hit"]: + return + remaining = state["max_bytes"] - state["total_bytes"] + if remaining <= 0: + state["limit_hit"] = True + return + if len(chunk) > remaining: + chunk = chunk[:remaining] + state["limit_hit"] = True + buffer.extend(chunk) + state["total_bytes"] += len(chunk) + if broadcast_callback: + await broadcast_callback(chunk, stream_name) + + +async def sandbox_execute( + cmd: List[str], + config: SandboxConfig, + broadcast_callback=None, +) -> Tuple[str, str, int, Optional[str]]: + """ + Execute a subprocess under sandbox resource constraints. + + Args: + cmd: Command list to execute. + config: SandboxConfig with timeout, memory, output limits. + When timeout_seconds is 0 or None, no wall-clock timeout is + applied internally (the caller handles it externally). 
+ broadcast_callback: Optional async callable(chunk: bytes, stream_name: str) + invoked for each output chunk to enable live streaming. + + Returns (stdout_str, stderr_str, exit_code, violation_reason). + violation_reason is None on success, or one of + "timeout", "memory_limit", "output_limit". + """ + preexec_fn = _build_preexec_fn(config) if IS_LINUX else None + + rss_before = 0 + try: + import resource + rss_before = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + except (ImportError, AttributeError): + pass + + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + preexec_fn=preexec_fn, + ) + + stdout_buffer = bytearray() + stderr_buffer = bytearray() + + state = { + "total_bytes": 0, + "max_bytes": config.max_output_bytes, + "limit_hit": False, + "lock": asyncio.Lock(), + } + + violation_reason = None + + reader_task = asyncio.gather( + _read_stream(process.stdout, stdout_buffer, state, broadcast_callback, "stdout"), + _read_stream(process.stderr, stderr_buffer, state, broadcast_callback, "stderr"), + ) + + try: + if config.timeout_seconds: + try: + await asyncio.wait_for(reader_task, timeout=config.timeout_seconds) + except asyncio.TimeoutError: + if state["limit_hit"]: + violation_reason = "output_limit" + else: + violation_reason = "timeout" + reader_task.cancel() + try: + await reader_task + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + await _terminate_process(process) + else: + if state["limit_hit"]: + violation_reason = "output_limit" + await _terminate_process(process) + else: + await process.wait() + else: + await reader_task + if state["limit_hit"]: + violation_reason = "output_limit" + await _terminate_process(process) + else: + await process.wait() + except asyncio.CancelledError: + violation_reason = None + if not reader_task.done(): + reader_task.cancel() + try: + await reader_task + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + raise + finally: 
+ if process.returncode is None: + await _terminate_process(process) + + stdout_str = stdout_buffer.decode("utf-8", errors="replace") + stderr_str = stderr_buffer.decode("utf-8", errors="replace") + exit_code = process.returncode if process.returncode is not None else -1 + + if violation_reason is None and exit_code != 0: + rss_delta = 0 + try: + import resource + rss_after = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + rss_delta = rss_after - rss_before + except (ImportError, AttributeError): + pass + + if IS_LINUX: + rss_bytes = rss_delta * 1024 + else: + rss_bytes = rss_delta + + limit_bytes = config.max_memory_mb * 1024 * 1024 + + if classify_memory_violation(exit_code, stderr_str, rss_bytes, limit_bytes): + violation_reason = "memory_limit" + + return stdout_str, stderr_str, exit_code, violation_reason diff --git a/testing/backend/integration/test_chaos_execution.py b/testing/backend/integration/test_chaos_execution.py index b2196d7d..72a13055 100644 --- a/testing/backend/integration/test_chaos_execution.py +++ b/testing/backend/integration/test_chaos_execution.py @@ -236,7 +236,7 @@ async def test_upsert_failure_after_successful_scan_marks_task_failed(chaos_env) executor, "_execute_command", new_callable=AsyncMock, - return_value=(ping_stdout, 0), + return_value=(ping_stdout, 0, None), ), patch.object( executor, @@ -293,7 +293,7 @@ async def test_nonzero_exit_raw_artifact_present_task_is_failed(chaos_env): executor, "_execute_command", new_callable=AsyncMock, - return_value=(failure_output, 2), + return_value=(failure_output, 2, None), ): await executor.execute_task(task_id) diff --git a/testing/backend/integration/test_phase2_plugins.py b/testing/backend/integration/test_phase2_plugins.py index f1ff0625..2153602f 100644 --- a/testing/backend/integration/test_phase2_plugins.py +++ b/testing/backend/integration/test_phase2_plugins.py @@ -27,7 +27,7 @@ def parse_scantool_ids() -> set[str]: def run_plugin_test(test_client, plugin_id, inputs, 
mock_output): """Helper to run a plugin test with mocked execution.""" with patch("backend.secuscan.executor.TaskExecutor._execute_command") as mock_exec: - mock_exec.return_value = (mock_output, 0) + mock_exec.return_value = (mock_output, 0, None) payload = { "plugin_id": plugin_id, diff --git a/testing/backend/integration/test_phase3_plugins.py b/testing/backend/integration/test_phase3_plugins.py index b40fc878..eccd5b5c 100644 --- a/testing/backend/integration/test_phase3_plugins.py +++ b/testing/backend/integration/test_phase3_plugins.py @@ -18,7 +18,7 @@ def run_plugin_test(test_client, plugin_id, inputs, mock_output): """Helper to run a plugin test with mocked execution.""" with patch("backend.secuscan.executor.TaskExecutor._execute_command") as mock_exec: - mock_exec.return_value = (mock_output, 0) + mock_exec.return_value = (mock_output, 0, None) payload = { "plugin_id": plugin_id, diff --git a/testing/backend/integration/test_routes.py b/testing/backend/integration/test_routes.py index c390327a..c3425c4c 100644 --- a/testing/backend/integration/test_routes.py +++ b/testing/backend/integration/test_routes.py @@ -55,7 +55,7 @@ def test_plugin_summary(test_client): def test_start_task(test_client): """Test starting a task with a mocked executor.""" with patch("backend.secuscan.executor.TaskExecutor._execute_command") as mock_exec: - mock_exec.return_value = ("Mocked successful output", 0) + mock_exec.return_value = ("Mocked successful output", 0, None) payload = { "plugin_id": "http_inspector", diff --git a/testing/backend/test_sandbox_blocking_issues.py b/testing/backend/test_sandbox_blocking_issues.py new file mode 100644 index 00000000..f31f003b --- /dev/null +++ b/testing/backend/test_sandbox_blocking_issues.py @@ -0,0 +1,483 @@ +""" +Integration tests for blocking issues in sandbox hardening. + +Tests specifically for: +1. Memory-limit classification reliability +2. Legacy timeout argument path compatibility +3. 
Output-limit handling boundary precision +4. Task cancellation and process cleanup +""" + +import asyncio +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from backend.secuscan.models import SandboxConfig +from backend.secuscan.sandbox_executor import ( + sandbox_execute, + classify_memory_violation, +) + + +@pytest.mark.asyncio +async def test_timeout_enforcement_with_default(): + """Issue #2: Timeout enforcement with default timeout fallback. + + When timeout is None, should use global settings (600s default). + Verifies backward compatibility of the legacy timeout argument. + """ + from backend.secuscan.executor import TaskExecutor + + exec_ = TaskExecutor() + + # Test with explicit timeout + output, exit_code, violation = await exec_._execute_command( + [sys.executable, "-c", "import time; time.sleep(60)"], + "test-legacy-timeout", + timeout=2, + ) + assert violation == "timeout", f"Expected timeout, got {violation}" + assert exit_code != 0 + assert output == "" + + # Test with None (should use default 600s) + output2, exit_code2, violation2 = await exec_._execute_command( + [sys.executable, "-c", "print('done')"], + "test-legacy-none", + timeout=None, + ) + assert violation2 is None, f"Expected no violation, got {violation2}" + assert exit_code2 == 0 + assert "done" in output2 + + +@pytest.mark.asyncio +async def test_memory_limit_detection_comprehensive(): + """Issue #1: Memory limit detection must be reliable. 
+ + Test all 3 conditions: + - Condition A: SIGSEGV (exit codes -11 or 139) + - Condition B: MemoryError or "Cannot allocate memory" in stderr + - Condition C: RSS >= 95% of limit AND process failed + """ + # Condition A: SIGSEGV (exit code -11) + assert classify_memory_violation(-11, "", 0, 512*1024*1024) is True + + # Condition A: SIGSEGV (exit code 139) + assert classify_memory_violation(139, "", 0, 512*1024*1024) is True + + # Condition B: MemoryError in stderr + assert classify_memory_violation(1, "MemoryError: out of memory", 0, 512*1024*1024) is True + + # Condition B: Cannot allocate memory + assert classify_memory_violation(1, "Cannot allocate memory", 0, 512*1024*1024) is True + + # Condition C: RSS at 95% threshold with failure + limit = 512 * 1024 * 1024 + assert classify_memory_violation(137, "", int(limit * 0.95), limit) is True + + # Condition C: RSS at 94% should not trigger (below threshold) + assert classify_memory_violation(137, "", int(limit * 0.94), limit) is False + + # Condition C: Success (exit_code 0) should not trigger even at high RSS + assert classify_memory_violation(0, "", int(limit * 0.99), limit) is False + + +@pytest.mark.asyncio +async def test_output_limit_exact_boundary(): + """Issue #3: Output limit must be enforced at exact byte boundary. + + Verifies that reading stops exactly at the limit and no bytes beyond. + """ + cfg = SandboxConfig(max_output_bytes=1000, timeout_seconds=30) + + # Generate more than limit to test truncation + stdout, stderr, exit_code, violation = await sandbox_execute( + [sys.executable, "-c", "print('x' * 2000)"], + cfg, + ) + + # Total bytes (stdout + stderr) should not exceed limit + total_bytes = len(stdout.encode('utf-8')) + len(stderr.encode('utf-8')) + assert total_bytes <= 1000, f"Total bytes {total_bytes} exceeds limit of 1000" + assert violation == "output_limit" + # Exit code may be 0 if Python finished before termination signal was sent; + # output cap is the correctness criterion here. 
+ + +@pytest.mark.asyncio +async def test_output_limit_no_partial_chunks(): + """Issue #3: Output limit prevents partial chunk overruns. + + When a chunk would exceed the limit, it must be truncated exactly. + """ + cfg = SandboxConfig(max_output_bytes=512, timeout_seconds=30) + + stdout, stderr, exit_code, violation = await sandbox_execute( + [sys.executable, "-c", "print('A' * 1000000)"], + cfg, + ) + + stdout_bytes = len(stdout.encode('utf-8')) + stderr_bytes = len(stderr.encode('utf-8')) + total = stdout_bytes + stderr_bytes + + assert total <= 512, f"Output {total} bytes exceeds limit of 512" + assert violation == "output_limit" + + +@pytest.mark.asyncio +async def test_cancellation_with_process_cleanup(): + """Process cancellation must properly clean up child processes. + + Verifies that cancelling the task terminates the process (no orphans). + """ + cfg = SandboxConfig(timeout_seconds=30) + + task = asyncio.create_task( + sandbox_execute( + [sys.executable, "-c", "import time; time.sleep(120)"], + cfg, + ) + ) + + # Give it time to start + await asyncio.sleep(0.1) + + # Cancel the task + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + # Wait a bit and verify no zombie process + await asyncio.sleep(0.5) + + +@pytest.mark.asyncio +async def test_memory_classification_called_always(): + """Issue #1: Memory classification must be checked always. + + Verifies that we check memory violation even when exit_code == 0 + (in case RSS heuristic applies, e.g., OOM killer killed the process). 
+ """ + cfg = SandboxConfig(timeout_seconds=30) + + stdout, stderr, exit_code, violation = await sandbox_execute( + [sys.executable, "-c", "print('success')"], + cfg, + ) + + # Should succeed + assert exit_code == 0 + assert "success" in stdout + # Memory violation should be checked even for successful exit + assert violation is None or violation == "memory_limit" + + +@pytest.mark.asyncio +async def test_legacy_timeout_none_uses_default(): + """Issue #2: Legacy _execute_command with timeout=None must use defaults. + + Verifies backward compatibility when timeout is not specified. + """ + from backend.secuscan.executor import TaskExecutor + + exec_ = TaskExecutor() + + # Call without timeout (None) + output, exit_code, violation = await exec_._execute_command( + [sys.executable, "-c", "print('hello world')"], + "test-legacy-none2", + timeout=None, + ) + + assert exit_code == 0 + assert "hello world" in output + assert violation is None + + +@pytest.mark.asyncio +async def test_output_limit_stops_both_readers(): + """Issue #3: Output limit must stop both stdout and stderr readers. + + Verifies that shared state properly coordinates both readers. + """ + cfg = SandboxConfig(max_output_bytes=256, timeout_seconds=30) + + # Script that writes to both stdout and stderr + script = """ +import sys +for i in range(100): + print("stdout" * 10) + sys.stderr.write("stderr" * 10 + "\\n") +""" + + stdout, stderr, exit_code, violation = await sandbox_execute( + [sys.executable, "-c", script], + cfg, + ) + + total_bytes = len(stdout.encode('utf-8')) + len(stderr.encode('utf-8')) + assert total_bytes <= 256, f"Total bytes {total_bytes} exceeds limit 256" + assert violation == "output_limit" + + +@pytest.mark.asyncio +async def test_output_limit_early_reader_termination(): + """Verify that when limit is hit, readers exit immediately. + + Tests that the check at the start of the loop prevents further reads. 
+ """ + cfg = SandboxConfig(max_output_bytes=100, timeout_seconds=30) + + stdout, stderr, exit_code, violation = await sandbox_execute( + [sys.executable, "-c", "print('x' * 10000)"], + cfg, + ) + + total = len(stdout.encode('utf-8')) + len(stderr.encode('utf-8')) + assert total <= 100 + assert violation == "output_limit" + + +@pytest.mark.asyncio +async def test_memory_classification_includes_exit_137(): + """Verify memory classification catches exit code 137 (OOM killer). + + Exit code 137 = 128 + 9 (SIGKILL), often from OOM killer on Linux. + """ + limit = 512 * 1024 * 1024 + + # RSS at threshold, exit 137 (SIGKILL from OOM) + assert classify_memory_violation(137, "", int(limit * 0.95), limit) is True + + # Without high RSS, exit 137 should not be classified as memory_limit + # (could be another cause) + assert classify_memory_violation(137, "", int(limit * 0.80), limit) is False + + +@pytest.mark.asyncio +async def test_live_output_broadcasting(): + """Regression: sandbox path must broadcast output chunks for live streaming.""" + from backend.secuscan.executor import TaskExecutor + from unittest.mock import AsyncMock + + exec_ = TaskExecutor() + exec_._broadcast = AsyncMock() + + await exec_._execute_command( + [sys.executable, "-c", "print('hello from live stream')"], + "test-broadcast", + timeout=30, + ) + + calls = exec_._broadcast.await_args_list + output_calls = [c for c in calls if c.args[1] == "output"] + assert len(output_calls) > 0, ( + f"Expected at least one output broadcast call, got {len(calls)} total" + ) + + all_text = "".join(c.args[2] for c in output_calls) + assert "hello from live stream" in all_text, ( + f"Broadcast output did not contain expected text: {all_text!r}" + ) + + +@pytest.mark.asyncio +async def test_stderr_captured_in_output(): + """Regression: stderr must be merged into raw output, not discarded.""" + from backend.secuscan.executor import TaskExecutor + + exec_ = TaskExecutor() + + output, exit_code, violation = await 
exec_._execute_command( + [ + sys.executable, "-c", + "import sys; sys.stderr.write('diagnostic info\\n'); print('stdout line')", + ], + "test-stderr-capture", + timeout=30, + ) + + assert exit_code == 0, f"Expected exit_code 0, got {exit_code}" + assert "stdout line" in output, f"Expected stdout in output: {output!r}" + assert "diagnostic info" in output, f"Expected stderr in output: {output!r}" + + +# --------------------------------------------------------------------------- +# Comprehensive precision regression tests for all 5 owner-specified categories +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_timeout_external_via_execute_command(): + """Timeout category: _execute_command applies timeout via external asyncio.wait_for + (legacy-compatible path), returns ("", -1, "timeout") on expiry.""" + from backend.secuscan.executor import TaskExecutor + + exec_ = TaskExecutor() + output, exit_code, violation = await exec_._execute_command( + [sys.executable, "-c", "import time; time.sleep(30)"], + "test-ext-timeout", + timeout=1, + ) + assert violation == "timeout", f"Expected timeout, got {violation}" + assert exit_code == -1, f"Expected exit_code -1 on timeout, got {exit_code}" + assert output == "", f"Expected empty output on timeout, got {output!r}" + + +@pytest.mark.asyncio +async def test_timeout_internal_via_sandbox_execute(): + """Timeout category: direct sandbox_execute with timeout_seconds still + applies internal timeout for callers that don't go through _execute_command.""" + cfg = SandboxConfig(timeout_seconds=1, max_memory_mb=512) + stdout, stderr, exit_code, violation = await sandbox_execute( + [sys.executable, "-c", "import time; time.sleep(30)"], + cfg, + ) + assert violation == "timeout", f"Expected timeout, got {violation}" + assert exit_code != 0, f"Expected non-zero exit, got {exit_code}" + # stderr may have timeout noise, stdout should be empty + assert stdout == "", 
f"Expected empty stdout on timeout, got {stdout!r}" + + +@pytest.mark.asyncio +async def test_memory_classification_sigsegv_exit_codes(): + """Memory category: SIGSEGV signals must always classify as memory_limit.""" + limit = 512 * 1024 * 1024 + for code in (-11, 139): + assert classify_memory_violation(code, "", 0, limit) is True, ( + f"Exit code {code} should be classified as memory violation" + ) + + +@pytest.mark.asyncio +async def test_memory_classification_stderr_strings(): + """Memory category: MemoryError / Cannot allocate memory strings classify.""" + limit = 512 * 1024 * 1024 + assert classify_memory_violation(1, "MemoryError: out of memory", 0, limit) is True + assert classify_memory_violation(1, "Cannot allocate memory", 0, limit) is True + # Non-memory error should not classify just from stderr + assert classify_memory_violation(1, "Segmentation fault (core dumped)", 0, limit) is False + + +@pytest.mark.asyncio +async def test_memory_classification_rss_delta_heuristic(): + """Memory category: RSS at or above 95% threshold with non-zero exit classifies.""" + limit = 512 * 1024 * 1024 + # At threshold + assert classify_memory_violation(137, "", int(limit * 0.95), limit) is True + # Just below threshold + assert classify_memory_violation(137, "", int(limit * 0.94), limit) is False + # Above threshold but exit 0 should not classify + assert classify_memory_violation(0, "", int(limit * 0.99), limit) is False + + +@pytest.mark.asyncio +async def test_memory_classification_exit_137_with_rss(): + """Memory category: exit 137 (SIGKILL) + high RSS classifies as memory.""" + limit = 512 * 1024 * 1024 + assert classify_memory_violation(137, "", int(limit * 0.95), limit) is True + assert classify_memory_violation(137, "", int(limit * 0.80), limit) is False + + +@pytest.mark.asyncio +async def test_output_limit_lock_prevents_race(): + """Output category: asyncio.Lock prevents concurrent readers from exceeding max_bytes. 
+ + Both stdout and stderr writers produce output concurrently. Without the lock, + the shared total_bytes could be read simultaneously by both readers, causing + both to consume up to the remaining capacity and exceed the limit. + """ + cfg = SandboxConfig(max_output_bytes=512, timeout_seconds=10) + script = ( + "import sys\n" + "for i in range(500):\n" + " sys.stdout.write('a' * 120)\n" + " sys.stderr.write('b' * 120)\n" + ) + stdout, stderr, exit_code, violation = await sandbox_execute( + [sys.executable, "-c", script], + cfg, + ) + total = len(stdout.encode("utf-8")) + len(stderr.encode("utf-8")) + assert total <= 512, f"Lock should enforce total <= 512, got {total} bytes" + assert violation == "output_limit", f"Expected output_limit, got {violation}" + + +@pytest.mark.asyncio +async def test_output_limit_strict_boundary(): + """Output category: output is capped at exactly max_output_bytes, not rounded up.""" + for limit in (256, 511, 1023): + cfg = SandboxConfig(max_output_bytes=limit, timeout_seconds=10) + stdout, stderr, exit_code, violation = await sandbox_execute( + [sys.executable, "-c", f"print('x' * {limit * 10})"], + cfg, + ) + stdout_bytes = len(stdout.encode("utf-8")) + stderr_bytes = len(stderr.encode("utf-8")) + total = stdout_bytes + stderr_bytes + assert total <= limit, ( + f"Limit {limit}: total {total} bytes exceeds limit" + ) + + +@pytest.mark.asyncio +async def test_cancellation_raises_cancelled_error(): + """Cancellation category: cancelling a sandbox_execute task raises CancelledError + and does not leave orphan processes.""" + cfg = SandboxConfig(timeout_seconds=60) + task = asyncio.create_task( + sandbox_execute( + [sys.executable, "-c", "import time; time.sleep(60)"], + cfg, + ) + ) + await asyncio.sleep(0.2) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + await asyncio.sleep(0.5) + + +@pytest.mark.asyncio +async def test_legacy_timeout_signature_preserved(): + """Legacy compatibility: _execute_command(self, 
command, task_id, timeout=600) + signature must accept all three positional forms.""" + from backend.secuscan.executor import TaskExecutor + + exec_ = TaskExecutor() + + # Form 1: default timeout (=600) + output, exit_code, violation = await exec_._execute_command( + [sys.executable, "-c", "print('ok')"], + "test-legacy-default", + ) + assert exit_code == 0 + assert "ok" in output + assert violation is None + + # Form 2: explicit timeout + output, exit_code, violation = await exec_._execute_command( + [sys.executable, "-c", "print('ok2')"], + "test-legacy-explicit", + timeout=30, + ) + assert exit_code == 0 + assert "ok2" in output + assert violation is None + + # Form 3: timeout=None falls back to settings.sandbox_timeout + output, exit_code, violation = await exec_._execute_command( + [sys.executable, "-c", "print('ok3')"], + "test-legacy-none", + timeout=None, + ) + assert exit_code == 0 + assert "ok3" in output + assert violation is None diff --git a/testing/backend/test_sandbox_executor.py b/testing/backend/test_sandbox_executor.py new file mode 100644 index 00000000..3248200e --- /dev/null +++ b/testing/backend/test_sandbox_executor.py @@ -0,0 +1,232 @@ +import asyncio +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from backend.secuscan.models import SandboxConfig +from backend.secuscan.sandbox_executor import ( + sandbox_execute, + _terminate_process, + _build_preexec_fn, + classify_memory_violation, +) + + +@pytest.mark.asyncio +async def test_legacy_timeout_compatibility(): + """Test Case 1: Legacy Timeout Compatibility. + + _execute_command(cmd, timeout=1) must apply the timeout and return + violation_reason "timeout". 
+    """
+    from backend.secuscan.executor import TaskExecutor
+
+    exec_ = TaskExecutor()
+    output, exit_code, violation_reason = await exec_._execute_command(
+        [sys.executable, "-c", "import time; time.sleep(30)"],  # far longer than the 1 s timeout
+        "test-legacy-timeout",
+        timeout=1,
+    )
+    assert violation_reason == "timeout"
+    assert exit_code != 0
+
+
+@pytest.mark.asyncio
+async def test_signal_escalation():
+    """Test Case 2: Signal Escalation.
+
+    When a process ignores SIGTERM, verify that terminate() is called first,
+    then kill() after the grace period, and process.wait() is called twice (reap).
+    """
+    mock_process = MagicMock()
+    mock_process.returncode = None
+    mock_process.terminate = MagicMock()
+    mock_process.kill = MagicMock()
+
+    wait_count = 0
+
+    async def wait_side_effect():
+        nonlocal wait_count
+        wait_count += 1
+        if wait_count == 1:
+            await asyncio.sleep(999)  # first wait() hangs, simulating a process that ignores SIGTERM
+
+    mock_process.wait = wait_side_effect
+
+    with patch("backend.secuscan.sandbox_executor.SIGTERM_GRACE", 0.05):  # short grace keeps the test fast
+        await _terminate_process(mock_process)
+
+    mock_process.terminate.assert_called_once()
+    mock_process.kill.assert_called_once()
+    assert wait_count == 2  # the hanging wait plus the second wait that returns immediately
+
+
+class TestMemoryLimitClassification:
+    """Test Case 3: Memory Limit Classification."""
+
+    @pytest.mark.parametrize("exit_code", [-11, 139])  # SIGSEGV as -11 and as 128 + 11
+    def test_sigsegv(self, exit_code):
+        assert classify_memory_violation(
+            exit_code=exit_code,
+            stderr_text="",
+            rss_bytes=0,
+            limit_bytes=512 * 1024 * 1024,
+        ) is True
+
+    def test_memory_error_string(self):
+        assert classify_memory_violation(
+            exit_code=1,
+            stderr_text="MemoryError: unable to allocate",
+            rss_bytes=0,
+            limit_bytes=512 * 1024 * 1024,
+        ) is True
+
+    def test_cannot_allocate_memory(self):
+        assert classify_memory_violation(
+            exit_code=1,
+            stderr_text="Cannot allocate memory",
+            rss_bytes=0,
+            limit_bytes=512 * 1024 * 1024,
+        ) is True
+
+    def test_rss_heuristic(self):
+        limit_bytes = 512 * 1024 * 1024
+        assert classify_memory_violation(
+            exit_code=137,  # 128 + 9, i.e. killed by SIGKILL
+            stderr_text="",
+            rss_bytes=limit_bytes,
+            limit_bytes=limit_bytes,
+        ) is True
+
+    def test_rss_below_threshold(self):
+        limit_bytes = 512 * 1024 * 1024
+        assert classify_memory_violation(
+            exit_code=1,
+            stderr_text="",
+            rss_bytes=int(limit_bytes * 0.50),
+            limit_bytes=limit_bytes,
+        ) is False
+
+    def test_zero_exit_not_classified(self):
+        limit_bytes = 512 * 1024 * 1024
+        assert classify_memory_violation(
+            exit_code=0,  # a clean exit is not a violation even with RSS near the limit
+            stderr_text="",
+            rss_bytes=int(limit_bytes * 0.99),
+            limit_bytes=limit_bytes,
+        ) is False
+
+
+@pytest.mark.asyncio
+async def test_proactive_output_truncation():
+    """Test Case 4: Proactive Output Truncation.
+
+    When subprocess output exceeds max_output_bytes, reading must stop
+    at the boundary, the process terminated, and violation_reason returned.
+    """
+    cfg = SandboxConfig(max_output_bytes=1024, timeout_seconds=30)
+    stdout, stderr, exit_code, violation_reason = await sandbox_execute(
+        [sys.executable, "-c", "print('A' * 10000000)"],
+        cfg,
+    )
+    assert violation_reason == "output_limit"
+    assert len(stdout.encode("utf-8")) + len(stderr.encode("utf-8")) <= cfg.max_output_bytes  # hard cap, no 2x slack
+    assert exit_code != 0
+
+
+@pytest.mark.asyncio
+async def test_task_cancellation_safety():
+    """Test Case 5: Task Cancellation Safety.
+
+    If the parent coroutine is cancelled, the subprocess must be
+    terminated and reaped, never orphaned.
+    """
+    cfg = SandboxConfig(timeout_seconds=30)
+    task = asyncio.create_task(
+        sandbox_execute(
+            [sys.executable, "-c", "import time; time.sleep(60)"],
+            cfg,
+        )
+    )
+    await asyncio.sleep(0.2)  # let the task start running before it is cancelled
+    task.cancel()
+    with pytest.raises(asyncio.CancelledError):  # NOTE(review): reaping itself is not asserted here
+        await task
+
+
+@pytest.mark.asyncio
+async def test_platform_guard_non_linux():
+    """Test Case 6: Platform Guard Verification.
+
+    On Linux, preexec_fn applies RLIMIT_AS. On other platforms,
+    it must be None (checked at call site in sandbox_execute).
+    Timeout and output limits must remain active on all platforms.
+    """
+    built = _build_preexec_fn(SandboxConfig(max_memory_mb=128))
+    assert callable(built)  # NOTE(review): no non-Linux platform is simulated; only the factory result is checked
+
+    cfg = SandboxConfig(max_output_bytes=100, timeout_seconds=30)
+    stdout, stderr, exit_code, violation_reason = await sandbox_execute(
+        [sys.executable, "-c", "print('x' * 5000)"],
+        cfg,
+    )
+    assert violation_reason == "output_limit"
+    assert len(stdout.encode("utf-8")) <= cfg.max_output_bytes  # hard cap, not 5x the configured limit
+
+    cfg2 = SandboxConfig(timeout_seconds=1)
+    stdout2, stderr2, exit_code2, vr2 = await sandbox_execute(
+        [sys.executable, "-c", "import time; time.sleep(30)"],
+        cfg2,
+    )
+    assert vr2 == "timeout"
+
+
+@pytest.mark.asyncio
+async def test_sandbox_execute_normal_completion():
+    cfg = SandboxConfig(timeout_seconds=30)
+    stdout, stderr, exit_code, violation_reason = await sandbox_execute(
+        [sys.executable, "-c", "print('hello world')"],
+        cfg,
+    )
+    assert "hello world" in stdout
+    assert exit_code == 0
+    assert violation_reason is None  # a clean run reports no sandbox violation
+
+
+def test_sandbox_violation_exception():
+    from backend.secuscan.models import SandboxViolation
+    exc = SandboxViolation("timeout")
+    assert exc.reason == "timeout"
+    assert str(exc) == "timeout"
+
+
+@pytest.mark.asyncio
+async def test_resolve_sandbox_config_global_defaults(monkeypatch):
+    from backend.secuscan.sandbox_executor import resolve_sandbox_config
+    monkeypatch.setattr(
+        "backend.secuscan.config.settings.sandbox_timeout",
+        42,
+    )
+    monkeypatch.setattr(
+        "backend.secuscan.config.settings.sandbox_memory_mb",
+        256,
+    )
+    resolved = resolve_sandbox_config(None)  # None: no plugin-level config supplied
+    assert resolved.timeout_seconds == 42
+    assert resolved.max_memory_mb == 256
+    assert resolved.max_output_bytes == 5_242_880  # not patched here: relies on the 5 MB settings default
+
+
+@pytest.mark.asyncio
+async def test_resolve_sandbox_config_plugin_overrides():
+    from backend.secuscan.sandbox_executor import resolve_sandbox_config
+    resolved = resolve_sandbox_config(
+        SandboxConfig(timeout_seconds=999, max_memory_mb=2048)
+    )
+    assert resolved.timeout_seconds == 999
+    assert resolved.max_memory_mb == 2048
+    assert resolved.max_output_bytes == 5_242_880  # field not overridden by the plugin config
diff --git a/testing/backend/unit/test_executor.py b/testing/backend/unit/test_executor.py index f5a45c0f..4be80c3c 100644 --- a/testing/backend/unit/test_executor.py +++ b/testing/backend/unit/test_executor.py @@ -276,7 +276,7 @@ async def test_execute_task_releases_limiter_on_normal_completion(setup_test_env executor = TaskExecutor() async def fake_command(*args, **kwargs): - return "80/tcp open http", 0 + return "80/tcp open http", 0, None with patch.object(executor, "_execute_command", side_effect=fake_command), \ patch("backend.secuscan.executor.concurrent_limiter") as mock_limiter, \