Secure MCP Base Classes Addressing Key Security Concerns
Below are enhanced code snippets from the MCP base classes showing how they implement critical security controls using the Template Method pattern:

1. OAuth 2.1 with Azure AD Integration

**Client-Side Token Acquisition**

In [None]:
# mcp_client_base.py
import os

import jwt
from msal import ConfidentialClientApplication

class MCPBaseClient:
    """Base MCP client that authenticates to MCP servers with Azure AD tokens.

    App credentials are read from the ``AZURE_CLIENT_ID``,
    ``AZURE_CLIENT_SECRET`` and ``AZURE_TENANT_ID`` environment variables.
    """

    def __init__(self, context_type):
        """Create a client for the MCP service registered under *context_type*.

        Args:
            context_type: Registry key used to discover the target service
                and the audience its tokens must carry.
        """
        self.context_type = context_type
        # NOTE(review): _discover_service is not defined in this snippet;
        # presumably provided elsewhere (base/subclass) — confirm.
        self.service_url = self._discover_service()
        self.auth_client = ConfidentialClientApplication(
            client_id=os.getenv("AZURE_CLIENT_ID"),
            client_credential=os.getenv("AZURE_CLIENT_SECRET"),
            authority=f"https://login.microsoftonline.com/{os.getenv('AZURE_TENANT_ID')}"
        )

    def _get_token(self):
        """Acquire an app-only access token via the OAuth client-credentials grant.

        PKCE only applies to the authorization-code flow, and access-token
        lifetime is decided by Azure AD rather than requested by the client.
        Neither ``enable_pkce`` nor ``token_expiration`` is an MSAL
        ``acquire_token_for_client`` parameter, so they are not passed.

        Returns:
            The bearer access token string.

        Raises:
            RuntimeError: If Azure AD returned an error instead of a token.
        """
        result = self.auth_client.acquire_token_for_client(
            scopes=[f"api://{self._get_audience()}/.default"],
        )
        # MSAL reports failures as a dict with "error"/"error_description"
        # instead of raising; surface them rather than a bare KeyError.
        if not result or "access_token" not in result:
            details = result or {}
            raise RuntimeError(
                "Azure AD token acquisition failed: "
                f"{details.get('error', 'unknown_error')}: "
                f"{details.get('error_description', 'no description')}"
            )
        return result["access_token"]

    def _get_audience(self):
        """Return the token audience advertised for this context in the MCP registry."""
        # NOTE(review): `servicedirectory` is not imported in this snippet;
        # presumably a registry client module — confirm.
        service = servicedirectory.resolve_service(self.context_type)
        return service.metadata["audience"]

**Server-Side Token Validation**



In [None]:
# mcp_server_base.py
from jwt import PyJWT, JWT
from jwt.exceptions import InvalidAudienceError

class MCPBaseServer(ABC):
    """MCP server base: Azure AD bearer-token validation."""

    def _validate_azure_token(self, token: str) -> dict:
        """Validate an Azure AD bearer token and return its claims.

        Checks performed:
        - The RS256 signature is verified against ``_get_azure_signing_key()``.
        - ``exp``, ``iat`` and ``aud`` must be present.
        - ``aud`` must equal ``self.expected_audience``, which guards against
          confused-deputy token reuse.
        - If ``self.expected_issuer`` is set, ``iss`` must be present and match.
        - The ``roles`` claim must contain ``mcp_consumer``.

        Args:
            token: Encoded JWT taken from the request.

        Returns:
            The decoded claims dict.

        Raises:
            InvalidAudienceError: ``aud`` does not match; a fingerprint of the
                token is passed to ``_log_security_event`` first.
            PermissionError: The ``mcp_consumer`` role is missing.
            Any other PyJWT validation error propagates unchanged.
        """
        import hashlib

        decoder = PyJWT()
        signing_key = self._get_azure_signing_key()

        # Optional issuer pinning. Azure AD signing keys are shared across
        # tenants, so the signature alone does not prove which tenant minted
        # the token. None keeps the previous behaviour (issuer not checked).
        expected_issuer = getattr(self, "expected_issuer", None)
        required_claims = ["exp", "iat", "aud"] + (["iss"] if expected_issuer else [])

        try:
            # STRICT AUDIENCE VALIDATION
            payload = decoder.decode(
                token,
                key=signing_key,
                algorithms=["RS256"],
                audience=self.expected_audience,  # Critical for confused deputy
                issuer=expected_issuer,
                options={"require": required_claims},
                leeway=0,  # No tolerance for expired tokens
            )
        except InvalidAudienceError:
            # Log a short fingerprint, never the bearer token itself: a log
            # line must not become a source of replayable credentials.
            fingerprint = hashlib.sha256(token.encode("utf-8")).hexdigest()[:16]
            self._log_security_event("invalid_audience", fingerprint)
            raise

        # ADDITIONAL CLAIM VALIDATION
        # App roles are a list; reject any other shape so that a string claim
        # cannot satisfy the check through substring matching.
        roles = payload.get("roles")
        if not isinstance(roles, list) or "mcp_consumer" not in roles:
            raise PermissionError("Missing required role")
        return payload

2. Prompt Injection Protection
**Input Sanitization & Validation**

In [None]:
class MCPBaseServer(ABC):
    """MCP server base: input screening plus a template-method request pipeline."""

    # A match of any pattern rejects the whole request.
    # NOTE(review): the second pattern also rejects ordinary text containing
    # quotes, semicolons or braces (e.g. "O'Brien") — confirm that is intended.
    PROMPT_INJECTION_PATTERNS = [
        r"(?i)ignore previous|system:|assistant:|prompt injection",
        r"(\{|\}|\"|'|;|--|\/\*|\*\/)"  # Suspicious characters
    ]

    def _sanitize_input(self, input_data: dict) -> dict:
        """Return a copy of *input_data* with every string screened and HTML-escaped.

        The walk recurses into nested dicts, lists and tuples, so strings
        nested inside sequences cannot bypass the checks. Non-string scalars
        are passed through unchanged.
        NOTE(review): dict keys are not screened.

        Raises:
            SecurityException: A string matches a PROMPT_INJECTION_PATTERNS entry.
        """
        return {key: self._sanitize_value(value) for key, value in input_data.items()}

    def _sanitize_value(self, value):
        """Sanitize a single value, recursing into dict/list/tuple containers."""
        if isinstance(value, str):
            for pattern in self.PROMPT_INJECTION_PATTERNS:
                if re.search(pattern, value):
                    raise SecurityException("Potential prompt injection detected")
            return html.escape(value)
        if isinstance(value, dict):
            return self._sanitize_input(value)
        if isinstance(value, list):
            return [self._sanitize_value(item) for item in value]
        if isinstance(value, tuple):
            return tuple(self._sanitize_value(item) for item in value)
        return value

    @abstractmethod
    def _execute_query(self, sanitized_input: dict):
        """Run the backend-specific query; implemented in subclasses."""
        pass

    # Template method
    def process_request(self, raw_input: dict):
        """Sanitize *raw_input*, then hand it to the subclass's _execute_query."""
        sanitized = self._sanitize_input(raw_input)
        return self._execute_query(sanitized)

**REST Server Implementation**


In [None]:
class RESTServer(MCPBaseServer):
    """MCP server that proxies user lookups to a REST backend at ``self.base_url``."""

    # Seconds to wait on the backend. requests applies no timeout by default,
    # so a stalled upstream would otherwise block the worker indefinitely.
    # Tune per backend.
    REQUEST_TIMEOUT = 10

    def _execute_query(self, input_data):
        """GET ``/users/<user_id>`` from the backend.

        Returns:
            The ``requests`` response object.

        Raises:
            KeyError: ``user_id`` is missing from *input_data*.
            SecurityException: The path fails the allow-list.
        """
        # GenAI-generated from OpenAPI spec
        query = f"/users/{input_data['user_id']}"

        # ALLOW-LIST VALIDATION
        # fullmatch rather than match(...$): `$` also matches just before a
        # trailing "\n", which would let "abc\n" through.
        if not re.fullmatch(r"\/users\/[a-zA-Z0-9-]{1,36}", query):
            raise SecurityException("Invalid query structure")

        return requests.get(f"{self.base_url}{query}", timeout=self.REQUEST_TIMEOUT)

3. Credential Exposure Prevention
**Secret Management Integration**

