diff --git a/README.md b/README.md index c08b4ec..a87da60 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,11 @@ trailmark analyze path/to/project trailmark analyze --language rust path/to/project trailmark analyze --language javascript path/to/project +# Polyglot: auto-detect and merge every supported language found in the +# tree, or pass an explicit comma-separated list. +trailmark analyze --language auto path/to/project +trailmark analyze --language python,rust,solidity path/to/project + # Summary statistics trailmark analyze --summary path/to/project diff --git a/src/trailmark/query/api.py b/src/trailmark/query/api.py index 1467a04..f0bfb49 100644 --- a/src/trailmark/query/api.py +++ b/src/trailmark/query/api.py @@ -5,6 +5,7 @@ import importlib import json from dataclasses import asdict +from pathlib import Path from typing import Any from trailmark.analysis.augment import augment_from_sarif, augment_from_weaudit @@ -38,6 +39,30 @@ "masm": ("trailmark.parsers.masm", "MasmParser"), } +# Extensions used for language auto-detection. Kept in sync with each parser's +# internal _EXTENSIONS tuple. No extension is currently shared between +# languages (`.h` is not listed); if one ever is, the language registered +# first in this dict wins during detection.
+_LANGUAGE_EXTENSIONS: dict[str, tuple[str, ...]] = {
+    "python": (".py",),
+    "javascript": (".js", ".jsx", ".mjs", ".cjs"),
+    "typescript": (".ts", ".tsx"),
+    "php": (".php",),
+    "ruby": (".rb",),
+    "c": (".c",),
+    "cpp": (".cpp", ".cc", ".cxx", ".hpp", ".hh", ".hxx"),
+    "c_sharp": (".cs",),
+    "java": (".java",),
+    "go": (".go",),
+    "rust": (".rs",),
+    "solidity": (".sol",),
+    "cairo": (".cairo",),
+    "circom": (".circom",),
+    "haskell": (".hs",),
+    "erlang": (".erl",),
+    "masm": (".masm",),
+}
+
 _SUPPORTED_LANGUAGES = frozenset(_PARSER_MAP.keys())
 
 
@@ -52,6 +77,145 @@ def _get_parser(language: str) -> LanguageParser:
     return cls()
 
 
