-
Notifications
You must be signed in to change notification settings - Fork 5
Add generic email domain whitelist middleware with configuration file #121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d9c173e
f0988a3
2092ede
97bca0f
abdb438
b585121
797b58e
669044f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,151 @@ | ||
| """ | ||
| Domain whitelist management for email access control. | ||
| Loads domain whitelist definitions from domain-whitelist.json and provides | ||
| validation for user email domains. | ||
| """ | ||
|
|
||
| import json | ||
| import logging | ||
| from pathlib import Path | ||
| from typing import Optional, Set | ||
| from dataclasses import dataclass | ||
|
|
||
logger = logging.getLogger(__name__)


@dataclass
class DomainWhitelistConfig:
    """Configuration for domain whitelist."""

    enabled: bool  # whether the whitelist is enforced at all
    domains: Set[str]  # allowed domains, stored lowercase and non-empty
    subdomain_matching: bool  # also accept subdomains of listed domains
    version: str
    description: str


class DomainWhitelistManager:
    """Manages domain whitelist configuration and validation."""

    def __init__(self, config_path: Optional[Path] = None):
        """Initialize the domain whitelist manager.

        Args:
            config_path: Path to domain-whitelist.json. If None, the standard
                locations are searched and the first existing file is used.
        """
        self.config: Optional[DomainWhitelistConfig] = None

        if config_path is None:
            config_path = self._find_default_config()
        else:
            # Accept plain strings as well as Path objects.
            config_path = Path(config_path)

        if config_path is not None and config_path.exists():
            self._load_config(config_path)
        else:
            logger.warning("No domain-whitelist.json found, domain whitelist disabled")
            self.config = self._disabled_config("No config loaded")

    @staticmethod
    def _find_default_config() -> Optional[Path]:
        """Return the first existing domain-whitelist.json in the standard locations."""
        backend_root = Path(__file__).parent.parent
        project_root = backend_root.parent

        # Order matters: overrides win over defaults, then the backend folders.
        search_paths = (
            project_root / "config" / "overrides" / "domain-whitelist.json",
            project_root / "config" / "defaults" / "domain-whitelist.json",
            backend_root / "configfilesadmin" / "domain-whitelist.json",
            backend_root / "configfiles" / "domain-whitelist.json",
        )
        return next((path for path in search_paths if path.exists()), None)

    @staticmethod
    def _disabled_config(description: str) -> DomainWhitelistConfig:
        """Build the fallback configuration used when no usable file is available."""
        return DomainWhitelistConfig(
            enabled=False,
            domains=set(),
            subdomain_matching=True,
            version="1.0",
            description=description,
        )

    @staticmethod
    def _normalize_domain(domain_entry) -> Optional[str]:
        """Return the trimmed, lowercased domain for one config entry.

        An entry is either a plain string or a dict with a 'domain' key.
        Returns None when the entry carries no usable domain (missing key,
        non-string value, empty or whitespace-only string).
        """
        raw = domain_entry.get('domain') if isinstance(domain_entry, dict) else domain_entry
        if not isinstance(raw, str):
            return None
        return raw.strip().lower() or None

    def _load_config(self, config_path: Path):
        """Load domain whitelist configuration from JSON file.

        Entries without a usable domain are skipped with a warning instead of
        being stored as an empty string, which would let malformed addresses
        such as "user@" pass the check. Any error while reading or parsing the
        file results in a disabled configuration.

        Args:
            config_path: Path of the JSON file to read.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config_data = json.load(f)

            # Extract domains from the list of domain objects
            domains = set()
            for domain_entry in config_data.get('domains', []):
                domain = self._normalize_domain(domain_entry)
                if domain is None:
                    logger.warning(
                        "Skipping empty or invalid domain entry in %s", config_path
                    )
                    continue
                domains.add(domain)

            self.config = DomainWhitelistConfig(
                enabled=config_data.get('enabled', False),
                domains=domains,
                subdomain_matching=config_data.get('subdomain_matching', True),
                version=config_data.get('version', '1.0'),
                description=config_data.get('description', ''),
            )

            logger.info("Loaded %d domains from %s", len(self.config.domains), config_path)
            logger.debug("Domain whitelist enabled: %s", self.config.enabled)

        except Exception as e:
            # Deliberate best-effort: a broken file must not crash startup.
            logger.error("Error loading domain-whitelist.json: %s", e)
            # Use disabled config on error
            self.config = self._disabled_config("Error loading config")

    def is_enabled(self) -> bool:
        """Check if domain whitelist is enabled.

        Returns:
            True if enabled, False otherwise
        """
        return self.config is not None and self.config.enabled

    def is_domain_allowed(self, email: str) -> bool:
        """Check if an email address is from an allowed domain.

        Args:
            email: Email address to validate

        Returns:
            True if domain is allowed, False otherwise. Always True when the
            whitelist is disabled or no configuration is loaded.
        """
        if not self.config or not self.config.enabled:
            # If not enabled or no config, allow all
            return True

        if not email or "@" not in email:
            return False

        domain = email.split("@", 1)[1].lower()
        if not domain:
            # "user@" has no domain part and can never be whitelisted.
            return False

        allowed = self.config.domains

        # Check if domain is in whitelist (O(1) lookup)
        if domain in allowed:
            return True

        # Check subdomains if enabled - check each parent level
        if self.config.subdomain_matching:
            # e.g., for "mail.dept.sandia.gov" check: "dept.sandia.gov", "sandia.gov"
            parts = domain.split(".")
            for i in range(1, len(parts)):
                parent_domain = ".".join(parts[i:])
                # A trailing dot yields an empty parent; never treat it as a match.
                if parent_domain and parent_domain in allowed:
                    return True

        return False

    def get_domains(self) -> Set[str]:
        """Get the set of whitelisted domains.

        Returns:
            Set of allowed domains
        """
        return self.config.domains if self.config else set()
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,88 @@ | ||||||
| """Email domain whitelist validation middleware. | ||||||
| This middleware enforces that users must have email addresses from whitelisted | ||||||
| domains. Configuration is loaded from domain-whitelist.json and can be | ||||||
| enabled/disabled via the FEATURE_DOMAIN_WHITELIST_ENABLED feature flag. | ||||||
| """ | ||||||
|
|
||||||
| import logging | ||||||
| from fastapi import Request | ||||||
| from starlette.middleware.base import BaseHTTPMiddleware | ||||||
| from starlette.responses import JSONResponse, RedirectResponse, Response | ||||||
|
|
||||||
| from core.domain_whitelist import DomainWhitelistManager | ||||||
|
|
||||||
| logger = logging.getLogger(__name__) | ||||||
|
|
||||||
|
|
||||||
| class DomainWhitelistMiddleware(BaseHTTPMiddleware): | ||||||
| """Middleware to enforce email domain whitelist restrictions.""" | ||||||
|
|
||||||
| def __init__(self, app, auth_redirect_url: str = "/auth"): | ||||||
| """Initialize domain whitelist middleware. | ||||||
| Args: | ||||||
| app: ASGI application | ||||||
| auth_redirect_url: URL to redirect to on auth failure (default: /auth) | ||||||
| """ | ||||||
| super().__init__(app) | ||||||
| self.auth_redirect_url = auth_redirect_url | ||||||
| self.whitelist_manager = DomainWhitelistManager() | ||||||
|
|
||||||
| if self.whitelist_manager.is_enabled(): | ||||||
| logger.info(f"Domain whitelist enabled with {len(self.whitelist_manager.get_domains())} domains") | ||||||
| else: | ||||||
| logger.info("Domain whitelist disabled") | ||||||
|
|
||||||
| async def dispatch(self, request: Request, call_next) -> Response: | ||||||
| """Check if user email is from a whitelisted domain. | ||||||
| Args: | ||||||
| request: Incoming HTTP request | ||||||
| call_next: Next middleware/handler in chain | ||||||
| Returns: | ||||||
| Response from next handler if authorized, or 403/redirect if not | ||||||
| """ | ||||||
| # Skip check for health endpoint and auth redirect endpoint | ||||||
| if request.url.path == '/api/health' or request.url.path == self.auth_redirect_url: | ||||||
| return await call_next(request) | ||||||
|
|
||||||
| # If whitelist is not enabled in config, allow all | ||||||
| if not self.whitelist_manager.is_enabled(): | ||||||
| return await call_next(request) | ||||||
|
|
||||||
| # Get email from request state (set by AuthMiddleware) | ||||||
| email = getattr(request.state, "user_email", None) | ||||||
|
|
||||||
| if not email or "@" not in email: | ||||||
| logger.warning("Domain whitelist check failed: missing or invalid email") | ||||||
| return self._unauthorized_response(request, "User email required") | ||||||
|
|
||||||
| # Check if domain is allowed | ||||||
| if not self.whitelist_manager.is_domain_allowed(email): | ||||||
| domain = email.split("@", 1)[1].lower() | ||||||
| logger.warning(f"Domain whitelist check failed: unauthorized domain {domain}") | ||||||
|
||||||
| logger.warning(f"Domain whitelist check failed: unauthorized domain {domain}") | |
| logger.warning("Domain whitelist check failed: unauthorized domain") |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -27,6 +27,7 @@ | |||||
| from core.middleware import AuthMiddleware | ||||||
| from core.rate_limit_middleware import RateLimitMiddleware | ||||||
| from core.security_headers_middleware import SecurityHeadersMiddleware | ||||||
| from core.domain_whitelist_middleware import DomainWhitelistMiddleware | ||||||
| from core.otel_config import setup_opentelemetry | ||||||
| from core.utils import sanitize_for_logging | ||||||
| from core.auth import get_user_from_header | ||||||
|
|
@@ -132,6 +133,12 @@ async def lifespan(app: FastAPI): | |||||
| """ | ||||||
| app.add_middleware(SecurityHeadersMiddleware) | ||||||
| app.add_middleware(RateLimitMiddleware) | ||||||
| # Domain whitelist check (if enabled) - add before Auth so it runs after | ||||||
|
||||||
| # Domain whitelist check (if enabled) - add before Auth so it runs after | |
| # Domain whitelist check (if enabled) - add before AuthMiddleware registration (executes after AuthMiddleware in request flow) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Potential bug: Empty domain strings can be added to the whitelist set. If a domain entry has an empty
`domain` field (line 77), it will add an empty string to the `domains` set. This could cause unexpected behavior where emails like `user@` (malformed) might incorrectly match against an empty domain. Add validation to skip empty domains: