From 3d6e8a99e3adcf64ebdd9d474005002cb51b4a51 Mon Sep 17 00:00:00 2001 From: Edison-A-N Date: Tue, 21 Oct 2025 13:59:23 +0800 Subject: [PATCH 1/2] fix: prevent circular reference recursion in schema resolution --- fastapi_mcp/openapi/utils.py | 38 ++++-- tests/test_circular_reference.py | 201 +++++++++++++++++++++++++++++++ 2 files changed, 231 insertions(+), 8 deletions(-) create mode 100644 tests/test_circular_reference.py diff --git a/fastapi_mcp/openapi/utils.py b/fastapi_mcp/openapi/utils.py index 1821d57..6f5d03f 100644 --- a/fastapi_mcp/openapi/utils.py +++ b/fastapi_mcp/openapi/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Set, Optional def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str: @@ -16,42 +16,64 @@ def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str: return param_schema.get("type", "string") -def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dict[str, Any]) -> Dict[str, Any]: +def resolve_schema_references( + schema_part: Dict[str, Any], reference_schema: Dict[str, Any], visited_paths: Optional[Set[str]] = None +) -> Dict[str, Any]: """ - Resolve schema references in OpenAPI schemas. + Resolve schema references in OpenAPI schemas with circular reference detection. Args: schema_part: The part of the schema being processed that may contain references reference_schema: The complete schema used to resolve references from + visited_paths: Set of already visited reference paths to detect circular references Returns: The schema with references resolved """ + if visited_paths is None: + visited_paths = set() + # Make a copy to avoid modifying the input schema schema_part = schema_part.copy() # Handle $ref directly in the schema if "$ref" in schema_part: ref_path = schema_part["$ref"] + + # Check for circular reference + if ref_path in visited_paths: + # Return a placeholder schema to break the cycle + return {"type": "object", "description": "Circular reference detected", "properties": {}} + # Standard OpenAPI references are in the format "#/components/schemas/ModelName" if ref_path.startswith("#/components/schemas/"): model_name = ref_path.split("/")[-1] if "components" in reference_schema and "schemas" in reference_schema["components"]: if model_name in reference_schema["components"]["schemas"]: - # Replace with the resolved schema + # Add current path to visited paths + visited_paths.add(ref_path) + + # Get the referenced schema and resolve it recursively ref_schema = reference_schema["components"]["schemas"][model_name].copy() - # Remove the $ref key and merge with the original schema + resolved_ref = resolve_schema_references(ref_schema, reference_schema, visited_paths) + + # Remove the $ref key and merge with the resolved schema schema_part.pop("$ref") - schema_part.update(ref_schema) + schema_part.update(resolved_ref) + + # Remove from visited paths after processing + visited_paths.remove(ref_path) + return schema_part # Recursively resolve references in all dictionary values for key, value in schema_part.items(): if isinstance(value, dict): - schema_part[key] = resolve_schema_references(value, reference_schema) + schema_part[key] = resolve_schema_references(value, reference_schema, visited_paths) elif isinstance(value, list): # Only process list items that are dictionaries since only they can contain refs schema_part[key] = [ - resolve_schema_references(item, reference_schema) if isinstance(item, dict) else item for item in value + resolve_schema_references(item, reference_schema, visited_paths) if isinstance(item, dict) else item + for item in value ] return schema_part diff --git a/tests/test_circular_reference.py b/tests/test_circular_reference.py new file mode 100644 index 0000000..5a3781d --- /dev/null +++ b/tests/test_circular_reference.py @@ -0,0 +1,201 @@ +""" +Test cases for circular reference handling in OpenAPI schema resolution. +""" + +import pytest +from fastapi import FastAPI +from pydantic import BaseModel +from typing import List, Optional +from fastapi.openapi.utils import get_openapi + +from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools +from fastapi_mcp.openapi.utils import resolve_schema_references + + +class User(BaseModel): + """User model with circular reference to Post.""" + + id: int + name: str + posts: List["Post"] = [] + + +class Post(BaseModel): + """Post model with circular reference to User.""" + + id: int + title: str + author: "User" + comments: List["Comment"] = [] + + +class Comment(BaseModel): + """Comment model with circular reference to Post.""" + + id: int + content: str + post: "Post" + + +class TreeNode(BaseModel): + """Tree node with self-reference.""" + + id: int + value: str + children: List["TreeNode"] = [] + parent: Optional["TreeNode"] = None + + +def create_circular_reference_app() -> FastAPI: + """Create a FastAPI app with circular reference models.""" + app = FastAPI(title="Circular Reference Test", version="1.0.0") + + @app.get("/users/{user_id}", operation_id="get_user") + async def get_user(user_id: int) -> User: + """Get a user by ID.""" + return User(id=user_id, name="Test User", posts=[]) + + @app.get("/posts/{post_id}", operation_id="get_post") + async def get_post(post_id: int) -> Post: + """Get a post by ID.""" + return Post(id=post_id, title="Test Post", author=User(id=1, name="Author")) + + @app.get("/comments/{comment_id}", operation_id="get_comment") + async def get_comment(comment_id: int) -> Comment: + """Get a comment by ID.""" + return Comment(id=comment_id, content="Test Comment", post=Post(id=1, title="Post")) + + @app.get("/tree/{node_id}", operation_id="get_tree_node") + async def get_tree_node(node_id: int) -> TreeNode: + """Get a tree node by ID.""" + return TreeNode(id=node_id, value="Test Node", children=[]) + + return app + + +def test_circular_reference_schema_resolution(): + """Test that circular references are handled gracefully.""" + app = create_circular_reference_app() + + # Generate OpenAPI schema + openapi_schema = get_openapi( + title=app.title, + version=app.version, + openapi_version=app.openapi_version, + description=app.description, + routes=app.routes, + ) + + # Test that resolve_schema_references doesn't raise RecursionError + try: + resolved_schema = resolve_schema_references(openapi_schema, openapi_schema) + assert resolved_schema is not None + # Should not raise RecursionError + except RecursionError as e: + pytest.fail(f"RecursionError occurred: {e}") + + +def test_circular_reference_mcp_conversion(): + """Test that MCP conversion works with circular references.""" + app = create_circular_reference_app() + + # Generate OpenAPI schema + openapi_schema = get_openapi( + title=app.title, + version=app.version, + openapi_version=app.openapi_version, + description=app.description, + routes=app.routes, + ) + + # Test that convert_openapi_to_mcp_tools doesn't raise RecursionError + try: + tools, operation_map = convert_openapi_to_mcp_tools(openapi_schema) + + # Should successfully convert without errors + assert len(tools) == 4 # get_user, get_post, get_comment, get_tree_node + assert len(operation_map) == 4 + + # Check that tools are created properly + for tool in tools: + assert tool.name in ["get_user", "get_post", "get_comment", "get_tree_node"] + assert tool.description is not None + assert tool.inputSchema is not None + + except RecursionError as e: + pytest.fail(f"RecursionError occurred during MCP conversion: {e}") + + +def test_self_reference_schema(): + """Test schema with self-reference (TreeNode).""" + app = create_circular_reference_app() + + openapi_schema = get_openapi( + title=app.title, + version=app.version, + openapi_version=app.openapi_version, + description=app.description, + routes=app.routes, + ) + + # Test that self-referencing schemas are handled + try: + resolved_schema = resolve_schema_references(openapi_schema, openapi_schema) + + # Check that TreeNode schema is properly resolved + tree_node_schema = resolved_schema["components"]["schemas"]["TreeNode"] + assert "properties" in tree_node_schema + assert "children" in tree_node_schema["properties"] + + except RecursionError as e: + pytest.fail(f"RecursionError occurred with self-reference: {e}") + + +def test_complex_circular_reference(): + """Test complex circular reference chain: User -> Post -> Comment -> Post.""" + app = create_circular_reference_app() + + openapi_schema = get_openapi( + title=app.title, + version=app.version, + openapi_version=app.openapi_version, + description=app.description, + routes=app.routes, + ) + + # Test that complex circular references are handled + try: + resolved_schema = resolve_schema_references(openapi_schema, openapi_schema) + + # Check that all schemas are resolved + assert "components" in resolved_schema + assert "schemas" in resolved_schema["components"] + + schemas = resolved_schema["components"]["schemas"] + assert "User" in schemas + assert "Post" in schemas + assert "Comment" in schemas + assert "TreeNode" in schemas + + except RecursionError as e: + pytest.fail(f"RecursionError occurred with complex circular reference: {e}") + + +def test_circular_reference_with_visited_paths(): + """Test that visited paths tracking works correctly.""" + # Create a simple schema with circular reference + schema_with_circular_ref = { + "components": { + "schemas": { + "ModelA": {"type": "object", "properties": {"b": {"$ref": "#/components/schemas/ModelB"}}}, + "ModelB": {"type": "object", "properties": {"a": {"$ref": "#/components/schemas/ModelA"}}}, + } + } + } + + # Test that circular reference is detected and handled + try: + resolved = resolve_schema_references(schema_with_circular_ref, schema_with_circular_ref) + assert resolved is not None + except RecursionError as e: + pytest.fail(f"RecursionError occurred with visited paths tracking: {e}") From 09ef03d5dfaeac8ae6677f9ef1719cbc502ca157 Mon Sep 17 00:00:00 2001 From: Edison-A-N Date: Tue, 21 Oct 2025 14:52:08 +0800 Subject: [PATCH 2/2] fix: add missing author parameter to Post model in test --- tests/test_circular_reference.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_circular_reference.py b/tests/test_circular_reference.py index 5a3781d..4f43137 100644 --- a/tests/test_circular_reference.py +++ b/tests/test_circular_reference.py @@ -63,7 +63,9 @@ async def get_post(post_id: int) -> Post: @app.get("/comments/{comment_id}", operation_id="get_comment") async def get_comment(comment_id: int) -> Comment: """Get a comment by ID.""" - return Comment(id=comment_id, content="Test Comment", post=Post(id=1, title="Post")) + return Comment( + id=comment_id, content="Test Comment", post=Post(id=1, title="Post", author=User(id=1, name="Author")) + ) @app.get("/tree/{node_id}", operation_id="get_tree_node") async def get_tree_node(node_id: int) -> TreeNode: