Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ FEATURE_FILES_PANEL_ENABLED=true # Uploaded/session files panel
FEATURE_CHAT_HISTORY_ENABLED=false # Previous chat history list
FEATURE_COMPLIANCE_LEVELS_ENABLED=false # Compliance level filtering for MCP servers and data sources
FEATURE_SPLASH_SCREEN_ENABLED=false # Startup splash screen for displaying policies and information
FEATURE_DOMAIN_WHITELIST_ENABLED=false # Restrict access to whitelisted email domains (config/defaults/domain-whitelist.json)

# (Adjust above to stage rollouts. For a bare-bones chat set them all to false.)

Expand Down
151 changes: 151 additions & 0 deletions backend/core/domain_whitelist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""
Domain whitelist management for email access control.
Loads domain whitelist definitions from domain-whitelist.json and provides
validation for user email domains.
"""

import json
import logging
from pathlib import Path
from typing import Optional, Set
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class DomainWhitelistConfig:
"""Configuration for domain whitelist."""
enabled: bool
domains: Set[str]
subdomain_matching: bool
version: str
description: str


class DomainWhitelistManager:
"""Manages domain whitelist configuration and validation."""

def __init__(self, config_path: Optional[Path] = None):
"""Initialize the domain whitelist manager.
Args:
config_path: Path to domain-whitelist.json. If None, uses default location.
"""
self.config: Optional[DomainWhitelistConfig] = None

if config_path is None:
# Try to find config in standard locations
backend_root = Path(__file__).parent.parent
project_root = backend_root.parent

search_paths = [
project_root / "config" / "overrides" / "domain-whitelist.json",
project_root / "config" / "defaults" / "domain-whitelist.json",
backend_root / "configfilesadmin" / "domain-whitelist.json",
backend_root / "configfiles" / "domain-whitelist.json",
]

for path in search_paths:
if path.exists():
config_path = path
break

if config_path and config_path.exists():
self._load_config(config_path)
else:
logger.warning("No domain-whitelist.json found, domain whitelist disabled")
self.config = DomainWhitelistConfig(
enabled=False,
domains=set(),
subdomain_matching=True,
version="1.0",
description="No config loaded"
)

def _load_config(self, config_path: Path):
"""Load domain whitelist configuration from JSON file."""
try:
with open(config_path, 'r', encoding='utf-8') as f:
config_data = json.load(f)

# Extract domains from the list of domain objects
domains = set()
for domain_entry in config_data.get('domains', []):
if isinstance(domain_entry, dict):
domains.add(domain_entry.get('domain', '').lower())
elif isinstance(domain_entry, str):
domains.add(domain_entry.lower())
Comment on lines +77 to +79
Copy link

Copilot AI Nov 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential bug: Empty domain strings can be added to the whitelist set. If a domain entry has an empty domain field (line 77), it will add an empty string to the domains set. This could cause unexpected behavior where emails like user@ (malformed) might incorrectly match against an empty domain.

Add validation to skip empty domains:

for domain_entry in config_data.get('domains', []):
    if isinstance(domain_entry, dict):
        domain = domain_entry.get('domain', '').lower().strip()
        if domain:  # Only add non-empty domains
            domains.add(domain)
    elif isinstance(domain_entry, str):
        domain = domain_entry.lower().strip()
        if domain:  # Only add non-empty domains
            domains.add(domain)
Suggested change
domains.add(domain_entry.get('domain', '').lower())
elif isinstance(domain_entry, str):
domains.add(domain_entry.lower())
domain = domain_entry.get('domain', '').lower().strip()
if domain:
domains.add(domain)
elif isinstance(domain_entry, str):
domain = domain_entry.lower().strip()
if domain:
domains.add(domain)

Copilot uses AI. Check for mistakes.

self.config = DomainWhitelistConfig(
enabled=config_data.get('enabled', False),
domains=domains,
subdomain_matching=config_data.get('subdomain_matching', True),
version=config_data.get('version', '1.0'),
description=config_data.get('description', '')
)

logger.info(f"Loaded {len(self.config.domains)} domains from {config_path}")
logger.debug(f"Domain whitelist enabled: {self.config.enabled}")

except Exception as e:
logger.error(f"Error loading domain-whitelist.json: {e}")
# Use disabled config on error
self.config = DomainWhitelistConfig(
enabled=False,
domains=set(),
subdomain_matching=True,
version="1.0",
description="Error loading config"
)

def is_enabled(self) -> bool:
"""Check if domain whitelist is enabled.
Returns:
True if enabled, False otherwise
"""
return self.config is not None and self.config.enabled

def is_domain_allowed(self, email: str) -> bool:
"""Check if an email address is from an allowed domain.
Args:
email: Email address to validate
Returns:
True if domain is allowed, False otherwise
"""
if not self.config or not self.config.enabled:
# If not enabled or no config, allow all
return True

if not email or "@" not in email:
return False

domain = email.split("@", 1)[1].lower()

# Check if domain is in whitelist (O(1) lookup)
if domain in self.config.domains:
return True

# Check subdomains if enabled - check each parent level
if self.config.subdomain_matching:
# Split domain and check each parent level
# e.g., for "mail.dept.sandia.gov" check: "dept.sandia.gov", "sandia.gov"
parts = domain.split(".")
for i in range(1, len(parts)):
parent_domain = ".".join(parts[i:])
if parent_domain in self.config.domains:
return True

return False

def get_domains(self) -> Set[str]:
"""Get the set of whitelisted domains.
Returns:
Set of allowed domains
"""
return self.config.domains if self.config else set()
88 changes: 88 additions & 0 deletions backend/core/domain_whitelist_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Email domain whitelist validation middleware.
This middleware enforces that users must have email addresses from whitelisted
domains. Configuration is loaded from domain-whitelist.json and can be
enabled/disabled via the FEATURE_DOMAIN_WHITELIST_ENABLED feature flag.
"""

import logging
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, RedirectResponse, Response

from core.domain_whitelist import DomainWhitelistManager

logger = logging.getLogger(__name__)


class DomainWhitelistMiddleware(BaseHTTPMiddleware):
"""Middleware to enforce email domain whitelist restrictions."""

def __init__(self, app, auth_redirect_url: str = "/auth"):
"""Initialize domain whitelist middleware.
Args:
app: ASGI application
auth_redirect_url: URL to redirect to on auth failure (default: /auth)
"""
super().__init__(app)
self.auth_redirect_url = auth_redirect_url
self.whitelist_manager = DomainWhitelistManager()

if self.whitelist_manager.is_enabled():
logger.info(f"Domain whitelist enabled with {len(self.whitelist_manager.get_domains())} domains")
else:
logger.info("Domain whitelist disabled")

async def dispatch(self, request: Request, call_next) -> Response:
"""Check if user email is from a whitelisted domain.
Args:
request: Incoming HTTP request
call_next: Next middleware/handler in chain
Returns:
Response from next handler if authorized, or 403/redirect if not
"""
# Skip check for health endpoint and auth redirect endpoint
if request.url.path == '/api/health' or request.url.path == self.auth_redirect_url:
return await call_next(request)

# If whitelist is not enabled in config, allow all
if not self.whitelist_manager.is_enabled():
return await call_next(request)

# Get email from request state (set by AuthMiddleware)
email = getattr(request.state, "user_email", None)

if not email or "@" not in email:
logger.warning("Domain whitelist check failed: missing or invalid email")
return self._unauthorized_response(request, "User email required")

# Check if domain is allowed
if not self.whitelist_manager.is_domain_allowed(email):
domain = email.split("@", 1)[1].lower()
logger.warning(f"Domain whitelist check failed: unauthorized domain {domain}")
Copy link

Copilot AI Nov 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Potential security information disclosure: The middleware logs the rejected domain at WARNING level when a domain is not whitelisted. This could expose information about which domains are attempting to access the system, which might be sensitive in some security contexts.

