In [None]:
from abc import ABC, abstractmethod
import re
import html
import json
import os
import datetime
import jwt
from jwt import InvalidAudienceError, InvalidSignatureError
from google.cloud import secretmanager, kms, servicedirectory
from opa_client.opa import OpaClient

class SecurityException(Exception):
    """Raised when request input fails a security check.

    In this module it signals a string that matched one of the
    prompt-injection patterns during input sanitization.
    """

class PolicyViolationError(Exception):
    """Raised when the OPA policy check does not allow a context payload."""

class MCPBaseServer(ABC):
    """Template-method base class for MCP context servers.

    ``process_request`` runs a fixed pipeline: bearer-token extraction,
    Azure AD token validation, authorization, input sanitization, data
    fetching, context building, OPA policy validation and Cloud KMS signing.
    Subclasses provide the expected audience and the data/context steps
    (see the abstract methods at the bottom of the class).

    Note that ``__init__`` has side effects: it creates Secret Manager, KMS
    and OPA clients and registers the service in GCP Service Directory.
    """

    # Regexes that make ``_sanitize_input`` reject a string value.
    PROMPT_INJECTION_PATTERNS = [
        r"(?i)ignore previous|system:|assistant:|prompt injection",
        # Suspicious characters: braces, quotes, ';', SQL comment '--' and
        # C-style comment delimiters '/*' and '*/'.
        r"(\{|\}|\"|'|;|--|/\*|\*/)",
    ]

    # JWKS endpoint whose keys are used to verify Azure AD token signatures.
    JWKS_URI = "https://login.microsoftonline.com/common/discovery/keys"

    def __init__(self, service_name, context_type):
        """Create the GCP/OPA clients and register this service.

        Args:
            service_name: Used as the Service Directory service id and in
                the registered endpoint address.
            context_type: Sent to OPA with every policy check and stored in
                the registered endpoint metadata.
        """
        self.secret_client = secretmanager.SecretManagerServiceClient()
        self.kms_client = kms.KeyManagementServiceClient()
        # NOTE(review): confirm the installed opa_client accepts a full URL
        # here; some versions take ``host`` and ``port`` as separate arguments.
        self.opa = OpaClient("http://opa:8181")
        self.service_name = service_name
        self.context_type = context_type

        # Register service during initialization
        self.register_service()

    # TEMPLATE METHOD (invariant sequence)
    def process_request(self, request):
        """Run the full authentication -> signing pipeline for *request*.

        Args:
            request: Incoming request; must expose a ``headers`` mapping
                (read by ``_extract_token``).

        Returns:
            The dict produced by ``sign_context``.

        Raises:
            PolicyViolationError: if the OPA check does not allow the context.
            Any exception raised by a pipeline step. ``handle_error`` is
            called first, then the exception is re-raised unchanged.
        """
        try:
            # 1. Authentication
            token = self._extract_token(request)
            request_payload = self._validate_azure_token(token)

            # 2. Authorization
            self.validate_authorization(request_payload)

            # 3. Sanitize Input
            # NOTE(review): ``request`` is read via ``.headers`` above but
            # ``_sanitize_input`` iterates it with ``.items()``; confirm the
            # request type supports both (this may need the parsed body).
            sanitized_request = self._sanitize_input(request)

            # 4. Data fetching (abstract)
            raw_data = self.fetch_data(sanitized_request)

            # 5. Context building (abstract)
            context_payload = self.build_context(raw_data)

            # 6. Policy validation
            if not self.validate_policy(context_payload):
                raise PolicyViolationError("OPA policy check failed")

            # 7. Digital signing
            return self.sign_context(context_payload)
        except Exception as e:
            self.handle_error(e)
            raise  # Re-raise after handling

    # COMMON IMPLEMENTATIONS
    def _extract_token(self, request):
        """Return the bearer token from the request's ``Authorization`` header.

        The scheme name is matched case-insensitively (RFC 7235). The rest of
        the header, with surrounding whitespace stripped, is the token.

        Raises:
            ValueError: if the header is missing, uses another scheme, or
                carries an empty token.
        """
        auth_header = request.headers.get("Authorization", "") or ""
        scheme, _, token = auth_header.partition(" ")
        token = token.strip()
        if scheme.lower() != "bearer" or not token:
            raise ValueError("Invalid authorization header")
        return token

    def _validate_azure_token(self, token):
        """Validate an Azure AD JWT and return its decoded claims.

        Checks performed:
        - the signature, using the key published at ``JWKS_URI``;
        - the algorithm, which is pinned to RS256;
        - presence of the ``exp``, ``iat`` and ``aud`` claims;
        - the audience, which must equal ``get_expected_audience()``;
        - expiry, with zero clock leeway;
        - the issuer, but only when ``get_expected_issuer()`` returns a
          value. In that case ``iss`` is also required and must match.

        Raises:
            jwt.InvalidAudienceError, jwt.InvalidSignatureError,
            jwt.InvalidIssuerError: logged via ``_log_security_event`` and
                re-raised.
            Other PyJWT errors propagate unchanged.
        """
        # One PyJWKClient per server instance, so its JWKS cache survives
        # across requests instead of refetching the key set on every call.
        jwks_client = getattr(self, "_jwks_client", None)
        if jwks_client is None:
            jwks_client = self._jwks_client = jwt.PyJWKClient(self.JWKS_URI)
        signing_key = jwks_client.get_signing_key_from_jwt(token)

        required_claims = ["exp", "iat", "aud"]
        extra_kwargs = {}
        expected_issuer = self.get_expected_issuer()
        if expected_issuer is not None:
            # The JWKS URI is the multi-tenant "common" endpoint, so the
            # issuer is what pins tokens to a specific tenant.
            required_claims.append("iss")
            extra_kwargs["issuer"] = expected_issuer

        try:
            # Strict audience validation
            return jwt.decode(
                token,
                key=signing_key.key,
                algorithms=["RS256"],
                audience=self.get_expected_audience(),
                options={"require": required_claims},
                leeway=0,  # No tolerance for expired tokens
                **extra_kwargs,
            )
        except InvalidAudienceError:
            self._log_security_event("invalid_audience", token)
            raise
        except InvalidSignatureError:
            self._log_security_event("invalid_signature", token)
            raise
        except jwt.InvalidIssuerError:
            self._log_security_event("invalid_issuer", token)
            raise

    def get_expected_issuer(self):
        """Return the required ``iss`` claim value, or ``None`` to skip the check.

        Override in a subclass with your tenant's issuer to reject tokens
        minted for other tenants. The default ``None`` keeps the issuer
        unchecked, which is the historical behaviour.
        """
        return None

    def _sanitize_input(self, input_data: dict) -> dict:
        """Return a sanitized copy of *input_data*.

        Every string value at any nesting depth, inside dicts or lists, is
        checked against ``PROMPT_INJECTION_PATTERNS`` and then HTML-escaped
        to prevent XSS. Non-string scalars are copied through unchanged.
        Dict keys are neither checked nor escaped.

        Raises:
            SecurityException: if a string matches an injection pattern. The
                message names the field path (e.g. ``'tags[0]'``, ``'a.b'``).
        """
        return {
            key: self._sanitize_value(value, str(key))
            for key, value in input_data.items()
        }

    def _sanitize_value(self, value, field: str):
        """Sanitize one value located at *field*, recursing into dicts and lists."""
        if isinstance(value, str):
            # Check for prompt injection patterns
            for pattern in self.PROMPT_INJECTION_PATTERNS:
                if re.search(pattern, value):
                    raise SecurityException(
                        f"Potential prompt injection in field '{field}'"
                    )
            # HTML escape to prevent XSS
            return html.escape(value)
        if isinstance(value, dict):
            return {
                key: self._sanitize_value(item, f"{field}.{key}")
                for key, item in value.items()
            }
        if isinstance(value, list):
            return [
                self._sanitize_value(item, f"{field}[{index}]")
                for index, item in enumerate(value)
            ]
        return value

    @staticmethod
    def _gcp_project() -> str:
        """Return ``$GCP_PROJECT``, failing fast if it is unset or empty.

        Without this check, resource names like ``projects/None/...`` would be
        sent to GCP and fail with a confusing error.

        Raises:
            RuntimeError: if ``GCP_PROJECT`` is not set.
        """
        project = os.getenv("GCP_PROJECT")
        if not project:
            raise RuntimeError("GCP_PROJECT environment variable is not set")
        return project

    def get_secret(self, secret_id: str) -> str:
        """Return the latest version of *secret_id* from GCP Secret Manager.

        The payload is decoded as UTF-8 and returned verbatim.
        """
        name = f"projects/{self._gcp_project()}/secrets/{secret_id}/versions/latest"
        response = self.secret_client.access_secret_version(name=name)
        return response.payload.data.decode("UTF-8")

    def validate_policy(self, context_payload):
        """Ask OPA whether *context_payload* is allowed.

        Sends the payload, ``context_type`` and a UTC ISO-8601 timestamp to
        the ``mcp/policy/allow`` rule.

        Returns:
            True only if OPA's ``result`` is exactly ``True``. A missing
            response, an undefined decision or any non-boolean value counts
            as "deny" (fail closed).
        """
        opa_input = {
            "input": {
                "context": context_payload,
                "context_type": self.context_type,
                # Timezone-aware UTC timestamp (ends in "+00:00");
                # datetime.utcnow() is deprecated and returned a naive value.
                "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
            }
        }
        # NOTE(review): confirm ``check_policy`` exists on the installed
        # opa_client version and whether it already wraps ``data`` in "input".
        result = self.opa.check_policy("mcp/policy/allow", data=opa_input)
        return isinstance(result, dict) and result.get("result") is True

    @staticmethod
    def _crc32c(data: bytes) -> int:
        """Return the CRC32C (Castagnoli) checksum of *data*.

        Cloud KMS uses this checksum for request/response integrity checks.
        """
        crc = 0xFFFFFFFF
        for byte in data:
            crc ^= byte
            for _ in range(8):
                # Reflected polynomial 0x82F63B78; -(crc & 1) is 0 or all-ones.
                crc = (crc >> 1) ^ (0x82F63B78 & -(crc & 1))
        return crc ^ 0xFFFFFFFF

    def _kms_key_algorithm(self, key_version_name: str) -> str:
        """Return the algorithm name of a KMS key version, cached per instance.

        The sign response carries no algorithm, so it is read from the
        CryptoKeyVersion resource.
        """
        # NOTE(review): this requires cloudkms.cryptoKeyVersions.get on the key.
        cache = getattr(self, "_kms_algorithm_cache", None)
        if cache is None:
            cache = self._kms_algorithm_cache = {}
        if key_version_name not in cache:
            version = self.kms_client.get_crypto_key_version(name=key_version_name)
            cache[key_version_name] = version.algorithm.name
        return cache[key_version_name]

    def sign_context(self, context_payload):
        """Sign *context_payload* with the Cloud KMS key named by secret ``KMS_KEY_PATH``.

        The signed bytes are the payload serialized as canonical JSON
        (``sort_keys=True``, separators ``(",", ":")``), UTF-8 encoded, so
        a verifier can rebuild them exactly. A CRC32C of the data is sent
        with the request, and the response checksums are verified.

        Returns:
            ``{"context": payload, "signature": hex str, "algorithm": str}``.

        Raises:
            SecurityException: if KMS reports a checksum mismatch.
        """
        # Strip surrounding whitespace (e.g. a trailing newline stored with
        # the secret) so it is a valid resource name.
        kms_path = self.get_secret("KMS_KEY_PATH").strip()
        data = json.dumps(
            context_payload, sort_keys=True, separators=(",", ":")
        ).encode("utf-8")

        # NOTE(review): kms_path must name a CryptoKeyVersion, and the raw
        # ``data`` field is only accepted by algorithms that sign raw data;
        # most RSA/EC algorithms require a pre-computed ``digest`` instead.
        response = self.kms_client.asymmetric_sign(
            request={
                "name": kms_path,
                "data": data,
                "data_crc32c": self._crc32c(data),
            }
        )
        if not response.verified_data_crc32c:
            raise SecurityException("KMS did not verify the request data checksum")
        if response.signature_crc32c != self._crc32c(response.signature):
            raise SecurityException("KMS signature checksum mismatch")

        return {
            "context": context_payload,
            "signature": response.signature.hex(),
            "algorithm": self._kms_key_algorithm(kms_path),
        }

    def register_service(self):
        """Register this server in GCP Service Directory (namespace ``mcp``).

        If a service with this id already exists, creation is skipped rather
        than failing construction. The existing entry is NOT updated.
        """
        from google.api_core import exceptions as google_exceptions

        project = self._gcp_project()
        client = servicedirectory.RegistrationServiceClient()
        parent = f"projects/{project}/locations/global/namespaces/mcp"

        # NOTE(review): verify against the Service Directory v1 API:
        # - Endpoint.address is documented as an IP address;
        # - Cloud Run hostnames are not "<service>.run.app";
        # - v1 uses ``annotations`` rather than ``metadata``;
        # - endpoints may need ``create_endpoint`` rather than inline creation.
        service = servicedirectory.Service(
            name=client.service_path(project, "global", "mcp", self.service_name),
            endpoints=[
                servicedirectory.Endpoint(
                    address=f"{self.service_name}.run.app",
                    port=443,
                    metadata={
                        "context_type": self.context_type,
                        "auth_scheme": "oauth2.1",
                        "policy_version": "v1.2",
                    },
                )
            ],
            metadata={"service_type": "mcp"},
        )

        try:
            client.create_service(
                parent=parent, service=service, service_id=self.service_name
            )
        except google_exceptions.AlreadyExists:
            # Another instance (or an earlier start) registered it already.
            pass

    def handle_error(self, error):
        """Report a pipeline error by printing its type and message to stdout."""
        # Implement logging, metrics, etc.
        print(f"Error processing request: {type(error).__name__} - {str(error)}")
        # Example: Send to error monitoring service
        # error_client.report(error)

    def _log_security_event(self, event_type: str, token: str):
        """Print a security alert with the first 6 characters of *token*."""
        print(f"Security alert: {event_type} detected in token {token[:6]}...")

    # ABSTRACT METHODS (implemented in subclasses)
    @abstractmethod
    def get_expected_audience(self) -> str:
        """Return the expected audience for token validation"""
        pass

    @abstractmethod
    def validate_authorization(self, request_payload: dict):
        """Perform additional claim validation"""
        pass

    @abstractmethod
    def fetch_data(self, request_payload: dict):
        """Retrieve data from source system"""
        pass

    @abstractmethod
    def build_context(self, raw_data) -> dict:
        """Convert to agent-consumable JSON-LD format"""
        pass