diff --git a/pyproject.toml b/pyproject.toml index dd8f6417..1789c73d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "beautifulsoup4==4.14.3", "pygls>=2.0,<3.0", "lsprotocol>=2024.0.0", + "mcp>=1.0", ] [project.urls] diff --git a/src/reqstool/command.py b/src/reqstool/command.py index 7c46ab79..592bb259 100755 --- a/src/reqstool/command.py +++ b/src/reqstool/command.py @@ -322,6 +322,11 @@ class ComboRawTextandArgsDefaultUltimateHelpFormatter( help="Write server logs to a file (in addition to stderr)", ) + # command: mcp + mcp_parser = subparsers.add_parser("mcp", help="Start the Model Context Protocol server (stdio)") + mcp_source_subparsers = mcp_parser.add_subparsers(dest="source", required=True) + self._add_subparsers_source(mcp_source_subparsers, include_report_options=False, include_filter_options=False) + args = self.__parser.parse_args() return args @@ -432,11 +437,26 @@ def command_lsp(self, lsp_args: argparse.Namespace): logging.fatal("reqstool LSP server crashed: %s", exc) sys.exit(1) + def command_mcp(self, mcp_args: argparse.Namespace): + try: + from reqstool.mcp.server import start_server + except ImportError: + print( + "MCP server requires extra dependencies: pip install 'mcp>=1.0'", + file=sys.stderr, + ) + sys.exit(1) + try: + start_server(location=self._get_initial_source(mcp_args)) + except Exception as exc: + logging.fatal("reqstool MCP server crashed: %s", exc) + sys.exit(1) + def print_help(self): self.__parser.print_help(sys.stderr) -def main(): +def main(): # noqa: C901 command = Command() args = command.get_arguments() @@ -466,6 +486,8 @@ def main(): exit_code = command.command_status(status_args=args) elif args.command == "lsp": command.command_lsp(lsp_args=args) + elif args.command == "mcp": + command.command_mcp(mcp_args=args) else: command.print_help() except MissingRequirementsFileError as exc: diff --git a/src/reqstool/commands/generate_json/generate_json.py b/src/reqstool/commands/generate_json/generate_json.py index d17b7c8c..471d520d 100644 --- a/src/reqstool/commands/generate_json/generate_json.py +++ b/src/reqstool/commands/generate_json/generate_json.py @@ -1,8 +1,6 @@ # Copyright © LFV -from __future__ import annotations - import json import logging diff --git a/src/reqstool/commands/report/report.py b/src/reqstool/commands/report/report.py index f2c89995..31dfb24a 100644 --- a/src/reqstool/commands/report/report.py +++ b/src/reqstool/commands/report/report.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations from enum import Enum diff --git a/src/reqstool/commands/status/status.py b/src/reqstool/commands/status/status.py index 49105849..f05f8282 100644 --- a/src/reqstool/commands/status/status.py +++ b/src/reqstool/commands/status/status.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import json import re diff --git a/src/reqstool/common/models/lifecycle.py b/src/reqstool/common/models/lifecycle.py index 7e1773a7..080237c1 100644 --- a/src/reqstool/common/models/lifecycle.py +++ b/src/reqstool/common/models/lifecycle.py @@ -1,9 +1,8 @@ # Copyright © LFV -from __future__ import annotations from enum import Enum, unique -from typing import Optional +from typing import Optional, Self from pydantic import BaseModel, ConfigDict @@ -31,7 +30,7 @@ class LifecycleData(BaseModel): state: LIFECYCLESTATE = LIFECYCLESTATE.EFFECTIVE @classmethod - def from_dict(cls, data: Optional[dict]) -> LifecycleData: + def from_dict(cls, data: Optional[dict]) -> Self: if data is None: return cls(state=LIFECYCLESTATE.EFFECTIVE, reason=None) return cls( diff --git a/src/reqstool/common/project_session.py b/src/reqstool/common/project_session.py new file mode 100644 index 00000000..25bb8229 --- /dev/null +++ b/src/reqstool/common/project_session.py @@ -0,0 +1,91 @@ +# Copyright © LFV + + +import logging + +from reqstool.common.validators.lifecycle_validator import LifecycleValidator +from reqstool.common.validators.semantic_validator import SemanticValidator +from reqstool.common.validator_error_holder import ValidationErrorHolder +from reqstool.locations.location import LocationInterface +from reqstool.model_generators.combined_raw_datasets_generator import CombinedRawDatasetsGenerator +from reqstool.storage.database import RequirementsDatabase +from reqstool.storage.database_filter_processor import DatabaseFilterProcessor +from reqstool.storage.requirements_repository import RequirementsRepository + +logger = logging.getLogger(__name__) + + +class ProjectSession: + """Long-lived database session for a reqstool project loaded from any LocationInterface. + + Keeps the SQLite database open for the lifetime of the session (unlike the + build_database() context manager which closes on exit). Suitable for servers + (MCP, LSP) that need persistent read access after a one-time build. + """ + + def __init__(self, location: LocationInterface): + self._location = location + self._db: RequirementsDatabase | None = None + self._repo: RequirementsRepository | None = None + self._urn_source_paths: dict[str, dict[str, str]] = {} + self._ready: bool = False + self._error: str | None = None + + @property + def ready(self) -> bool: + return self._ready + + @property + def error(self) -> str | None: + return self._error + + @property + def repo(self) -> RequirementsRepository | None: + return self._repo + + @property + def urn_source_paths(self) -> dict[str, dict[str, str]]: + return self._urn_source_paths + + def build(self) -> None: + self.close() + self._error = None + db = RequirementsDatabase() + try: + holder = ValidationErrorHolder() + semantic_validator = SemanticValidator(validation_error_holder=holder) + + crdg = CombinedRawDatasetsGenerator( + initial_location=self._location, + semantic_validator=semantic_validator, + database=db, + ) + crd = crdg.combined_raw_datasets + + DatabaseFilterProcessor(db, crd.raw_datasets).apply_filters() + LifecycleValidator(RequirementsRepository(db)) + + self._db = db + self._repo = RequirementsRepository(db) + self._urn_source_paths = dict(crd.urn_source_paths) + self._ready = True + logger.info("Built project session for %s", self._location) + except SystemExit as e: + logger.warning("build() called sys.exit(%s) for %s", e.code, self._location) + self._error = f"Pipeline error (exit code {e.code})" + db.close() + except Exception as e: + logger.error("Failed to build project session for %s: %s", self._location, e) + self._error = str(e) + db.close() + + def rebuild(self) -> None: + self.build() + + def close(self) -> None: + if self._db is not None: + self._db.close() + self._db = None + self._repo = None + self._urn_source_paths = {} + self._ready = False diff --git a/src/reqstool/common/queries/__init__.py b/src/reqstool/common/queries/__init__.py new file mode 100644 index 00000000..051704bb --- /dev/null +++ b/src/reqstool/common/queries/__init__.py @@ -0,0 +1 @@ +# Copyright © LFV diff --git a/src/reqstool/common/queries/details.py b/src/reqstool/common/queries/details.py new file mode 100644 index 00000000..41dbd4d1 --- /dev/null +++ b/src/reqstool/common/queries/details.py @@ -0,0 +1,187 @@ +# Copyright © LFV + + +from reqstool.common.models.urn_id import UrnId +from reqstool.storage.requirements_repository import RequirementsRepository + + +def _svc_test_summary(svc_urn_id: UrnId, repo: RequirementsRepository) -> dict: + test_results = repo.get_test_results_for_svc(svc_urn_id) + return { + "passed": sum(1 for t in test_results if t.status.value == "passed"), + "failed": sum(1 for t in test_results if t.status.value == "failed"), + "skipped": sum(1 for t in test_results if t.status.value == "skipped"), + "missing": sum(1 for t in test_results if t.status.value == "missing"), + } + + +def get_requirement_details( + raw_id: str, + repo: RequirementsRepository, + urn_source_paths: dict[str, dict[str, str]] | None = None, +) -> dict | None: + initial_urn = repo.get_initial_urn() + urn_id = UrnId.assure_urn_id(initial_urn, raw_id) + all_reqs = repo.get_all_requirements() + req = all_reqs.get(urn_id) + if req is None: + return None + + svc_urn_ids = repo.get_svcs_for_req(req.id) + all_svcs = repo.get_all_svcs() + svcs = [all_svcs[uid] for uid in svc_urn_ids if uid in all_svcs] + + impls = repo.get_annotations_impls_for_req(req.id) + references = [str(ref_id) for rd in (req.references or []) for ref_id in rd.requirement_ids] + + paths = urn_source_paths or {} + return { + "type": "requirement", + "id": req.id.id, + "urn": req.id.urn, + "title": req.title, + "significance": req.significance.value, + "description": req.description, + "rationale": req.rationale or "", + "revision": str(req.revision), + "lifecycle": { + "state": req.lifecycle.state.value, + "reason": req.lifecycle.reason or "", + }, + "categories": [c.value for c in req.categories], + "implementation": req.implementation.value, + "references": references, + "implementations": [{"element_kind": a.element_kind, "fqn": a.fully_qualified_name} for a in impls], + "svcs": [ + { + "id": s.id.id, + "urn": s.id.urn, + "title": s.title, + "verification": s.verification.value, + "lifecycle_state": s.lifecycle.state.value, + "test_summary": _svc_test_summary(s.id, repo), + } + for s in svcs + ], + "location": repo.get_urn_location(req.id.urn), + "source_paths": paths.get(req.id.urn, {}), + } + + +def get_svc_details( + raw_id: str, + repo: RequirementsRepository, + urn_source_paths: dict[str, dict[str, str]] | None = None, +) -> dict | None: + initial_urn = repo.get_initial_urn() + urn_id = UrnId.assure_urn_id(initial_urn, raw_id) + all_svcs = repo.get_all_svcs() + svc = all_svcs.get(urn_id) + if svc is None: + return None + + mvr_urn_ids = repo.get_mvrs_for_svc(svc.id) + all_mvrs = repo.get_all_mvrs() + mvrs = [all_mvrs[uid] for uid in mvr_urn_ids if uid in all_mvrs] + + test_annotations = repo.get_annotations_tests_for_svc(svc.id) + test_results = repo.get_test_results_for_svc(svc.id) + + all_reqs = repo.get_all_requirements() + + paths = urn_source_paths or {} + return { + "type": "svc", + "id": svc.id.id, + "urn": svc.id.urn, + "title": svc.title, + "description": svc.description or "", + "verification": svc.verification.value, + "instructions": svc.instructions or "", + "revision": str(svc.revision), + "lifecycle": { + "state": svc.lifecycle.state.value, + "reason": svc.lifecycle.reason or "", + }, + "requirement_ids": [ + { + "id": r.id, + "urn": r.urn, + "title": all_reqs[r].title if r in all_reqs else "", + "lifecycle_state": all_reqs[r].lifecycle.state.value if r in all_reqs else "", + } + for r in svc.requirement_ids + ], + "test_annotations": [{"element_kind": a.element_kind, "fqn": a.fully_qualified_name} for a in test_annotations], + "test_results": [{"fqn": t.fully_qualified_name, "status": t.status.value} for t in test_results], + "test_summary": { + "passed": sum(1 for t in test_results if t.status.value == "passed"), + "failed": sum(1 for t in test_results if t.status.value == "failed"), + "skipped": sum(1 for t in test_results if t.status.value == "skipped"), + "missing": sum(1 for t in test_results if t.status.value == "missing"), + }, + "mvrs": [ + { + "id": m.id.id, + "urn": m.id.urn, + "passed": m.passed, + "comment": m.comment or "", + } + for m in mvrs + ], + "location": repo.get_urn_location(svc.id.urn), + "source_paths": paths.get(svc.id.urn, {}), + } + + +def get_mvr_details( + raw_id: str, + repo: RequirementsRepository, + urn_source_paths: dict[str, dict[str, str]] | None = None, +) -> dict | None: + initial_urn = repo.get_initial_urn() + urn_id = UrnId.assure_urn_id(initial_urn, raw_id) + all_mvrs = repo.get_all_mvrs() + mvr = all_mvrs.get(urn_id) + if mvr is None: + return None + + paths = urn_source_paths or {} + return { + "type": "mvr", + "id": mvr.id.id, + "urn": mvr.id.urn, + "passed": mvr.passed, + "comment": mvr.comment or "", + "svc_ids": [{"id": s.id, "urn": s.urn} for s in mvr.svc_ids], + "location": repo.get_urn_location(mvr.id.urn), + "source_paths": paths.get(mvr.id.urn, {}), + } + + +def get_requirement_status(raw_id: str, repo: RequirementsRepository) -> dict | None: + """Lightweight status check — avoids the full detail lookup.""" + initial_urn = repo.get_initial_urn() + urn_id = UrnId.assure_urn_id(initial_urn, raw_id) + req = repo.get_all_requirements().get(urn_id) + if req is None: + return None + + svc_urn_ids = repo.get_svcs_for_req(req.id) + test_summary = {"passed": 0, "failed": 0, "skipped": 0, "missing": 0} + for svc_uid in svc_urn_ids: + for t in repo.get_test_results_for_svc(svc_uid): + key = t.status.value + if key in test_summary: + test_summary[key] += 1 + + # skipped tests are not counted as failures; a requirement only "meets" if + # it has at least one implementation and no failed or missing test results + all_passing = test_summary["failed"] == 0 and test_summary["missing"] == 0 + return { + "id": req.id.id, + "lifecycle_state": req.lifecycle.state.value, + "implementation": req.implementation.value, + "test_summary": test_summary, + "meets_requirements": req.implementation.value != "not_implemented" and all_passing, + } diff --git a/src/reqstool/common/queries/list.py b/src/reqstool/common/queries/list.py new file mode 100644 index 00000000..ce5d3a3c --- /dev/null +++ b/src/reqstool/common/queries/list.py @@ -0,0 +1,45 @@ +# Copyright © LFV + + +from reqstool.storage.requirements_repository import RequirementsRepository + + +def get_requirements_list(repo: RequirementsRepository) -> list[dict]: + return [ + { + "id": r.id.id, + "title": r.title, + "lifecycle_state": r.lifecycle.state.value, + } + for r in repo.get_all_requirements().values() + ] + + +def get_svcs_list(repo: RequirementsRepository) -> list[dict]: + return [ + { + "id": s.id.id, + "title": s.title, + "lifecycle_state": s.lifecycle.state.value, + "verification": s.verification.value, + } + for s in repo.get_all_svcs().values() + ] + + +def get_mvrs_list(repo: RequirementsRepository) -> list[dict]: + return [ + { + "id": m.id.id, + "passed": m.passed, + } + for m in repo.get_all_mvrs().values() + ] + + +def get_list(repo: RequirementsRepository) -> dict: + return { + "requirements": get_requirements_list(repo), + "svcs": get_svcs_list(repo), + "mvrs": get_mvrs_list(repo), + } diff --git a/src/reqstool/common/validators/lifecycle_validator.py b/src/reqstool/common/validators/lifecycle_validator.py index 7e8a3a36..01dee66e 100644 --- a/src/reqstool/common/validators/lifecycle_validator.py +++ b/src/reqstool/common/validators/lifecycle_validator.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations from collections import namedtuple import logging diff --git a/src/reqstool/lsp/annotation_parser.py b/src/reqstool/lsp/annotation_parser.py index 12ea5314..c4e8d155 100644 --- a/src/reqstool/lsp/annotation_parser.py +++ b/src/reqstool/lsp/annotation_parser.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import re from dataclasses import dataclass diff --git a/src/reqstool/lsp/features/code_actions.py b/src/reqstool/lsp/features/code_actions.py index 619882f8..6d7dae3f 100644 --- a/src/reqstool/lsp/features/code_actions.py +++ b/src/reqstool/lsp/features/code_actions.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import re diff --git a/src/reqstool/lsp/features/codelens.py b/src/reqstool/lsp/features/codelens.py index f924d11a..9f41eb2a 100644 --- a/src/reqstool/lsp/features/codelens.py +++ b/src/reqstool/lsp/features/codelens.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations from lsprotocol import types diff --git a/src/reqstool/lsp/features/completion.py b/src/reqstool/lsp/features/completion.py index f5c101d3..38f88c2a 100644 --- a/src/reqstool/lsp/features/completion.py +++ b/src/reqstool/lsp/features/completion.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import os import re diff --git a/src/reqstool/lsp/features/definition.py b/src/reqstool/lsp/features/definition.py index ba967d81..5afe6d41 100644 --- a/src/reqstool/lsp/features/definition.py +++ b/src/reqstool/lsp/features/definition.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import logging import os diff --git a/src/reqstool/lsp/features/details.py b/src/reqstool/lsp/features/details.py index cb964d82..bd49f1c3 100644 --- a/src/reqstool/lsp/features/details.py +++ b/src/reqstool/lsp/features/details.py @@ -1,122 +1,19 @@ # Copyright © LFV -from __future__ import annotations +from reqstool.common.queries.details import get_mvr_details as _get_mvr_details +from reqstool.common.queries.details import get_requirement_details as _get_requirement_details +from reqstool.common.queries.details import get_svc_details as _get_svc_details from reqstool.lsp.project_state import ProjectState -def _svc_test_summary(svc_id: str, project: ProjectState) -> dict: - test_results = project.get_test_results_for_svc(svc_id) - return { - "passed": sum(1 for t in test_results if t.status.value == "passed"), - "failed": sum(1 for t in test_results if t.status.value == "failed"), - "skipped": sum(1 for t in test_results if t.status.value == "skipped"), - "missing": sum(1 for t in test_results if t.status.value == "missing"), - } - - def get_requirement_details(raw_id: str, project: ProjectState) -> dict | None: - req = project.get_requirement(raw_id) - if req is None: - return None - svcs = project.get_svcs_for_req(raw_id) - impls = project.get_impl_annotations_for_req(raw_id) - references = [str(ref_id) for rd in (req.references or []) for ref_id in rd.requirement_ids] - return { - "type": "requirement", - "id": req.id.id, - "urn": req.id.urn, - "title": req.title, - "significance": req.significance.value, - "description": req.description, - "rationale": req.rationale or "", - "revision": str(req.revision), - "lifecycle": { - "state": req.lifecycle.state.value, - "reason": req.lifecycle.reason or "", - }, - "categories": [c.value for c in req.categories], - "implementation": req.implementation.value, - "references": references, - "implementations": [{"element_kind": a.element_kind, "fqn": a.fully_qualified_name} for a in impls], - "svcs": [ - { - "id": s.id.id, - "urn": s.id.urn, - "title": s.title, - "verification": s.verification.value, - "lifecycle_state": s.lifecycle.state.value, - "test_summary": _svc_test_summary(s.id.id, project), - } - for s in svcs - ], - "location": project.get_urn_location(req.id.urn), - "source_paths": project.get_yaml_paths().get(req.id.urn, {}), - } + return _get_requirement_details(raw_id, project.repo, project.urn_source_paths) def get_svc_details(raw_id: str, project: ProjectState) -> dict | None: - svc = project.get_svc(raw_id) - if svc is None: - return None - mvrs = project.get_mvrs_for_svc(raw_id) - test_annotations = project.get_test_annotations_for_svc(raw_id) - test_results = project.get_test_results_for_svc(raw_id) - return { - "type": "svc", - "id": svc.id.id, - "urn": svc.id.urn, - "title": svc.title, - "description": svc.description or "", - "verification": svc.verification.value, - "instructions": svc.instructions or "", - "revision": str(svc.revision), - "lifecycle": { - "state": svc.lifecycle.state.value, - "reason": svc.lifecycle.reason or "", - }, - "requirement_ids": [ - { - "id": r.id, - "urn": r.urn, - "title": req.title if (req := project.get_requirement(r.id)) else "", - "lifecycle_state": req.lifecycle.state.value if req else "", - } - for r in svc.requirement_ids - ], - "test_annotations": [{"element_kind": a.element_kind, "fqn": a.fully_qualified_name} for a in test_annotations], - "test_results": [{"fqn": t.fully_qualified_name, "status": t.status.value} for t in test_results], - "test_summary": { - "passed": sum(1 for t in test_results if t.status.value == "passed"), - "failed": sum(1 for t in test_results if t.status.value == "failed"), - "skipped": sum(1 for t in test_results if t.status.value == "skipped"), - "missing": sum(1 for t in test_results if t.status.value == "missing"), - }, - "mvrs": [ - { - "id": m.id.id, - "urn": m.id.urn, - "passed": m.passed, - "comment": m.comment or "", - } - for m in mvrs - ], - "location": project.get_urn_location(svc.id.urn), - "source_paths": project.get_yaml_paths().get(svc.id.urn, {}), - } + return _get_svc_details(raw_id, project.repo, project.urn_source_paths) def get_mvr_details(raw_id: str, project: ProjectState) -> dict | None: - mvr = project.get_mvr(raw_id) - if mvr is None: - return None - return { - "type": "mvr", - "id": mvr.id.id, - "urn": mvr.id.urn, - "passed": mvr.passed, - "comment": mvr.comment or "", - "svc_ids": [{"id": s.id, "urn": s.urn} for s in mvr.svc_ids], - "location": project.get_urn_location(mvr.id.urn), - "source_paths": project.get_yaml_paths().get(mvr.id.urn, {}), - } + return _get_mvr_details(raw_id, project.repo, project.urn_source_paths) diff --git a/src/reqstool/lsp/features/diagnostics.py b/src/reqstool/lsp/features/diagnostics.py index 02b91945..4d40b188 100644 --- a/src/reqstool/lsp/features/diagnostics.py +++ b/src/reqstool/lsp/features/diagnostics.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import logging import os diff --git a/src/reqstool/lsp/features/document_symbols.py b/src/reqstool/lsp/features/document_symbols.py index 05806dc9..b2159b82 100644 --- a/src/reqstool/lsp/features/document_symbols.py +++ b/src/reqstool/lsp/features/document_symbols.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import os import re @@ -17,6 +16,32 @@ } +class _YamlItem: + """A parsed YAML list item with its fields and line range.""" + + __slots__ = ("fields", "start_line", "end_line", "id_line") + + def __init__(self, start_line: int): + self.fields: dict[str, str] = {} + self.start_line = start_line + self.end_line = start_line + self.id_line = start_line + + @property + def range(self) -> types.Range: + return types.Range( + start=types.Position(line=self.start_line, character=0), + end=types.Position(line=self.end_line, character=0), + ) + + @property + def selection_range(self) -> types.Range: + return types.Range( + start=types.Position(line=self.id_line, character=0), + end=types.Position(line=self.id_line, character=0), + ) + + def handle_document_symbols( uri: str, text: str, @@ -163,32 +188,6 @@ def _symbols_for_mvrs( return symbols -class _YamlItem: - """A parsed YAML list item with its fields and line range.""" - - __slots__ = ("fields", "start_line", "end_line", "id_line") - - def __init__(self, start_line: int): - self.fields: dict[str, str] = {} - self.start_line = start_line - self.end_line = start_line - self.id_line = start_line - - @property - def range(self) -> types.Range: - return types.Range( - start=types.Position(line=self.start_line, character=0), - end=types.Position(line=self.end_line, character=0), - ) - - @property - def selection_range(self) -> types.Range: - return types.Range( - start=types.Position(line=self.id_line, character=0), - end=types.Position(line=self.id_line, character=0), - ) - - def _parse_yaml_items(text: str) -> list[_YamlItem]: """Parse YAML text to extract list items under the main collection key. diff --git a/src/reqstool/lsp/features/hover.py b/src/reqstool/lsp/features/hover.py index f2101307..ef6b49bb 100644 --- a/src/reqstool/lsp/features/hover.py +++ b/src/reqstool/lsp/features/hover.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import json import os diff --git a/src/reqstool/lsp/features/implementation.py b/src/reqstool/lsp/features/implementation.py index 53b16153..f26775b8 100644 --- a/src/reqstool/lsp/features/implementation.py +++ b/src/reqstool/lsp/features/implementation.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import logging import os diff --git a/src/reqstool/lsp/features/inlay_hints.py b/src/reqstool/lsp/features/inlay_hints.py index 50a04995..5bfb9627 100644 --- a/src/reqstool/lsp/features/inlay_hints.py +++ b/src/reqstool/lsp/features/inlay_hints.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations from lsprotocol import types diff --git a/src/reqstool/lsp/features/list.py b/src/reqstool/lsp/features/list.py index e70f18dd..fc518ca1 100644 --- a/src/reqstool/lsp/features/list.py +++ b/src/reqstool/lsp/features/list.py @@ -1,38 +1,11 @@ # Copyright © LFV -from __future__ import annotations +from reqstool.common.queries.list import get_list as _get_list from reqstool.lsp.project_state import ProjectState def get_list(project: ProjectState) -> dict: - reqs = project._repo.get_all_requirements() if project._repo else {} - svcs = project._repo.get_all_svcs() if project._repo else {} - mvrs = project._repo.get_all_mvrs() if project._repo else {} - - return { - "requirements": [ - { - "id": r.id.id, - "title": r.title, - "lifecycle_state": r.lifecycle.state.value, - } - for r in reqs.values() - ], - "svcs": [ - { - "id": s.id.id, - "title": s.title, - "lifecycle_state": s.lifecycle.state.value, - "verification": s.verification.value, - } - for s in svcs.values() - ], - "mvrs": [ - { - "id": m.id.id, - "passed": m.passed, - } - for m in mvrs.values() - ], - } + if project.repo is None: + return {"requirements": [], "svcs": [], "mvrs": []} + return _get_list(project.repo) diff --git a/src/reqstool/lsp/features/references.py b/src/reqstool/lsp/features/references.py index f157062e..290a873d 100644 --- a/src/reqstool/lsp/features/references.py +++ b/src/reqstool/lsp/features/references.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import os import re diff --git a/src/reqstool/lsp/features/semantic_tokens.py b/src/reqstool/lsp/features/semantic_tokens.py index 3517d5f0..6a694992 100644 --- a/src/reqstool/lsp/features/semantic_tokens.py +++ b/src/reqstool/lsp/features/semantic_tokens.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations from lsprotocol import types diff --git a/src/reqstool/lsp/features/workspace_symbols.py b/src/reqstool/lsp/features/workspace_symbols.py index 7cb8c00a..8402a9aa 100644 --- a/src/reqstool/lsp/features/workspace_symbols.py +++ b/src/reqstool/lsp/features/workspace_symbols.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import os import re diff --git a/src/reqstool/lsp/project_state.py b/src/reqstool/lsp/project_state.py index 4ab81534..effe6a9f 100644 --- a/src/reqstool/lsp/project_state.py +++ b/src/reqstool/lsp/project_state.py @@ -1,92 +1,29 @@ # Copyright © LFV -from __future__ import annotations import logging from reqstool.common.models.urn_id import UrnId -from reqstool.common.validators.lifecycle_validator import LifecycleValidator -from reqstool.common.validators.semantic_validator import SemanticValidator -from reqstool.common.validator_error_holder import ValidationErrorHolder +from reqstool.common.project_session import ProjectSession from reqstool.locations.local_location import LocalLocation -from reqstool.model_generators.combined_raw_datasets_generator import CombinedRawDatasetsGenerator from reqstool.models.annotations import AnnotationData from reqstool.models.mvrs import MVRData from reqstool.models.requirements import RequirementData from reqstool.models.svcs import SVCData from reqstool.models.test_data import TestData -from reqstool.storage.database import RequirementsDatabase -from reqstool.storage.database_filter_processor import DatabaseFilterProcessor -from reqstool.storage.requirements_repository import RequirementsRepository logger = logging.getLogger(__name__) -class ProjectState: +class ProjectState(ProjectSession): def __init__(self, reqstool_path: str): + super().__init__(LocalLocation(path=reqstool_path)) self._reqstool_path = reqstool_path - self._db: RequirementsDatabase | None = None - self._repo: RequirementsRepository | None = None - self._ready: bool = False - self._error: str | None = None - self._urn_source_paths: dict[str, dict[str, str]] = {} - - @property - def ready(self) -> bool: - return self._ready - - @property - def error(self) -> str | None: - return self._error @property def reqstool_path(self) -> str: return self._reqstool_path - def build(self) -> None: - self.close() - self._error = None - db = RequirementsDatabase() - try: - location = LocalLocation(path=self._reqstool_path) - holder = ValidationErrorHolder() - semantic_validator = SemanticValidator(validation_error_holder=holder) - - crdg = CombinedRawDatasetsGenerator( - initial_location=location, - semantic_validator=semantic_validator, - database=db, - ) - crd = crdg.combined_raw_datasets - - DatabaseFilterProcessor(db, crd.raw_datasets).apply_filters() - LifecycleValidator(RequirementsRepository(db)) - - self._db = db - self._repo = RequirementsRepository(db) - self._urn_source_paths = dict(crd.urn_source_paths) - self._ready = True - logger.info("Built project state for %s", self._reqstool_path) - except SystemExit as e: - logger.warning("build_database() called sys.exit(%s) for %s", e.code, self._reqstool_path) - self._error = f"Pipeline error (exit code {e.code})" - db.close() - except Exception as e: - logger.error("Failed to build project state for %s: %s", self._reqstool_path, e) - self._error = str(e) - db.close() - - def rebuild(self) -> None: - self.build() - - def close(self) -> None: - if self._db is not None: - self._db.close() - self._db = None - self._repo = None - self._urn_source_paths = {} - self._ready = False - def get_initial_urn(self) -> str | None: if not self._ready or self._repo is None: return None diff --git a/src/reqstool/lsp/root_discovery.py b/src/reqstool/lsp/root_discovery.py index 44494026..62bdbd30 100644 --- a/src/reqstool/lsp/root_discovery.py +++ b/src/reqstool/lsp/root_discovery.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import fnmatch import logging diff --git a/src/reqstool/lsp/server.py b/src/reqstool/lsp/server.py index 297724fa..2d324478 100644 --- a/src/reqstool/lsp/server.py +++ b/src/reqstool/lsp/server.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import logging diff --git a/src/reqstool/lsp/workspace_manager.py b/src/reqstool/lsp/workspace_manager.py index 74ee3efc..cb971b2b 100644 --- a/src/reqstool/lsp/workspace_manager.py +++ b/src/reqstool/lsp/workspace_manager.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import logging import os diff --git a/src/reqstool/lsp/yaml_schema.py b/src/reqstool/lsp/yaml_schema.py index 565876fa..e537e0a2 100644 --- a/src/reqstool/lsp/yaml_schema.py +++ b/src/reqstool/lsp/yaml_schema.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import json import logging diff --git a/src/reqstool/mcp/__init__.py b/src/reqstool/mcp/__init__.py new file mode 100644 index 00000000..051704bb --- /dev/null +++ b/src/reqstool/mcp/__init__.py @@ -0,0 +1 @@ +# Copyright © LFV diff --git a/src/reqstool/mcp/server.py b/src/reqstool/mcp/server.py new file mode 100644 index 00000000..fd4ba78c --- /dev/null +++ b/src/reqstool/mcp/server.py @@ -0,0 +1,112 @@ +# Copyright © LFV + + +import logging + +from reqstool.common.project_session import ProjectSession +from reqstool.common.queries.details import ( + get_mvr_details, + get_requirement_details, + get_requirement_status as _get_requirement_status, + get_svc_details, +) +from reqstool.common.queries.list import get_mvrs_list, get_requirements_list, get_svcs_list +from reqstool.locations.location import LocationInterface +from reqstool.services.statistics_service import StatisticsService +from reqstool.storage.requirements_repository import RequirementsRepository + +logger = logging.getLogger(__name__) + + +def start_server(location: LocationInterface) -> None: # noqa: C901 + try: + from mcp.server.fastmcp import FastMCP + except ImportError as exc: + raise ImportError("MCP server requires extra dependencies: pip install 'mcp>=1.0'") from exc + + session = ProjectSession(location) + session.build() + + if not session.ready: + raise RuntimeError(f"Failed to load reqstool project: {session.error}") + + if session.repo is None: + raise RuntimeError("Project session repo is None after successful build") + repo: RequirementsRepository = session.repo + urn_source_paths = session.urn_source_paths + + mcp = FastMCP("reqstool") + + @mcp.tool() + def list_requirements() -> list[dict]: + """List all requirements with id, title, and lifecycle state.""" + return get_requirements_list(repo) + + @mcp.tool() + def get_requirement(id: str) -> dict: + """Get full details for a requirement by ID (e.g. REQ_010).""" + result = get_requirement_details(id, repo, urn_source_paths) + if result is None: + raise ValueError(f"Requirement {id!r} not found") + return result + + @mcp.tool() + def list_svcs() -> list[dict]: + """List all SVCs with id, title, lifecycle state, and verification type.""" + return get_svcs_list(repo) + + @mcp.tool() + def get_svc(id: str) -> dict: + """Get full details for an SVC by ID (e.g. SVC_010).""" + result = get_svc_details(id, repo, urn_source_paths) + if result is None: + raise ValueError(f"SVC {id!r} not found") + return result + + @mcp.tool() + def list_mvrs() -> list[dict]: + """List all MVRs with id and passed status.""" + return get_mvrs_list(repo) + + @mcp.tool() + def get_mvr(id: str) -> dict: + """Get full details for an MVR by ID.""" + result = get_mvr_details(id, repo, urn_source_paths) + if result is None: + raise ValueError(f"MVR {id!r} not found") + return result + + @mcp.tool() + def get_status() -> dict: + """Get overall traceability status — completion per requirement, test totals.""" + return StatisticsService(repo).to_status_dict() + + @mcp.tool() + def get_requirement_status(id: str) -> dict: + """Quick status check for one requirement: lifecycle state, implementation status, test summary.""" + result = _get_requirement_status(id, repo) + if result is None: + raise ValueError(f"Requirement {id!r} not found") + return result + + @mcp.tool() + def list_annotations() -> list[dict]: + """List all implementation annotations (@Requirements) found in source code.""" + impl_annotations = repo.get_annotations_impls() + result = [] + for urn_id, ann_list in impl_annotations.items(): + for ann in ann_list: + result.append( + { + "req_id": urn_id.id, + "req_urn": urn_id.urn, + "element_kind": ann.element_kind, + "fqn": ann.fully_qualified_name, + } + ) + return result + + try: + mcp.run() + finally: + session.close() diff --git a/src/reqstool/services/export_service.py b/src/reqstool/services/export_service.py index 732a4ad0..9acf6338 100644 --- a/src/reqstool/services/export_service.py +++ b/src/reqstool/services/export_service.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import logging diff --git a/src/reqstool/services/statistics_service.py b/src/reqstool/services/statistics_service.py index 12a9af4f..a75495dc 100644 --- a/src/reqstool/services/statistics_service.py +++ b/src/reqstool/services/statistics_service.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations from dataclasses import dataclass, field diff --git a/src/reqstool/storage/database.py b/src/reqstool/storage/database.py index 3b607b17..00d81788 100644 --- a/src/reqstool/storage/database.py +++ b/src/reqstool/storage/database.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import logging import sqlite3 diff --git a/src/reqstool/storage/database_filter_processor.py b/src/reqstool/storage/database_filter_processor.py index c9c656f4..1c6a9725 100644 --- a/src/reqstool/storage/database_filter_processor.py +++ b/src/reqstool/storage/database_filter_processor.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import logging diff --git a/src/reqstool/storage/el_to_sql_compiler.py b/src/reqstool/storage/el_to_sql_compiler.py index ffd6b5bf..fcab1321 100644 --- a/src/reqstool/storage/el_to_sql_compiler.py +++ b/src/reqstool/storage/el_to_sql_compiler.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations import re diff --git a/src/reqstool/storage/pipeline.py b/src/reqstool/storage/pipeline.py index 9304088c..8436dac2 100644 --- a/src/reqstool/storage/pipeline.py +++ b/src/reqstool/storage/pipeline.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations from contextlib import contextmanager from typing import Generator diff --git a/src/reqstool/storage/requirements_repository.py b/src/reqstool/storage/requirements_repository.py index 39251284..83205046 100644 --- a/src/reqstool/storage/requirements_repository.py +++ b/src/reqstool/storage/requirements_repository.py @@ -1,6 +1,5 @@ # Copyright © LFV -from __future__ import annotations from packaging.version import Version diff --git a/tests/integration/reqstool/lsp/conftest.py b/tests/integration/reqstool/lsp/conftest.py index 3e6c448c..5c3da0c6 100644 --- a/tests/integration/reqstool/lsp/conftest.py +++ b/tests/integration/reqstool/lsp/conftest.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import asyncio import os import sys diff --git a/tests/integration/reqstool/lsp/test_lsp_integration.py b/tests/integration/reqstool/lsp/test_lsp_integration.py index 88a83035..0b18b513 100644 --- a/tests/integration/reqstool/lsp/test_lsp_integration.py +++ b/tests/integration/reqstool/lsp/test_lsp_integration.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import os from pathlib import Path diff --git a/tests/integration/reqstool/mcp/conftest.py b/tests/integration/reqstool/mcp/conftest.py new file mode 100644 index 00000000..8e65c892 --- /dev/null +++ b/tests/integration/reqstool/mcp/conftest.py @@ -0,0 +1,58 @@ +# Copyright © LFV + + +import asyncio +import sys +from pathlib import Path + +import pytest +import pytest_asyncio +from mcp.client.session import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client + +FIXTURE_DIR = str(Path(__file__).resolve().parents[3] / "fixtures" / "reqstool-regression-python") + +pytestmark = [pytest.mark.integration, pytest.mark.asyncio(loop_scope="session")] + + +@pytest.fixture(scope="session") +def fixture_dir(): + import os + + assert os.path.isdir(FIXTURE_DIR), f"Fixture directory not found: {FIXTURE_DIR}" + return FIXTURE_DIR + + +@pytest_asyncio.fixture(loop_scope="session", scope="session") +async def mcp_session(fixture_dir): + """Session-scoped async fixture: start MCP server, initialize session, yield, shutdown. + + The entire stdio_client + ClientSession lifecycle runs inside a single asyncio Task + so that anyio cancel scopes are always entered and exited by the same task. + """ + ready: asyncio.Queue = asyncio.Queue() + done = asyncio.Event() + + async def _lifecycle(): + params = StdioServerParameters( + command=sys.executable, + args=["-m", "reqstool.command", "mcp", "local", "-p", fixture_dir], + ) + try: + async with stdio_client(params) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + await ready.put(session) + await done.wait() + except Exception as exc: + await ready.put(exc) + + task = asyncio.create_task(_lifecycle()) + result = await ready.get() + if isinstance(result, Exception): + raise result + + yield result + + done.set() + await task diff --git a/tests/integration/reqstool/mcp/test_mcp_integration.py b/tests/integration/reqstool/mcp/test_mcp_integration.py new file mode 100644 index 00000000..c48aa1fd --- /dev/null +++ b/tests/integration/reqstool/mcp/test_mcp_integration.py @@ -0,0 +1,184 @@ +# Copyright © LFV + + +import json + +import pytest + +pytestmark = [pytest.mark.integration, pytest.mark.asyncio(loop_scope="session")] + +# IDs present in the reqstool-regression-python fixture +KNOWN_REQ_ID = "REQ_PASS" +KNOWN_SVC_ID = "SVC_010" + + +def _parse_result(result) -> list | dict: + """FastMCP returns each list item as a separate TextContent block.""" + blocks = [json.loads(b.text) for b in result.content if hasattr(b, "text")] + return blocks if len(blocks) != 1 else blocks[0] + + +# --------------------------------------------------------------------------- +# Tool discovery +# --------------------------------------------------------------------------- + + +async def test_list_tools(mcp_session): + """Server advertises all 9 expected tools.""" + result = await mcp_session.list_tools() + tool_names = {t.name for t in result.tools} + expected = { + "list_requirements", + "get_requirement", + "list_svcs", + "get_svc", + "list_mvrs", + "get_mvr", + "get_status", + "get_requirement_status", + "list_annotations", + } + assert expected.issubset(tool_names), f"Missing tools: {expected - tool_names}" + + +# --------------------------------------------------------------------------- +# list_requirements +# --------------------------------------------------------------------------- + + +async def test_list_requirements(mcp_session): + result = await mcp_session.call_tool("list_requirements", {}) + reqs = _parse_result(result) + assert isinstance(reqs, list) + assert len(reqs) > 0 + for req in reqs: + assert "id" in req + assert "title" in req + assert "lifecycle_state" in req + + +# --------------------------------------------------------------------------- +# get_requirement +# --------------------------------------------------------------------------- + + +async def test_get_requirement_known(mcp_session): + result = await mcp_session.call_tool("get_requirement", {"id": KNOWN_REQ_ID}) + req = _parse_result(result) + assert req["id"] == KNOWN_REQ_ID + assert req["type"] == "requirement" + assert "svcs" in req + assert "implementations" in req + assert "lifecycle" in req + assert "source_paths" in req + + +async def test_get_requirement_not_found(mcp_session): + result = await mcp_session.call_tool("get_requirement", {"id": "REQ_NONEXISTENT"}) + assert result.isError + + +# --------------------------------------------------------------------------- +# list_svcs +# --------------------------------------------------------------------------- + + +async def test_list_svcs(mcp_session): + result = await mcp_session.call_tool("list_svcs", {}) + svcs = _parse_result(result) + assert isinstance(svcs, list) + assert len(svcs) > 0 + for svc in svcs: + assert "id" in svc + assert "title" in svc + assert "lifecycle_state" in svc + assert "verification" in svc + + +# --------------------------------------------------------------------------- +# get_svc +# --------------------------------------------------------------------------- + + +async def test_get_svc_known(mcp_session): + result = await mcp_session.call_tool("get_svc", {"id": KNOWN_SVC_ID}) + svc = _parse_result(result) + assert svc["id"] == KNOWN_SVC_ID + assert svc["type"] == "svc" + assert "test_summary" in svc + assert "requirement_ids" in svc + assert "mvrs" in svc + + +async def test_get_svc_not_found(mcp_session): + result = await mcp_session.call_tool("get_svc", {"id": "SVC_NONEXISTENT"}) + assert result.isError + + +# --------------------------------------------------------------------------- +# list_mvrs / get_mvr +# --------------------------------------------------------------------------- + + +async def test_list_mvrs(mcp_session): + result = await mcp_session.call_tool("list_mvrs", {}) + mvrs = _parse_result(result) + assert isinstance(mvrs, list) + for mvr in mvrs: + assert "id" in mvr + assert "passed" in mvr + + +async def test_get_mvr_not_found(mcp_session): + result = await mcp_session.call_tool("get_mvr", {"id": "MVR_NONEXISTENT"}) + assert result.isError + + +# --------------------------------------------------------------------------- +# get_status +# --------------------------------------------------------------------------- + + +async def test_get_status(mcp_session): + result = await mcp_session.call_tool("get_status", {}) + status = _parse_result(result) + assert "requirements" in status + assert "totals" in status + + +# --------------------------------------------------------------------------- +# get_requirement_status +# --------------------------------------------------------------------------- + + +async def test_get_requirement_status(mcp_session): + result = await mcp_session.call_tool("get_requirement_status", {"id": KNOWN_REQ_ID}) + status = _parse_result(result) + assert status["id"] == KNOWN_REQ_ID + assert "lifecycle_state" in status + assert "implementation" in status + assert "test_summary" in status + assert "meets_requirements" in status + assert isinstance(status["meets_requirements"], bool) + + +async def test_get_requirement_status_not_found(mcp_session): + result = await mcp_session.call_tool("get_requirement_status", {"id": "REQ_NONEXISTENT"}) + assert result.isError + + +# --------------------------------------------------------------------------- +# list_annotations +# --------------------------------------------------------------------------- + + +async def test_list_annotations(mcp_session): + result = await mcp_session.call_tool("list_annotations", {}) + annotations = _parse_result(result) + assert isinstance(annotations, list) + assert len(annotations) > 0 + for ann in annotations: + assert "req_id" in ann + assert "req_urn" in ann + assert "element_kind" in ann + assert "fqn" in ann diff --git a/tests/unit/reqstool/common/queries/__init__.py b/tests/unit/reqstool/common/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/reqstool/common/queries/test_details.py b/tests/unit/reqstool/common/queries/test_details.py new file mode 100644 index 00000000..809a6d05 --- /dev/null +++ b/tests/unit/reqstool/common/queries/test_details.py @@ -0,0 +1,109 @@ +# Copyright © LFV + +import pytest + +from reqstool.common.project_session import ProjectSession +from reqstool.common.queries.details import ( + get_mvr_details, + get_requirement_details, + get_requirement_status, + get_svc_details, +) +from reqstool.locations.local_location import LocalLocation + + +@pytest.fixture +def session(local_testdata_resources_rootdir_w_path): + path = local_testdata_resources_rootdir_w_path("test_standard/baseline/ms-001") + s = ProjectSession(LocalLocation(path=path)) + s.build() + yield s + s.close() + + +def test_get_requirement_details_known(session): + result = get_requirement_details("REQ_010", session.repo) + assert result is not None + assert result["type"] == "requirement" + assert result["id"] == "REQ_010" + assert "title" in result + assert "significance" in result + assert "description" in result + assert "lifecycle" in result + assert isinstance(result["references"], list) + assert isinstance(result["implementations"], list) + assert isinstance(result["svcs"], list) + assert "location" in result + assert result["source_paths"] == {} # no urn_source_paths passed + + +def test_get_requirement_details_with_source_paths(session): + result = get_requirement_details("REQ_010", session.repo, session.urn_source_paths) + assert result is not None + assert isinstance(result["source_paths"], dict) + + +def test_get_requirement_details_unknown(session): + assert get_requirement_details("REQ_NONEXISTENT", session.repo) is None + + +def test_get_requirement_details_implementations(session): + result = get_requirement_details("REQ_010", session.repo, session.urn_source_paths) + assert result is not None + assert len(result["implementations"]) > 0 + impl = result["implementations"][0] + assert "element_kind" in impl + assert "fqn" in impl + + +def test_get_svc_details_known(session): + repo = session.repo + svc_ids = [uid.id for uid in repo.get_all_svcs()] + assert svc_ids + result = get_svc_details(svc_ids[0], repo) + assert result is not None + assert result["type"] == "svc" + assert "title" in result + assert "verification" in result + assert "requirement_ids" in result + assert "test_summary" in result + assert set(result["test_summary"].keys()) == {"passed", "failed", "skipped", "missing"} + assert "mvrs" in result + + +def test_get_svc_details_unknown(session): + assert get_svc_details("SVC_NONEXISTENT", session.repo) is None + + +def test_get_svc_details_requirement_ids_enriched(session): + repo = session.repo + svc_ids = [uid.id for uid in repo.get_all_svcs()] + for svc_id in svc_ids: + result = get_svc_details(svc_id, repo) + assert result is not None + for req_entry in result["requirement_ids"]: + assert "id" in req_entry + assert "urn" in req_entry + assert "title" in req_entry + assert "lifecycle_state" in req_entry + break + + +def test_get_mvr_details_unknown(session): + assert get_mvr_details("MVR_NONEXISTENT", session.repo) is None + + +def test_get_requirement_status_known(session): + result = get_requirement_status("REQ_010", session.repo) + assert result is not None + assert result["id"] == "REQ_010" + assert "lifecycle_state" in result + assert "implementation" in result + assert "test_summary" in result + assert set(result["test_summary"].keys()) == {"passed", "failed", "skipped", "missing"} + assert "meets_requirements" in result + assert isinstance(result["meets_requirements"], bool) + + +def test_get_requirement_status_unknown(session): + assert get_requirement_status("REQ_NONEXISTENT", session.repo) is None diff --git a/tests/unit/reqstool/common/queries/test_list.py b/tests/unit/reqstool/common/queries/test_list.py new file mode 100644 index 00000000..97052673 --- /dev/null +++ b/tests/unit/reqstool/common/queries/test_list.py @@ -0,0 +1,83 @@ +# Copyright © LFV + +import pytest + +from reqstool.common.project_session import ProjectSession +from reqstool.common.queries.list import get_list, get_mvrs_list, get_requirements_list, get_svcs_list +from reqstool.locations.local_location import LocalLocation + + +@pytest.fixture +def repo(local_testdata_resources_rootdir_w_path): + path = local_testdata_resources_rootdir_w_path("test_standard/baseline/ms-001") + session = ProjectSession(LocalLocation(path=path)) + session.build() + yield session.repo + session.close() + + +def test_get_list_structure(repo): + result = get_list(repo) + assert isinstance(result, dict) + assert "requirements" in result + assert "svcs" in result + assert "mvrs" in result + + +def test_get_list_requirements(repo): + result = get_list(repo) + reqs = result["requirements"] + assert len(reqs) > 0 + for req in reqs: + assert "id" in req + assert "title" in req + assert "lifecycle_state" in req + assert isinstance(req["id"], str) + assert isinstance(req["title"], str) + + +def test_get_list_svcs(repo): + result = get_list(repo) + svcs = result["svcs"] + assert len(svcs) > 0 + for svc in svcs: + assert "id" in svc + assert "title" in svc + assert "lifecycle_state" in svc + assert "verification" in svc + + +def test_get_list_mvrs(repo): + result = get_list(repo) + # MVRs may be empty in this fixture — just check structure + for mvr in result["mvrs"]: + assert "id" in mvr + assert "passed" in mvr + assert isinstance(mvr["passed"], bool) + + +def test_get_requirements_list(repo): + reqs = get_requirements_list(repo) + assert isinstance(reqs, list) + assert len(reqs) > 0 + for req in reqs: + assert "id" in req + assert "title" in req + assert "lifecycle_state" in req + + +def test_get_svcs_list(repo): + svcs = get_svcs_list(repo) + assert isinstance(svcs, list) + assert len(svcs) > 0 + for svc in svcs: + assert "id" in svc + assert "verification" in svc + + +def test_get_mvrs_list(repo): + mvrs = get_mvrs_list(repo) + assert isinstance(mvrs, list) + for mvr in mvrs: + assert "id" in mvr + assert "passed" in mvr diff --git a/tests/unit/reqstool/common/test_project_session.py b/tests/unit/reqstool/common/test_project_session.py new file mode 100644 index 00000000..eac0991d --- /dev/null +++ b/tests/unit/reqstool/common/test_project_session.py @@ -0,0 +1,86 @@ +# Copyright © LFV + +from reqstool.common.project_session import ProjectSession +from reqstool.locations.local_location import LocalLocation + + +def test_build_standard_ms001(local_testdata_resources_rootdir_w_path): + path = local_testdata_resources_rootdir_w_path("test_standard/baseline/ms-001") + session = ProjectSession(LocalLocation(path=path)) + try: + session.build() + assert session.ready + assert session.error is None + assert session.repo is not None + assert session.repo.get_initial_urn() == "ms-001" + assert len(session.urn_source_paths) > 0 + finally: + session.close() + + +def test_build_basic_ms101(local_testdata_resources_rootdir_w_path): + path = local_testdata_resources_rootdir_w_path("test_basic/baseline/ms-101") + session = ProjectSession(LocalLocation(path=path)) + try: + session.build() + assert session.ready + assert session.error is None + assert session.repo is not None + finally: + session.close() + + +def test_build_nonexistent_path(): + session = ProjectSession(LocalLocation(path="/nonexistent/path")) + session.build() + assert not session.ready + assert session.error is not None + assert session.repo is None + + +def test_rebuild(local_testdata_resources_rootdir_w_path): + path = local_testdata_resources_rootdir_w_path("test_standard/baseline/ms-001") + session = ProjectSession(LocalLocation(path=path)) + try: + session.build() + assert session.ready + session.rebuild() + assert session.ready + assert session.repo is not None + finally: + session.close() + + +def test_close_idempotent(local_testdata_resources_rootdir_w_path): + path = local_testdata_resources_rootdir_w_path("test_standard/baseline/ms-001") + session = ProjectSession(LocalLocation(path=path)) + session.build() + session.close() + assert not session.ready + assert session.repo is None + session.close() # should not raise + + +def test_urn_source_paths_populated(local_testdata_resources_rootdir_w_path): + path = local_testdata_resources_rootdir_w_path("test_standard/baseline/ms-001") + session = ProjectSession(LocalLocation(path=path)) + try: + session.build() + assert session.ready + paths = session.urn_source_paths + assert isinstance(paths, dict) + assert len(paths) > 0 + for urn, file_map in paths.items(): + assert isinstance(urn, str) + assert isinstance(file_map, dict) + finally: + session.close() + + +def test_urn_source_paths_cleared_on_close(local_testdata_resources_rootdir_w_path): + path = local_testdata_resources_rootdir_w_path("test_standard/baseline/ms-001") + session = ProjectSession(LocalLocation(path=path)) + session.build() + assert len(session.urn_source_paths) > 0 + session.close() + assert session.urn_source_paths == {} diff --git a/tests/unit/reqstool/lsp/test_details.py b/tests/unit/reqstool/lsp/test_details.py index 4ed9ffd4..1950d1da 100644 --- a/tests/unit/reqstool/lsp/test_details.py +++ b/tests/unit/reqstool/lsp/test_details.py @@ -2,7 +2,7 @@ import pytest -from reqstool.lsp.features.details import get_mvr_details, get_requirement_details, get_svc_details +from reqstool.common.queries.details import get_mvr_details, get_requirement_details, get_svc_details from reqstool.lsp.project_state import ProjectState @@ -16,7 +16,7 @@ def project(local_testdata_resources_rootdir_w_path): def test_get_requirement_details_known(project): - result = get_requirement_details("REQ_010", project) + result = get_requirement_details("REQ_010", project._repo, project.urn_source_paths) assert result is not None assert result["type"] == "requirement" assert result["id"] == "REQ_010" @@ -36,14 +36,14 @@ def test_get_requirement_details_known(project): def test_get_requirement_details_unknown(project): - result = get_requirement_details("REQ_NONEXISTENT", project) + result = get_requirement_details("REQ_NONEXISTENT", project._repo, project.urn_source_paths) assert result is None def test_get_svc_details_known(project): svc_ids = project.get_all_svc_ids() assert svc_ids, "No SVCs in test fixture" - result = get_svc_details(svc_ids[0], project) + result = get_svc_details(svc_ids[0], project._repo, project.urn_source_paths) assert result is not None assert result["type"] == "svc" assert result["id"] == svc_ids[0] @@ -65,18 +65,17 @@ def test_get_svc_details_known(project): def test_get_svc_details_unknown(project): - result = get_svc_details("SVC_NONEXISTENT", project) + result = get_svc_details("SVC_NONEXISTENT", project._repo, project.urn_source_paths) assert result is None def test_get_mvr_details_unknown(project): - # No MVRs in the test_standard fixture; get_mvr should return None - result = get_mvr_details("MVR_NONEXISTENT", project) + result = get_mvr_details("MVR_NONEXISTENT", project._repo, project.urn_source_paths) assert result is None def test_get_requirement_details_fields(project): - result = get_requirement_details("REQ_010", project) + result = get_requirement_details("REQ_010", project._repo, project.urn_source_paths) assert result is not None assert result["id"] == "REQ_010" assert result["lifecycle"]["state"] in ("draft", "effective", "deprecated", "obsolete") @@ -84,8 +83,7 @@ def test_get_requirement_details_fields(project): def test_get_requirement_details_implementations(project): - # annotations.yml has implementations for REQ_010 - result = get_requirement_details("REQ_010", project) + result = get_requirement_details("REQ_010", project._repo, project.urn_source_paths) assert result is not None assert len(result["implementations"]) > 0 impl = result["implementations"][0] @@ -97,7 +95,7 @@ def test_get_requirement_details_implementations(project): def test_get_svc_details_requirement_ids_enriched(project): svc_ids = project.get_all_svc_ids() for svc_id in svc_ids: - result = get_svc_details(svc_id, project) + result = get_svc_details(svc_id, project._repo, project.urn_source_paths) assert result is not None for req_entry in result["requirement_ids"]: assert "id" in req_entry @@ -108,11 +106,9 @@ def test_get_svc_details_requirement_ids_enriched(project): def test_get_svc_details_test_results(project): - # Find a SVC that has test annotations (SVCs in the fixture are linked to test methods) svc_ids = project.get_all_svc_ids() - # Look for an SVC that has test_annotations in the fixture for svc_id in svc_ids: - result = get_svc_details(svc_id, project) + result = get_svc_details(svc_id, project._repo, project.urn_source_paths) assert result is not None if result["test_annotations"]: assert all("element_kind" in a and "fqn" in a for a in result["test_annotations"]) @@ -122,10 +118,9 @@ def test_get_svc_details_test_results(project): def test_get_requirement_details_location_keys(project): - result = get_requirement_details("REQ_010", project) + result = get_requirement_details("REQ_010", project._repo, project.urn_source_paths) assert result is not None loc = result["location"] - # local fixture populates location_type and location_uri assert loc is None or isinstance(loc, dict) if loc is not None: assert "type" in loc @@ -136,7 +131,7 @@ def test_get_requirement_details_location_keys(project): def test_get_svc_details_location_keys(project): svc_ids = project.get_all_svc_ids() - result = get_svc_details(svc_ids[0], project) + result = get_svc_details(svc_ids[0], project._repo, project.urn_source_paths) assert result is not None loc = result["location"] assert loc is None or isinstance(loc, dict) diff --git a/tests/unit/reqstool/lsp/test_project_state.py b/tests/unit/reqstool/lsp/test_project_state.py index 01256ec5..2b0333b0 100644 --- a/tests/unit/reqstool/lsp/test_project_state.py +++ b/tests/unit/reqstool/lsp/test_project_state.py @@ -11,6 +11,7 @@ def test_build_standard_ms001(local_testdata_resources_rootdir_w_path): assert state.ready assert state.error is None assert state.get_initial_urn() == "ms-001" + assert len(state.urn_source_paths) > 0, "urn_source_paths should be populated after build" finally: state.close()