+def _resolve_languages(path: str, spec: str) -> list[str]:
+    """Expand a ``language`` argument into a concrete list of languages.
+
+    Accepts:
+    - ``"auto"`` — detect from file extensions under ``path``.
+    - ``"python,rust"`` — comma-separated explicit list.
+    - ``"python"`` — single language (the common case; returned as a
+      single-element list).
+
+    Surrounding whitespace and empty items are ignored and duplicates are
+    dropped (first occurrence wins) so no language is parsed twice.
+
+    Raises:
+        ValueError: if ``"auto"`` finds nothing under ``path``, if any name
+            is not a registered parser, or if ``spec`` names no language.
+    """
+    spec = spec.strip()
+    if spec == "auto":
+        detected = detect_languages(path)
+        if not detected:
+            msg = f"No supported languages detected under {path}"
+            raise ValueError(msg)
+        return detected
+    # dict.fromkeys de-duplicates while keeping the caller's order.
+    names = list(dict.fromkeys(part.strip() for part in spec.split(",") if part.strip()))
+    if not names:
+        # "" or "," names no language; fail loudly rather than let the
+        # caller build a silently empty graph.
+        msg = f"Unsupported language: {spec}"
+        raise ValueError(msg)
+    for name in names:
+        if name not in _PARSER_MAP:
+            msg = f"Unsupported language: {name}"
+            raise ValueError(msg)
+    return names
+
+
+def _parse_and_merge(path: str, languages: list[str]) -> CodeGraph:
+    """Parse ``path`` with each language's parser and merge into one graph.
+
+    A single language returns that parser's graph untouched; several
+    languages are merged, in order, into a graph tagged ``"polyglot"``.
+    """
+    if len(languages) == 1:
+        # Preserves pre-polyglot behavior exactly for the common case.
+        return _get_parser(languages[0]).parse_directory(path)
+
+    merged = CodeGraph(
+        language="polyglot",
+        root_path=str(Path(path).resolve()),
+    )
+    for lang in languages:
+        sub = _get_parser(lang).parse_directory(path)
+        merged.merge(sub)
+        # Re-assert the marker after every merge so the result stays
+        # tagged "polyglot" whatever merge() does with `language`.
+        merged.language = "polyglot"
+    return merged
+
+
+def detect_languages(path: str) -> list[str]:
+    """Return the languages with at least one source file under ``path``.
+
+    Detection walks the directory once, classifies each file by extension,
+    and returns the languages that have at least one match. Order is the
+    order languages are registered in ``_LANGUAGE_EXTENSIONS`` (not
+    alphabetical), which keeps the result deterministic.
+
+    Directories named in ``_SKIP_DIR_NAMES`` are pruned *below* ``path``
+    only: ancestors of ``path`` and ``path`` itself are never filtered, so
+    a project checked out under e.g. ``~/build/app`` is still scanned.
+    A missing ``path`` yields ``[]``.
+    """
+    import os
+
+    root = Path(path)
+    if not root.exists():
+        return []
+
+    ext_to_language: dict[str, str] = {}
+    for lang, exts in _LANGUAGE_EXTENSIONS.items():
+        for ext in exts:
+            # When languages share an extension (none currently do, but
+            # guard against it), the FIRST registration wins.
+            ext_to_language.setdefault(ext, lang)
+
+    found: set[str] = set()
+    for _dirpath, dirs, files in os.walk(root):
+        # Prune in place so os.walk never descends into vendor / generated
+        # dirs; only names below the root are tested, never its ancestors.
+        dirs[:] = [d for d in dirs if not _should_skip_dir(d)]
+        for name in files:
+            language = ext_to_language.get(_file_extension(name))
+            if language is not None:
+                found.add(language)
+        if len(found) == len(_LANGUAGE_EXTENSIONS):
+            # Every known language has been seen; stop walking the tree.
+            break
+
+    return [lang for lang in _LANGUAGE_EXTENSIONS if lang in found]
+
+
+_SKIP_DIR_NAMES = frozenset(
+    {
+        ".git",
+        ".hg",
+        ".svn",
+        "node_modules",
+        "__pycache__",
+        ".venv",
+        "venv",
+        "env",
+        ".tox",
+        "dist",
+        "build",
+        "target",
+        ".mutants",
+        "mutants",
+    }
+)
+
+
+def _should_skip_dir(dirpath: str) -> bool:
+    """Return True if any component of ``dirpath`` is in ``_SKIP_DIR_NAMES``.
+
+    Pass a bare directory name or a path relative to the scan root; an
+    absolute path would also match skip names in the root's ancestors.
+    """
+    return any(part in _SKIP_DIR_NAMES for part in Path(dirpath).parts)
+
+
+def _file_extension(name: str) -> str:
+    """Return the lowercase extension including leading dot, or ''."""
+    dot = name.rfind(".")
+    if dot < 0:
+        return ""
+    return name[dot:].lower()
+
+
 class QueryEngine:
     """Facade for building and querying code graphs."""
 
@@ -68,13 +232,18 @@ def from_directory(
     ) -> QueryEngine:
         """Parse a directory and return a ready-to-query engine.
 
+        ``language`` accepts a specific language name (e.g. ``"python"``,
+        ``"rust"``, ``"solidity"``), ``"auto"`` to detect and merge every
+        language with at least one matching file under ``path``, or a
+        comma-separated list like ``"python,rust"`` for an explicit set.
+
         Entrypoint detection runs automatically so that ``attack_surface()``
         and the entrypoint-dependent preanalysis passes have data to work
         with. Pass ``detect_entrypoints_=False`` to skip it (e.g. when the
         caller wants to drive detection separately).
         """
-        parser = _get_parser(language)
-        graph = parser.parse_directory(path)
+        languages = _resolve_languages(path, language)
+        graph = _parse_and_merge(path, languages)
         if detect_entrypoints_:
             graph.entrypoints.update(detect_entrypoints(graph, path))
         store = GraphStore(graph)
diff --git a/tests/test_polyglot.py b/tests/test_polyglot.py
new file mode 100644
index 0000000..a28cfa5
--- /dev/null
+++ b/tests/test_polyglot.py
@@ -0,0 +1,119 @@
+"""Tests for multi-language (polyglot) graph building.
+
+``QueryEngine.from_directory(path, language="auto")`` walks the tree,
+detects every supported language with at least one matching file, and
+merges the resulting graphs. Explicit ``"python,rust"``-style lists
+are also supported.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import pytest
+
+from trailmark.query.api import QueryEngine, detect_languages
+
+
+class TestDetectLanguages:
+    def test_single_language_directory(self, tmp_path: Path) -> None:
+        (tmp_path / "a.py").write_text("x = 1\n")
+        assert detect_languages(str(tmp_path)) == ["python"]
+
+    def test_multiple_languages_in_one_dir(self, tmp_path: Path) -> None:
+        (tmp_path / "a.py").write_text("x = 1\n")
+        (tmp_path / "b.rs").write_text("fn x() {}\n")
+        (tmp_path / "c.sol").write_text("contract X {}\n")
+        detected = set(detect_languages(str(tmp_path)))
+        assert detected == {"python", "rust", "solidity"}
+
+    def test_empty_directory_returns_empty(self, tmp_path: Path) -> None:
+        assert detect_languages(str(tmp_path)) == []
+
+    def test_unknown_extensions_ignored(self, tmp_path: Path) -> None:
+        (tmp_path / "README.md").write_text("# hello\n")
+        (tmp_path / "data.json").write_text("{}\n")
+        assert detect_languages(str(tmp_path)) == []
+
+    def test_skips_vendor_directories(self, tmp_path: Path) -> None:
+        (tmp_path / "app.py").write_text("x = 1\n")
+        vendor = tmp_path / "node_modules" / "dep"
+        vendor.mkdir(parents=True)
+        (vendor / "thing.js").write_text("export default {};\n")
+        detected = detect_languages(str(tmp_path))
+        assert detected == ["python"], detected
+
+    def test_missing_path_returns_empty(self, tmp_path: Path) -> None:
+        assert detect_languages(str(tmp_path / "does-not-exist")) == []
+
+
+class TestFromDirectoryAuto:
+    def test_auto_detects_and_merges(self, tmp_path: Path) -> None:
+        (tmp_path / "app.py").write_text(
+            "def handler():\n pass\n",
+        )
+        (tmp_path / "Vault.sol").write_text(
+            "// SPDX-License-Identifier: MIT\n"
+            "pragma solidity ^0.8.0;\n"
+            "contract Vault {\n"
+            " function withdraw(uint256 amount) external {}\n"
+            "}\n",
+        )
+        engine = QueryEngine.from_directory(str(tmp_path), language="auto")
+        summary = engine.summary()
+        # Sanity check only: both files together should yield >= 2 nodes.
+        assert summary["total_nodes"] >= 2
+
+    def test_auto_merges_entrypoints(self, tmp_path: Path) -> None:
+        """Detected entrypoints from different languages coexist."""
+        (tmp_path / "cli.py").write_text("def main():\n pass\n")
+        (tmp_path / "Vault.sol").write_text(
+            "// SPDX-License-Identifier: MIT\n"
+            "pragma solidity ^0.8.0;\n"
+            "contract Vault {\n"
+            " function withdraw() external {}\n"
+            "}\n",
+        )
+        engine = QueryEngine.from_directory(str(tmp_path), language="auto")
+        surface = engine.attack_surface()
+        descriptions = [ep.get("description") or "" for ep in surface]
+        assert any("main" in d.lower() for d in descriptions), surface
+        assert any("solidity" in d.lower() for d in descriptions), surface
+
+    def test_auto_on_empty_dir_raises(self, tmp_path: Path) -> None:
+        with pytest.raises(ValueError, match="No supported languages"):
+            QueryEngine.from_directory(str(tmp_path), language="auto")
+
+    def test_explicit_list_merges(self, tmp_path: Path) -> None:
+        """`python,rust` builds and merges both, skipping other languages."""
+        (tmp_path / "a.py").write_text("def main():\n pass\n")
+        (tmp_path / "b.rs").write_text("fn main() {}\n")
+        (tmp_path / "ignored.sol").write_text(
+            "contract X { function y() external {} }\n",
+        )
+        engine = QueryEngine.from_directory(
+            str(tmp_path),
+            language="python,rust",
+        )
+        summary = engine.summary()
+        # Solidity was not requested, so no Solidity entrypoint should
+        # show up in the attack surface.
+        surface_descriptions = [
+            (ep.get("description") or "").lower() for ep in engine.attack_surface()
+        ]
+        assert not any("solidity" in d for d in surface_descriptions)
+        # But the Python and Rust `main`s should both be counted.
+        assert summary["entrypoints"] >= 2
+
+    def test_unsupported_language_raises(self, tmp_path: Path) -> None:
+        (tmp_path / "a.py").write_text("x = 1\n")
+        with pytest.raises(ValueError, match="Unsupported language"):
+            QueryEngine.from_directory(str(tmp_path), language="cobol")
+
+    def test_single_language_preserved(self, tmp_path: Path) -> None:
+        """Pre-polyglot behavior intact when one language is specified."""
+        (tmp_path / "a.py").write_text("def main():\n pass\n")
+        engine = QueryEngine.from_directory(str(tmp_path), language="python")
+        surface = engine.attack_surface()
+        assert len(surface) == 1
+        assert surface[0]["node_id"] == "a:main"