Consider either:

  1. Lowering this to DEBUG level to reduce exposure in production logs
  2. Using a more generic log message without the specific domain
  3. Ensuring logs are properly secured if this information is intentionally captured for security monitoring

Example:

logger.debug(f"Domain whitelist check failed: unauthorized domain {domain}")
# or
logger.warning("Domain whitelist check failed: unauthorized domain")
Suggested change
logger.warning(f"Domain whitelist check failed: unauthorized domain {domain}")
logger.warning("Domain whitelist check failed: unauthorized domain")

Copilot uses AI. Check for mistakes.
return self._unauthorized_response(
request,
"Access restricted to whitelisted domains"
)

return await call_next(request)

def _unauthorized_response(self, request: Request, detail: str) -> Response:
"""Return appropriate unauthorized response based on endpoint type.
Args:
request: Incoming HTTP request
detail: Error detail message
Returns:
JSONResponse for API endpoints, RedirectResponse for others
"""
if request.url.path.startswith('/api/'):
return JSONResponse(
status_code=403,
content={"detail": detail}
)
return RedirectResponse(url=self.auth_redirect_url, status_code=302)
7 changes: 7 additions & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from core.middleware import AuthMiddleware
from core.rate_limit_middleware import RateLimitMiddleware
from core.security_headers_middleware import SecurityHeadersMiddleware
from core.domain_whitelist_middleware import DomainWhitelistMiddleware
from core.otel_config import setup_opentelemetry
from core.utils import sanitize_for_logging
from core.auth import get_user_from_header
Expand Down Expand Up @@ -132,6 +133,12 @@ async def lifespan(app: FastAPI):
"""
app.add_middleware(SecurityHeadersMiddleware)
app.add_middleware(RateLimitMiddleware)
# Domain whitelist check (if enabled) - add before Auth so it runs after
Copy link

Copilot AI Nov 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The comment "add before Auth so it runs after" is potentially confusing. In FastAPI/Starlette middleware, when you call app.add_middleware(), middlewares are registered in reverse order of execution. The middleware added last runs first in the request processing chain.

In this code:

  1. SecurityHeadersMiddleware is added first → runs last (modifies response headers)
  2. RateLimitMiddleware is added second → runs third
  3. DomainWhitelistMiddleware is added third → runs second
  4. AuthMiddleware is added last → runs first

So DomainWhitelistMiddleware actually runs after AuthMiddleware in the request flow, which is correct (it needs request.state.user_email set by AuthMiddleware). However, the comment says "add before Auth so it runs after" which might be confusing since we're adding it before the AuthMiddleware registration.

Consider rewording to: "Add DomainWhitelistMiddleware before AuthMiddleware registration (executes after AuthMiddleware in request flow)" for clarity.

Suggested change
# Domain whitelist check (if enabled) - add before Auth so it runs after
# Domain whitelist check (if enabled) - add before AuthMiddleware registration (executes after AuthMiddleware in request flow)

Copilot uses AI. Check for mistakes.
if config.app_settings.feature_domain_whitelist_enabled:
app.add_middleware(
DomainWhitelistMiddleware,
auth_redirect_url=config.app_settings.auth_redirect_url
)
app.add_middleware(
AuthMiddleware,
debug_mode=config.app_settings.debug_mode,
Expand Down
6 changes: 6 additions & 0 deletions backend/modules/config/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,12 @@ def agent_mode_available(self) -> bool:
description="Enable compliance level filtering for MCP servers and data sources",
validation_alias=AliasChoices("FEATURE_COMPLIANCE_LEVELS_ENABLED"),
)
# Email domain whitelist feature gate
feature_domain_whitelist_enabled: bool = Field(
False,
description="Enable email domain whitelist restriction (configured in domain-whitelist.json)",
validation_alias=AliasChoices("FEATURE_DOMAIN_WHITELIST_ENABLED", "FEATURE_DOE_LAB_CHECK_ENABLED"),
)

# Capability tokens (for headless access to downloads/iframes)
capability_token_secret: str = ""
Expand Down
Loading
Loading