Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/tostr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from typing import Annotated
from loguru import logger

from tostr.exceptions import ToasterError
from tostr.exceptions import TostrError

from tostr.commands import init_async, inspect_async, skeleton_async, watch_async, clean_db

from tostr.mcp import mcp
from tostr.server import mcp

from tostr.core.utils.logger import configure_cli_logging

Expand Down Expand Up @@ -67,7 +67,7 @@ def watch(
configure_cli_logging(debug)
try:
asyncio.run(watch_async(path))
except ToasterError as e:
except TostrError as e:
typer.secho(f"❌ Error: {e}", fg="red", err=True)
raise typer.Exit(code=1)

Expand All @@ -94,7 +94,7 @@ def clean(
configure_cli_logging(debug)
try:
clean_db(path)
except ToasterError as e:
except TostrError as e:
typer.secho(f"❌ Error: {e}", fg="red", err=True)
raise typer.Exit(code=1)

Expand Down Expand Up @@ -137,7 +137,7 @@ def init(
start_time = time.perf_counter()
try:
asyncio.run(init_async(path, use_cache, ignore))
except ToasterError as e:
except TostrError as e:
typer.secho(f"❌ Error: {e}", fg="red", err=True)
raise typer.Exit(code=1)

Expand Down Expand Up @@ -190,7 +190,7 @@ def inspect(
try:
result = asyncio.run(inspect_async(id, path, include_body=include_body, pretty=pretty))
print(result)
except ToasterError as e:
except TostrError as e:
typer.secho(f"❌ Error: {e}", fg="red", err=True)
raise typer.Exit(code=1)

Expand Down Expand Up @@ -237,7 +237,7 @@ def skeleton(
try:
result = asyncio.run(skeleton_async(subpath, path, pretty=pretty))
print(result)
except ToasterError as e:
except TostrError as e:
typer.secho(f"❌ Error: {e}", fg="red", err=True)
raise typer.Exit(code=1)
end_time = time.perf_counter()
Expand Down
4 changes: 2 additions & 2 deletions src/tostr/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ async def skeleton_async(subpath: str, project_path: Path, pretty: bool = True):

active_tasks = {}

async def watch_async(target_path: Path):
async def watch_async(target_path: Path, stop_event: asyncio.Event = None):
llm = get_llm_client()

logger.info("Starting Listener")
try:
async for changes in awatch(target_path):
async for changes in awatch(target_path, stop_event=stop_event):
for change_type, path in changes:
path = Path(path).relative_to(target_path)
if ".tostr" in str(path):
Expand Down
1 change: 1 addition & 0 deletions src/tostr/core/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def from_dict(self, d: dict) -> BaseFile:
description=d.get("description", ""),
imports=d.get("imports", []),
body=d.get("body", ""),
diff_hash=d.get("diff_hash", ""),
package=d.get("package", ""),
_inbound_dependency_strings=json.loads(d.get("inbound_dependency_strings", [])),
_outbound_dependency_strings=json.loads(d.get("outbound_dependency_strings", [])),
Expand Down
44 changes: 29 additions & 15 deletions src/tostr/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,6 @@ def edges(self):
elif isinstance(self.parent, str):
edges.add((self.id, self.parent, "is_child_of"))

# for child_set in self.children.values():
# for child in child_set:
# edges.update(child.edges)
# edges.add((child.id, self.id, "contains"))

return edges

@property
Expand Down Expand Up @@ -143,6 +138,15 @@ def set_parent(self, parent: "BaseStruct"):
logger.warning(f"Attempted to set parent of {self} to itself. Skipping to avoid circular reference.")
return
self.parent = parent

def calculate_distributed_hash(self):
    """Derive this struct's ``diff_hash`` from the hashes of its children.

    Collects the non-empty ``diff_hash`` of every entry in
    ``self.all_children``, sorts the values so the result does not depend on
    iteration order, and stores the MD5 hex digest of their concatenation in
    ``self.diff_hash``.

    ``self.diff_hash`` is left untouched when there are no children, or when
    none of the children carries a hash.
    """
    # NOTE(review): whether ``all_children`` holds only direct children or
    # the whole subtree is not visible from this block - confirm.
    if not self.all_children:
        return

    child_hashes = sorted(
        child.diff_hash
        for child in self.all_children
        if getattr(child, "diff_hash", None)
    )
    if child_hashes:
        # In this block the digest is only stored as a fingerprint, so mark
        # the MD5 call as non-security use; FIPS-restricted Python builds
        # reject a plain hashlib.md5() call. The digest value is unchanged.
        digest = hashlib.md5(
            "".join(child_hashes).encode("utf-8"),
            usedforsecurity=False,
        )
        self.diff_hash = digest.hexdigest()

def add_dependency(self, target: "BaseStruct"):
self.outbound_dependencies.add(target)
Expand Down Expand Up @@ -175,7 +179,6 @@ async def resolve_description_async(self, llm: "LLMClient", visited: set[str] =
@classmethod
def from_dict(cls, d: dict):
data = d.copy()
# REMOVE all init=False here
id = data.pop("id", None)
instance = cls(**data)
if id:
Expand Down Expand Up @@ -213,6 +216,7 @@ def __str__(self):
@dataclass(eq=False)
class Directory(BaseStruct):
_IDPREFIX: ClassVar[str] = "D"
diff_hash: str = ""

def __init__(self, path, registry=None, parent=None, uid=None):
uid = uid or str(path)
Expand All @@ -233,11 +237,11 @@ def parse_children(self):

for path in full_path.glob("*"):
if self.registry.config.is_ignored(path):
logger.debug(f"Skipping '{path}' due to path ignore rules")
logger.debug(f"Skipping \'{path}\' due to path ignore rules")
continue
else:
if path.is_dir():
logger.debug(f"🔍 Parsing directory '{path}'")
logger.debug(f"🔍 Parsing directory \'{path}\'")
relative_path = self.registry.relative_to_project(path)
directory = Directory(path=relative_path, registry=self.registry, parent=self)
self.registry.add_struct(directory)
Expand All @@ -247,18 +251,26 @@ def parse_children(self):
logger.debug(f"Attempting to resolve builder for suffix {path.parts[-1]}")
try:
if self.registry.config.is_ignored(path):
logger.debug(f"Skipping '{path}' due to path ignore rules")
logger.debug(f"Skipping \'{path}\' due to path ignore rules")
continue
builder = StructBuilderProvider.get_builder(path.suffix, self.registry)
except LanguageNotSupportedError as e:
continue
instance = builder.build_file().from_path(path, parent=self)

# Calculate file hash from children if any exist
instance.calculate_distributed_hash()

self.registry.add_struct(instance)
self.add_child(instance)

# Calculate directory hash from its direct children
self.calculate_distributed_hash()

def to_dict(self) -> dict:
    """Return the parent class's dict form, tagged as a directory.

    Adds ``"type": "Directory"`` and the current ``diff_hash`` to the mapping
    produced by ``super().to_dict()``.
    """
    payload = super().to_dict()
    payload.update({"type": "Directory", "diff_hash": self.diff_hash})
    return payload

@dataclass(eq=False)
Expand All @@ -268,6 +280,7 @@ class BaseFile(BaseStruct):
imports: List[str] = field(default_factory=list)
package: str = ""
body: str = ""
diff_hash: str = ""
node: "Node" = None

async def resolve_description_async(self, llm: "LLMClient", visited: set[str] = None):
Expand All @@ -279,8 +292,9 @@ def to_dict(self) -> dict:
data["type"] = "BaseFile"
data["imports"] = self.imports
data["body"] = self.body
data["diff_hash"] = self.diff_hash
return data

@dataclass(eq=False)
class BaseCodeStruct(BaseStruct):

Expand Down Expand Up @@ -329,7 +343,7 @@ def needs_description(self) -> bool:
def imports(self) -> List[str]:
return self.parent.imports

def resolve_type(self, type_name: str) -> Optional[BaseStruct]:
def resolve_type(self, type_name: str) -> Optional["BaseStruct"]:
"""Resolves a simple or scoped type name to a struct using package and imports."""
if not type_name: return None
if type_name in self._type_cache:
Expand All @@ -340,7 +354,7 @@ def resolve_type(self, type_name: str) -> Optional[BaseStruct]:

# 2. Same package
if not dep:
package = getattr(self.parent, 'package', None) if isinstance(self.parent, BaseFile) else None
package = getattr(self.parent, "package", None) if isinstance(self.parent, BaseFile) else None
if package:
dep = self.registry.get_struct_by_uid(f"{package}.{type_name}")

Expand Down Expand Up @@ -462,7 +476,7 @@ class BaseMethod(BaseCodeStruct):
# pass

def resolve_dependencies(self):
logger.info(f"Resolving dependencies for method {self.uid}")
logger.debug(f"Resolving dependencies for method {self.uid}")
for dep_info in self.dependency_names:
# Handle both old (name, arity) and new (name, arity, receiver, is_creation) formats
if len(dep_info) == 2:
Expand All @@ -483,7 +497,7 @@ def resolve_dependencies(self):
search_scope = self.parent.children if self.parent else self.children
for child_set in list(search_scope.values()):
for child in list(child_set):
if child.name == name and getattr(child, 'arity', -1) == arity:
if child.name == name and getattr(child, "arity", -1) == arity:
self.add_dependency(child)
resolved = True
break
Expand Down Expand Up @@ -512,7 +526,7 @@ def resolve_dependencies(self):
if self.parent._potential_parents_cache is None:
potential_parents = []
# Same package
package = getattr(self.parent.parent, 'package', None) if isinstance(self.parent.parent, BaseFile) else None
package = getattr(self.parent.parent, "package", None) if isinstance(self.parent.parent, BaseFile) else None
if package: potential_parents.append(f"{package}.*")

# All imports
Expand Down
68 changes: 19 additions & 49 deletions src/tostr/core/parser.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from pathlib import Path
from abc import ABC
import asyncio
import hashlib
from loguru import logger

from tostr.core.models import BaseFile, Directory
from tostr.core.models import BaseFile, Directory, BaseStruct
from tostr.core.registry import Registry
from tostr.core.providers import StructBuilderProvider
from tostr.exceptions import LanguageNotSupportedError
Expand All @@ -13,13 +14,10 @@ def __init__(self, project_dir: str, llm=None, registry: Registry=None):
self.project_dir = project_dir
self.llm = llm
self.registry = registry
# self.path_ignore = ["venv", ".venv", "env", ".env", "build", "dist", "__pycache__", ".tostr", ".git"]

@property
def files(self):
    """Return every struct in the registry's ``uid_map`` that is a ``BaseFile``.

    The list is rebuilt on each access; nothing is cached on the parser.
    """
    # The earlier form read ``self._files`` (never set in ``__init__``) and
    # called ``.filter()`` on the ``.values()`` view, which has no such
    # method; a comprehension does the filtering directly.
    return [
        struct
        for struct in self.registry.uid_map.values()
        if isinstance(struct, BaseFile)
    ]

async def parse(self, subpath: Path = None):
if not subpath:
Expand All @@ -28,23 +26,19 @@ async def parse(self, subpath: Path = None):
subpath = Path(subpath)

self.parse_path(subpath)

self.resolve_dependencies()

await self.resolve_descriptions_async()

def parse_path(self, subpath: Path, parent: Directory = None):
if subpath.is_dir():
logger.debug(f"🔍 Parsing directory '{subpath}'")

if parent is None:
# ROOT creation
root_path = subpath
if self.registry:
root_path = self.registry.relative_to_project(subpath)
root = Directory(path=root_path, registry=self.registry)
self.registry.root = root
logger.debug(f"Created registry root: {root}")
self.registry.add_struct(root)
else:
root = parent
Expand All @@ -58,36 +52,27 @@ def parse_path(self, subpath: Path, parent: Directory = None):
existing = self.registry.get_struct_by_uid(str(relative_path))

if path.is_dir():
if existing and isinstance(existing, Directory):
logger.debug(f"Using cached directory '{path}'")
root.add_child(existing)
self.parse_path(path, parent=existing)
continue

directory = Directory(path=relative_path, registry=self.registry, parent=root)
self.registry.add_struct(directory)
root.add_child(directory)
self.parse_path(path, parent=directory)
else:
if existing and isinstance(existing, BaseFile):
if not self.registry.is_stale(existing):
logger.debug(f"Using cached file '{path}'")
root.add_child(existing)
continue

file = self.parse_file(path, parent=root)
if file:
file.calculate_distributed_hash()
self.registry.add_struct(file)
root.add_child(file)

root.calculate_distributed_hash()
else:
logger.debug(f"🔍 Parsing file '{subpath}'")
file = self.parse_file(subpath)
self.registry.root = file
self.registry.add_struct(file)
if file:
file.calculate_distributed_hash()
self.registry.root = file
self.registry.add_struct(file)

# @abstractmethod
def parse_file(self, subpath: Path, parent: BaseStruct=None) -> BaseFile:
logger.debug(f"Attempting to resolve builder for suffix {subpath.parts[-1]}")
logger.debug(f"Attempting to resolve builder for suffix {subpath.suffix}")
if self.registry.config.is_ignored(subpath):
logger.debug(f"Skipping '{subpath}' due to path ignore rules")
return None
Expand All @@ -98,34 +83,19 @@ def parse_file(self, subpath: Path, parent: BaseStruct=None) -> BaseFile:
logger.warning(str(e))
return None
file_obj = builder.build_file().from_path(subpath, parent=parent)
# logger.debug(json.dumps(file_obj.to_dict(), indent=2))
return file_obj

def resolve_dependencies(self):
    """Resolve dependencies starting at the registry root, if there is one.

    Does nothing when ``self.registry.root`` is falsy (for example when no
    root was set); otherwise logs the root and calls its
    ``resolve_dependencies()`` exactly once.
    """
    root = self.registry.root
    if not root:
        # No root registered: nothing to resolve, and dereferencing it
        # would raise.
        return
    logger.info(f"Starting dependency resolution from root: {root}")
    root.resolve_dependencies()

def load_cache(self):
    """Ask the registry to load its cache by calling ``registry.load_cache()``."""
    self.registry.load_cache()

async def resolve_descriptions_async(self):
    """Resolve descriptions for every file in the registry concurrently.

    Calls ``resolve_description_async(self.llm, visited)`` once on each entry
    of ``self.registry.files``, passing the same ``visited`` set to all of
    them, and awaits the resulting awaitables together. Returns ``None``;
    returns immediately when the registry has no files.
    """
    # One set shared by all files in this run.
    visited_ucids = set()
    pending = [
        file.resolve_description_async(self.llm, visited_ucids)
        for file in self.registry.files
    ]
    if not pending:
        return
    await asyncio.gather(*pending)
Loading