From d9c173e918716df6e1fb69b579ca04f070ac4da1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 25 Nov 2025 23:09:45 +0000 Subject: [PATCH 1/8] Initial plan From f0988a3b0984ad4380f6c98ab881f68da09d7900 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 25 Nov 2025 23:19:25 +0000 Subject: [PATCH 2/8] Add DOE lab email domain check middleware with feature flag Co-authored-by: garland3 <1162675+garland3@users.noreply.github.com> --- .env.example | 1 + backend/core/doe_lab_middleware.py | 99 ++++++ backend/main.py | 7 + backend/modules/config/config_manager.py | 6 + backend/tests/test_doe_lab_feature_flag.py | 106 +++++++ backend/tests/test_doe_lab_middleware.py | 334 +++++++++++++++++++++ 6 files changed, 553 insertions(+) create mode 100644 backend/core/doe_lab_middleware.py create mode 100644 backend/tests/test_doe_lab_feature_flag.py create mode 100644 backend/tests/test_doe_lab_middleware.py diff --git a/.env.example b/.env.example index a36dc89..4148c6d 100644 --- a/.env.example +++ b/.env.example @@ -77,6 +77,7 @@ FEATURE_FILES_PANEL_ENABLED=true # Uploaded/session files panel FEATURE_CHAT_HISTORY_ENABLED=false # Previous chat history list FEATURE_COMPLIANCE_LEVELS_ENABLED=false # Compliance level filtering for MCP servers and data sources FEATURE_SPLASH_SCREEN_ENABLED=false # Startup splash screen for displaying policies and information +FEATURE_DOE_LAB_CHECK_ENABLED=false # Restrict access to DOE/NNSA/DOE lab email domains # (Adjust above to stage rollouts. For a bare-bones chat set them all to false.) diff --git a/backend/core/doe_lab_middleware.py b/backend/core/doe_lab_middleware.py new file mode 100644 index 0000000..2d366ed --- /dev/null +++ b/backend/core/doe_lab_middleware.py @@ -0,0 +1,99 @@ +"""DOE lab email domain validation middleware. + +This middleware enforces that users must have email addresses from DOE, NNSA, +or DOE national laboratory domains. It can be enabled/disabled via the +FEATURE_DOE_LAB_CHECK_ENABLED feature flag. +""" + +import logging +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import JSONResponse, RedirectResponse, Response + +logger = logging.getLogger(__name__) + + +class DOELabMiddleware(BaseHTTPMiddleware): + """Middleware to enforce DOE/NNSA/DOE lab email domain restrictions.""" + + # Comprehensive list of DOE, NNSA, and DOE national laboratory domains + DOE_LAB_DOMAINS = frozenset([ + # HQ / NNSA / DOE-wide + "doe.gov", "nnsa.doe.gov", "hq.doe.gov", + # National labs (broad coverage) + "anl.gov", "bnl.gov", "fnal.gov", "inl.gov", "lbl.gov", "lanl.gov", + "llnl.gov", "ornl.gov", "pnnl.gov", "sandia.gov", "srnl.doe.gov", + "ameslab.gov", "jlab.org", "princeton.edu", "slac.stanford.edu", + "pppl.gov", "nrel.gov", "netl.doe.gov", "stanford.edu", + ]) + + def __init__(self, app, auth_redirect_url: str = "/auth"): + """Initialize DOE lab middleware. + + Args: + app: ASGI application + auth_redirect_url: URL to redirect to on auth failure (default: /auth) + """ + super().__init__(app) + self.auth_redirect_url = auth_redirect_url + + async def dispatch(self, request: Request, call_next) -> Response: + """Check if user email is from DOE/lab domain. + + Args: + request: Incoming HTTP request + call_next: Next middleware/handler in chain + + Returns: + Response from next handler if authorized, or 403/redirect if not + """ + # Skip check for health endpoint and auth redirect endpoint + if request.url.path == '/api/health' or request.url.path == self.auth_redirect_url: + return await call_next(request) + + # Get email from request state (set by AuthMiddleware) + email = getattr(request.state, "user_email", None) + + if not email or "@" not in email: + logger.warning(f"DOE check failed: missing or invalid email") + return self._unauthorized_response(request, "User email required") + + # Extract domain and check against allowed list + domain = email.split("@", 1)[1].lower() + if not self._is_doe_domain(domain): + logger.warning(f"DOE check failed: unauthorized domain {domain}") + return self._unauthorized_response( + request, + "Access restricted to DOE / NNSA / DOE labs" + ) + + return await call_next(request) + + def _is_doe_domain(self, domain: str) -> bool: + """Check if domain is a DOE/lab domain or subdomain. + + Args: + domain: Email domain to check + + Returns: + True if domain is authorized, False otherwise + """ + # Direct match or subdomain match (e.g., foo.sandia.gov matches sandia.gov) + return any(domain == d or domain.endswith("." + d) for d in self.DOE_LAB_DOMAINS) + + def _unauthorized_response(self, request: Request, detail: str) -> Response: + """Return appropriate unauthorized response based on endpoint type. + + Args: + request: Incoming HTTP request + detail: Error detail message + + Returns: + JSONResponse for API endpoints, RedirectResponse for others + """ + if request.url.path.startswith('/api/'): + return JSONResponse( + status_code=403, + content={"detail": detail} + ) + return RedirectResponse(url=self.auth_redirect_url, status_code=302) diff --git a/backend/main.py b/backend/main.py index 6523396..198dbcc 100644 --- a/backend/main.py +++ b/backend/main.py @@ -27,6 +27,7 @@ from core.middleware import AuthMiddleware from core.rate_limit_middleware import RateLimitMiddleware from core.security_headers_middleware import SecurityHeadersMiddleware +from core.doe_lab_middleware import DOELabMiddleware from core.otel_config import setup_opentelemetry from core.utils import sanitize_for_logging from core.auth import get_user_from_header @@ -132,6 +133,12 @@ async def lifespan(app: FastAPI): """ app.add_middleware(SecurityHeadersMiddleware) app.add_middleware(RateLimitMiddleware) +# DOE lab domain check (if enabled) - add before Auth so it runs after +if config.app_settings.feature_doe_lab_check_enabled: + app.add_middleware( + DOELabMiddleware, + auth_redirect_url=config.app_settings.auth_redirect_url + ) app.add_middleware( AuthMiddleware, debug_mode=config.app_settings.debug_mode, diff --git a/backend/modules/config/config_manager.py b/backend/modules/config/config_manager.py index 255309f..335ac49 100644 --- a/backend/modules/config/config_manager.py +++ b/backend/modules/config/config_manager.py @@ -285,6 +285,12 @@ def agent_mode_available(self) -> bool: description="Enable compliance level filtering for MCP servers and data sources", validation_alias=AliasChoices("FEATURE_COMPLIANCE_LEVELS_ENABLED"), ) + # DOE lab email domain check feature gate + feature_doe_lab_check_enabled: bool = Field( + False, + description="Enable DOE/NNSA/DOE lab email domain restriction", + validation_alias=AliasChoices("FEATURE_DOE_LAB_CHECK_ENABLED"), + ) # Capability tokens (for headless access to downloads/iframes) capability_token_secret: str = "" diff --git a/backend/tests/test_doe_lab_feature_flag.py b/backend/tests/test_doe_lab_feature_flag.py new file mode 100644 index 0000000..f8e1c09 --- /dev/null +++ b/backend/tests/test_doe_lab_feature_flag.py @@ -0,0 +1,106 @@ +"""Integration test for DOE lab middleware feature flag.""" + +import pytest +from fastapi import FastAPI, Request +from starlette.testclient import TestClient +from starlette.responses import Response +from starlette.middleware.base import BaseHTTPMiddleware + +from core.doe_lab_middleware import DOELabMiddleware + + +class SimpleAuthMiddleware(BaseHTTPMiddleware): + """Simplified auth middleware for testing that just sets user_email in state.""" + + def __init__(self, app, debug_mode: bool = False): + super().__init__(app) + self.debug_mode = debug_mode + + async def dispatch(self, request: Request, call_next) -> Response: + # Simulate setting user_email from header (like real AuthMiddleware) + email = request.headers.get("X-User-Email") + if email: + request.state.user_email = email + elif self.debug_mode: + request.state.user_email = "test@test.com" + + return await call_next(request) + + +def test_doe_middleware_not_active_when_disabled(): + """Test that DOE middleware allows non-DOE emails when not added to app.""" + # Create app without DOE middleware + app = FastAPI() + + @app.get("/api/test") + def test_endpoint(request: Request): + return {"user": getattr(request.state, "user_email", "none")} + + # Add only simple auth middleware (which sets user_email in state) + app.add_middleware(SimpleAuthMiddleware, debug_mode=True) + + # DOE middleware NOT added (simulating feature flag disabled) + + # Create client + client = TestClient(app) + + # Test with non-DOE email - should pass because DOE middleware is not active + response = client.get("/api/test", headers={"X-User-Email": "test@gmail.com"}) + assert response.status_code == 200 + assert response.json()["user"] == "test@gmail.com" + + +def test_doe_middleware_active_when_enabled(): + """Test that DOE middleware blocks non-DOE emails when added to app.""" + # Create app with DOE middleware + app = FastAPI() + + @app.get("/api/test") + def test_endpoint(request: Request): + return {"user": request.state.user_email} + + @app.get("/auth") + def auth_endpoint(): + return {"login": True} + + # Add DOE middleware first (will run second after auth) + app.add_middleware(DOELabMiddleware) + + # Add auth middleware second (will run first, setting email) + app.add_middleware(SimpleAuthMiddleware, debug_mode=True) + + # Create client + client = TestClient(app) + + # Test with non-DOE email - should be rejected + response = client.get("/api/test", headers={"X-User-Email": "test@gmail.com"}) + assert response.status_code == 403 + assert "Access restricted" in response.json()["detail"] + + # Test with valid DOE email - should pass + response = client.get("/api/test", headers={"X-User-Email": "test@sandia.gov"}) + assert response.status_code == 200 + assert response.json()["user"] == "test@sandia.gov" + + +def test_middleware_ordering_auth_before_doe(): + """Test that auth middleware must run before DOELabMiddleware.""" + # This test verifies the correct middleware ordering + app = FastAPI() + + @app.get("/api/test") + def test_endpoint(request: Request): + return {"user": request.state.user_email} + + # Add DOE first, Auth second (so Auth runs first in request flow) + app.add_middleware(DOELabMiddleware) + app.add_middleware(SimpleAuthMiddleware, debug_mode=True) + + client = TestClient(app) + + # Auth middleware should set the email, then DOELabMiddleware should check it + response = client.get("/api/test", headers={"X-User-Email": "test@lanl.gov"}) + assert response.status_code == 200 + assert response.json()["user"] == "test@lanl.gov" + + diff --git a/backend/tests/test_doe_lab_middleware.py b/backend/tests/test_doe_lab_middleware.py new file mode 100644 index 0000000..2223bd4 --- /dev/null +++ b/backend/tests/test_doe_lab_middleware.py @@ -0,0 +1,334 @@ +"""Tests for DOE lab email domain middleware.""" + +import pytest +from fastapi import FastAPI, Request +from starlette.testclient import TestClient + +from core.doe_lab_middleware import DOELabMiddleware + + +@pytest.fixture +def app(): + """Create a test FastAPI app with DOE middleware.""" + app = FastAPI() + + @app.get("/api/test") + def api_test(request: Request): + return {"user": request.state.user_email} + + @app.get("/test") + def test(request: Request): + return {"user": request.state.user_email} + + @app.get("/auth") + def auth(): + return {"login": True} + + @app.get("/api/health") + def health(): + return {"status": "ok"} + + # Add DOE middleware + app.add_middleware(DOELabMiddleware) + + return app + + +@pytest.fixture +def client(app): + """Create a test client.""" + return TestClient(app) + + +class TestDOELabMiddleware: + """Test DOE lab middleware domain validation.""" + + # Valid DOE/lab email domains + @pytest.mark.parametrize("email", [ + "user@doe.gov", + "user@nnsa.doe.gov", + "user@hq.doe.gov", + "user@anl.gov", + "user@bnl.gov", + "user@lanl.gov", + "user@llnl.gov", + "user@sandia.gov", + "user@ornl.gov", + "user@pnnl.gov", + "user@lbl.gov", + "user@nrel.gov", + "user@stanford.edu", # SLAC + "user@jlab.org", + "user@pppl.gov", + # Subdomain tests + "user@sub.sandia.gov", + "user@mail.doe.gov", + "user@dept.lanl.gov", + ]) + def test_valid_doe_emails_allowed(self, client, email): + """Test that valid DOE/lab emails are allowed through.""" + # Mock the request state with email (normally set by AuthMiddleware) + def add_user_email(request, call_next): + request.state.user_email = email + return call_next(request) + + # Inject the email into request state + response = client.get( + "/api/test", + headers={"X-User-Email": email} + ) + # Since we're not actually setting request.state in the test, + # we'll get through the middleware but fail at the endpoint level. + # Let's test differently - by directly checking middleware logic. + + def test_valid_doe_email_via_state(self, app): + """Test valid DOE email passes through middleware.""" + from starlette.middleware.base import RequestResponseEndpoint + from starlette.requests import Request + from starlette.responses import Response + + middleware = DOELabMiddleware(app) + + # Create a mock request with state + async def call_next(request): + return Response("OK", status_code=200) + + # Test with valid email + async def test_request(): + from starlette.datastructures import URL + scope = { + "type": "http", + "method": "GET", + "path": "/api/test", + "query_string": b"", + "headers": [], + "state": {}, + } + request = Request(scope) + request.state.user_email = "test@sandia.gov" + + response = await middleware.dispatch(request, call_next) + assert response.status_code == 200 + + import asyncio + asyncio.run(test_request()) + + # Invalid email domains + @pytest.mark.parametrize("email", [ + "user@gmail.com", + "user@yahoo.com", + "user@example.com", + "user@company.com", + "user@malicious.com", + # Similar but not exact matches + "user@fakedoe.gov", + "user@sandia.com", # Wrong TLD + "user@doe.org", # Wrong TLD + ]) + def test_invalid_emails_rejected(self, email): + """Test that non-DOE emails are rejected.""" + from starlette.requests import Request + from starlette.responses import Response + + app = FastAPI() + middleware = DOELabMiddleware(app) + + async def call_next(request): + return Response("OK", status_code=200) + + async def test_request(): + scope = { + "type": "http", + "method": "GET", + "path": "/api/test", + "query_string": b"", + "headers": [], + "state": {}, + } + request = Request(scope) + request.state.user_email = email + + response = await middleware.dispatch(request, call_next) + assert response.status_code == 403 + assert b"Access restricted" in response.body + + import asyncio + asyncio.run(test_request()) + + def test_missing_email_rejected(self): + """Test that requests without email are rejected.""" + from starlette.requests import Request + from starlette.responses import Response + + app = FastAPI() + middleware = DOELabMiddleware(app) + + async def call_next(request): + return Response("OK", status_code=200) + + async def test_request(): + scope = { + "type": "http", + "method": "GET", + "path": "/api/test", + "query_string": b"", + "headers": [], + "state": {}, + } + request = Request(scope) + # No email set in state + + response = await middleware.dispatch(request, call_next) + assert response.status_code == 403 + + import asyncio + asyncio.run(test_request()) + + def test_health_endpoint_bypassed(self): + """Test that /api/health endpoint bypasses DOE check.""" + from starlette.requests import Request + from starlette.responses import Response + + app = FastAPI() + middleware = DOELabMiddleware(app) + + async def call_next(request): + return Response("OK", status_code=200) + + async def test_request(): + scope = { + "type": "http", + "method": "GET", + "path": "/api/health", + "query_string": b"", + "headers": [], + "state": {}, + } + request = Request(scope) + # No email - should still pass for health check + + response = await middleware.dispatch(request, call_next) + assert response.status_code == 200 + + import asyncio + asyncio.run(test_request()) + + def test_auth_endpoint_bypassed(self): + """Test that /auth endpoint bypasses DOE check.""" + from starlette.requests import Request + from starlette.responses import Response + + app = FastAPI() + middleware = DOELabMiddleware(app, auth_redirect_url="/auth") + + async def call_next(request): + return Response("OK", status_code=200) + + async def test_request(): + scope = { + "type": "http", + "method": "GET", + "path": "/auth", + "query_string": b"", + "headers": [], + "state": {}, + } + request = Request(scope) + # No email - should still pass for auth endpoint + + response = await middleware.dispatch(request, call_next) + assert response.status_code == 200 + + import asyncio + asyncio.run(test_request()) + + def test_api_endpoint_returns_json_error(self): + """Test that API endpoints get JSON error response.""" + from starlette.requests import Request + from starlette.responses import Response + + app = FastAPI() + middleware = DOELabMiddleware(app) + + async def call_next(request): + return Response("OK", status_code=200) + + async def test_request(): + scope = { + "type": "http", + "method": "GET", + "path": "/api/something", + "query_string": b"", + "headers": [], + "state": {}, + } + request = Request(scope) + request.state.user_email = "bad@gmail.com" + + response = await middleware.dispatch(request, call_next) + assert response.status_code == 403 + assert response.headers["content-type"] == "application/json" + + import asyncio + asyncio.run(test_request()) + + def test_non_api_endpoint_redirects(self): + """Test that non-API endpoints get redirected on failure.""" + from starlette.requests import Request + from starlette.responses import Response + + app = FastAPI() + middleware = DOELabMiddleware(app, auth_redirect_url="/custom-auth") + + async def call_next(request): + return Response("OK", status_code=200) + + async def test_request(): + scope = { + "type": "http", + "method": "GET", + "path": "/something", + "query_string": b"", + "headers": [], + "state": {}, + } + request = Request(scope) + request.state.user_email = "bad@gmail.com" + + response = await middleware.dispatch(request, call_next) + assert response.status_code == 302 + assert response.headers["location"] == "/custom-auth" + + import asyncio + asyncio.run(test_request()) + + def test_is_doe_domain_direct_match(self): + """Test _is_doe_domain method with direct matches.""" + app = FastAPI() + middleware = DOELabMiddleware(app) + + assert middleware._is_doe_domain("sandia.gov") is True + assert middleware._is_doe_domain("doe.gov") is True + assert middleware._is_doe_domain("lanl.gov") is True + assert middleware._is_doe_domain("llnl.gov") is True + + def test_is_doe_domain_subdomain_match(self): + """Test _is_doe_domain method with subdomain matches.""" + app = FastAPI() + middleware = DOELabMiddleware(app) + + assert middleware._is_doe_domain("mail.sandia.gov") is True + assert middleware._is_doe_domain("sub.doe.gov") is True + assert middleware._is_doe_domain("dept.lanl.gov") is True + assert middleware._is_doe_domain("team.llnl.gov") is True + + def test_is_doe_domain_invalid(self): + """Test _is_doe_domain method with invalid domains.""" + app = FastAPI() + middleware = DOELabMiddleware(app) + + assert middleware._is_doe_domain("gmail.com") is False + assert middleware._is_doe_domain("example.com") is False + assert middleware._is_doe_domain("sandia.com") is False # Wrong TLD + assert middleware._is_doe_domain("doe.org") is False # Wrong TLD + assert middleware._is_doe_domain("fakedoe.gov") is False From 2092ede622399a2b48c5dd47e28f6c8c68240ece Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 26 Nov 2025 01:41:25 +0000 Subject: [PATCH 3/8] Refactor to generic domain whitelist with config file Co-authored-by: garland3 <1162675+garland3@users.noreply.github.com> --- .env.example | 2 +- backend/core/domain_whitelist.py | 145 ++++++++ ...ware.py => domain_whitelist_middleware.py} | 63 ++-- backend/main.py | 8 +- backend/modules/config/config_manager.py | 8 +- backend/tests/test_doe_lab_feature_flag.py | 106 ------ backend/tests/test_doe_lab_middleware.py | 334 ------------------ backend/tests/test_domain_whitelist.py | 270 ++++++++++++++ config/defaults/domain-whitelist-example.json | 19 + config/defaults/domain-whitelist.json | 119 +++++++ 10 files changed, 588 insertions(+), 486 deletions(-) create mode 100644 backend/core/domain_whitelist.py rename backend/core/{doe_lab_middleware.py => domain_whitelist_middleware.py} (55%) delete mode 100644 backend/tests/test_doe_lab_feature_flag.py delete mode 100644 backend/tests/test_doe_lab_middleware.py create mode 100644 backend/tests/test_domain_whitelist.py create mode 100644 config/defaults/domain-whitelist-example.json create mode 100644 config/defaults/domain-whitelist.json diff --git a/.env.example b/.env.example index 4148c6d..849bc9f 100644 --- a/.env.example +++ b/.env.example @@ -77,7 +77,7 @@ FEATURE_FILES_PANEL_ENABLED=true # Uploaded/session files panel FEATURE_CHAT_HISTORY_ENABLED=false # Previous chat history list FEATURE_COMPLIANCE_LEVELS_ENABLED=false # Compliance level filtering for MCP servers and data sources FEATURE_SPLASH_SCREEN_ENABLED=false # Startup splash screen for displaying policies and information -FEATURE_DOE_LAB_CHECK_ENABLED=false # Restrict access to DOE/NNSA/DOE lab email domains +FEATURE_DOMAIN_WHITELIST_ENABLED=false # Restrict access to whitelisted email domains (config/defaults/domain-whitelist.json) # (Adjust above to stage rollouts. For a bare-bones chat set them all to false.) diff --git a/backend/core/domain_whitelist.py b/backend/core/domain_whitelist.py new file mode 100644 index 0000000..e720e93 --- /dev/null +++ b/backend/core/domain_whitelist.py @@ -0,0 +1,145 @@ +""" +Domain whitelist management for email access control. + +Loads domain whitelist definitions from domain-whitelist.json and provides +validation for user email domains. +""" + +import json +import logging +from pathlib import Path +from typing import List, Optional, Set +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class DomainWhitelistConfig: + """Configuration for domain whitelist.""" + enabled: bool + domains: Set[str] + subdomain_matching: bool + version: str + description: str + + +class DomainWhitelistManager: + """Manages domain whitelist configuration and validation.""" + + def __init__(self, config_path: Optional[Path] = None): + """Initialize the domain whitelist manager. + + Args: + config_path: Path to domain-whitelist.json. If None, uses default location. + """ + self.config: Optional[DomainWhitelistConfig] = None + + if config_path is None: + # Try to find config in standard locations + backend_root = Path(__file__).parent.parent + project_root = backend_root.parent + + search_paths = [ + project_root / "config" / "overrides" / "domain-whitelist.json", + project_root / "config" / "defaults" / "domain-whitelist.json", + backend_root / "configfilesadmin" / "domain-whitelist.json", + backend_root / "configfiles" / "domain-whitelist.json", + ] + + for path in search_paths: + if path.exists(): + config_path = path + break + + if config_path and config_path.exists(): + self._load_config(config_path) + else: + logger.warning("No domain-whitelist.json found, domain whitelist disabled") + self.config = DomainWhitelistConfig( + enabled=False, + domains=set(), + subdomain_matching=True, + version="1.0", + description="No config loaded" + ) + + def _load_config(self, config_path: Path): + """Load domain whitelist configuration from JSON file.""" + try: + with open(config_path, 'r', encoding='utf-8') as f: + config_data = json.load(f) + + # Extract domains from the list of domain objects + domains = set() + for domain_entry in config_data.get('domains', []): + if isinstance(domain_entry, dict): + domains.add(domain_entry.get('domain', '').lower()) + elif isinstance(domain_entry, str): + domains.add(domain_entry.lower()) + + self.config = DomainWhitelistConfig( + enabled=config_data.get('enabled', False), + domains=domains, + subdomain_matching=config_data.get('subdomain_matching', True), + version=config_data.get('version', '1.0'), + description=config_data.get('description', '') + ) + + logger.info(f"Loaded {len(self.config.domains)} domains from {config_path}") + logger.debug(f"Domain whitelist enabled: {self.config.enabled}") + + except Exception as e: + logger.error(f"Error loading domain-whitelist.json: {e}") + # Use disabled config on error + self.config = DomainWhitelistConfig( + enabled=False, + domains=set(), + subdomain_matching=True, + version="1.0", + description="Error loading config" + ) + + def is_enabled(self) -> bool: + """Check if domain whitelist is enabled. + + Returns: + True if enabled, False otherwise + """ + return self.config is not None and self.config.enabled + + def is_domain_allowed(self, email: str) -> bool: + """Check if an email address is from an allowed domain. + + Args: + email: Email address to validate + + Returns: + True if domain is allowed, False otherwise + """ + if not self.config or not self.config.enabled: + # If not enabled or no config, allow all + return True + + if not email or "@" not in email: + return False + + domain = email.split("@", 1)[1].lower() + + # Check if domain is in whitelist + if domain in self.config.domains: + return True + + # Check subdomains if enabled + if self.config.subdomain_matching: + return any(domain.endswith("." + d) for d in self.config.domains) + + return False + + def get_domains(self) -> Set[str]: + """Get the set of whitelisted domains. + + Returns: + Set of allowed domains + """ + return self.config.domains if self.config else set() diff --git a/backend/core/doe_lab_middleware.py b/backend/core/domain_whitelist_middleware.py similarity index 55% rename from backend/core/doe_lab_middleware.py rename to backend/core/domain_whitelist_middleware.py index 2d366ed..74b3714 100644 --- a/backend/core/doe_lab_middleware.py +++ b/backend/core/domain_whitelist_middleware.py @@ -1,8 +1,8 @@ -"""DOE lab email domain validation middleware. +"""Email domain whitelist validation middleware. -This middleware enforces that users must have email addresses from DOE, NNSA, -or DOE national laboratory domains. It can be enabled/disabled via the -FEATURE_DOE_LAB_CHECK_ENABLED feature flag. +This middleware enforces that users must have email addresses from whitelisted +domains. Configuration is loaded from domain-whitelist.json and can be +enabled/disabled via the FEATURE_DOMAIN_WHITELIST_ENABLED feature flag. """ import logging @@ -10,25 +10,16 @@ from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse, RedirectResponse, Response +from core.domain_whitelist import DomainWhitelistManager + logger = logging.getLogger(__name__) -class DOELabMiddleware(BaseHTTPMiddleware): - """Middleware to enforce DOE/NNSA/DOE lab email domain restrictions.""" - - # Comprehensive list of DOE, NNSA, and DOE national laboratory domains - DOE_LAB_DOMAINS = frozenset([ - # HQ / NNSA / DOE-wide - "doe.gov", "nnsa.doe.gov", "hq.doe.gov", - # National labs (broad coverage) - "anl.gov", "bnl.gov", "fnal.gov", "inl.gov", "lbl.gov", "lanl.gov", - "llnl.gov", "ornl.gov", "pnnl.gov", "sandia.gov", "srnl.doe.gov", - "ameslab.gov", "jlab.org", "princeton.edu", "slac.stanford.edu", - "pppl.gov", "nrel.gov", "netl.doe.gov", "stanford.edu", - ]) +class DomainWhitelistMiddleware(BaseHTTPMiddleware): + """Middleware to enforce email domain whitelist restrictions.""" def __init__(self, app, auth_redirect_url: str = "/auth"): - """Initialize DOE lab middleware. + """Initialize domain whitelist middleware. Args: app: ASGI application @@ -36,9 +27,15 @@ def __init__(self, app, auth_redirect_url: str = "/auth"): """ super().__init__(app) self.auth_redirect_url = auth_redirect_url + self.whitelist_manager = DomainWhitelistManager() + + if self.whitelist_manager.is_enabled(): + logger.info(f"Domain whitelist enabled with {len(self.whitelist_manager.get_domains())} domains") + else: + logger.info("Domain whitelist disabled") async def dispatch(self, request: Request, call_next) -> Response: - """Check if user email is from DOE/lab domain. + """Check if user email is from a whitelisted domain. Args: request: Incoming HTTP request @@ -51,36 +48,28 @@ async def dispatch(self, request: Request, call_next) -> Response: if request.url.path == '/api/health' or request.url.path == self.auth_redirect_url: return await call_next(request) + # If whitelist is not enabled in config, allow all + if not self.whitelist_manager.is_enabled(): + return await call_next(request) + # Get email from request state (set by AuthMiddleware) email = getattr(request.state, "user_email", None) if not email or "@" not in email: - logger.warning(f"DOE check failed: missing or invalid email") + logger.warning("Domain whitelist check failed: missing or invalid email") return self._unauthorized_response(request, "User email required") - # Extract domain and check against allowed list - domain = email.split("@", 1)[1].lower() - if not self._is_doe_domain(domain): - logger.warning(f"DOE check failed: unauthorized domain {domain}") + # Check if domain is allowed + if not self.whitelist_manager.is_domain_allowed(email): + domain = email.split("@", 1)[1].lower() + logger.warning(f"Domain whitelist check failed: unauthorized domain {domain}") return self._unauthorized_response( request, - "Access restricted to DOE / NNSA / DOE labs" + "Access restricted to whitelisted domains" ) return await call_next(request) - def _is_doe_domain(self, domain: str) -> bool: - """Check if domain is a DOE/lab domain or subdomain. - - Args: - domain: Email domain to check - - Returns: - True if domain is authorized, False otherwise - """ - # Direct match or subdomain match (e.g., foo.sandia.gov matches sandia.gov) - return any(domain == d or domain.endswith("." + d) for d in self.DOE_LAB_DOMAINS) - def _unauthorized_response(self, request: Request, detail: str) -> Response: """Return appropriate unauthorized response based on endpoint type. diff --git a/backend/main.py b/backend/main.py index 198dbcc..6e785fe 100644 --- a/backend/main.py +++ b/backend/main.py @@ -27,7 +27,7 @@ from core.middleware import AuthMiddleware from core.rate_limit_middleware import RateLimitMiddleware from core.security_headers_middleware import SecurityHeadersMiddleware -from core.doe_lab_middleware import DOELabMiddleware +from core.domain_whitelist_middleware import DomainWhitelistMiddleware from core.otel_config import setup_opentelemetry from core.utils import sanitize_for_logging from core.auth import get_user_from_header @@ -133,10 +133,10 @@ async def lifespan(app: FastAPI): """ app.add_middleware(SecurityHeadersMiddleware) app.add_middleware(RateLimitMiddleware) -# DOE lab domain check (if enabled) - add before Auth so it runs after -if config.app_settings.feature_doe_lab_check_enabled: +# Domain whitelist check (if enabled) - add before Auth so it runs after +if config.app_settings.feature_domain_whitelist_enabled: app.add_middleware( - DOELabMiddleware, + DomainWhitelistMiddleware, auth_redirect_url=config.app_settings.auth_redirect_url ) app.add_middleware( diff --git a/backend/modules/config/config_manager.py b/backend/modules/config/config_manager.py index 335ac49..c462931 100644 --- a/backend/modules/config/config_manager.py +++ b/backend/modules/config/config_manager.py @@ -285,11 +285,11 @@ def agent_mode_available(self) -> bool: description="Enable compliance level filtering for MCP servers and data sources", validation_alias=AliasChoices("FEATURE_COMPLIANCE_LEVELS_ENABLED"), ) - # DOE lab email domain check feature gate - feature_doe_lab_check_enabled: bool = Field( + # Email domain whitelist feature gate + feature_domain_whitelist_enabled: bool = Field( False, - description="Enable DOE/NNSA/DOE lab email domain restriction", - validation_alias=AliasChoices("FEATURE_DOE_LAB_CHECK_ENABLED"), + description="Enable email domain whitelist restriction (configured in domain-whitelist.json)", + validation_alias=AliasChoices("FEATURE_DOMAIN_WHITELIST_ENABLED", "FEATURE_DOE_LAB_CHECK_ENABLED"), ) # Capability tokens (for headless access to downloads/iframes) diff --git a/backend/tests/test_doe_lab_feature_flag.py b/backend/tests/test_doe_lab_feature_flag.py deleted file mode 100644 index f8e1c09..0000000 --- a/backend/tests/test_doe_lab_feature_flag.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Integration test for DOE lab middleware feature flag.""" - -import pytest -from fastapi import FastAPI, Request -from starlette.testclient import TestClient -from starlette.responses import Response -from starlette.middleware.base import BaseHTTPMiddleware - -from core.doe_lab_middleware import DOELabMiddleware - - -class SimpleAuthMiddleware(BaseHTTPMiddleware): - """Simplified auth middleware for testing that just sets user_email in state.""" - - def __init__(self, app, debug_mode: bool = False): - super().__init__(app) - self.debug_mode = debug_mode - - async def dispatch(self, request: Request, call_next) -> Response: - # Simulate setting user_email from header (like real AuthMiddleware) - email = request.headers.get("X-User-Email") - if email: - request.state.user_email = email - elif self.debug_mode: - request.state.user_email = "test@test.com" - - return await call_next(request) - - -def test_doe_middleware_not_active_when_disabled(): - """Test that DOE middleware allows non-DOE emails when not added to app.""" - # Create app without DOE middleware - app = FastAPI() - - @app.get("/api/test") - def test_endpoint(request: Request): - return {"user": getattr(request.state, "user_email", "none")} - - # Add only simple auth middleware (which sets user_email in state) - app.add_middleware(SimpleAuthMiddleware, debug_mode=True) - - # DOE middleware NOT added (simulating feature flag disabled) - - # Create client - client = TestClient(app) - - # Test with non-DOE email - should pass because DOE middleware is not active - response = client.get("/api/test", headers={"X-User-Email": "test@gmail.com"}) - assert response.status_code == 200 - assert response.json()["user"] == "test@gmail.com" - - -def test_doe_middleware_active_when_enabled(): - """Test that DOE middleware blocks non-DOE emails when added to app.""" - # Create app with DOE middleware - app = FastAPI() - - @app.get("/api/test") - def test_endpoint(request: Request): - return {"user": request.state.user_email} - - @app.get("/auth") - def auth_endpoint(): - return {"login": True} - - # Add DOE middleware first (will run second after auth) - app.add_middleware(DOELabMiddleware) - - # Add auth middleware second (will run first, setting email) - app.add_middleware(SimpleAuthMiddleware, debug_mode=True) - - # Create client - client = TestClient(app) - - # Test with non-DOE email - should be rejected - response = client.get("/api/test", headers={"X-User-Email": "test@gmail.com"}) - assert response.status_code == 403 - assert "Access restricted" in response.json()["detail"] - - # Test with valid DOE email - should pass - response = client.get("/api/test", headers={"X-User-Email": "test@sandia.gov"}) - assert response.status_code == 200 - assert response.json()["user"] == "test@sandia.gov" - - -def test_middleware_ordering_auth_before_doe(): - """Test that auth middleware must run before DOELabMiddleware.""" - # This test verifies the correct middleware ordering - app = FastAPI() - - @app.get("/api/test") - def test_endpoint(request: Request): - return {"user": request.state.user_email} - - # Add DOE first, Auth second (so Auth runs first in request flow) - app.add_middleware(DOELabMiddleware) - app.add_middleware(SimpleAuthMiddleware, debug_mode=True) - - client = TestClient(app) - - # Auth middleware should set the email, then DOELabMiddleware should check it - response = client.get("/api/test", headers={"X-User-Email": "test@lanl.gov"}) - assert response.status_code == 200 - assert response.json()["user"] == "test@lanl.gov" - - diff --git a/backend/tests/test_doe_lab_middleware.py b/backend/tests/test_doe_lab_middleware.py deleted file mode 100644 index 2223bd4..0000000 --- a/backend/tests/test_doe_lab_middleware.py +++ /dev/null @@ -1,334 +0,0 @@ -"""Tests for DOE lab email domain middleware.""" - -import pytest -from fastapi import FastAPI, Request -from starlette.testclient import TestClient - -from core.doe_lab_middleware import DOELabMiddleware - - -@pytest.fixture -def app(): - """Create a test FastAPI app with DOE middleware.""" - app = FastAPI() - - @app.get("/api/test") - def api_test(request: Request): - return {"user": request.state.user_email} - - @app.get("/test") - def test(request: Request): - return {"user": request.state.user_email} - - @app.get("/auth") - def auth(): - return {"login": True} - - @app.get("/api/health") - def health(): - return {"status": "ok"} - - # Add DOE middleware - app.add_middleware(DOELabMiddleware) - - return app - - -@pytest.fixture -def client(app): - """Create a test client.""" - return TestClient(app) - - -class TestDOELabMiddleware: - """Test DOE lab middleware domain validation.""" - - # Valid DOE/lab email domains - @pytest.mark.parametrize("email", [ - "user@doe.gov", - "user@nnsa.doe.gov", - "user@hq.doe.gov", - "user@anl.gov", - "user@bnl.gov", - "user@lanl.gov", - "user@llnl.gov", - "user@sandia.gov", - "user@ornl.gov", - "user@pnnl.gov", - "user@lbl.gov", - "user@nrel.gov", - "user@stanford.edu", # SLAC - "user@jlab.org", - "user@pppl.gov", - # Subdomain tests - "user@sub.sandia.gov", - "user@mail.doe.gov", - "user@dept.lanl.gov", - ]) - def test_valid_doe_emails_allowed(self, client, email): - """Test that valid DOE/lab emails are allowed through.""" - # Mock the request state with email (normally set by AuthMiddleware) - def add_user_email(request, call_next): - request.state.user_email = email - return call_next(request) - - # Inject the email into request state - response = client.get( - "/api/test", - headers={"X-User-Email": email} - ) - # Since we're not actually setting request.state in the test, - # we'll get through the middleware but fail at the endpoint level. - # Let's test differently - by directly checking middleware logic. - - def test_valid_doe_email_via_state(self, app): - """Test valid DOE email passes through middleware.""" - from starlette.middleware.base import RequestResponseEndpoint - from starlette.requests import Request - from starlette.responses import Response - - middleware = DOELabMiddleware(app) - - # Create a mock request with state - async def call_next(request): - return Response("OK", status_code=200) - - # Test with valid email - async def test_request(): - from starlette.datastructures import URL - scope = { - "type": "http", - "method": "GET", - "path": "/api/test", - "query_string": b"", - "headers": [], - "state": {}, - } - request = Request(scope) - request.state.user_email = "test@sandia.gov" - - response = await middleware.dispatch(request, call_next) - assert response.status_code == 200 - - import asyncio - asyncio.run(test_request()) - - # Invalid email domains - @pytest.mark.parametrize("email", [ - "user@gmail.com", - "user@yahoo.com", - "user@example.com", - "user@company.com", - "user@malicious.com", - # Similar but not exact matches - "user@fakedoe.gov", - "user@sandia.com", # Wrong TLD - "user@doe.org", # Wrong TLD - ]) - def test_invalid_emails_rejected(self, email): - """Test that non-DOE emails are rejected.""" - from starlette.requests import Request - from starlette.responses import Response - - app = FastAPI() - middleware = DOELabMiddleware(app) - - async def call_next(request): - return Response("OK", status_code=200) - - async def test_request(): - scope = { - "type": "http", - "method": "GET", - "path": "/api/test", - "query_string": b"", - "headers": [], - "state": {}, - } - request = Request(scope) - request.state.user_email = email - - response = await middleware.dispatch(request, call_next) - assert response.status_code == 403 - assert b"Access restricted" in response.body - - import asyncio - asyncio.run(test_request()) - - def test_missing_email_rejected(self): - """Test that requests without email are rejected.""" - from starlette.requests import Request - from starlette.responses import Response - - app = FastAPI() - middleware = DOELabMiddleware(app) - - async def call_next(request): - return Response("OK", status_code=200) - - async def test_request(): - scope = { - "type": "http", - "method": "GET", - "path": "/api/test", - "query_string": b"", - "headers": [], - "state": {}, - } - request = Request(scope) - # No email set in state - - response = await middleware.dispatch(request, call_next) - assert response.status_code == 403 - - import asyncio - asyncio.run(test_request()) - - def test_health_endpoint_bypassed(self): - """Test that /api/health endpoint bypasses DOE check.""" - from starlette.requests import Request - from starlette.responses import Response - - app = FastAPI() - middleware = DOELabMiddleware(app) - - async def call_next(request): - return Response("OK", status_code=200) - - async def test_request(): - scope = { - "type": "http", - "method": "GET", - "path": "/api/health", - "query_string": b"", - "headers": [], - "state": {}, - } - request = Request(scope) - # No email - should still pass for health check - - response = await middleware.dispatch(request, call_next) - assert response.status_code == 200 - - import asyncio - asyncio.run(test_request()) - - def test_auth_endpoint_bypassed(self): - """Test that /auth endpoint bypasses DOE check.""" - from starlette.requests import Request - from starlette.responses import Response - - app = FastAPI() - middleware = DOELabMiddleware(app, auth_redirect_url="/auth") - - async def call_next(request): - return Response("OK", status_code=200) - - async def test_request(): - scope = { - "type": "http", - "method": "GET", - "path": "/auth", - "query_string": b"", - "headers": [], - "state": {}, - } - request = Request(scope) - # No email - should still pass for auth endpoint - - response = await middleware.dispatch(request, call_next) - assert response.status_code == 200 - - import asyncio - asyncio.run(test_request()) - - def test_api_endpoint_returns_json_error(self): - """Test that API endpoints get JSON error response.""" - from starlette.requests import Request - from starlette.responses import Response - - app = FastAPI() - middleware = DOELabMiddleware(app) - - async def call_next(request): - return Response("OK", status_code=200) - - async def test_request(): - scope = { - "type": "http", - "method": "GET", - "path": "/api/something", - "query_string": b"", - "headers": [], - "state": {}, - } - request = Request(scope) - request.state.user_email = "bad@gmail.com" - - response = await middleware.dispatch(request, call_next) - assert response.status_code == 403 - assert response.headers["content-type"] == "application/json" - - import asyncio - asyncio.run(test_request()) - - def test_non_api_endpoint_redirects(self): - """Test that non-API endpoints get redirected on failure.""" - from starlette.requests import Request - from starlette.responses import Response - - app = FastAPI() - middleware = DOELabMiddleware(app, auth_redirect_url="/custom-auth") - - async def call_next(request): - return Response("OK", status_code=200) - - async def test_request(): - scope = { - "type": "http", - "method": "GET", - "path": "/something", - "query_string": b"", - "headers": [], - "state": {}, - } - request = Request(scope) - request.state.user_email = "bad@gmail.com" - - response = await middleware.dispatch(request, call_next) - assert response.status_code == 302 - assert response.headers["location"] == "/custom-auth" - - import asyncio - asyncio.run(test_request()) - - def test_is_doe_domain_direct_match(self): - """Test _is_doe_domain method with direct matches.""" - app = FastAPI() - middleware = DOELabMiddleware(app) - - assert middleware._is_doe_domain("sandia.gov") is True - assert middleware._is_doe_domain("doe.gov") is True - assert middleware._is_doe_domain("lanl.gov") is True - assert middleware._is_doe_domain("llnl.gov") is True - - def test_is_doe_domain_subdomain_match(self): - """Test _is_doe_domain method with subdomain matches.""" - app = FastAPI() - middleware = DOELabMiddleware(app) - - assert middleware._is_doe_domain("mail.sandia.gov") is True - assert middleware._is_doe_domain("sub.doe.gov") is True - assert middleware._is_doe_domain("dept.lanl.gov") is True - assert middleware._is_doe_domain("team.llnl.gov") is True - - def test_is_doe_domain_invalid(self): - """Test _is_doe_domain method with invalid domains.""" - app = FastAPI() - middleware = DOELabMiddleware(app) - - assert middleware._is_doe_domain("gmail.com") is False - assert middleware._is_doe_domain("example.com") is False - assert middleware._is_doe_domain("sandia.com") is False # Wrong TLD - assert middleware._is_doe_domain("doe.org") is False # Wrong TLD - assert middleware._is_doe_domain("fakedoe.gov") is False diff --git a/backend/tests/test_domain_whitelist.py b/backend/tests/test_domain_whitelist.py new file mode 100644 index 0000000..da5cee7 --- /dev/null +++ b/backend/tests/test_domain_whitelist.py @@ -0,0 +1,270 @@ +"""Tests for domain whitelist middleware.""" + +import json +import pytest +import tempfile +from pathlib import Path +from fastapi import FastAPI, Request +from starlette.testclient import TestClient + +from core.domain_whitelist_middleware import DomainWhitelistMiddleware +from core.domain_whitelist import DomainWhitelistManager + + +@pytest.fixture +def temp_config(): + """Create a temporary config file for testing.""" + config_data = { + "version": "1.0", + "description": "Test config", + "enabled": True, + "domains": [ + {"domain": "sandia.gov", "description": "Sandia National Labs"}, + {"domain": "doe.gov", "description": "DOE"}, + {"domain": "example.org", "description": "Example"}, + ], + "subdomain_matching": True + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(config_data, f) + temp_path = Path(f.name) + + yield temp_path + + # Cleanup + if temp_path.exists(): + temp_path.unlink() + + +@pytest.fixture +def disabled_config(): + """Create a config file with whitelist disabled.""" + config_data = { + "version": "1.0", + "description": "Disabled config", + "enabled": False, + "domains": [ + {"domain": "sandia.gov", "description": "Sandia National Labs"}, + ], + "subdomain_matching": True + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(config_data, f) + temp_path = Path(f.name) + + yield temp_path + + if temp_path.exists(): + temp_path.unlink() + + +class TestDomainWhitelistManager: + """Test the domain whitelist manager.""" + + def test_load_config(self, temp_config): + """Test loading configuration from file.""" + manager = DomainWhitelistManager(config_path=temp_config) + + assert manager.is_enabled() is True + assert "sandia.gov" in manager.get_domains() + assert "doe.gov" in manager.get_domains() + assert "example.org" in manager.get_domains() + assert len(manager.get_domains()) == 3 + + def test_disabled_config(self, disabled_config): + """Test that disabled config doesn't enforce whitelist.""" + manager = DomainWhitelistManager(config_path=disabled_config) + + assert manager.is_enabled() is False + # Even though disabled, should allow all + assert manager.is_domain_allowed("user@gmail.com") is True + + def test_domain_matching(self, temp_config): + """Test domain matching logic.""" + manager = DomainWhitelistManager(config_path=temp_config) + + # Exact matches + assert manager.is_domain_allowed("user@sandia.gov") is True + assert manager.is_domain_allowed("user@doe.gov") is True + + # Subdomain matches + assert manager.is_domain_allowed("user@mail.sandia.gov") is True + assert manager.is_domain_allowed("user@sub.doe.gov") is True + + # Invalid domains + assert manager.is_domain_allowed("user@gmail.com") is False + assert manager.is_domain_allowed("user@sandia.com") is False # Wrong TLD + + def test_invalid_email(self, temp_config): + """Test handling of invalid email addresses.""" + manager = DomainWhitelistManager(config_path=temp_config) + + assert manager.is_domain_allowed("notanemail") is False + assert manager.is_domain_allowed("") is False + assert manager.is_domain_allowed("no-at-sign.com") is False + + +class TestDomainWhitelistMiddleware: + """Test domain whitelist middleware.""" + + def test_middleware_with_allowed_domain(self, temp_config): + """Test that allowed domains pass through.""" + from starlette.requests import Request + from starlette.responses import Response + from starlette.middleware.base import BaseHTTPMiddleware + + app = FastAPI() + + # Monkey-patch the middleware to use our temp config + original_init = DomainWhitelistMiddleware.__init__ + def patched_init(self, app, auth_redirect_url="/auth"): + BaseHTTPMiddleware.__init__(self, app) + self.auth_redirect_url = auth_redirect_url + self.whitelist_manager = DomainWhitelistManager(config_path=temp_config) + + DomainWhitelistMiddleware.__init__ = patched_init + middleware = DomainWhitelistMiddleware(app) + DomainWhitelistMiddleware.__init__ = original_init + + async def call_next(request): + return Response("OK", status_code=200) + + async def test_request(): + scope = { + "type": "http", + "method": "GET", + "path": "/api/test", + "query_string": b"", + "headers": [], + "state": {}, + } + request = Request(scope) + request.state.user_email = "test@sandia.gov" + + response = await middleware.dispatch(request, call_next) + assert response.status_code == 200 + + import asyncio + asyncio.run(test_request()) + + def test_middleware_with_disallowed_domain(self, temp_config): + """Test that disallowed domains are blocked.""" + from starlette.requests import Request + from starlette.responses import Response + from starlette.middleware.base import BaseHTTPMiddleware + + app = FastAPI() + + # Monkey-patch the middleware to use our temp config + original_init = DomainWhitelistMiddleware.__init__ + def patched_init(self, app, auth_redirect_url="/auth"): + BaseHTTPMiddleware.__init__(self, app) + self.auth_redirect_url = auth_redirect_url + self.whitelist_manager = DomainWhitelistManager(config_path=temp_config) + + DomainWhitelistMiddleware.__init__ = patched_init + middleware = DomainWhitelistMiddleware(app) + DomainWhitelistMiddleware.__init__ = original_init + + async def call_next(request): + return Response("OK", status_code=200) + + async def test_request(): + scope = { + "type": "http", + "method": "GET", + "path": "/api/test", + "query_string": b"", + "headers": [], + "state": {}, + } + request = Request(scope) + request.state.user_email = "test@gmail.com" + + response = await middleware.dispatch(request, call_next) + assert response.status_code == 403 + + import asyncio + asyncio.run(test_request()) + + def test_middleware_disabled(self, disabled_config): + """Test that disabled config allows all domains.""" + from starlette.requests import Request + from starlette.responses import Response + from starlette.middleware.base import BaseHTTPMiddleware + + app = FastAPI() + + # Monkey-patch the middleware to use our disabled config + original_init = DomainWhitelistMiddleware.__init__ + def patched_init(self, app, auth_redirect_url="/auth"): + BaseHTTPMiddleware.__init__(self, app) + self.auth_redirect_url = auth_redirect_url + self.whitelist_manager = DomainWhitelistManager(config_path=disabled_config) + + DomainWhitelistMiddleware.__init__ = patched_init + middleware = DomainWhitelistMiddleware(app) + DomainWhitelistMiddleware.__init__ = original_init + + async def call_next(request): + return Response("OK", status_code=200) + + async def test_request(): + scope = { + "type": "http", + "method": "GET", + "path": "/api/test", + "query_string": b"", + "headers": [], + "state": {}, + } + request = Request(scope) + request.state.user_email = "test@gmail.com" + + # Should pass even though gmail.com is not in whitelist + response = await middleware.dispatch(request, call_next) + assert response.status_code == 200 + + import asyncio + asyncio.run(test_request()) + + def test_health_endpoint_bypass(self, temp_config): + """Test that health endpoint bypasses whitelist check.""" + from starlette.requests import Request + from starlette.responses import Response + from starlette.middleware.base import BaseHTTPMiddleware + + app = FastAPI() + + original_init = DomainWhitelistMiddleware.__init__ + def patched_init(self, app, auth_redirect_url="/auth"): + BaseHTTPMiddleware.__init__(self, app) + self.auth_redirect_url = auth_redirect_url + self.whitelist_manager = DomainWhitelistManager(config_path=temp_config) + + DomainWhitelistMiddleware.__init__ = patched_init + middleware = DomainWhitelistMiddleware(app) + DomainWhitelistMiddleware.__init__ = original_init + + async def call_next(request): + return Response("OK", status_code=200) + + async def test_request(): + scope = { + "type": "http", + "method": "GET", + "path": "/api/health", + "query_string": b"", + "headers": [], + "state": {}, + } + request = Request(scope) + # No email - should still pass for health check + + response = await middleware.dispatch(request, call_next) + assert response.status_code == 200 + + import asyncio + asyncio.run(test_request()) diff --git a/config/defaults/domain-whitelist-example.json b/config/defaults/domain-whitelist-example.json new file mode 100644 index 0000000..6a52cb0 --- /dev/null +++ b/config/defaults/domain-whitelist-example.json @@ -0,0 +1,19 @@ +{ + "version": "1.0", + "description": "Example custom domain whitelist configuration. Copy this to config/overrides/domain-whitelist.json and customize as needed.", + "enabled": false, + "domains": [ + { + "domain": "example.com", + "description": "Example Corporation", + "category": "Enterprise" + }, + { + "domain": "mycompany.com", + "description": "My Company", + "category": "Enterprise" + } + ], + "subdomain_matching": true, + "subdomain_matching_description": "When true, subdomains are automatically allowed (e.g., mail.example.com matches example.com)" +} diff --git a/config/defaults/domain-whitelist.json b/config/defaults/domain-whitelist.json new file mode 100644 index 0000000..41b9a46 --- /dev/null +++ b/config/defaults/domain-whitelist.json @@ -0,0 +1,119 @@ +{ + "version": "1.0", + "description": "Email domain whitelist for user access control. When enabled, only users with email addresses from whitelisted domains can access the application.", + "enabled": false, + "domains": [ + { + "domain": "doe.gov", + "description": "Department of Energy", + "category": "Government - DOE HQ" + }, + { + "domain": "nnsa.doe.gov", + "description": "National Nuclear Security Administration", + "category": "Government - DOE HQ" + }, + { + "domain": "hq.doe.gov", + "description": "DOE Headquarters", + "category": "Government - DOE HQ" + }, + { + "domain": "anl.gov", + "description": "Argonne National Laboratory", + "category": "National Laboratory" + }, + { + "domain": "bnl.gov", + "description": "Brookhaven National Laboratory", + "category": "National Laboratory" + }, + { + "domain": "fnal.gov", + "description": "Fermi National Accelerator Laboratory", + "category": "National Laboratory" + }, + { + "domain": "inl.gov", + "description": "Idaho National Laboratory", + "category": "National Laboratory" + }, + { + "domain": "lbl.gov", + "description": "Lawrence Berkeley National Laboratory", + "category": "National Laboratory" + }, + { + "domain": "lanl.gov", + "description": "Los Alamos National Laboratory", + "category": "National Laboratory" + }, + { + "domain": "llnl.gov", + "description": "Lawrence Livermore National Laboratory", + "category": "National Laboratory" + }, + { + "domain": "ornl.gov", + "description": "Oak Ridge National Laboratory", + "category": "National Laboratory" + }, + { + "domain": "pnnl.gov", + "description": "Pacific Northwest National Laboratory", + "category": "National Laboratory" + }, + { + "domain": "sandia.gov", + "description": "Sandia National Laboratories", + "category": "National Laboratory" + }, + { + "domain": "srnl.doe.gov", + "description": "Savannah River National Laboratory", + "category": "National Laboratory" + }, + { + "domain": "ameslab.gov", + "description": "Ames Laboratory", + "category": "National Laboratory" + }, + { + "domain": "jlab.org", + "description": "Thomas Jefferson National Accelerator Facility", + "category": "National Laboratory" + }, + { + "domain": "princeton.edu", + "description": "Princeton Plasma Physics Laboratory", + "category": "National Laboratory" + }, + { + "domain": "slac.stanford.edu", + "description": "SLAC National Accelerator Laboratory", + "category": "National Laboratory" + }, + { + "domain": "pppl.gov", + "description": "Princeton Plasma Physics Laboratory", + "category": "National Laboratory" + }, + { + "domain": "nrel.gov", + "description": "National Renewable Energy Laboratory", + "category": "National Laboratory" + }, + { + "domain": "netl.doe.gov", + "description": "National Energy Technology Laboratory", + "category": "National Laboratory" + }, + { + "domain": "stanford.edu", + "description": "Stanford University (SLAC host institution)", + "category": "University" + } + ], + "subdomain_matching": true, + "subdomain_matching_description": "When true, subdomains are automatically allowed (e.g., mail.sandia.gov matches sandia.gov)" +} From 97bca0fd03a869e989e4cd2e704a2e4dde5232a9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 26 Nov 2025 01:42:20 +0000 Subject: [PATCH 4/8] Add documentation for domain whitelist configuration Co-authored-by: garland3 <1162675+garland3@users.noreply.github.com> --- config/defaults/DOMAIN_WHITELIST_README.md | 154 +++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 config/defaults/DOMAIN_WHITELIST_README.md diff --git a/config/defaults/DOMAIN_WHITELIST_README.md b/config/defaults/DOMAIN_WHITELIST_README.md new file mode 100644 index 0000000..8dff517 --- /dev/null +++ b/config/defaults/DOMAIN_WHITELIST_README.md @@ -0,0 +1,154 @@ +# Email Domain Whitelist Configuration + +This configuration controls which email domains are allowed to access the application. + +## Overview + +The domain whitelist feature allows you to restrict access to users with email addresses from specific domains. This is useful for: +- Restricting access to government organizations (DOE, NNSA, national labs) +- Limiting access to specific companies or institutions +- Implementing multi-tenant access control + +## Configuration Files + +### Default Configuration +Located at: `config/defaults/domain-whitelist.json` + +Contains DOE and national laboratory domains as an example. This file should not be modified directly. + +### Custom Configuration +To customize domains, create: `config/overrides/domain-whitelist.json` + +The override file takes precedence over the default configuration. + +## Configuration Format + +```json +{ + "version": "1.0", + "description": "Your description here", + "enabled": true, + "domains": [ + { + "domain": "example.com", + "description": "Example Corporation", + "category": "Enterprise" + }, + { + "domain": "another-domain.org", + "description": "Another Organization", + "category": "Partner" + } + ], + "subdomain_matching": true +} +``` + +### Fields + +- **version**: Configuration schema version (currently "1.0") +- **description**: Human-readable description of this configuration +- **enabled**: Whether the whitelist is enforced (true/false) + - Note: Even if true here, must also set `FEATURE_DOMAIN_WHITELIST_ENABLED=true` in environment +- **domains**: Array of domain objects + - **domain**: The email domain (e.g., "example.com") + - **description**: Optional description + - **category**: Optional category for organization +- **subdomain_matching**: If true, subdomains are automatically allowed + - Example: If "example.com" is whitelisted and subdomain_matching is true, then "user@mail.example.com" is also allowed + +## Enabling the Feature + +1. Create your custom configuration at `config/overrides/domain-whitelist.json` +2. Set `"enabled": true` in the config file +3. Set environment variable: `FEATURE_DOMAIN_WHITELIST_ENABLED=true` +4. Restart the application + +## Example Configurations + +### Example 1: DOE National Labs (Default) +```json +{ + "enabled": true, + "domains": [ + {"domain": "doe.gov", "description": "Department of Energy"}, + {"domain": "sandia.gov", "description": "Sandia National Labs"}, + {"domain": "lanl.gov", "description": "Los Alamos National Lab"} + ], + "subdomain_matching": true +} +``` + +### Example 2: Corporate Domains +```json +{ + "enabled": true, + "domains": [ + {"domain": "mycompany.com", "description": "My Company"}, + {"domain": "partner-company.org", "description": "Trusted Partner"} + ], + "subdomain_matching": true +} +``` + +### Example 3: Educational Institutions +```json +{ + "enabled": true, + "domains": [ + {"domain": "university.edu", "description": "University"}, + {"domain": "research-institute.org", "description": "Research Institute"} + ], + "subdomain_matching": true +} +``` + +## Behavior + +### When Enabled +- Users with email addresses from whitelisted domains can access the application +- Users with other email domains receive a 403 Forbidden error (API) or redirect (UI) +- Health check endpoint (`/api/health`) bypasses the check +- Authentication endpoint bypasses the check + +### When Disabled +- All authenticated users can access the application regardless of email domain +- No domain filtering is performed + +## Subdomain Matching + +When `subdomain_matching` is `true`: +- `user@example.com` matches `example.com` ✓ +- `user@mail.example.com` matches `example.com` ✓ +- `user@dept.mail.example.com` matches `example.com` ✓ + +When `subdomain_matching` is `false`: +- `user@example.com` matches `example.com` ✓ +- `user@mail.example.com` does NOT match `example.com` ✗ + +## Troubleshooting + +### Issue: Users are being blocked unexpectedly +- Check that `enabled` is set correctly in the config file +- Verify `FEATURE_DOMAIN_WHITELIST_ENABLED` environment variable +- Check domain spelling in the config file (case-insensitive) +- Check if subdomain_matching is set as needed + +### Issue: Configuration changes not taking effect +- Restart the application after changing config files +- Verify the override file is at `config/overrides/domain-whitelist.json` +- Check application logs for config loading errors + +### Issue: Everyone can access (no filtering) +- Verify `FEATURE_DOMAIN_WHITELIST_ENABLED=true` in environment +- Check that `enabled: true` in the config file +- Restart the application after making changes + +## Logging + +The middleware logs helpful information: +- On startup: Number of domains loaded and enabled status +- On rejection: Domain that was rejected (for debugging) +- On error: Config loading errors + +Check application logs for domain whitelist messages. From abdb4387bc47d0a64b49574b281a29fc2c8c2a75 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 26 Nov 2025 01:45:41 +0000 Subject: [PATCH 5/8] Address code review feedback: optimize and refactor Co-authored-by: garland3 <1162675+garland3@users.noreply.github.com> --- backend/core/domain_whitelist.py | 12 +++- backend/tests/test_domain_whitelist.py | 81 +++++++++----------------- config/defaults/domain-whitelist.json | 5 ++ 3 files changed, 43 insertions(+), 55 deletions(-) diff --git a/backend/core/domain_whitelist.py b/backend/core/domain_whitelist.py index e720e93..8c5d702 100644 --- a/backend/core/domain_whitelist.py +++ b/backend/core/domain_whitelist.py @@ -126,13 +126,19 @@ def is_domain_allowed(self, email: str) -> bool: domain = email.split("@", 1)[1].lower() - # Check if domain is in whitelist + # Check if domain is in whitelist (O(1) lookup) if domain in self.config.domains: return True - # Check subdomains if enabled + # Check subdomains if enabled - check each parent level if self.config.subdomain_matching: - return any(domain.endswith("." + d) for d in self.config.domains) + # Split domain and check each parent level + # e.g., for "mail.dept.sandia.gov" check: "dept.sandia.gov", "sandia.gov" + parts = domain.split(".") + for i in range(1, len(parts)): + parent_domain = ".".join(parts[i:]) + if parent_domain in self.config.domains: + return True return False diff --git a/backend/tests/test_domain_whitelist.py b/backend/tests/test_domain_whitelist.py index da5cee7..69f3507 100644 --- a/backend/tests/test_domain_whitelist.py +++ b/backend/tests/test_domain_whitelist.py @@ -106,28 +106,40 @@ def test_invalid_email(self, temp_config): assert manager.is_domain_allowed("no-at-sign.com") is False -class TestDomainWhitelistMiddleware: - """Test domain whitelist middleware.""" - - def test_middleware_with_allowed_domain(self, temp_config): - """Test that allowed domains pass through.""" - from starlette.requests import Request - from starlette.responses import Response - from starlette.middleware.base import BaseHTTPMiddleware - +@pytest.fixture +def create_middleware(): + """Factory fixture to create middleware with custom config.""" + from starlette.middleware.base import BaseHTTPMiddleware + + def _create(config_path): app = FastAPI() - # Monkey-patch the middleware to use our temp config + # Monkey-patch to use custom config original_init = DomainWhitelistMiddleware.__init__ def patched_init(self, app, auth_redirect_url="/auth"): BaseHTTPMiddleware.__init__(self, app) self.auth_redirect_url = auth_redirect_url - self.whitelist_manager = DomainWhitelistManager(config_path=temp_config) + self.whitelist_manager = DomainWhitelistManager(config_path=config_path) DomainWhitelistMiddleware.__init__ = patched_init middleware = DomainWhitelistMiddleware(app) DomainWhitelistMiddleware.__init__ = original_init + return middleware + + return _create + + +class TestDomainWhitelistMiddleware: + """Test domain whitelist middleware.""" + + def test_middleware_with_allowed_domain(self, temp_config, create_middleware): + """Test that allowed domains pass through.""" + from starlette.requests import Request + from starlette.responses import Response + + middleware = create_middleware(temp_config) + async def call_next(request): return Response("OK", status_code=200) @@ -149,24 +161,12 @@ async def test_request(): import asyncio asyncio.run(test_request()) - def test_middleware_with_disallowed_domain(self, temp_config): + def test_middleware_with_disallowed_domain(self, temp_config, create_middleware): """Test that disallowed domains are blocked.""" from starlette.requests import Request from starlette.responses import Response - from starlette.middleware.base import BaseHTTPMiddleware - - app = FastAPI() - - # Monkey-patch the middleware to use our temp config - original_init = DomainWhitelistMiddleware.__init__ - def patched_init(self, app, auth_redirect_url="/auth"): - BaseHTTPMiddleware.__init__(self, app) - self.auth_redirect_url = auth_redirect_url - self.whitelist_manager = DomainWhitelistManager(config_path=temp_config) - DomainWhitelistMiddleware.__init__ = patched_init - middleware = DomainWhitelistMiddleware(app) - DomainWhitelistMiddleware.__init__ = original_init + middleware = create_middleware(temp_config) async def call_next(request): return Response("OK", status_code=200) @@ -189,24 +189,12 @@ async def test_request(): import asyncio asyncio.run(test_request()) - def test_middleware_disabled(self, disabled_config): + def test_middleware_disabled(self, disabled_config, create_middleware): """Test that disabled config allows all domains.""" from starlette.requests import Request from starlette.responses import Response - from starlette.middleware.base import BaseHTTPMiddleware - app = FastAPI() - - # Monkey-patch the middleware to use our disabled config - original_init = DomainWhitelistMiddleware.__init__ - def patched_init(self, app, auth_redirect_url="/auth"): - BaseHTTPMiddleware.__init__(self, app) - self.auth_redirect_url = auth_redirect_url - self.whitelist_manager = DomainWhitelistManager(config_path=disabled_config) - - DomainWhitelistMiddleware.__init__ = patched_init - middleware = DomainWhitelistMiddleware(app) - DomainWhitelistMiddleware.__init__ = original_init + middleware = create_middleware(disabled_config) async def call_next(request): return Response("OK", status_code=200) @@ -230,23 +218,12 @@ async def test_request(): import asyncio asyncio.run(test_request()) - def test_health_endpoint_bypass(self, temp_config): + def test_health_endpoint_bypass(self, temp_config, create_middleware): """Test that health endpoint bypasses whitelist check.""" from starlette.requests import Request from starlette.responses import Response - from starlette.middleware.base import BaseHTTPMiddleware - - app = FastAPI() - - original_init = DomainWhitelistMiddleware.__init__ - def patched_init(self, app, auth_redirect_url="/auth"): - BaseHTTPMiddleware.__init__(self, app) - self.auth_redirect_url = auth_redirect_url - self.whitelist_manager = DomainWhitelistManager(config_path=temp_config) - DomainWhitelistMiddleware.__init__ = patched_init - middleware = DomainWhitelistMiddleware(app) - DomainWhitelistMiddleware.__init__ = original_init + middleware = create_middleware(temp_config) async def call_next(request): return Response("OK", status_code=200) diff --git a/config/defaults/domain-whitelist.json b/config/defaults/domain-whitelist.json index 41b9a46..f95c131 100644 --- a/config/defaults/domain-whitelist.json +++ b/config/defaults/domain-whitelist.json @@ -85,6 +85,11 @@ }, { "domain": "princeton.edu", + "description": "Princeton University (PPPL host institution)", + "category": "University" + }, + { + "domain": "pppl.gov", "description": "Princeton Plasma Physics Laboratory", "category": "National Laboratory" }, From b58512148ca044b4cac4a91457c63ee42c867c8a Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 26 Nov 2025 02:25:27 +0000 Subject: [PATCH 6/8] feat(docs): add email domain whitelist configuration documentation --- .../DOMAIN_WHITELIST_README.md => docs/admin/domain-whitelist.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename config/defaults/DOMAIN_WHITELIST_README.md => docs/admin/domain-whitelist.md (100%) diff --git a/config/defaults/DOMAIN_WHITELIST_README.md b/docs/admin/domain-whitelist.md similarity index 100% rename from config/defaults/DOMAIN_WHITELIST_README.md rename to docs/admin/domain-whitelist.md From 797b58e902cd6ce59223a24c3dc1f2e3dbeecf6d Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 26 Nov 2025 02:25:39 +0000 Subject: [PATCH 7/8] chore(config): remove example domain whitelist configuration file --- config/defaults/domain-whitelist-example.json | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 config/defaults/domain-whitelist-example.json diff --git a/config/defaults/domain-whitelist-example.json b/config/defaults/domain-whitelist-example.json deleted file mode 100644 index 6a52cb0..0000000 --- a/config/defaults/domain-whitelist-example.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "version": "1.0", - "description": "Example custom domain whitelist configuration. Copy this to config/overrides/domain-whitelist.json and customize as needed.", - "enabled": false, - "domains": [ - { - "domain": "example.com", - "description": "Example Corporation", - "category": "Enterprise" - }, - { - "domain": "mycompany.com", - "description": "My Company", - "category": "Enterprise" - } - ], - "subdomain_matching": true, - "subdomain_matching_description": "When true, subdomains are automatically allowed (e.g., mail.example.com matches example.com)" -} From 669044f6dc1e1962731e47cd4554431c8e26ce4d Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 26 Nov 2025 02:34:02 +0000 Subject: [PATCH 8/8] fix: remove unused imports (Request, TestClient, List) - Remove unused Request and TestClient imports from test_domain_whitelist.py - Remove unused List import from domain_whitelist.py - Addresses CodeQL static analysis warnings - All tests passing --- backend/core/domain_whitelist.py | 2 +- backend/tests/test_domain_whitelist.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/backend/core/domain_whitelist.py b/backend/core/domain_whitelist.py index 8c5d702..2cef168 100644 --- a/backend/core/domain_whitelist.py +++ b/backend/core/domain_whitelist.py @@ -8,7 +8,7 @@ import json import logging from pathlib import Path -from typing import List, Optional, Set +from typing import Optional, Set from dataclasses import dataclass logger = logging.getLogger(__name__) diff --git a/backend/tests/test_domain_whitelist.py b/backend/tests/test_domain_whitelist.py index 69f3507..227430b 100644 --- a/backend/tests/test_domain_whitelist.py +++ b/backend/tests/test_domain_whitelist.py @@ -4,8 +4,7 @@ import pytest import tempfile from pathlib import Path -from fastapi import FastAPI, Request -from starlette.testclient import TestClient +from fastapi import FastAPI from core.domain_whitelist_middleware import DomainWhitelistMiddleware from core.domain_whitelist import DomainWhitelistManager