In [None]:
class MCPBaseServer(ABC):
    """MCP server base: fetches credentials from GCP Secret Manager at runtime."""

    # Allowed Secret Manager secret IDs: letters, digits, "_" and "-",
    # at most 255 characters.
    _SECRET_ID_PATTERN = r"[A-Za-z0-9_-]{1,255}"

    def __init__(self):
        self.secret_client = secretmanager.SecretManagerServiceClient()

    def _get_secret(self, secret_id: str) -> str:
        """Return the latest version of *secret_id* decoded as UTF-8.

        Raises:
            ValueError: *secret_id* is malformed, or the ``GCP_PROJECT``
                environment variable is unset or empty. The id is
                interpolated into a resource path, so "/" must never reach it.
        """
        import re

        if not isinstance(secret_id, str) or not re.fullmatch(self._SECRET_ID_PATTERN, secret_id):
            raise ValueError(f"Invalid secret id: {secret_id!r}")
        project = os.getenv("GCP_PROJECT")
        if not project:
            # Fail fast instead of requesting "projects/None/secrets/...".
            raise ValueError("GCP_PROJECT environment variable is not set")
        name = f"projects/{project}/secrets/{secret_id}/versions/latest"
        response = self.secret_client.access_secret_version(name=name)
        return response.payload.data.decode('UTF-8')

class JDBCServer(MCPBaseServer):
    """MCP server backed by Cloud SQL for PostgreSQL (pg8000 driver)."""

    def _db_connection(self):
        """Build a SQLAlchemy engine whose connections come from ``self.connector``.

        Connections use automatic IAM database authentication, so no
        database password is configured here.
        NOTE(review): ``self.connector`` is presumably a Cloud SQL Python
        Connector instance created elsewhere — confirm.
        """
        # IAM Authentication (no credentials)
        return create_engine(
            "postgresql+pg8000://",
            # The instance name and project are read when each connection is opened.
            creator=lambda: self.connector.connect(
                os.getenv("CLOUD_SQL_INSTANCE"),
                "pg8000",
                # NOTE(review): "service-account" looks like a placeholder for the
                # real service-account name — confirm.
                user=f"service-account@{os.getenv('GCP_PROJECT')}.iam",
                # The connector's IAM auth is off by default; without it a
                # password login is attempted with no password.
                enable_iam_auth=True,
            )
        )

4. SQL/GraphQL Injection Prevention
**Parameterized Queries**

In [None]:
class JDBCServer(MCPBaseServer):
    """MCP server exposing a fixed set of parameterized SQL queries."""

    # Allow-listed statements. Callers pick one by name and supply only bind
    # values, so request data never becomes SQL text.
    PREPARED_QUERIES = {
        "user_data": "SELECT * FROM users WHERE id = :user_id",
        "inventory": "SELECT * FROM products WHERE category = :category"
    }

    def _execute_query(self, input_data):
        """Run an allow-listed query by name with bound parameters.

        Args:
            input_data: dict with ``"query"`` (a PREPARED_QUERIES key) and
                ``"params"`` (bind values for that statement).
                NOTE(review): if reached through process_request, string
                params have already been html-escaped (e.g. "&" -> "&amp;"),
                which changes the bound values — confirm intended.

        Returns:
            A list with one dict per result row, keyed by column name.

        Raises:
            KeyError: ``query`` or ``params`` is missing.
            SecurityException: ``query`` is not allow-listed.
        """
        query_name = input_data["query"]
        params = input_data["params"]

        # QUERY ALLOW-LISTING
        if query_name not in self.PREPARED_QUERIES:
            raise SecurityException("Invalid query name")

        with self.engine.connect() as conn:
            # SAFE PARAMETERIZATION
            result = conn.execute(
                text(self.PREPARED_QUERIES[query_name]),
                params
            )
            # Rows are tuple-like in SQLAlchemy 1.4+/2.0, so dict(row) no
            # longer works; `_mapping` is the column-name -> value view.
            return [dict(row._mapping) for row in result]

5. Context Poisoning Protection
**Digital Signatures & Policy Enforcement**

Rego policy against context poisoning (`context_policy.rego`)


In [None]:
rego
package mcp.context

default allow = false

# Case-insensitive markers of prompt-injection / poisoned context.
malicious_pattern := `(?i)(system:|ignore previous|malicious-payload)`

# Prevent malicious context patterns: allow only when every check holds.
allow {
    not contains_malicious_pattern(input.context)
    within_freshness_limits(input.timestamp)
    context_size_ok(input.context)
}

# String contexts are matched directly.
contains_malicious_pattern(ctx) {
    is_string(ctx)
    regex.match(malicious_pattern, ctx)
}

# For non-string contexts (objects, arrays), regex.match would error, the rule
# would be undefined, and `not contains_malicious_pattern(...)` in `allow`
# would succeed, so the check failed open. Screen their JSON encoding instead.
contains_malicious_pattern(ctx) {
    not is_string(ctx)
    regex.match(malicious_pattern, json.marshal(ctx))
}