Skip to content

Commit 4badc10

Browse files
authored
Remove cross reference locally (#60)
* fix: remote cross reference * fix: remove other occ
1 parent fae4717 commit 4badc10

File tree

3 files changed

+2
-47
lines changed

3 files changed

+2
-47
lines changed

src/mcp_scan/MCPScanner.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
from collections.abc import Callable
66
from typing import Any
77

8-
from mcp_scan.models import CrossRefResult, ScanError, ScanPathResult, ServerScanResult
8+
from mcp_scan.models import ScanError, ScanPathResult, ServerScanResult
99

1010
from .mcp_client import check_server_with_timeout, scan_mcp_config_file
1111
from .StorageFile import StorageFile
12-
from .utils import calculate_distance
1312
from .verify_api import verify_scan_path
1413

1514
# Set up logger for this module
@@ -192,39 +191,9 @@ async def scan_path(self, path: str, inspect_only: bool = False) -> ScanPathResu
192191
path_result.servers[i] = await self.scan_server(server, inspect_only)
193192
logger.debug("Verifying server path: %s", path)
194193
path_result = await verify_scan_path(path_result, base_url=self.base_url, run_locally=self.local_only)
195-
path_result.cross_ref_result = await self.check_cross_references(path_result)
196194
await self.emit("path_scanned", path_result)
197195
return path_result
198196

199-
async def check_cross_references(self, path_result: ScanPathResult) -> CrossRefResult:
200-
logger.info("Checking cross references for path: %s", path_result.path)
201-
cross_ref_result = CrossRefResult(found=False)
202-
for server in path_result.servers:
203-
other_servers = [s for s in path_result.servers if s != server]
204-
other_server_names = [s.name for s in other_servers]
205-
other_entity_names = [e.name for s in other_servers for e in s.entities]
206-
flagged_names = set(map(str.lower, other_server_names + other_entity_names))
207-
logger.debug("Found %d potential cross-reference names", len(flagged_names))
208-
209-
if len(flagged_names) < 1:
210-
logger.debug("No flagged names found, skipping cross-reference check")
211-
continue
212-
213-
for entity in server.entities:
214-
tokens = (entity.description or "").lower().split()
215-
for token in tokens:
216-
best_distance = calculate_distance(reference=token, responses=list(flagged_names))[0]
217-
if ((best_distance[1] <= 2) and (len(token) >= 5)) or (token in flagged_names):
218-
logger.warning("Cross-reference found: %s with token %s", entity.name, token)
219-
cross_ref_result.found = True
220-
cross_ref_result.sources.append(f"{entity.name}:{token}")
221-
222-
if cross_ref_result.found:
223-
logger.info("Cross references detected with %d sources", len(cross_ref_result.sources))
224-
else:
225-
logger.debug("No cross references found")
226-
return cross_ref_result
227-
228197
async def scan(self) -> list[ScanPathResult]:
229198
logger.info("Starting scan of %d paths", len(self.paths))
230199
if self.context_manager is not None:

src/mcp_scan/models.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,6 @@ class EntityScanResult(BaseModel):
156156
messages: list[str] = []
157157

158158

159-
class CrossRefResult(BaseModel):
160-
model_config = ConfigDict()
161-
found: bool | None = None
162-
sources: list[str] = []
163-
164-
165159
class ServerSignature(BaseModel):
166160
metadata: Metadata
167161
prompts: list[Prompt] = Field(default_factory=list)
@@ -227,7 +221,6 @@ class ScanPathResult(BaseModel):
227221
path: str
228222
servers: list[ServerScanResult] = []
229223
error: ScanError | None = None
230-
cross_ref_result: CrossRefResult | None = None
231224

232225
@property
233226
def entities(self) -> list[Entity]:

src/mcp_scan/printer.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,7 @@ def print_scan_path_result(result: ScanPathResult, print_errors: bool = False) -
134134

135135
if len(result.servers) > 0:
136136
rich.print(path_print_tree)
137-
if result.cross_ref_result is not None and result.cross_ref_result.found:
138-
rich.print(
139-
rich.text.Text.from_markup(
140-
f"\n[bold yellow]:construction: Cross-Origin Violation: "
141-
f"Descriptions of server {result.cross_ref_result.sources} explicitly mention "
142-
f"tools or resources of other servers, or other servers.[/bold yellow]"
143-
),
144-
)
137+
145138
if print_errors and len(server_tracebacks) > 0:
146139
console = rich.console.Console()
147140
for server, traceback in server_tracebacks:

0 commit comments

Comments
 (0)