Skip to content

Remove cross reference locally #60

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 1 addition & 32 deletions src/mcp_scan/MCPScanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
from collections.abc import Callable
from typing import Any

from mcp_scan.models import CrossRefResult, ScanError, ScanPathResult, ServerScanResult
from mcp_scan.models import ScanError, ScanPathResult, ServerScanResult

from .mcp_client import check_server_with_timeout, scan_mcp_config_file
from .StorageFile import StorageFile
from .utils import calculate_distance
from .verify_api import verify_scan_path

# Set up logger for this module
Expand Down Expand Up @@ -192,39 +191,9 @@ async def scan_path(self, path: str, inspect_only: bool = False) -> ScanPathResu
path_result.servers[i] = await self.scan_server(server, inspect_only)
logger.debug("Verifying server path: %s", path)
path_result = await verify_scan_path(path_result, base_url=self.base_url, run_locally=self.local_only)
path_result.cross_ref_result = await self.check_cross_references(path_result)
await self.emit("path_scanned", path_result)
return path_result

async def check_cross_references(self, path_result: ScanPathResult) -> CrossRefResult:
logger.info("Checking cross references for path: %s", path_result.path)
cross_ref_result = CrossRefResult(found=False)
for server in path_result.servers:
other_servers = [s for s in path_result.servers if s != server]
other_server_names = [s.name for s in other_servers]
other_entity_names = [e.name for s in other_servers for e in s.entities]
flagged_names = set(map(str.lower, other_server_names + other_entity_names))
logger.debug("Found %d potential cross-reference names", len(flagged_names))

if len(flagged_names) < 1:
logger.debug("No flagged names found, skipping cross-reference check")
continue

for entity in server.entities:
tokens = (entity.description or "").lower().split()
for token in tokens:
best_distance = calculate_distance(reference=token, responses=list(flagged_names))[0]
if ((best_distance[1] <= 2) and (len(token) >= 5)) or (token in flagged_names):
logger.warning("Cross-reference found: %s with token %s", entity.name, token)
cross_ref_result.found = True
cross_ref_result.sources.append(f"{entity.name}:{token}")

if cross_ref_result.found:
logger.info("Cross references detected with %d sources", len(cross_ref_result.sources))
else:
logger.debug("No cross references found")
return cross_ref_result

async def scan(self) -> list[ScanPathResult]:
logger.info("Starting scan of %d paths", len(self.paths))
if self.context_manager is not None:
Expand Down
7 changes: 0 additions & 7 deletions src/mcp_scan/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,6 @@ class EntityScanResult(BaseModel):
messages: list[str] = []


class CrossRefResult(BaseModel):
model_config = ConfigDict()
found: bool | None = None
sources: list[str] = []


class ServerSignature(BaseModel):
metadata: Metadata
prompts: list[Prompt] = Field(default_factory=list)
Expand Down Expand Up @@ -196,7 +190,6 @@ class ScanPathResult(BaseModel):
path: str
servers: list[ServerScanResult] = []
error: ScanError | None = None
cross_ref_result: CrossRefResult | None = None

@property
def entities(self) -> list[Entity]:
Expand Down
9 changes: 1 addition & 8 deletions src/mcp_scan/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,7 @@ def print_scan_path_result(result: ScanPathResult, print_errors: bool = False) -

if len(result.servers) > 0:
rich.print(path_print_tree)
if result.cross_ref_result is not None and result.cross_ref_result.found:
rich.print(
rich.text.Text.from_markup(
f"\n[bold yellow]:construction: Cross-Origin Violation: "
f"Descriptions of server {result.cross_ref_result.sources} explicitly mention "
f"tools or resources of other servers, or other servers.[/bold yellow]"
),
)

if print_errors and len(server_tracebacks) > 0:
console = rich.console.Console()
for server, traceback in server_tracebacks:
Expand Down
Loading