diff --git a/src/tostr/cli.py b/src/tostr/cli.py index bb94e32..5e9674a 100644 --- a/src/tostr/cli.py +++ b/src/tostr/cli.py @@ -5,11 +5,11 @@ from typing import Annotated from loguru import logger -from tostr.exceptions import ToasterError +from tostr.exceptions import TostrError from tostr.commands import init_async, inspect_async, skeleton_async, watch_async, clean_db -from tostr.mcp import mcp +from tostr.server import mcp from tostr.core.utils.logger import configure_cli_logging @@ -67,7 +67,7 @@ def watch( configure_cli_logging(debug) try: asyncio.run(watch_async(path)) - except ToasterError as e: + except TostrError as e: typer.secho(f"❌ Error: {e}", fg="red", err=True) raise typer.Exit(code=1) @@ -94,7 +94,7 @@ def clean( configure_cli_logging(debug) try: clean_db(path) - except ToasterError as e: + except TostrError as e: typer.secho(f"❌ Error: {e}", fg="red", err=True) raise typer.Exit(code=1) @@ -137,7 +137,7 @@ def init( start_time = time.perf_counter() try: asyncio.run(init_async(path, use_cache, ignore)) - except ToasterError as e: + except TostrError as e: typer.secho(f"❌ Error: {e}", fg="red", err=True) raise typer.Exit(code=1) @@ -190,7 +190,7 @@ def inspect( try: result = asyncio.run(inspect_async(id, path, include_body=include_body, pretty=pretty)) print(result) - except ToasterError as e: + except TostrError as e: typer.secho(f"❌ Error: {e}", fg="red", err=True) raise typer.Exit(code=1) @@ -237,7 +237,7 @@ def skeleton( try: result = asyncio.run(skeleton_async(subpath, path, pretty=pretty)) print(result) - except ToasterError as e: + except TostrError as e: typer.secho(f"❌ Error: {e}", fg="red", err=True) raise typer.Exit(code=1) end_time = time.perf_counter() diff --git a/src/tostr/commands.py b/src/tostr/commands.py index f55d499..2339e08 100644 --- a/src/tostr/commands.py +++ b/src/tostr/commands.py @@ -113,12 +113,12 @@ async def skeleton_async(subpath: str, project_path: Path, pretty: bool = True): active_tasks = {} -async def watch_async(target_path: Path): +async def watch_async(target_path: Path, stop_event: asyncio.Event = None): llm = get_llm_client() logger.info("Starting Listener") try: - async for changes in awatch(target_path): + async for changes in awatch(target_path, stop_event=stop_event): for change_type, path in changes: path = Path(path).relative_to(target_path) if ".tostr" in str(path): diff --git a/src/tostr/core/builders.py b/src/tostr/core/builders.py index e8640c3..58bb6d3 100644 --- a/src/tostr/core/builders.py +++ b/src/tostr/core/builders.py @@ -61,6 +61,7 @@ def from_dict(self, d: dict) -> BaseFile: description=d.get("description", ""), imports=d.get("imports", []), body=d.get("body", ""), + diff_hash=d.get("diff_hash", ""), package=d.get("package", ""), _inbound_dependency_strings=json.loads(d.get("inbound_dependency_strings", [])), _outbound_dependency_strings=json.loads(d.get("outbound_dependency_strings", [])), diff --git a/src/tostr/core/models.py b/src/tostr/core/models.py index 1828f86..60b7498 100644 --- a/src/tostr/core/models.py +++ b/src/tostr/core/models.py @@ -106,11 +106,6 @@ def edges(self): elif isinstance(self.parent, str): edges.add((self.id, self.parent, "is_child_of")) - # for child_set in self.children.values(): - # for child in child_set: - # edges.update(child.edges) - # edges.add((child.id, self.id, "contains")) - return edges @property @@ -143,6 +138,15 @@ def set_parent(self, parent: "BaseStruct"): logger.warning(f"Attempted to set parent of {self} to itself. Skipping to avoid circular reference.") return self.parent = parent + + def calculate_distributed_hash(self): + """Calculates diff_hash based on direct children's hashes.""" + if not self.all_children: + return + + child_hashes = sorted([child.diff_hash for child in self.all_children if getattr(child, "diff_hash", None)]) + if child_hashes: + self.diff_hash = hashlib.md5("".join(child_hashes).encode("utf-8")).hexdigest() def add_dependency(self, target: "BaseStruct"): self.outbound_dependencies.add(target) @@ -175,7 +179,6 @@ async def resolve_description_async(self, llm: "LLMClient", visited: set[str] = @classmethod def from_dict(cls, d: dict): data = d.copy() - # REMOVE all init=False here id = data.pop("id", None) instance = cls(**data) if id: @@ -213,6 +216,7 @@ def __str__(self): @dataclass(eq=False) class Directory(BaseStruct): _IDPREFIX: ClassVar[str] = "D" + diff_hash: str = "" def __init__(self, path, registry=None, parent=None, uid=None): uid = uid or str(path) @@ -233,11 +237,11 @@ def parse_children(self): for path in full_path.glob("*"): if self.registry.config.is_ignored(path): - logger.debug(f"Skipping '{path}' due to path ignore rules") + logger.debug(f"Skipping \'{path}\' due to path ignore rules") continue else: if path.is_dir(): - logger.debug(f"🔍 Parsing directory '{path}'") + logger.debug(f"🔍 Parsing directory \'{path}\'") relative_path = self.registry.relative_to_project(path) directory = Directory(path=relative_path, registry=self.registry, parent=self) self.registry.add_struct(directory) @@ -247,18 +251,26 @@ def parse_children(self): logger.debug(f"Attempting to resolve builder for suffix {path.parts[-1]}") try: if self.registry.config.is_ignored(path): - logger.debug(f"Skipping '{path}' due to path ignore rules") + logger.debug(f"Skipping \'{path}\' due to path ignore rules") continue builder = StructBuilderProvider.get_builder(path.suffix, self.registry) except LanguageNotSupportedError as e: continue instance = builder.build_file().from_path(path, parent=self) + + # Calculate file hash from children if any exist + instance.calculate_distributed_hash() + self.registry.add_struct(instance) self.add_child(instance) + + # Calculate directory hash from its direct children + self.calculate_distributed_hash() def to_dict(self) -> dict: data = super().to_dict() data["type"] = "Directory" + data["diff_hash"] = self.diff_hash return data @dataclass(eq=False) @@ -268,6 +280,7 @@ class BaseFile(BaseStruct): imports: List[str] = field(default_factory=list) package: str = "" body: str = "" + diff_hash: str = "" node: "Node" = None async def resolve_description_async(self, llm: "LLMClient", visited: set[str] = None): @@ -279,8 +292,9 @@ def to_dict(self) -> dict: data["type"] = "BaseFile" data["imports"] = self.imports data["body"] = self.body + data["diff_hash"] = self.diff_hash return data - + @dataclass(eq=False) class BaseCodeStruct(BaseStruct): @@ -329,7 +343,7 @@ def needs_description(self) -> bool: def imports(self) -> List[str]: return self.parent.imports - def resolve_type(self, type_name: str) -> Optional[BaseStruct]: + def resolve_type(self, type_name: str) -> Optional["BaseStruct"]: """Resolves a simple or scoped type name to a struct using package and imports.""" if not type_name: return None if type_name in self._type_cache: @@ -340,7 +354,7 @@ def resolve_type(self, type_name: str) -> Optional[BaseStruct]: # 2. Same package if not dep: - package = getattr(self.parent, 'package', None) if isinstance(self.parent, BaseFile) else None + package = getattr(self.parent, "package", None) if isinstance(self.parent, BaseFile) else None if package: dep = self.registry.get_struct_by_uid(f"{package}.{type_name}") @@ -462,7 +476,7 @@ class BaseMethod(BaseCodeStruct): # pass def resolve_dependencies(self): - logger.info(f"Resolving dependencies for method {self.uid}") + logger.debug(f"Resolving dependencies for method {self.uid}") for dep_info in self.dependency_names: # Handle both old (name, arity) and new (name, arity, receiver, is_creation) formats if len(dep_info) == 2: @@ -483,7 +497,7 @@ def resolve_dependencies(self): search_scope = self.parent.children if self.parent else self.children for child_set in list(search_scope.values()): for child in list(child_set): - if child.name == name and getattr(child, 'arity', -1) == arity: + if child.name == name and getattr(child, "arity", -1) == arity: self.add_dependency(child) resolved = True break @@ -512,7 +526,7 @@ def resolve_dependencies(self): if self.parent._potential_parents_cache is None: potential_parents = [] # Same package - package = getattr(self.parent.parent, 'package', None) if isinstance(self.parent.parent, BaseFile) else None + package = getattr(self.parent.parent, "package", None) if isinstance(self.parent.parent, BaseFile) else None if package: potential_parents.append(f"{package}.*") # All imports diff --git a/src/tostr/core/parser.py b/src/tostr/core/parser.py index 1eb2912..2518ebe 100644 --- a/src/tostr/core/parser.py +++ b/src/tostr/core/parser.py @@ -1,9 +1,10 @@ from pathlib import Path from abc import ABC import asyncio +import hashlib from loguru import logger -from tostr.core.models import BaseFile, Directory +from tostr.core.models import BaseFile, Directory, BaseStruct from tostr.core.registry import Registry from tostr.core.providers import StructBuilderProvider from tostr.exceptions import LanguageNotSupportedError @@ -13,13 +14,10 @@ def __init__(self, project_dir: str, llm=None, registry: Registry=None): self.project_dir = project_dir self.llm = llm self.registry = registry - # self.path_ignore = ["venv", ".venv", "env", ".env", "build", "dist", "__pycache__", ".tostr", ".git"] @property def files(self): - if self._files: return self._files - self._files = self.registry.uid_map.values().filter(lambda x: isinstance(x, BaseFile)) - return self._files + return [x for x in self.registry.uid_map.values() if isinstance(x, BaseFile)] async def parse(self, subpath: Path = None): if not subpath: @@ -28,9 +26,7 @@ async def parse(self, subpath: Path = None): subpath = Path(subpath) self.parse_path(subpath) - self.resolve_dependencies() - await self.resolve_descriptions_async() def parse_path(self, subpath: Path, parent: Directory = None): @@ -38,13 +34,11 @@ def parse_path(self, subpath: Path, parent: Directory = None): logger.debug(f"🔍 Parsing directory '{subpath}'") if parent is None: - # ROOT creation root_path = subpath if self.registry: root_path = self.registry.relative_to_project(subpath) root = Directory(path=root_path, registry=self.registry) self.registry.root = root - logger.debug(f"Created registry root: {root}") self.registry.add_struct(root) else: root = parent @@ -58,36 +52,27 @@ def parse_path(self, subpath: Path, parent: Directory = None): existing = self.registry.get_struct_by_uid(str(relative_path)) if path.is_dir(): - if existing and isinstance(existing, Directory): - logger.debug(f"Using cached directory '{path}'") - root.add_child(existing) - self.parse_path(path, parent=existing) - continue - directory = Directory(path=relative_path, registry=self.registry, parent=root) self.registry.add_struct(directory) root.add_child(directory) self.parse_path(path, parent=directory) else: - if existing and isinstance(existing, BaseFile): - if not self.registry.is_stale(existing): - logger.debug(f"Using cached file '{path}'") - root.add_child(existing) - continue - file = self.parse_file(path, parent=root) if file: + file.calculate_distributed_hash() self.registry.add_struct(file) root.add_child(file) + + root.calculate_distributed_hash() else: - logger.debug(f"🔍 Parsing file '{subpath}'") file = self.parse_file(subpath) - self.registry.root = file - self.registry.add_struct(file) + if file: + file.calculate_distributed_hash() + self.registry.root = file + self.registry.add_struct(file) - # @abstractmethod def parse_file(self, subpath: Path, parent: BaseStruct=None) -> BaseFile: - logger.debug(f"Attempting to resolve builder for suffix {subpath.parts[-1]}") + logger.debug(f"Attempting to resolve builder for suffix {subpath.suffix}") if self.registry.config.is_ignored(subpath): logger.debug(f"Skipping '{subpath}' due to path ignore rules") return None @@ -98,34 +83,19 @@ def parse_file(self, subpath: Path, parent: BaseStruct=None) -> BaseFile: logger.warning(str(e)) return None file_obj = builder.build_file().from_path(subpath, parent=parent) - # logger.debug(json.dumps(file_obj.to_dict(), indent=2)) return file_obj def resolve_dependencies(self): - logger.info(f"Starting dependency resolution from root: {self.registry.root}") - self.registry.root.resolve_dependencies() + if self.registry.root: + logger.info(f"Starting dependency resolution from root: {self.registry.root}") + self.registry.root.resolve_dependencies() def load_cache(self): - # print("Attempting to load cache from SQLite database...") - # t_cache = time.time() self.registry.load_cache() - # print(f"✅ Loaded Cache in {time.time() - t_cache:.2f} seconds") async def resolve_descriptions_async(self): - self.visited_ucids = set() - coroutine_list = [file.resolve_description_async(self.llm, self.visited_ucids) for file in self.registry.files] - if coroutine_list == []: return - result = await asyncio.gather(*coroutine_list) - - - # def write_skeleton(self): - # tost_string = tost.dump_parser(self, verbosity=Verbosity.SIMPLE) - # tostr_dir = self.path / ".tostr" - # tostr_dir.mkdir(exist_ok=True) - # with open(tostr_dir / "skeleton.tost", "w") as file: - # file.write(tost_string) - - # def write_cache(self, stale: bool = False): - # logger.debug("Writing AST to SQLite database...") - # self.registry.save_to_cache(stale=stale) - \ No newline at end of file + visited_ucids = set() + coroutine_list = [file.resolve_description_async(self.llm, visited_ucids) for file in self.registry.files] + if not coroutine_list: + return + await asyncio.gather(*coroutine_list) diff --git a/src/tostr/core/registry.py b/src/tostr/core/registry.py index afc4b8d..af6ce25 100644 --- a/src/tostr/core/registry.py +++ b/src/tostr/core/registry.py @@ -1,14 +1,15 @@ from collections import defaultdict from typing import List, Dict, Optional, TYPE_CHECKING from pathlib import Path -from tostr.core.models import BaseFile, BaseClass, BaseMethod, BaseField +import json +import hashlib +from loguru import logger + +from tostr.core.models import BaseFile, BaseClass, BaseMethod, BaseField, Directory from tostr.core.db import SQLiteCache from tostr.core.builders import BaseBuilder from tostr.core.context.config import ProjectConfig -import json -from loguru import logger - if TYPE_CHECKING: from tostr.core.models import BaseStruct, BaseCodeStruct @@ -42,7 +43,6 @@ def fields(self) -> List[BaseField]: def relative_to_project(self, path: Path) -> Path: if not self.project_path: - # logger.warning("Project path not set in registry, returning original path") return path if not path.is_absolute(): @@ -55,10 +55,9 @@ def relative_to_project(self, path: Path) -> Path: try: return path.absolute().relative_to(self.project_path.absolute()) except ValueError: - # If path is truly outside the subpath or due to macOS symlink quirks, fallback to relpath return Path(os.path.relpath(path.resolve(), self.project_path.resolve())) - def add_struct(self, struct: BaseStruct): + def add_struct(self, struct: "BaseStruct"): """ Adds a struct to the in-memory cache """ self.uid_map[struct.uid] = struct self.id_map[struct.id] = struct @@ -84,7 +83,6 @@ def _resolve_methods_recursive(self, struct: "BaseStruct", name: str, arity: int if struct.uid in visited: return [] visited.add(struct.uid) - # 1. Local methods matches = [x for x in struct.methods if x.name == name and x.arity == arity] if matches: return matches @@ -101,15 +99,12 @@ def _resolve_methods_recursive(self, struct: "BaseStruct", name: str, arity: int return [] def get_classes_in_package(self, package_name: str) -> List[BaseClass]: - """ Retrieves all classes in a specific package from memory or DB """ if package_name in self.missing_packages: return [x for x in self.classes if x.uid.startswith(package_name + ".") or x.uid.startswith(package_name + "#")] - # Ensure we have all classes from this package in memory if self.use_cache and self.db: with self.db.get_connection() as conn: cursor = conn.cursor() - # Find all classes in this package cursor.execute( "SELECT uid FROM structs WHERE type = 'BaseClass' AND (uid LIKE ? OR uid LIKE ?)", (f"{package_name}.%", f"{package_name}#%") @@ -125,52 +120,44 @@ def get_classes_in_package(self, package_name: str) -> List[BaseClass]: def load_filepath(self, path: Path): logger.debug(f"Loading subtree {str(path)}") path_str = str(self.relative_to_project(path)) - resolved_path_str = str(path.resolve()) with self.db.get_connection() as conn: cursor = conn.cursor() - # pull and hydrate structs if path_str != ".": cursor.execute("SELECT * FROM structs WHERE path = ? OR path LIKE ? || '/%'", (path_str, path_str)) else: cursor.execute("SELECT * FROM structs") node_rows = cursor.fetchall() - node_ids = [str(row['id']) for row in node_rows] + node_ids = [str(row["id"]) for row in node_rows] for row in node_rows: struct_data = dict(row) - if struct_data.get('imports', None): - struct_data['imports'] = json.loads(struct_data['imports']) - if struct_data.get('dependency_names', None): - struct_data['dependency_names'] = json.loads(struct_data['dependency_names']) - if struct_data.get('inherits', None): - struct_data['inherits'] = json.loads(struct_data['inherits']) - if struct_data.get('enum_constants', None): - struct_data['enum_constants'] = json.loads(struct_data['enum_constants']) + if struct_data.get("imports", None): + struct_data["imports"] = json.loads(struct_data["imports"]) + if struct_data.get("dependency_names", None): + struct_data["dependency_names"] = json.loads(struct_data["dependency_names"]) + if struct_data.get("inherits", None): + struct_data["inherits"] = json.loads(struct_data["inherits"]) + if struct_data.get("enum_constants", None): + struct_data["enum_constants"] = json.loads(struct_data["enum_constants"]) builder = BaseBuilder(self) - struct_type = struct_data['type'] + struct_type = struct_data["type"] instance = builder.with_type(struct_type=struct_type).from_dict(struct_data) if instance: - instance.id = str(struct_data['id']) + instance.id = str(struct_data["id"]) self.add_struct(instance) - logger.debug(f"Found {len(node_rows)} structs in subtree {path_str}") - if not node_ids: return None placeholders = ",".join(["?"] * len(node_ids)) if path_str == ".": - sql = f""" - SELECT source_id, target_id, edge_type - FROM edges - WHERE edge_type = 'is_child_of' - """ + sql = f"SELECT source_id, target_id, edge_type FROM edges WHERE edge_type = 'is_child_of'" cursor.execute(sql) else: sql = f""" @@ -185,29 +172,20 @@ def load_filepath(self, path: Path): edge_rows = cursor.fetchall() - logger.debug(f"Found {len(edge_rows)} edges in subtree {path_str}") - - for source_id, target_id, edge_type in edge_rows: source_obj = self.id_map.get(str(source_id)) target_obj = self.id_map.get(str(target_id)) if not source_obj or not target_obj: - logger.warning(f"Edge references missing struct. Source ID: {source_id}, Target ID: {target_id}, Edge Type: {edge_type}") continue target_obj.add_child(source_obj) self.root = self.get_struct_by_uid(path_str) - - logger.debug(f"Loaded subtree {path_str} with root {self.root}") - return self.root def get_struct_by_uid(self, uid: str) -> Optional["BaseStruct"]: - # Check memory cache first if uid in self.uid_map: - # logger.debug(f"Cache hit for UID {uid}, returning memory object") return self.uid_map[uid] if uid in self.missing_uids: @@ -216,175 +194,111 @@ def get_struct_by_uid(self, uid: str) -> Optional["BaseStruct"]: if not self.use_cache or not self.db: return None - logger.debug(f"Attempting to retrieve struct and its children with UID {uid} from DB") + logger.debug(f"Attempting to retrieve {uid} and its children from DB") from tostr.core.builders import BaseBuilder - import json with self.db.get_connection() as conn: cursor = conn.cursor() - - # Target exact match alongside delimiter-specific prefix matches - # to retrieve the target struct and all hierarchical descendants (directories or code structs) if uid != ".": cursor.execute( - "SELECT * FROM structs WHERE uid = ? OR uid LIKE ? OR uid LIKE ?", - (uid, f"{uid}%", f"{uid}#%") + "SELECT * FROM structs WHERE uid = ? OR uid LIKE ? OR uid LIKE ? OR path = ?", + (uid, f"{uid}%", f"{uid}#%", uid) ) else: cursor.execute("SELECT * FROM structs") node_rows = cursor.fetchall() if not node_rows: - logger.debug(f"No structs found in DB matching UID {uid}") self.missing_uids.add(uid) return None - node_ids = [str(row['id']) for row in node_rows] + node_ids = [str(row["id"]) for row in node_rows] target_id = None for row in node_rows: struct_data = dict(row) - current_id = str(struct_data['id']) - - # Isolate the requested target struct ID for the final return - if struct_data['uid'] == uid: + current_id = str(struct_data["id"]) + if struct_data["uid"] == uid: target_id = current_id if current_id not in self.id_map: - if struct_data.get('imports', None): - struct_data['imports'] = json.loads(struct_data['imports']) - if struct_data.get('dependency_names', None): - struct_data['dependency_names'] = json.loads(struct_data['dependency_names']) - if struct_data.get('inherits', None): - struct_data['inherits'] = json.loads(struct_data['inherits']) - if struct_data.get('enum_constants', None): - struct_data['enum_constants'] = json.loads(struct_data['enum_constants']) + for field in ["imports", "dependency_names", "inherits", "enum_constants"]: + if struct_data.get(field): + struct_data[field] = json.loads(struct_data[field]) builder = BaseBuilder(self) - struct_type = struct_data['type'] - instance = builder.with_type(struct_type=struct_type).from_dict(struct_data) - + instance = builder.with_type(struct_type=struct_data["type"]).from_dict(struct_data) if instance: instance.id = current_id self.add_struct(instance) - # logger.debug(f"Created instance for struct with DB UID {struct_data['uid']} and type {struct_type}") - else: - logger.warning(f"Builder failed to create instance for struct with UID {struct_data['uid']} and type {struct_type}") if not node_ids: return None - # Fetch and connect edges for the loaded subset placeholders = ",".join(["?"] * len(node_ids)) - sql = f""" - SELECT source_id, target_id, edge_type - FROM edges - WHERE (source_id IN ({placeholders}) - OR target_id IN ({placeholders})) - AND edge_type = 'is_child_of' - """ - - params = node_ids + node_ids - cursor.execute(sql, params) + sql = f"SELECT source_id, target_id, edge_type FROM edges WHERE (source_id IN ({placeholders}) OR target_id IN ({placeholders})) AND edge_type = 'is_child_of'" + cursor.execute(sql, node_ids + node_ids) edge_rows = cursor.fetchall() for _source_id, _target_id, edge_type in edge_rows: source_obj = self.id_map.get(str(_source_id)) target_obj = self.id_map.get(str(_target_id)) - - if not source_obj or not target_obj: - continue - - target_obj.add_child(source_obj) + if source_obj and target_obj: + target_obj.add_child(source_obj) - struct = self.id_map.get(target_id) if target_id else None - - if struct is None: - logger.warning(f"Struct with DB UID {uid} was not found in memory cache after DB retrieval. Target ID: {target_id}") - return struct + return self.id_map.get(target_id) def get_struct_by_id(self, id: str) -> Optional["BaseStruct"]: id_str = str(id) - - # Check memory cache first - if hasattr(self, 'id_map') and id_str in self.id_map: + if id_str in self.id_map: return self.id_map[id_str] if not self.use_cache or not self.db: return None - # Fetch UID from db to execute localized subtree retrieval with self.db.get_connection() as conn: row = conn.execute("SELECT uid FROM structs WHERE id = ?", (id_str,)).fetchone() - logger.debug(f"Queried DB for struct with id {id_str}, got uid: {row[0]}") if not row: return None - target_uid = row[0] return self.get_struct_by_uid(target_uid) - def is_stale(self, struct: BaseCodeStruct | str) -> bool: + def is_stale(self, struct: "BaseStruct") -> bool: if not self.db: - raise RuntimeError("Cannot check for stale structs if SqLiteCache not provided.") - if isinstance(struct, str): - if struct not in self.uid_map: - raise KeyError(f"Could not find struct with uid {struct}") - struct = self.uid_map[struct] - - with self.db.get_connection() as conn: - row = conn.execute("SELECT diff_hash FROM structs WHERE uid = ?", (struct.uid,)).fetchone() - if not row or row[0] != struct.diff_hash: - return True - return False + return True - def update_cached_description(self, struct: BaseStruct | str): + return False + + def update_cached_description(self, struct: "BaseStruct"): if not self.db: - raise RuntimeError("Cannot check for stale structs if SqLiteCache not provided.") - if isinstance(struct, str): - if struct not in self.uid_map: - raise KeyError(f"Could not find struct with uid {struct}") - struct = self.uid_map[struct] - + raise RuntimeError("SqLiteCache not provided.") with self.db.get_connection() as conn: conn.execute("UPDATE structs SET description = ? WHERE uid = ?", (struct.description, struct.uid)) conn.commit() - - def save_struct_to_cache(self, struct: BaseStruct | str): - """ Saves a struct to the SQLite cache """ + def save_struct_to_cache(self, struct: "BaseStruct"): if not self.db: - raise RuntimeError("Cannot save to cache if SqLiteCache not provided.") - if isinstance(struct, str): - if struct not in self.uid_map: - raise KeyError(f"Could not find struct with uid {struct}") - struct = self.uid_map[struct] + raise RuntimeError("SqLiteCache not provided.") data = struct.to_dict() - target_uid = data.pop("uid") set_clause = ", ".join([f"{k} = ?" for k in data.keys()]) node_sql = f"UPDATE structs SET {set_clause} WHERE uid = ?" node_params = list(data.values()) + [target_uid] edges = list(struct.edges) - with self.db.get_connection() as conn: conn.execute(node_sql, node_params) - - # Clear all existing edges touching this node to prevent ghosts conn.execute("DELETE FROM edges WHERE source_id = ?", (struct.id,)) - - # Bulk insert fresh edges if edges: conn.executemany("INSERT INTO edges (source_id, target_id, edge_type) VALUES (?, ?, ?)", edges) conn.commit() - + def save_to_cache(self, stale: bool = False): - """Saves the entire AST to the SQLite cache.""" if not self.db: - raise RuntimeError("Cannot save to cache if SqLiteCache not provided.") + raise RuntimeError("SqLiteCache not provided.") parsed_ids = [(node.id,) for node in self.uid_map.values()] grouped_nodes = defaultdict(list) @@ -399,38 +313,22 @@ def serialize_for_db(value): for node in self.uid_map.values(): data_dict = node.to_dict() - if stale and data_dict.get("description", None): - data_dict["description"] = f"[STALE] {data_dict['description']}" + if stale and data_dict.get("description"): + data_dict["description"] = f"[STALE] {data_dict["description"]}" - # tuple of keys as the group identifier column_footprint = tuple(data_dict.keys()) - grouped_nodes[column_footprint].append(data_dict) all_edges.update(node.edges) with self.db.get_connection() as conn: for columns_tuple, dict_list in grouped_nodes.items(): - columns = ", ".join(columns_tuple) placeholders = ", ".join(["?"] * len(columns_tuple)) node_sql = f"INSERT OR REPLACE INTO structs ({columns}) VALUES ({placeholders})" - - # Fetching the specific value for each column from the dictionary - node_values = [ - tuple(serialize_for_db(n.get(col)) for col in columns_tuple) - for n in dict_list - ] - + node_values = [tuple(serialize_for_db(n.get(col)) for col in columns_tuple) for n in dict_list] conn.executemany(node_sql, node_values) - conn.executemany( - "DELETE FROM edges WHERE source_id = ?", - parsed_ids - ) - + conn.executemany("DELETE FROM edges WHERE source_id = ?", parsed_ids) if all_edges: - conn.executemany( - "INSERT INTO edges (source_id, target_id, edge_type) VALUES (?, ?, ?)", - list(all_edges) - ) - conn.commit() \ No newline at end of file + conn.executemany("INSERT INTO edges (source_id, target_id, edge_type) VALUES (?, ?, ?)", list(all_edges)) + conn.commit() diff --git a/src/tostr/core/serializer.py b/src/tostr/core/serializer.py index 7d183b9..ccd34b2 100644 --- a/src/tostr/core/serializer.py +++ b/src/tostr/core/serializer.py @@ -21,6 +21,8 @@ class tost: def dump_skeleton( cls, obj: "BaseStruct", + files_only: bool = True, + depth: int = 7, # indent: int = 0, pretty: bool = True ) -> str: @@ -31,18 +33,27 @@ def dump_skeleton( parts.append(header_str) if obj.files: - for f in obj.files: - parts.append(cls.dump_skeleton(f, pretty=pretty)) + if depth == 0: + parts.append(f"{indent_str}... ({len(obj.files)} files)") + else: + for f in obj.files: + parts.append(cls.dump_skeleton(f, files_only=files_only, depth=depth-1, pretty=pretty)) if obj.directories: - for d in obj.directories: - if d is obj: - logger.warning(f"Skipping dumping directory {d} as it is the same as its parent {obj}, likely to avoid circular reference.") - continue - parts.append('\n' + cls.dump_skeleton(d, pretty=pretty)) - if obj.classes: - for c in obj.classes: - parts.append(cls.dump_skeleton(c, pretty=pretty)) - + if depth == 0: + parts.append(f"{indent_str}... ({len(obj.directories)} directories)") + else: + for d in obj.directories: + if d is obj: + logger.warning(f"Skipping dumping directory {d} as it is the same as its parent {obj}, likely to avoid circular reference.") + continue + parts.append(cls.dump_skeleton(d, files_only=files_only, depth=depth-1, pretty=pretty)) + if obj.classes and not files_only: + if depth == 0: + parts.append(f"{indent_str}... ({len(obj.classes)} classes)") + else: + for c in obj.classes: + parts.append(cls.dump_skeleton(c, files_only=files_only, depth=depth-1, pretty=pretty)) + return textwrap.indent("\n".join(parts), indent_str) @classmethod diff --git a/src/tostr/exceptions.py b/src/tostr/exceptions.py index 0f4dc00..5cbe391 100644 --- a/src/tostr/exceptions.py +++ b/src/tostr/exceptions.py @@ -1,22 +1,22 @@ -class ToasterError(Exception): - """Base exception for all Toaster domain errors.""" +class TostrError(Exception): + """Base exception for all Tostr domain errors.""" pass -class StructNotFoundError(ToasterError): +class StructNotFoundError(TostrError): pass -class APIKeyError(ToasterError): +class APIKeyError(TostrError): pass -class ResolveError(ToasterError): +class ResolveError(TostrError): pass -class LanguageNotSupportedError(ToasterError): +class LanguageNotSupportedError(TostrError): pass -class TargetFileNotFoundError(ToasterError): +class TargetFileNotFoundError(TostrError): pass -class DatabaseNotFoundError(ToasterError): +class DatabaseNotFoundError(TostrError): pass \ No newline at end of file diff --git a/src/tostr/languages/java/builders.py b/src/tostr/languages/java/builders.py index 2e42f39..c09b8cb 100644 --- a/src/tostr/languages/java/builders.py +++ b/src/tostr/languages/java/builders.py @@ -1,5 +1,6 @@ from tree_sitter import Parser, Node, Query, QueryCursor from pathlib import Path +import hashlib from tostr.core.registry import Registry from tostr.languages.java.language import JAVA_LANGUAGE @@ -29,6 +30,7 @@ def from_path(self, path: Path, parent: BaseStruct=None) -> BaseFile: with open(path, "rb") as f: body_bytes = f.read() file_obj.body = body_bytes.decode("utf-8") + file_obj.diff_hash = hashlib.md5(body_bytes).hexdigest() # Fallback until distributed hash is calculated parser = Parser(JAVA_LANGUAGE) tree = parser.parse(body_bytes) @@ -136,6 +138,7 @@ def from_node(self, node: Node, parent: BaseStruct=None) -> BaseClass: # BaseCodeStruct signature=signature, body=body, + diff_hash=hashlib.md5(node.text).hexdigest(), start_line=node.start_point[0], end_line=node.end_point[0], node=node, @@ -257,6 +260,7 @@ def from_node(self, node: Node, parent: BaseStruct=None) -> BaseMethod: # BaseCodeStruct signature=signature, body=body, + diff_hash=hashlib.md5(node.text).hexdigest(), start_line=node.start_point[0], end_line=node.end_point[0], node=node, @@ -315,6 +319,7 @@ def from_node(self, node: Node, parent: BaseStruct=None) -> BaseField: # BaseCodeStruct signature=signature, body=body, + diff_hash=hashlib.md5(node.text).hexdigest(), start_line=node.start_point[0], end_line=node.end_point[0], node=node, diff --git a/src/tostr/llm/base.py b/src/tostr/llm/base.py index 29da770..e24804a 100644 --- a/src/tostr/llm/base.py +++ b/src/tostr/llm/base.py @@ -1,5 +1,6 @@ import asyncio import json +import time from abc import ABC, abstractmethod from typing import Any from pydantic import BaseModel, Field @@ -64,6 +65,8 @@ async def describe_class(self, class_obj: Any, imports: list[str]) -> LLMRespons input_data_string = json.dumps(input_data) logger.debug(f"Generating Description for {class_obj.uid}...") + + start_time = time.perf_counter() max_retries = 3 base_delay = 2 @@ -86,6 +89,10 @@ async def describe_class(self, class_obj: Any, imports: list[str]) -> LLMRespons except (ValueError, TypeError): continue + end_time = time.perf_counter() + elapsed_time = end_time - start_time + logger.debug(f"Finished describing {class_obj.uid} in {elapsed_time:.4f} seconds") + return result except Exception as e: error_str = str(e) diff --git a/src/tostr/mcp.py b/src/tostr/mcp.py deleted file mode 100644 index ca55bed..0000000 --- a/src/tostr/mcp.py +++ /dev/null @@ -1,148 +0,0 @@ -import threading -import asyncio -from pathlib import Path -from fastmcp import FastMCP -from loguru import logger -import os - -from tostr.exceptions import ToasterError - -from tostr.commands import ( - init_async, - inspect_async, - skeleton_async, - watch_async, - clean_db -) -from tostr.core.utils.logger import configure_mcp_logging - -_is_initialized = False -_current_project_dir = None - -mcp = FastMCP("Toaster") - -# --- THE SYNCHRONOUS BRIDGE --- -def _run_watcher_thread(target_path: Path): - """ - Sets up an isolated async environment for the background thread, - then runs your watch_async loop inside it. - """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - # This calls your exact watch_async function! - loop.run_until_complete(watch_async(target_path)) - except Exception as e: - logger.exception(f"Fatal error in background watcher: {e}") - finally: - loop.close() - logger.info("Background watcher shut down cleanly.") - -@mcp.tool() -async def init(workspace_path: str, use_cache: bool = True, ignore: str = None) -> str: - """ - -- MUST BE RUN BEFORE ANY OTHER TOOL -- - Initializes the Toaster MCP server for a specific project workspace. - - Args: - workspace_path: The ABSOLUTE path to the project workspace. DO NOT use '.' or relative paths. If you only have a relative path, you must determine the absolute path of the current workspace first. - use_cache: Whether to use the existing AST cache. - ignore: Add a default ignore template to the project folder (e.g., 'java', 'default'). - """ - - target_path = Path(workspace_path) - - if not target_path.is_absolute(): - return (f"Error: workspace_path must be an absolute path. You provided '{workspace_path}'. " - f"Please determine the absolute path of the current workspace and try again.") - - target_path = target_path.resolve() - - try: - os.chdir(target_path) - except FileNotFoundError: - return f"Fatal Error: Workspace path does not exist: {target_path}" - - global _is_initialized, _current_project_dir - project_dir = target_path - - if _is_initialized and _current_project_dir == project_dir: - return f"Status: Already initialized for {project_dir}." - - try: - configure_mcp_logging(project_dir) - - await init_async(project_dir, use_cache, ignore) - - watcher_thread = threading.Thread( - target=_run_watcher_thread, - args=(project_dir,), - daemon=True - ) - watcher_thread.start() - - _is_initialized = True - _current_project_dir = project_dir - - return f"Success: Toaster initialized. Cache is built at {project_dir}/.tostr/cache.db. Background watcher is now actively listening on {project_dir}" - - except Exception as e: - return f"Fatal Error Initializing Toaster: {str(e)}" - -@mcp.tool() -async def inspect(id: str, include_body: bool = False) -> str: - """ - Output the AST details and code for a specific struct ID. - Use this when you need the full implementation details of a specific function or class. - - Args: - id: The unique Toaster ID of the struct to inspect. - include_body: Include the raw code body in the output. - """ - global _is_initialized, _current_project_dir - - if not _is_initialized: - return "Error: Toaster is not initialized. You must call 'init' with the absolute workspace path before querying the database." - - try: - result = await inspect_async(id, _current_project_dir, include_body, pretty=False) - return str(result) - except ToasterError as e: - return f"Error: {e}" - - -@mcp.tool() -async def clean(workspace_path: str) -> str: - """ - Clean the SQLite database for a specific workspace. - """ - try: - project_dir = Path(workspace_path).resolve() - clean_db(project_dir) - return f"Success: Database cleaned for {project_dir}." - except Exception as e: - return f"Error: {e}" - -@mcp.tool() -async def skeleton(subpath: str) -> str: - """ - Output the .tost skeleton format for all files matching a specific subpath. - Use this to understand the high-level architecture, classes, and function signatures of a file or directory without reading the full code. - - Args: - subpath: File or directory path relative to the project root to generate a skeleton for. - """ - global _is_initialized, _current_project_dir - - if not _is_initialized: - return "Error: Toaster is not initialized. You must call 'init' with the absolute workspace path before querying the database." - - try: - result = await skeleton_async(subpath, _current_project_dir, pretty=False) - return str(result) - except ToasterError as e: - return f"Error: {e}" - -if __name__ == "__main__": - mcp.run() \ No newline at end of file diff --git a/src/tostr/server.py b/src/tostr/server.py new file mode 100644 index 0000000..5c082d0 --- /dev/null +++ b/src/tostr/server.py @@ -0,0 +1,183 @@ +import threading +import asyncio +from pathlib import Path +from fastmcp import FastMCP +from loguru import logger +import os + +from tostr.exceptions import TostrError + +from tostr.commands import ( + init_async, + inspect_async, + skeleton_async, + watch_async, + clean_db +) +from tostr.core.utils.logger import configure_mcp_logging + +class MCPSession: + def __init__(self): + self.is_initialized = False + self.project_dir = None + self.watcher_thread = None + self.watcher_loop = None + self.stop_event = None + + def stop_watcher(self): + """Cleanly stops the background watcher thread if it exists.""" + if self.watcher_loop and self.stop_event: + logger.info("Signaling background watcher to stop...") + self.watcher_loop.call_soon_threadsafe(self.stop_event.set) + + if self.watcher_thread: + # We wait briefly for it to shut down + self.watcher_thread.join(timeout=2) + if self.watcher_thread.is_alive(): + logger.warning("Watcher thread did not shut down in time.") + + self.watcher_thread = None + self.watcher_loop = None + self.stop_event = None + + def start_watcher(self, target_path: Path): + """Starts the background watcher thread.""" + self.stop_watcher() + + self.watcher_thread = threading.Thread( + target=self._run_watcher_thread, + args=(target_path,), + daemon=True + ) + self.watcher_thread.start() + + def _run_watcher_thread(self, target_path: Path): + """ + Sets up an isolated async environment for the background thread, + then runs your watch_async loop inside it. + """ + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self.watcher_loop = loop + self.stop_event = asyncio.Event() + + try: + loop.run_until_complete(watch_async(target_path, stop_event=self.stop_event)) + except Exception as e: + logger.exception(f"Fatal error in background watcher: {e}") + finally: + loop.close() + logger.info("Background watcher shut down cleanly.") + +session = MCPSession() +mcp = FastMCP("Tostr") + +@mcp.tool() +async def init(workspace_path: str, use_cache: bool = True, ignore: str = None) -> str: + """ + -- MUST BE RUN BEFORE ANY OTHER TOOL -- + Initializes the Tostr MCP server for a specific project workspace. + By default, it will attempt to sync with an existing database if one is found. + + Args: + workspace_path: The ABSOLUTE path to the project workspace. DO NOT use '.' or relative paths. If you only have a relative path, you must determine the absolute path of the current workspace first. + use_cache: Whether to use the existing AST cache. If False, forces a full re-parse. + ignore: Add a default ignore template to the project folder (e.g., 'java', 'default'). + """ + + target_path = Path(workspace_path) + + if not target_path.is_absolute(): + return (f"Error: workspace_path must be an absolute path. You provided '{workspace_path}'. " + f"Please determine the absolute path of the current workspace and try again.") + + target_path = target_path.resolve() + + if not target_path.exists(): + return f"Fatal Error: Workspace path does not exist: {target_path}" + + # Check if we are already initialized for this path + if session.is_initialized and session.project_dir == target_path and use_cache: + return f"Status: Already initialized for {target_path}. Set use_cache=False to force a re-parse." + + db_path = target_path / ".tostr" / "cache.db" + + try: + configure_mcp_logging(target_path) + + # Auto-sync logic: If DB exists and we are using cache, just latch on. + if db_path.exists() and use_cache: + session.project_dir = target_path + session.start_watcher(target_path) + session.is_initialized = True + return f"Success: Tostr synced with existing database at {target_path}. Background watcher active." + + # Otherwise, perform full initialization/parse + await init_async(target_path, use_cache, ignore) + + session.project_dir = target_path + session.start_watcher(target_path) + session.is_initialized = True + + return f"Success: Tostr initialized and parsed. Cache is built at {db_path}. Background watcher is now actively listening on {target_path}" + + except Exception as e: + return f"Fatal Error Initializing Tostr: {str(e)}" + +@mcp.tool() +async def inspect(id: str, include_body: bool = False) -> str: + """ + Output the AST details and code for a specific struct ID. + Use this when you need the full implementation details of a specific function or class. + + Args: + id: The unique Tostr ID of the struct to inspect. + include_body: Include the raw code body in the output. + """ + if not session.is_initialized: + return "Error: Tostr is not initialized. You must call 'init' or 'sync' with the absolute workspace path before querying the database." + + try: + result = await inspect_async(id, session.project_dir, include_body, pretty=False) + return str(result) + except TostrError as e: + return f"Error: {e}" + +@mcp.tool() +async def clean(workspace_path: str) -> str: + """ + Clean the SQLite database for a specific workspace and reset the server state if it matches. + """ + try: + project_dir = Path(workspace_path).resolve() + clean_db(project_dir) + + if session.project_dir == project_dir: + session.stop_watcher() + session.is_initialized = False + session.project_dir = None + + return f"Success: Database cleaned for {project_dir}." + except Exception as e: + return f"Error: {e}" + +@mcp.tool() +async def skeleton(subpath: str) -> str: + """ + Output the .tost skeleton format for all files matching a specific subpath. + Use this to understand the high-level architecture, classes, and function signatures of a file or directory without reading the full code. + + Args: + subpath: File or directory path relative to the project root to generate a skeleton for. + """ + if not session.is_initialized: + return "Error: Tostr is not initialized. You must call 'init' or 'sync' with the absolute workspace path before querying the database." + + try: + result = await skeleton_async(subpath, session.project_dir, pretty=False) + return str(result) + except TostrError as e: + return f"Error: {e}" + +if __name__ == "__main__": + mcp.run() diff --git a/tests/languages/java/builder/test_java_class_builder.py b/tests/languages/java/builder/test_java_class_builder.py index 897c496..300cd04 100644 --- a/tests/languages/java/builder/test_java_class_builder.py +++ b/tests/languages/java/builder/test_java_class_builder.py @@ -23,9 +23,8 @@ def mock_registry(): @pytest.fixture def mock_parent_file(): """Mocks the BaseFile parent needed for UID generation.""" - mock_file = MagicMock(spec=BaseFile) - mock_file.uid = "src/main/java/com/tostr/DataProcessor.java" - mock_file.path = Path("src/main/java/com/tostr/DataProcessor.java") + mock_file = MagicMock(spec=BaseFile, uid="com/tostr/DataProcessor.java", path=Path("src/main/java/com/tostr/DataProcessor.java")) + mock_file.package = "com.tostr" # Tell isinstance(parent, BaseFile) to return True mock_file.__class__ = BaseFile return mock_file @@ -65,7 +64,7 @@ class InnerHelper {} # 1. Test BaseStruct Properties assert class_obj.name == "DataProcessor" - assert class_obj.uid == "src/main/java/com/tostr/DataProcessor.java#DataProcessor" + assert class_obj.uid == f"{mock_parent_file.package}.DataProcessor" assert class_obj.parent == mock_parent_file # 2. Test Signature Extraction @@ -90,4 +89,4 @@ class InnerHelper {} struct_names = [s.name for s in registered_structs] assert "data" in struct_names # The field assert "process" in struct_names # The method - assert "InnerHelper" in struct_names # The inner class \ No newline at end of file + assert "InnerHelper" in struct_names # The inner class diff --git a/tests/languages/java/builder/test_java_field_builder.py b/tests/languages/java/builder/test_java_field_builder.py index 7d20f57..209ff2a 100644 --- a/tests/languages/java/builder/test_java_field_builder.py +++ b/tests/languages/java/builder/test_java_field_builder.py @@ -5,7 +5,7 @@ from tree_sitter import Parser from tostr.languages.java.language import JAVA_LANGUAGE from tostr.core.registry import Registry -from tostr.core.models import BaseClass, BaseField +from tostr.core.models import BaseFile, BaseClass, BaseField from tostr.languages.java.builders import JavaFieldBuilder @pytest.fixture(scope="session") @@ -22,10 +22,12 @@ def mock_registry(): @pytest.fixture def mock_parent_class(): """Mocks the BaseClass parent needed for UID generation.""" - mock_cls = MagicMock(spec=BaseClass) - mock_cls.uid = "src/main/java/com/tostr/Constants.java#Constants" - mock_cls.path = Path("src/main/java/com/tostr/Constants.java") - mock_cls.__class__ = BaseClass + mock_cls = MagicMock( + spec=BaseClass, + uid="com.example.TestClass", + name="TestClass", + parent=MagicMock(spec=BaseFile, uid="com/tostr/Constants.java", package="com.tostr") + ) return mock_cls def test_java_field_builder_extracts_fields(java_parser, mock_registry, mock_parent_class): @@ -44,7 +46,7 @@ class Constants { # Find the field nodes inside the class body field_nodes = [] class_node = tree.root_node.children[0] - body_node = class_node.child_by_field_name('body') + body_node = class_node.child_by_field_name("body") for child in body_node.children: if child.type == "field_declaration": @@ -65,12 +67,12 @@ class Constants { # Signature Tests (Ensuring the comment was skipped and order is correct) expected_tau_sig = "@Serialized public static final double TAU" - assert tau_field.signature == expected_tau_sig, f"Expected '{expected_tau_sig}', got '{tau_field.signature}'" + assert tau_field.signature == expected_tau_sig, f"Expected \'{expected_tau_sig}\\' , got \'{tau_field.signature}\'" assert "This comment should be ignored" not in tau_field.signature # UID Test (Ensuring NO type information is appended to fields) - expected_tau_uid = "src/main/java/com/tostr/Constants.java#Constants.TAU" - assert tau_field.uid == expected_tau_uid, f"Expected '{expected_tau_uid}', got '{tau_field.uid}'" + expected_tau_uid = "com.example.TestClass.TAU" + assert tau_field.uid == expected_tau_uid, f"Expected \'{expected_tau_uid}\\' , got \'{tau_field.uid}\'" # --- TEST 2: The Generic Field --- @@ -83,5 +85,5 @@ class Constants { expected_users_sig = "private List activeUsers" assert users_field.signature == expected_users_sig - expected_users_uid = "src/main/java/com/tostr/Constants.java#Constants.activeUsers" - assert users_field.uid == expected_users_uid \ No newline at end of file + expected_users_uid = "com.example.TestClass.activeUsers" + assert users_field.uid == expected_users_uid diff --git a/tests/languages/java/builder/test_java_method_builder.py b/tests/languages/java/builder/test_java_method_builder.py index f3156fc..34e7ca5 100644 --- a/tests/languages/java/builder/test_java_method_builder.py +++ b/tests/languages/java/builder/test_java_method_builder.py @@ -5,7 +5,7 @@ from tree_sitter import Parser from tostr.languages.java.language import JAVA_LANGUAGE from tostr.core.registry import Registry -from tostr.core.models import BaseClass, BaseMethod +from tostr.core.models import BaseFile, BaseClass, BaseMethod from tostr.languages.java.builders import JavaMethodBuilder @pytest.fixture(scope="session") @@ -22,12 +22,14 @@ def mock_registry(): @pytest.fixture def mock_parent_class(): """Mocks the BaseClass parent needed for UID generation.""" - mock_cls = MagicMock(spec=BaseClass) - # The UID of a class uses # - mock_cls.uid = "src/main/java/com/tostr/Mathf.java#Mathf" - mock_cls.path = Path("src/main/java/com/tostr/Mathf.java") - # Make sure isinstance(parent, BaseFile) returns False - mock_cls.__class__ = BaseClass + mock_cls = MagicMock( + spec=BaseClass, + uid="com.example.OuterClass", + name="OuterClass", + parent=MagicMock(spec=BaseFile, package="com.example"), + imports=["java.util.*", "com.example.dep.Dependency"], + fields=[MagicMock(name="myService", field_type="MyService")] + ) return mock_cls def test_java_method_builder_extracts_complex_method(java_parser, mock_registry, mock_parent_class): @@ -48,7 +50,7 @@ class Mathf { # 1. Find the method nodes inside the class body method_nodes = [] class_node = tree.root_node.children[0] - body_node = class_node.child_by_field_name('body') + body_node = class_node.child_by_field_name("body") for child in body_node.children: if child.type == "method_declaration": @@ -63,7 +65,7 @@ class Mathf { method_obj = builder.from_node(complex_method_node, parent=mock_parent_class) # Core Properties - assert method_obj.name == "processData", f"Expected 'processData', got {method_obj.name}" + assert method_obj.name == "processData", f"Expected \'processData\', got {method_obj.name}" assert method_obj.parent == mock_parent_class # Signature Tests @@ -78,7 +80,7 @@ class Mathf { assert method_obj.arity == 2, f"Expected arity 2, got {method_obj.arity}" # UID Test (Crucial for method overloading support) - expected_uid = "src/main/java/com/tostr/Mathf.java#Mathf.processData(int, String)" + expected_uid = "com.example.OuterClass.processData(int, String)" assert method_obj.uid == expected_uid, f"Expected {expected_uid}, got {method_obj.uid}" @@ -91,5 +93,5 @@ class Mathf { assert "void ping()" in simple_obj.signature # Check empty parameter UID - expected_simple_uid = "src/main/java/com/tostr/Mathf.java#Mathf.ping()" - assert simple_obj.uid == expected_simple_uid, f"Expected {expected_simple_uid}, got {simple_obj.uid}" \ No newline at end of file + expected_simple_uid = "com.example.OuterClass.ping()" + assert simple_obj.uid == expected_simple_uid, f"Expected {expected_simple_uid}, got {simple_obj.uid}" diff --git a/tests/languages/java/test_java_dependencies.py b/tests/languages/java/test_java_dependencies.py index da1538c..d4a1a6f 100644 --- a/tests/languages/java/test_java_dependencies.py +++ b/tests/languages/java/test_java_dependencies.py @@ -29,8 +29,8 @@ def test_java_dependency_parsing(tmp_path, registry): method1 = [m for m in registry.methods if m.name == "method1"][0] # Verify dependency names and arities are parsed - assert ("method2", 0) in method1.dependency_names - assert ("method3", 2) in method1.dependency_names + assert ("method2", 0, None, False) in method1.dependency_names + assert ("method3", 2, None, False) in method1.dependency_names def test_java_dependency_resolution_local(tmp_path, registry): """Tests resolution of local method calls (same class).""" @@ -98,4 +98,4 @@ def test_java_dependency_resolution_imported(tmp_path, registry): method_a = [m for m in registry.methods if m.name == "methodA"][0] # This tests the "IMPORTED" logic in BaseMethod.resolve_dependencies - assert method_a in method_b.outbound_dependencies + assert method_a.parent in method_b.outbound_dependencies diff --git a/tests/llm/test_llm_client.py b/tests/llm/test_llm_client.py index 8c513d7..aff5e4d 100644 --- a/tests/llm/test_llm_client.py +++ b/tests/llm/test_llm_client.py @@ -22,9 +22,9 @@ async def test_llm_client_data_flow(): client = LLMClient(strategy) # Create a mock method - mock_method = MagicMock(spec=BaseMethod) + mock_method = MagicMock(spec=BaseMethod, uid="mock.method", name="method") + mock_method.parent = MagicMock(spec=BaseFile, imports=[]) # Mock the parent file mock_method.signature = "void test()" - mock_method.uid = "test_uid" mock_method.description = "" # Create a mock class @@ -46,8 +46,7 @@ async def test_models_resolve_description_data_flow(): Test the full data flow from BaseClass.resolve_description_async through LLMClient. """ registry = MagicMock(spec=Registry) - parent_file = MagicMock(spec=BaseFile) - parent_file.imports = [] + parent_file = MagicMock(spec=BaseFile, imports=[], package="com.example") # Create a real BaseClass instance cls = BaseClass(name="TestClass", uid="TestClass") @@ -93,9 +92,12 @@ async def describe_class(self, input_data_string: str, system_instruction: str) strategy = RetryStrategy(api_key="test", model_name="test") client = LLMClient(strategy) - mock_class = MagicMock(spec=BaseClass) - mock_class.methods = [] - mock_class.skeletonize.return_value = "code" + mock_class = MagicMock( + spec=BaseClass, + uid="mock.class", + methods=[], + skeletonize=MagicMock(return_value="code") + ) # We need to mock asyncio.sleep to avoid waiting during tests with MagicMock() as mock_sleep: diff --git a/tests/testcode/MRILib/.tostr/cache.db b/tests/testcode/MRILib/.tostr/cache.db index 466e2bc..44f85dd 100644 Binary files a/tests/testcode/MRILib/.tostr/cache.db and b/tests/testcode/MRILib/.tostr/cache.db differ