# In [None]:
import datetime
import hashlib
import html
import json
import logging
import os
import re
from abc import ABC, abstractmethod

import jwt
from google.cloud import kms, secretmanager, servicedirectory
from opa_client.opa import OpaClient

class SecurityException(Exception):
    """Raised when a request fails an authentication or input-safety check."""


class PolicyViolationError(SecurityException):
    """Raised when the OPA policy does not explicitly allow a context payload."""


logger = logging.getLogger(__name__)


class MCPBaseServer(ABC):
    """Template-method base class for MCP context servers.

    ``process_request`` runs a fixed pipeline: authenticate, sanitize, fetch,
    build context, check policy, sign. Subclasses supply the source-specific
    steps (``validate_authorization``, ``fetch_data``, ``build_context``).
    """

    # Regexes; a string input matching any of them is rejected outright.
    # NOTE(review): the second pattern rejects any quote, brace, ';', '--' or
    # C-style comment marker, which also blocks benign text such as "don't".
    PROMPT_INJECTION_PATTERNS = [
        r"(?i)ignore previous|system:|assistant:|prompt injection",
        r"(\{|\}|\"|'|;|--|\/\*|\*\/)",  # Suspicious characters
    ]

    # JWKS endpoint used to resolve token signing keys when no ``jwks_url``
    # is supplied. Pass a tenant-specific URL to narrow it.
    DEFAULT_AZURE_JWKS_URL = "https://login.microsoftonline.com/common/discovery/v2.0/keys"

    def __init__(self, service_name, context_type, expected_audience=None,
                 expected_issuer=None, jwks_url=None):
        """Create the GCP/OPA clients and store identity and token settings.

        Args:
            service_name: Name used for registry registration and logging.
            context_type: Source type sent to OPA and the registry.
            expected_audience: Required ``aud`` claim for incoming tokens.
                Token validation fails closed while this is unset.
            expected_issuer: Optional required ``iss`` claim; unchecked when None.
            jwks_url: JWKS endpoint for signing keys; defaults to
                ``DEFAULT_AZURE_JWKS_URL``.
        """
        self.secret_client = secretmanager.SecretManagerServiceClient()
        self.kms_client = kms.KeyManagementServiceClient()
        # NOTE(review): confirm this matches the installed OPA client's
        # constructor (some clients take host/port rather than a full URL).
        self.opa = OpaClient("http://opa:8181/v1/data/mcp/policy")
        self.service_name = service_name
        self.context_type = context_type
        self.expected_audience = expected_audience
        self.expected_issuer = expected_issuer
        self.jwks_url = jwks_url or self.DEFAULT_AZURE_JWKS_URL
        self._jwks_client = None

    # TEMPLATE METHOD (invariant sequence)
    def process_request(self, request):
        """Run the full pipeline for one request and return the signed context.

        Args:
            request: Dict-like request. A ``"headers"`` entry, if present, is
                used only for authentication and is not forwarded.

        Returns:
            The context payload with a ``"signature"`` field added.

        Raises:
            Any exception from a pipeline step. ``handle_error`` is invoked
            first, then the exception is re-raised so failures never turn
            into a silent ``None`` result.
        """
        try:
            # 1. Authentication
            self._validate_azure_token(request)

            # 2. Sanitize input. Transport headers were consumed in step 1 and
            # legitimately contain characters the injection filter rejects
            # (';' in User-Agent, '-' runs in base64url tokens), so only the
            # remaining payload is scanned and passed on.
            payload = {key: value for key, value in request.items() if key != "headers"}
            clean_request = self._sanitize_input(payload)

            # 3. Data fetching (abstract) -- operates on sanitized input only
            raw_data = self.fetch_data(clean_request)

            # 4. Context building (abstract)
            context_payload = self.build_context(raw_data)

            # 5. Policy validation
            if not self.validate_policy(context_payload):
                raise PolicyViolationError("OPA policy check failed")

            # 6. Digital signing
            return self.sign_context(context_payload)
        except Exception as error:
            self.handle_error(error)
            raise

    # COMMON IMPLEMENTATIONS
    def _validate_azure_token(self, request):
        """Verify the request's Azure AD bearer token and authorize its claims.

        Returns:
            The decoded token claims.

        Raises:
            SecurityException: no audience configured, or no bearer token.
            jwt.InvalidTokenError: signature, audience, expiry or other
                claim validation failed (logged as a security event).
        """
        if not self.expected_audience:
            # Fail closed: without an audience check a token minted for a
            # different service would be accepted (confused deputy).
            raise SecurityException("expected_audience is not configured")

        token = self._extract_bearer_token(request)
        try:
            signing_key = self._get_azure_signing_key(token)
            # STRICT AUDIENCE VALIDATION
            claims = jwt.decode(
                token,
                key=signing_key,
                algorithms=["RS256"],
                audience=self.expected_audience,  # Critical for confused deputy
                issuer=self.expected_issuer,  # None disables the issuer check
                options={"require": ["exp", "iat", "aud"]},
                leeway=0,  # No tolerance for expired tokens
            )
        except jwt.InvalidAudienceError:
            self._log_security_event("invalid_audience", token)
            raise
        except jwt.InvalidSignatureError:
            self._log_security_event("invalid_signature", token)
            raise
        except jwt.InvalidTokenError:
            self._log_security_event("invalid_token", token)
            raise

        self.validate_authorization(claims)
        return claims

    def _extract_bearer_token(self, request):
        """Return the bearer token from the request's Authorization header.

        NOTE(review): assumes headers arrive as a mapping under
        ``request["headers"]``; override for other transports.

        Raises:
            SecurityException: header missing or not a Bearer credential.
        """
        headers = request.get("headers") or {}
        auth_value = next(
            (value for name, value in headers.items() if str(name).lower() == "authorization"),
            None,
        )
        if not isinstance(auth_value, str):
            raise SecurityException("Missing bearer token")
        scheme, _, token = auth_value.strip().partition(" ")
        token = token.strip()
        if scheme.lower() != "bearer" or not token:
            raise SecurityException("Missing bearer token")
        return token

    def _get_azure_signing_key(self, token):
        """Resolve the public key that signed *token* from the JWKS endpoint."""
        jwks_client = getattr(self, "_jwks_client", None)
        if jwks_client is None:
            # Reuse one client so any key-set caching it does persists.
            jwks_client = jwt.PyJWKClient(self.jwks_url)
            self._jwks_client = jwks_client
        return jwks_client.get_signing_key_from_jwt(token).key

    def _log_security_event(self, event, token):
        """Log an authentication failure without writing the raw credential."""
        # A short hash lets events for the same token be correlated without
        # leaking a replayable bearer token into the logs.
        fingerprint = hashlib.sha256(token.encode("utf-8")).hexdigest()[:16]
        logger.warning(
            "security_event=%s service=%s token_sha256=%s",
            event, self.service_name, fingerprint,
        )

    def _sanitize_input(self, input_data: dict) -> dict:
        """Return an HTML-escaped copy of *input_data*, recursing into containers.

        Strings are checked against ``PROMPT_INJECTION_PATTERNS`` and escaped;
        dicts, lists and tuples are processed recursively; other values are
        copied unchanged. Keys are not scanned.

        Raises:
            SecurityException: a string matched an injection pattern.
        """
        return {key: self._sanitize_value(value) for key, value in input_data.items()}

    def _sanitize_value(self, value):
        """Sanitize one value for ``_sanitize_input`` (see that method)."""
        if isinstance(value, str):
            for pattern in self.PROMPT_INJECTION_PATTERNS:
                if re.search(pattern, value):
                    raise SecurityException("Potential prompt injection detected")
            return html.escape(value)
        if isinstance(value, dict):
            return self._sanitize_input(value)
        if isinstance(value, (list, tuple)):
            # Without this, strings inside lists would bypass the check.
            items = [self._sanitize_value(item) for item in value]
            return items if isinstance(value, list) else tuple(items)
        return value

    @staticmethod
    def _project_id():
        """Return the GCP project id from ``GCP_PROJECT``, failing if unset."""
        project = os.getenv("GCP_PROJECT")
        if not project:
            raise RuntimeError("GCP_PROJECT environment variable is not set")
        return project

    def get_secret(self, secret_id: str) -> str:
        """Return the latest version of Secret Manager secret *secret_id* as text."""
        name = f"projects/{self._project_id()}/secrets/{secret_id}/versions/latest"
        response = self.secret_client.access_secret_version(name=name)
        return response.payload.data.decode("UTF-8")

    def validate_policy(self, context_payload):
        """Return True only if OPA explicitly allows *context_payload*.

        The OPA input carries the payload, this server's context type and a
        UTC timestamp with an explicit offset (RFC 3339 compatible).
        """
        opa_input = {
            "context": context_payload,
            "source_type": self.context_type,
            "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        }
        # NOTE(review): confirm `check_policy` is the method exposed by the
        # installed OPA client and that it returns {"result": {"allow": ...}}.
        decision = self.opa.check_policy(input=opa_input)
        result = decision.get("result") if isinstance(decision, dict) else None
        # Fail closed: only an explicit boolean True counts as "allow".
        return isinstance(result, dict) and result.get("allow") is True

    def sign_context(self, context_payload):
        """Sign *context_payload* with Cloud KMS and return it with a signature.

        The payload is serialized canonically (sorted keys, no whitespace) so a
        verifier can rebuild the exact signed bytes, then SHA-256 hashed. The
        hex signature is added under ``"signature"``.
        NOTE(review): assumes the KMS key uses a SHA-256 signing algorithm.
        """
        # Secret values often carry a trailing newline; a key path must not.
        kms_path = self.get_secret("KMS_KEY_PATH").strip()
        canonical = json.dumps(
            context_payload, sort_keys=True, separators=(",", ":")
        ).encode("utf-8")
        digest = hashlib.sha256(canonical).digest()
        response = self.kms_client.asymmetric_sign(
            request={"name": kms_path, "digest": {"sha256": digest}}
        )
        return {**context_payload, "signature": response.signature.hex()}

    def register_service(self, location="global", namespace="mcp"):
        """Register this server in the MCP registry (GCP Service Directory).

        Args:
            location: Service Directory location of the namespace.
            namespace: Namespace id to register under.

        Returns:
            The created service resource.
        """
        parent = f"projects/{self._project_id()}/locations/{location}/namespaces/{namespace}"
        service = {
            # Descriptive data is stored as string annotations on the service.
            # NOTE(review): the endpoint URL pattern is a guess -- deployed
            # Cloud Run URLs normally include a hash and region.
            "annotations": {
                "endpoint": f"https://{self.service_name}.run.app",
                "context_type": self.context_type,
                "auth_scheme": "oauth2.1",
                "policy_version": "v1.2",
            }
        }
        client = servicedirectory.RegistrationServiceClient()
        return client.create_service(
            parent=parent, service=service, service_id=self.service_name
        )

    def handle_error(self, error):
        """Central hook for pipeline failures; logs *error* with its traceback.

        ``process_request`` re-raises after this hook returns, so overriding it
        can add reporting but cannot turn a failure into a success.
        """
        logger.error(
            "%s request failed: %s", self.service_name, type(error).__name__,
            exc_info=error,
        )

    # ABSTRACT METHODS (implemented in subclasses)
    @abstractmethod
    def validate_authorization(self, request_payload):
        """Apply additional checks to the decoded token claims; raise to reject."""

    @abstractmethod
    def fetch_data(self, request_payload):
        """Retrieve data from the source system for the sanitized request."""

    @abstractmethod
    def build_context(self, raw_data):
        """Convert fetched data to an agent-consumable JSON-LD dict."""