diff --git a/.claude/settings.local.json b/.claude/settings.local.json
index 24b715a..80065c6 100644
--- a/.claude/settings.local.json
+++ b/.claude/settings.local.json
@@ -1,6 +1,10 @@
{
- "enabledMcpjsonServers": ["codeembed"],
+ "enabledMcpjsonServers": [
+ "codeembed"
+ ],
"permissions": {
- "allow": ["mcp__codeembed__search"]
+ "allow": [
+ "mcp__codeembed__search"
+ ]
}
}
diff --git a/.mcp.json b/.mcp.json
index 48b640c..3bc24a4 100644
--- a/.mcp.json
+++ b/.mcp.json
@@ -1,8 +1,10 @@
{
"mcpServers": {
"codeembed": {
- "command": "uv",
- "args": ["run", "codeembed", "serve"]
+ "command": "codeembed",
+ "args": [
+ "serve"
+ ]
}
}
}
diff --git a/.vscode/mcp.json b/.vscode/mcp.json
index 18df0a5..dbd1245 100644
--- a/.vscode/mcp.json
+++ b/.vscode/mcp.json
@@ -1,8 +1,10 @@
{
"servers": {
"codeembed": {
- "command": "uv",
- "args": ["run", "codeembed", "serve"]
+ "command": "codeembed",
+ "args": [
+ "serve"
+ ]
}
}
}
diff --git a/README.md b/README.md
index 6899b78..f189f35 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# CodeEmbed
-Embeds your codebase into a local vector database and exposes it as an MCP tool, giving AI assistants like Claude Code fast semantic search over your code.
+Embeds your codebase into a local vector and graph database and exposes it as an MCP tool, giving AI assistants like Claude Code fast semantic search over your code using Graph RAG.
Particularly useful for questions like:
@@ -12,7 +12,7 @@ For other questions, the agent will fall back to normal lookups.
CodeEmbed can improve lookup speed and accuracy, especially for finding existing implementations before writing new ones.
Note that the biggest bottleneck in coding agents is LLM thinking and token generation — solid prompts and follow-up questions still matter.
-Uses [ChromaDB](https://github.com/chroma-core/chroma) for local vector storage and either [Ollama](https://github.com/ollama/ollama) or OpenAI (including OpenAI models via Azure AI Foundry) for LLM analysis.
+Uses [ChromaDB](https://github.com/chroma-core/chroma) for vector storage, SQLite for graph storage, and either [Ollama](https://github.com/ollama/ollama) or OpenAI (including OpenAI models via Azure AI Foundry) for LLM analysis.
## Prerequisites
diff --git a/pyproject.toml b/pyproject.toml
index 221f98f..31b904b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "codeembed"
-version = "0.1.1"
+version = "0.2.0b3"
description = "Embeds your codebase and makes it available for quick LLM lookups via MCP."
readme = "README.md"
requires-python = ">=3.11"
diff --git a/src/codeembed/bootstrap/services.py b/src/codeembed/bootstrap/services.py
index beb9e80..3e98c2c 100644
--- a/src/codeembed/bootstrap/services.py
+++ b/src/codeembed/bootstrap/services.py
@@ -10,6 +10,7 @@
from codeembed.doc_embedder.doc_embedder import DocEmbedder
from codeembed.doc_provider.local_doc_provider import LocalDocProvider
from codeembed.doc_search_service.doc_search_service import DocSearchService
+from codeembed.graph_db.sqlite_adapter import SqliteGraphDb
from codeembed.llm.base import LLMServiceBase
from codeembed.llm.ollama_adapter import OllamaLLMService
from codeembed.vector_db.chromadb_adapter import ChromaDbAdapter
@@ -23,10 +24,21 @@
_DEFAULT_SLEEP_INTERVAL = 60
+@lru_cache(maxsize=1)
+def get_vector_db() -> ChromaDbAdapter:
+ return ChromaDbAdapter(collection_name="codebase")
+
+
+@lru_cache(maxsize=1)
+def get_graph_db() -> SqliteGraphDb:
+ return SqliteGraphDb(db_path=".codeembed/graph.db")
+
+
@lru_cache(maxsize=1)
def get_search_service() -> DocSearchService:
- vector_db = ChromaDbAdapter(collection_name="codebase")
- search_service = DocSearchService(vector_db)
+ vector_db = get_vector_db()
+ graph_db = get_graph_db()
+ search_service = DocSearchService(vector_db, graph_db)
return search_service
@@ -195,10 +207,11 @@ def get_embedder_service() -> DocEmbedder:
base_path=".",
supported_file_extensions=_SUPPORTED_FILE_EXTENSIONS,
)
- vector_db = ChromaDbAdapter(collection_name="codebase")
+ vector_db = get_vector_db()
llm_service = get_llm_service()
+ graph_db = get_graph_db()
embedder = DocEmbedder(
- doc_provider, vector_db, llm_service, llm_model=config.llm_model, debounce_seconds=config.debounce
+ doc_provider, vector_db, graph_db, llm_service, llm_model=config.llm_model, debounce_seconds=config.debounce
)
return embedder
diff --git a/src/codeembed/cli.py b/src/codeembed/cli.py
index 08fa7e2..556443f 100644
--- a/src/codeembed/cli.py
+++ b/src/codeembed/cli.py
@@ -418,6 +418,18 @@ def serve():
mcp.run(transport="stdio")
+@app.command()
+def search(
+ query: str = typer.Argument(..., help="Natural-language search query"),
+ top_n: int = typer.Option(10, "--top-n", "-n", help="Number of results to return"),
+):
+ """Search the embedded codebase using semantic similarity."""
+ from codeembed.bootstrap.services import get_search_service
+
+ result = get_search_service().search(query, top_n)
+ typer.echo(result)
+
+
@app.command()
def embed():
"""Embed codebase into the vector database."""
diff --git a/src/codeembed/delta_computer/delta_computer.py b/src/codeembed/delta_computer/delta_computer.py
index abd40a6..3dd693c 100644
--- a/src/codeembed/delta_computer/delta_computer.py
+++ b/src/codeembed/delta_computer/delta_computer.py
@@ -15,7 +15,7 @@ def __init__(self, doc_provider: DocProviderBase, vector_db: VectorDbBase, debou
self._vector_db = vector_db
self._debounce_seconds = debounce_seconds
- def compute_deltas(self) -> Tuple[Set[UUID], Set[str]]:
+ def compute_deltas(self) -> Tuple[Set[UUID], Set[str], Set[str]]:
"""
Returns chunk IDs to delete and file paths to process.
@@ -26,6 +26,7 @@ def compute_deltas(self) -> Tuple[Set[UUID], Set[str]]:
file_path_to_chunk_ids: Dict[str, List[UUID]] = {}
chunk_ids_to_delete: Set[UUID] = set()
+ file_paths_to_delete: Set[str] = set()
# Collect modified_at stored in our database.
old_modified_at: Dict[str, datetime] = {}
@@ -71,5 +72,6 @@ def compute_deltas(self) -> Tuple[Set[UUID], Set[str]]:
if file_path not in current:
for chunk_id in file_path_to_chunk_ids.get(file_path, []):
chunk_ids_to_delete.add(chunk_id)
+ file_paths_to_delete.add(file_path)
- return chunk_ids_to_delete, file_paths_to_update
+ return chunk_ids_to_delete, file_paths_to_update, file_paths_to_delete
diff --git a/src/codeembed/doc_embedder/doc_embedder.py b/src/codeembed/doc_embedder/doc_embedder.py
index 8fe97a5..710527e 100644
--- a/src/codeembed/doc_embedder/doc_embedder.py
+++ b/src/codeembed/doc_embedder/doc_embedder.py
@@ -2,10 +2,14 @@
from typing import List
from uuid import uuid4
+from pydantic import BaseModel
+
from codeembed.delta_computer.delta_computer import DeltaComputer
from codeembed.doc_provider.base import DocProviderBase
from codeembed.doc_splitters.generic_splitter import FileSplitter
from codeembed.doc_splitters.models import FileSegment
+from codeembed.graph_db.base import GraphDbBase
+from codeembed.graph_db.models import Edge
from codeembed.llm.base import LLMServiceBase
from codeembed.llm.models import ChatMessage
from codeembed.vector_db.base import VectorDbBase
@@ -14,7 +18,7 @@
logger = logging.getLogger(__name__)
-def _segment_to_chunk(
+def _summarize_chunk_with_llm(
llm_service: LLMServiceBase,
segment: FileSegment,
full_content: str,
@@ -25,7 +29,13 @@ def _segment_to_chunk(
# NOTE: For markdown files we could embed directly without LLM summarization.
# Just split on ## headers.
- logger.info("Analyzing segment %s in file %s...", segment.content.split("\n")[0], file_path)
+ logger.info(
+ "Summarizing segment %s in file %s:%d-%d...",
+ segment.content.split("\n")[0],
+ file_path,
+ segment.line_start,
+ segment.line_end,
+ )
messages: List[ChatMessage] = [
{"role": "system", "content": "You are an expert at describing code."},
@@ -55,24 +65,150 @@ def _segment_to_chunk(
},
]
- result = llm_service.generate_response(messages, llm_model, max_tokens=1024, temperature=0.3)
+ result = llm_service.generate_response(messages, llm_model, max_tokens=4096, temperature=0.3)
- logger.info("Generated summary for segment in file %s: %s", file_path, result.response)
+ logger.info("Generated summary for segment in file %s. Length: %d", file_path, len(result.response))
return result.response
+class _Edge(BaseModel):
+ source: str
+ relation: str
+ target: str
+
+
+class _GraphOutput(BaseModel):
+ edges: List[_Edge] # source, relation, target
+
+
+def _normalize_edge(edge: _Edge) -> _Edge:
+ return _Edge(
+ source=edge.source.strip(),
+ relation=edge.relation.strip().upper().replace(" ", "_"),
+ target=edge.target.strip(),
+ )
+
+
+def _find_graph_relations_with_llm(
+ llm_service: LLMServiceBase,
+ segment: FileSegment,
+ full_content: str,
+ file_path: str,
+ llm_model: str,
+ summary: str,
+) -> List[_Edge]:
+
+ logger.info(
+ "Extracting graph relations for segment %s in file %s:%d-%d...",
+ segment.content.split("\n")[0],
+ file_path,
+ segment.line_start,
+ segment.line_end,
+ )
+
+ messages: List[ChatMessage] = [
+ {
+ "role": "system",
+ "content": (
+ "You extract relationships from code or text and return structured graph edges.\n\n"
+ "Node ID format rules — follow these exactly:\n"
+ "- Class method: ClassName.method_name e.g. AuthService.login\n"
+ "- Class: ClassName e.g. AuthService\n"
+ "- Module-level func: function_name e.g. jwt_decode\n\n"
+ "Allowed relations (use ONLY these):\n"
+ "- CALLS — a function or method invokes another\n"
+ "- EXTENDS — a class inherits from another\n"
+ "- IMPLEMENTS — a class implements an interface or abstract base\n\n"
+ "Forbidden:\n"
+ "- Never emit IMPORTS edges — import statements are visible in the raw code and add no value\n"
+ "- Never emit USES edges\n"
+ "- Never use module paths (e.g. codeembed.llm.models) as node IDs — always prefer the class or "
+ "function name\n\n"
+ "Other rules:\n"
+ "- Only output relations explicitly present in the code or text\n"
+ "- Use full file context to resolve ambiguous references\n"
+ "- Do NOT invent nodes or relations\n"
+ "- Ignore trivial variable assignments and local-only references\n"
+ "- Node IDs must be real code symbols (class names, function names) — NEVER string "
+ "literals, single letters, placeholder values, or test input data\n"
+ ),
+ },
+ {
+ "role": "user",
+ "content": f"""
+{file_path}
+
+
+{summary}
+
+
+
+{full_content}
+
+
+
+{segment.content}
+
+
+Extract graph relations from the Segment only (use FullFileContent for context/disambiguation).
+
+Examples of correct output:
+ AuthService.login CALLS UserRepository.find_by_email
+ AuthService.login CALLS JwtService.sign
+ UserRepository EXTENDS BaseRepository
+ OllamaLLMService IMPLEMENTS LLMServiceBase
+
+Return STRICT JSON:
+{{
+ "edges": [
+ {{
+ "source": "...",
+ "relation": "...",
+ "target": "..."
+ }}
+ ]
+}}
+""",
+ },
+ ]
+
+ try:
+ result = llm_service.generate_structured_output(
+ messages=messages,
+ llm_model=llm_model,
+ output_format=_GraphOutput,
+ max_tokens=4096,
+ temperature=0.1,
+ )
+ except Exception as e:
+ logger.warning(
+ "Failed to extract graph relations for segment in '%s' (lines %s-%s): %s",
+ file_path,
+ segment.line_start,
+ segment.line_end,
+ e,
+ )
+ return []
+
+ logger.info("Extracted %d graph edges for segment in file %s.", len(result.data.edges), file_path)
+
+ return [_normalize_edge(edge) for edge in result.data.edges]
+
+
class DocEmbedder:
def __init__(
self,
doc_provider: DocProviderBase,
vector_db: VectorDbBase,
+ graph_db: GraphDbBase,
llm_service: LLMServiceBase,
llm_model: str,
debounce_seconds: int = 10,
) -> None:
self._doc_provider = doc_provider
self._vector_db = vector_db
+ self._graph_db = graph_db
self._llm_service = llm_service
self._llm_model = llm_model
self._debounce_seconds = debounce_seconds
@@ -82,17 +218,22 @@ def embed_codebase(self) -> None:
logger.info("Computing deltas...")
- chunks_ids_to_remove, files_to_update = DeltaComputer(
+ chunks_ids_to_remove, files_to_update, file_paths_to_delete = DeltaComputer(
self._doc_provider, self._vector_db, self._debounce_seconds
).compute_deltas()
- logger.info(f"Detected {len(chunks_ids_to_remove)} chunks to delete from vector database.")
- logger.info(f"Detected {len(files_to_update)} files to process.")
-
if chunks_ids_to_remove:
logger.info(f"Deleting {len(chunks_ids_to_remove)} chunks from vector database.")
self._vector_db.delete_chunks(list(chunks_ids_to_remove))
+ for file_path in files_to_update | file_paths_to_delete:
+ logger.info(f"Deleting edges for file '{file_path}' from graph database.")
+ self._graph_db.delete_edges_by_file_path(file_path)
+
+ if not files_to_update:
+ logger.info("No files to update. Embedding process is complete.")
+ return
+
logger.info(f"Processing {len(files_to_update)} files...")
num_processed = 0
@@ -100,29 +241,59 @@ def embed_codebase(self) -> None:
splitter = FileSplitter()
+ # TODO: Add multi-threading.
+
for i, file in enumerate(files_to_update):
logger.info(f"Processing file '{file}' ({i + 1}/{len(files_to_update)})...")
doc = self._doc_provider.get_content(file)
segments = splitter.split_file(doc.content, file)
chunks = []
+ edges: List[Edge] = []
for segment in segments:
- summary = _segment_to_chunk(self._llm_service, segment, doc.content, file, self._llm_model)
- chunks.append(
- Chunk(
- id=uuid4(),
- modified_at=doc.modified_at,
- content=summary,
- file_path=file,
- line_start=segment.line_start,
- line_end=segment.line_end,
- raw_code=segment.content,
- file_sha256_checksum=doc.sha256_checksum,
- )
+ summary = _summarize_chunk_with_llm(self._llm_service, segment, doc.content, file, self._llm_model)
+
+ _edges = _find_graph_relations_with_llm(
+ self._llm_service, segment, doc.content, file, self._llm_model, summary
+ )
+
+ chunk = Chunk(
+ id=uuid4(),
+ modified_at=doc.modified_at,
+ content=summary,
+ file_path=file,
+ line_start=segment.line_start,
+ line_end=segment.line_end,
+ raw_code=segment.content,
+ file_sha256_checksum=doc.sha256_checksum,
+ graph_node_ids=[edge.source for edge in _edges],
+ )
+
+ chunks.append(chunk)
+
+ edges.extend(
+ [
+ Edge(
+ source=edge.source,
+ relation=edge.relation,
+ target=edge.target,
+ file_path=file,
+ chunk_id=chunk.id,
+ properties={
+ "line_start": segment.line_start,
+ "line_end": segment.line_end,
+ "modified_at": doc.modified_at.isoformat(),
+ "file_sha256_checksum": doc.sha256_checksum,
+ },
+ )
+ for edge in _edges
+ ]
)
if not chunks:
logger.warning(f"No chunks generated for file '{file}'. Skipping embedding for this file.")
num_skipped += 1
continue
+ logger.info(f"Saving {len(edges)} edges to graph database.")
+ self._graph_db.add_edges(edges)
logger.info(f"Saving {len(chunks)} chunks to vector database.")
self._vector_db.add_chunks(chunks)
num_processed += 1
diff --git a/src/codeembed/doc_provider/local_doc_provider.py b/src/codeembed/doc_provider/local_doc_provider.py
index 9fe4a33..e362c6b 100644
--- a/src/codeembed/doc_provider/local_doc_provider.py
+++ b/src/codeembed/doc_provider/local_doc_provider.py
@@ -6,7 +6,7 @@
from codeembed.doc_provider.base import DocProviderBase
from codeembed.doc_provider.models import DocumentContent, DocumentMeta
-_SKIP_DIRS = frozenset({"venv", ".venv", "node_modules", "dist", "build"})
+_SKIP_DIRS = frozenset({"venv", ".venv", "node_modules", "dist", "build", "tests"})
_SKIP_FILES = frozenset({"__init__.py", ".env", ".env.local", "appsettings.json", "appsettings.Development.json"})
diff --git a/src/codeembed/doc_search_service/doc_search_service.py b/src/codeembed/doc_search_service/doc_search_service.py
index 12d14df..e0dc86e 100644
--- a/src/codeembed/doc_search_service/doc_search_service.py
+++ b/src/codeembed/doc_search_service/doc_search_service.py
@@ -1,8 +1,37 @@
from typing import Dict, List
+from codeembed.graph_db.base import GraphDbBase
+from codeembed.graph_db.models import Subgraph
from codeembed.utils.string_utils import truncate_string
from codeembed.vector_db.base import VectorDbBase
-from codeembed.vector_db.models import Chunk
+from codeembed.vector_db.models import Chunk, SearchResult
+
+_GRAPH_WEIGHT = 0.5
+
+
+def _rerank(
+ search_results: List[SearchResult],
+ graph_results: Subgraph,
+ graph_chunks: List[Chunk],
+) -> List[Chunk]:
+ chunk_id_to_depth: Dict[str, int] = {}
+ for edge in graph_results.edges:
+ depth = graph_results.depths.get(edge.source)
+ if depth is None:
+ continue
+ chunk_id = str(edge.chunk_id)
+ if chunk_id not in chunk_id_to_depth or depth < chunk_id_to_depth[chunk_id]:
+ chunk_id_to_depth[chunk_id] = depth
+
+ score_by_chunk_id: Dict[str, float] = {}
+ for r in search_results:
+ score_by_chunk_id[str(r.chunk.id)] = 1.0 / (1.0 + r.score)
+ for chunk_id, depth in chunk_id_to_depth.items():
+ score_by_chunk_id[chunk_id] = score_by_chunk_id.get(chunk_id, 0.0) + _GRAPH_WEIGHT / (depth + 1)
+
+ all_chunks: Dict[str, Chunk] = {str(r.chunk.id): r.chunk for r in search_results}
+ all_chunks.update({str(c.id): c for c in graph_chunks})
+ return sorted(all_chunks.values(), key=lambda c: score_by_chunk_id.get(str(c.id), 0.0), reverse=True)
class DocSearchService:
@@ -13,23 +42,55 @@ class DocSearchService:
def __init__(
self,
vector_db: VectorDbBase,
+ graph_db: GraphDbBase,
) -> None:
self._vector_db = vector_db
+ self._graph_db = graph_db
def search(self, query: str, top_n: int = 10) -> str:
"""Searches for relevant content from vector database and formats it for LLM consumption."""
- chunks = self._vector_db.search(query, top_n)
- chunks_by_file: Dict[str, List[Chunk]] = {}
+ # Initial vector search using semantic similarity (vector embedding).
+ search_results = self._vector_db.search(query, top_n)
+
+ # Get initial graph relations
+ node_ids = list(set(node_id for result in search_results for node_id in result.chunk.graph_node_ids))
- for chunk in chunks:
+ # Expand with GraphRAG (graph traversal) to get nearby chunks.
+ graph_results = self._graph_db.expand_nodes(node_ids, max_depth=2)
+
+ # Fetch chunks for the graph results.
+ initial_chunk_ids = list(set(result.chunk.id for result in search_results))
+ graph_chunk_ids = list(
+ set(edge.chunk_id for edge in graph_results.edges if edge.chunk_id not in initial_chunk_ids)
+ )
+ graph_chunks = self._vector_db.get_chunks(graph_chunk_ids)
+
+ ranked_chunks = _rerank(search_results, graph_results, graph_chunks)
+
+ # Collect all node IDs represented in ranked chunks for edge filtering.
+ ranked_node_ids = set(node_id for chunk in ranked_chunks for node_id in chunk.graph_node_ids)
+
+ # Keep only edges where both endpoints are visible in context. Deduplicate by (source, relation, target).
+ seen_edges: set = set()
+ visible_edges = []
+ for edge in graph_results.edges:
+ key = (edge.source, edge.relation, edge.target)
+ if key in seen_edges:
+ continue
+ if edge.source in ranked_node_ids and edge.target in ranked_node_ids:
+ seen_edges.add(key)
+ visible_edges.append(edge)
+
+ chunks_by_file: Dict[str, List[Chunk]] = {}
+ for chunk in ranked_chunks:
if chunk.file_path not in chunks_by_file:
chunks_by_file[chunk.file_path] = []
chunks_by_file[chunk.file_path].append(chunk)
res = f"{query}\n"
res += f"{top_n}\n"
- res += f"\n"
+ res += f"\n"
for file_path, chunks in chunks_by_file.items():
res += f' \n'
for chunk in chunks:
@@ -44,5 +105,10 @@ def search(self, query: str, top_n: int = 10) -> str:
)
res += " \n"
res += " \n"
+ if visible_edges:
+ res += " \n"
+ for edge in visible_edges:
+ res += f' \n'
+ res += " \n"
res += "\n"
return res
diff --git a/src/codeembed/graph_db/base.py b/src/codeembed/graph_db/base.py
new file mode 100644
index 0000000..f261604
--- /dev/null
+++ b/src/codeembed/graph_db/base.py
@@ -0,0 +1,44 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional, Set
+
+from codeembed.graph_db.models import Edge, Subgraph
+
+
+class GraphDbBase(ABC):
+ # ------------------------
+ # Traversal
+ # ------------------------
+
+ @abstractmethod
+ def expand_nodes(
+ self,
+ node_ids: List[str],
+ max_depth: int = 1,
+ relations: Optional[Set[str]] = None,
+ ) -> Subgraph:
+ pass
+
+ # ------------------------
+ # Edge operations
+ # ------------------------
+
+ @abstractmethod
+ def add_edge(self, edge: Edge) -> None:
+ pass
+
+ @abstractmethod
+ def add_edges(self, edges: List[Edge]) -> None:
+ pass
+
+ @abstractmethod
+ def delete_edges_by_file_path(self, file_path: str) -> None:
+ pass
+
+ @abstractmethod
+ def delete_edge(
+ self,
+ source: str,
+ target: str,
+ relation: str,
+ ) -> None:
+ pass
diff --git a/src/codeembed/graph_db/models.py b/src/codeembed/graph_db/models.py
new file mode 100644
index 0000000..562c37f
--- /dev/null
+++ b/src/codeembed/graph_db/models.py
@@ -0,0 +1,19 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, List
+from uuid import UUID
+
+
+@dataclass
+class Edge:
+ source: str
+ target: str
+ relation: str
+ file_path: str
+ chunk_id: UUID
+ properties: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class Subgraph:
+ edges: List[Edge] = field(default_factory=list)
+ depths: Dict[str, int] = field(default_factory=dict)
diff --git a/src/codeembed/graph_db/sqlite_adapter.py b/src/codeembed/graph_db/sqlite_adapter.py
new file mode 100644
index 0000000..68ae24b
--- /dev/null
+++ b/src/codeembed/graph_db/sqlite_adapter.py
@@ -0,0 +1,126 @@
+import json
+import sqlite3
+from typing import List, Optional, Set
+from uuid import UUID
+
+from codeembed.graph_db.base import GraphDbBase
+from codeembed.graph_db.models import Edge, Subgraph
+
+
+class SqliteGraphDb(GraphDbBase):
+ def __init__(self, db_path: str):
+ self._conn = sqlite3.connect(db_path)
+ self._cur = self._conn.cursor()
+ self._initialize_schema()
+
+ # ------------------------
+ # Traversal (GraphRAG core)
+ # ------------------------
+
+ def expand_nodes(
+ self,
+ node_ids: List[str],
+ max_depth: int = 1,
+ relations: Optional[Set[str]] = None,
+ ) -> Subgraph:
+
+ visited = set()
+ frontier = set(node_ids)
+
+ all_nodes = set(node_ids)
+ depths = {n: 0 for n in node_ids}
+ edges = []
+
+ for depth in range(1, max_depth + 1):
+ if not frontier:
+ break
+
+ placeholders = ",".join(["?"] * len(frontier))
+
+ query = f"""
+ SELECT source, target, relation, file_path, chunk_id, properties
+ FROM edges
+ WHERE source IN ({placeholders})
+ """
+
+ params = list(frontier)
+
+ if relations:
+ rel_placeholders = ",".join(["?"] * len(relations))
+ query += f" AND relation IN ({rel_placeholders})"
+ params += list(relations)
+
+ self._cur.execute(query, params)
+ rows = self._cur.fetchall()
+
+ next_frontier = set()
+
+ for src, tgt, rel, file_path, chunk_id, properties in rows:
+ _properties = json.loads(properties) if properties else {}
+ edges.append(
+ Edge(
+ source=src,
+ target=tgt,
+ relation=rel,
+ file_path=file_path,
+ chunk_id=UUID(chunk_id),
+ properties=_properties,
+ )
+ )
+ if tgt not in visited:
+ next_frontier.add(tgt)
+ depths[tgt] = depth
+
+ visited |= frontier
+ frontier = next_frontier
+ all_nodes |= next_frontier
+
+ return Subgraph(edges=edges, depths=depths)
+
+ # ------------------------
+ # Edge operations
+ # ------------------------
+
+ def add_edge(self, edge: Edge) -> None:
+ self.add_edges([edge])
+
+ def add_edges(self, edges: List[Edge]) -> None:
+ self._cur.executemany(
+ "INSERT OR REPLACE INTO edges (source, target, relation, file_path, chunk_id, properties) "
+ "VALUES (?, ?, ?, ?, ?, ?)",
+ [(e.source, e.target, e.relation, e.file_path, str(e.chunk_id), json.dumps(e.properties)) for e in edges],
+ )
+ self._conn.commit()
+
+ def delete_edge(self, source: str, target: str, relation: str) -> None:
+ self._cur.execute(
+ "DELETE FROM edges WHERE source = ? AND target = ? AND relation = ?",
+ (source, target, relation),
+ )
+ self._conn.commit()
+
+ def delete_edges_by_file_path(self, file_path: str) -> None:
+ self._cur.execute(
+ "DELETE FROM edges WHERE file_path = ?",
+ (file_path,),
+ )
+ self._conn.commit()
+
+ # ------------------------
+ # Schema
+ # ------------------------
+
+ def _initialize_schema(self):
+ self._conn.execute("""
+ CREATE TABLE IF NOT EXISTS edges (
+ source TEXT,
+ target TEXT,
+ relation TEXT,
+ file_path TEXT,
+ chunk_id TEXT,
+ properties TEXT,
+ PRIMARY KEY (source, target, relation)
+ )
+ """)
+
+ self._conn.commit()
diff --git a/src/codeembed/llm/ollama_adapter.py b/src/codeembed/llm/ollama_adapter.py
index 5721f89..6d9bdd9 100644
--- a/src/codeembed/llm/ollama_adapter.py
+++ b/src/codeembed/llm/ollama_adapter.py
@@ -29,6 +29,9 @@ def generate_structured_output(
data = resp["message"]["content"]
+ if not data or not data.strip():
+ raise ValueError(f"Ollama returned empty response for model '{llm_model}'")
+
model = output_format.model_validate_json(data)
return StructuredLLMResponse(
diff --git a/src/codeembed/vector_db/base.py b/src/codeembed/vector_db/base.py
index 8ddedfd..57d0289 100644
--- a/src/codeembed/vector_db/base.py
+++ b/src/codeembed/vector_db/base.py
@@ -2,7 +2,7 @@
from typing import Dict, Iterator, List, Optional
from uuid import UUID
-from codeembed.vector_db.models import Chunk
+from codeembed.vector_db.models import Chunk, SearchResult
class VectorDbBase(ABC):
@@ -11,7 +11,7 @@ def add_chunks(self, chunks: List[Chunk]) -> None:
pass
@abstractmethod
- def search(self, query: str, top_n: int) -> List[Chunk]:
+ def search(self, query: str, top_n: int) -> List[SearchResult]:
"""Vector search. Returns top_n most relevant results."""
@abstractmethod
@@ -22,6 +22,10 @@ def iter_chunks(self, where: Optional[Dict[str, str]] = None) -> Iterator[Chunk]
For simplicity exposes 'where' argument which is a ChromaDB specific filter.
"""
+ @abstractmethod
+ def get_chunks(self, chunk_ids: List[UUID]) -> List[Chunk]:
+ pass
+
@abstractmethod
def delete_chunks(self, chunk_ids: List[UUID]) -> None:
pass
diff --git a/src/codeembed/vector_db/chromadb_adapter.py b/src/codeembed/vector_db/chromadb_adapter.py
index 73afcef..dc80da5 100644
--- a/src/codeembed/vector_db/chromadb_adapter.py
+++ b/src/codeembed/vector_db/chromadb_adapter.py
@@ -1,3 +1,4 @@
+import json
from datetime import datetime
from typing import Dict, Iterator, List, Optional, Type, TypeVar
from uuid import UUID
@@ -6,7 +7,7 @@
from chromadb.api.types import Metadata, QueryResult
from codeembed.vector_db.base import VectorDbBase
-from codeembed.vector_db.models import Chunk
+from codeembed.vector_db.models import Chunk, SearchResult
T = TypeVar("T")
@@ -29,6 +30,7 @@ def add_chunks(self, chunks: List[Chunk]) -> None:
"line_end": chunk.line_end,
"raw_code": chunk.raw_code, # Assume ChromaDB can handle None values.
"file_sha256_checksum": chunk.file_sha256_checksum,
+ "graph_node_ids": json.dumps(chunk.graph_node_ids),
}
for chunk in chunks
]
@@ -39,18 +41,22 @@ def add_chunks(self, chunks: List[Chunk]) -> None:
metadatas=metadatas,
)
- def search(self, query: str, top_n: int) -> List[Chunk]:
+ def search(self, query: str, top_n: int) -> List[SearchResult]:
# TODO: Support filtering.
+ count = self._collection.count()
+ if count == 0:
+ return []
results: QueryResult = self._collection.query(
query_texts=[query],
- n_results=top_n,
+ n_results=min(top_n, count),
)
ids = results["ids"][0]
docs = results["documents"][0] # type: ignore
metas = results["metadatas"][0] # type: ignore
+ distances = results["distances"][0] # type: ignore
- chunks_out: List[Chunk] = []
+ chunks_out: List[SearchResult] = []
for i in range(len(ids)):
# Can be simplified by adding a "get_safe_val" or similar.
@@ -61,16 +67,21 @@ def search(self, query: str, top_n: int) -> List[Chunk]:
line_end = self._get_safe_val(metas[i], "line_end", int)
raw_code = self._get_safe_val(metas[i], "raw_code", str, allow_none=True)
file_sha256_checksum = self._get_safe_val(metas[i], "file_sha256_checksum", str)
+ graph_node_ids = self._get_safe_list_val(metas[i], "graph_node_ids", str, allow_none=True)
chunks_out.append(
- Chunk(
- id=UUID(ids[i]),
- content=docs[i],
- modified_at=modified_at,
- file_path=file_path,
- line_start=line_start,
- line_end=line_end,
- raw_code=raw_code,
- file_sha256_checksum=file_sha256_checksum,
+ SearchResult(
+ chunk=Chunk(
+ id=UUID(ids[i]),
+ content=docs[i],
+ modified_at=modified_at,
+ file_path=file_path,
+ line_start=line_start,
+ line_end=line_end,
+ raw_code=raw_code,
+ file_sha256_checksum=file_sha256_checksum,
+ graph_node_ids=graph_node_ids,
+ ),
+ score=distances[i],
)
)
@@ -104,6 +115,7 @@ def iter_chunks(self, where: Optional[Dict[str, str]] = None) -> Iterator[Chunk]
line_end = self._get_safe_val(metas[i], "line_end", int)
raw_code = self._get_safe_val(metas[i], "raw_code", str, allow_none=True)
file_sha256_checksum = self._get_safe_val(metas[i], "file_sha256_checksum", str)
+ graph_node_ids = self._get_safe_list_val(metas[i], "graph_node_ids", str, allow_none=True)
yield Chunk(
id=UUID(ids[i]),
content=docs[i],
@@ -113,14 +125,72 @@ def iter_chunks(self, where: Optional[Dict[str, str]] = None) -> Iterator[Chunk]
line_end=line_end,
raw_code=raw_code,
file_sha256_checksum=file_sha256_checksum,
+ graph_node_ids=graph_node_ids,
)
offset += limit
+ def get_chunks(self, chunk_ids: List[UUID]) -> List[Chunk]:
+ if not chunk_ids:
+ return []
+ str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids]
+ results = self._collection.get(ids=str_chunk_ids)
+
+ ids = results["ids"]
+ docs = results["documents"] or []
+ metas = results["metadatas"] or []
+
+ chunks_out: List[Chunk] = []
+
+ for i in range(len(ids)):
+ # Can be simplified by adding a "get_safe_val" or similar.
+ modified_at = self._get_safe_val(metas[i], "modified_at", str)
+ modified_at = datetime.fromisoformat(modified_at)
+ file_path = self._get_safe_val(metas[i], "file_path", str)
+ line_start = self._get_safe_val(metas[i], "line_start", int)
+ line_end = self._get_safe_val(metas[i], "line_end", int)
+ raw_code = self._get_safe_val(metas[i], "raw_code", str, allow_none=True)
+ file_sha256_checksum = self._get_safe_val(metas[i], "file_sha256_checksum", str)
+ graph_node_ids = self._get_safe_list_val(metas[i], "graph_node_ids", str, allow_none=True)
+ chunks_out.append(
+ Chunk(
+ id=UUID(ids[i]),
+ content=docs[i],
+ modified_at=modified_at,
+ file_path=file_path,
+ line_start=line_start,
+ line_end=line_end,
+ raw_code=raw_code,
+ file_sha256_checksum=file_sha256_checksum,
+ graph_node_ids=graph_node_ids,
+ )
+ )
+
+ return chunks_out
+
def delete_chunks(self, chunk_ids: List[UUID]) -> None:
# Maybe batch if list is very long? I hope ChromaDB does so internally.
self._collection.delete(ids=[str(chunk_id) for chunk_id in chunk_ids])
+ def _get_safe_list_val(
+ self, meta: Metadata, key: str, expected_elem_type: Type[T], allow_none: bool = False
+ ) -> List[T]:
+ val_str = self._get_safe_val(meta, key, str, allow_none=allow_none)
+ if val_str is None and allow_none:
+ return []
+ elif val_str is None:
+ raise ValueError(f"Expected a JSON string for key '{key}', got None.")
+ val_json = json.loads(val_str)
+ if not isinstance(val_json, list):
+ raise ValueError(f"Expected a list for key '{key}', got {type(val_json)}.")
+ for elem in val_json:
+ if not isinstance(elem, expected_elem_type):
+ raise ValueError(
+ f"Expected elements of type {expected_elem_type} in list for "
+ f"key '{key}', got element of type {type(elem)}."
+ )
+ return val_json
+
def _get_safe_val(self, meta: Metadata, key: str, expected_type: Type[T], allow_none: bool = False) -> T:
val = meta.get(key)
if val is None and allow_none:
diff --git a/src/codeembed/vector_db/models.py b/src/codeembed/vector_db/models.py
index c2a23fe..a7e85c1 100644
--- a/src/codeembed/vector_db/models.py
+++ b/src/codeembed/vector_db/models.py
@@ -1,6 +1,6 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from datetime import datetime
-from typing import Optional
+from typing import List, Optional
from uuid import UUID
@@ -14,3 +14,10 @@ class Chunk:
line_end: int
raw_code: Optional[str]
file_sha256_checksum: str
+ graph_node_ids: List[str] = field(default_factory=list)
+
+
+@dataclass
+class SearchResult:
+ chunk: Chunk
+ score: float
diff --git a/tests/test_delta_computer.py b/tests/test_delta_computer.py
index eec6cbf..d1a85bf 100644
--- a/tests/test_delta_computer.py
+++ b/tests/test_delta_computer.py
@@ -38,6 +38,9 @@ def add_chunks(self, *args, **kwargs):
def delete_chunks(self, *args, **kwargs):
raise NotImplementedError()
+ def get_chunks(self, *args, **kwargs):
+ raise NotImplementedError()
+
def test_detects_new_document():
now = utc_now()
@@ -51,7 +54,7 @@ def test_detects_new_document():
vector_db = FakeVectorDb([])
dc = DeltaComputer(doc_provider, vector_db, debounce_seconds=0)
- to_delete, to_update = dc.compute_deltas()
+ to_delete, to_update, _ = dc.compute_deltas()
assert to_delete == set()
assert to_update == {"file1.txt"}
@@ -79,7 +82,7 @@ def test_detects_deleted_document():
)
dc = DeltaComputer(doc_provider, vector_db)
- to_delete, to_update = dc.compute_deltas()
+ to_delete, to_update, _ = dc.compute_deltas()
assert to_delete == {chunk_id}
assert to_update == set()
@@ -113,7 +116,7 @@ def test_detects_updated_document():
)
dc = DeltaComputer(doc_provider, vector_db, debounce_seconds=0)
- to_delete, to_update = dc.compute_deltas()
+ to_delete, to_update, _ = dc.compute_deltas()
assert to_delete == {chunk_id}
assert to_update == {"file1.txt"}
diff --git a/tests/test_normalize_edge.py b/tests/test_normalize_edge.py
new file mode 100644
index 0000000..03cb098
--- /dev/null
+++ b/tests/test_normalize_edge.py
@@ -0,0 +1,34 @@
+from codeembed.doc_embedder.doc_embedder import _Edge, _normalize_edge
+
+
+def test_relation_uppercased():
+ edge = _Edge(source="A", relation="calls", target="B")
+ assert _normalize_edge(edge).relation == "CALLS"
+
+
+def test_relation_spaces_replaced_with_underscores():
+ edge = _Edge(source="A", relation="is supplier of", target="B")
+ assert _normalize_edge(edge).relation == "IS_SUPPLIER_OF"
+
+
+def test_relation_mixed_case_and_spaces():
+ edge = _Edge(source="A", relation="Is Supplier Of", target="B")
+ assert _normalize_edge(edge).relation == "IS_SUPPLIER_OF"
+
+
+def test_source_whitespace_stripped():
+ edge = _Edge(source=" AuthService.login ", relation="CALLS", target="B")
+ assert _normalize_edge(edge).source == "AuthService.login"
+
+
+def test_target_whitespace_stripped():
+ edge = _Edge(source="A", relation="CALLS", target=" jwt_decode ")
+ assert _normalize_edge(edge).target == "jwt_decode"
+
+
+def test_already_normalized_unchanged():
+ edge = _Edge(source="AuthService.login", relation="CALLS", target="JwtService.sign")
+ result = _normalize_edge(edge)
+ assert result.source == "AuthService.login"
+ assert result.relation == "CALLS"
+ assert result.target == "JwtService.sign"
diff --git a/tests/test_reranking.py b/tests/test_reranking.py
new file mode 100644
index 0000000..7de3803
--- /dev/null
+++ b/tests/test_reranking.py
@@ -0,0 +1,160 @@
+from datetime import datetime, timezone
+from typing import List, Optional
+from uuid import UUID, uuid4
+
+from codeembed.doc_search_service.doc_search_service import _rerank
+from codeembed.graph_db.models import Edge, Subgraph
+from codeembed.vector_db.models import Chunk, SearchResult
+
+_NOW = datetime(2024, 1, 1, tzinfo=timezone.utc)
+
+
+def _chunk(chunk_id: UUID, node_ids: Optional[List[str]] = None) -> Chunk:
+ return Chunk(
+ id=chunk_id,
+ content="",
+ modified_at=_NOW,
+ file_path="f.py",
+ line_start=1,
+ line_end=10,
+ raw_code=None,
+ file_sha256_checksum="",
+ graph_node_ids=node_ids or [],
+ )
+
+
+def _result(chunk: Chunk, distance: float) -> SearchResult:
+ return SearchResult(chunk=chunk, score=distance)
+
+
+def _edge(source: str, target: str, chunk_id: UUID) -> Edge:
+ return Edge(source=source, target=target, relation="CALLS", file_path="f.py", chunk_id=chunk_id)
+
+
+def test_empty_graph_preserves_semantic_order():
+ id_a, id_b = uuid4(), uuid4()
+ chunk_a = _chunk(id_a)
+ chunk_b = _chunk(id_b)
+ results = [_result(chunk_a, distance=0.1), _result(chunk_b, distance=0.5)]
+
+ ranked = _rerank(results, Subgraph(), [])
+
+ assert ranked[0].id == id_a
+ assert ranked[1].id == id_b
+
+
+def test_graph_bonus_can_promote_lower_semantic_chunk():
+ id_a, id_b = uuid4(), uuid4()
+ chunk_a = _chunk(id_a, node_ids=["ServiceA"])
+ chunk_b = _chunk(id_b, node_ids=["ServiceB"])
+
+ # chunk_b has better semantic score (lower distance)
+ results = [_result(chunk_a, distance=0.5), _result(chunk_b, distance=0.4)]
+
+ # chunk_a gets a depth-1 graph bonus via an edge from "ServiceA"
+ subgraph = Subgraph(
+ edges=[_edge("ServiceA", "X", chunk_id=id_a)],
+ depths={"ServiceA": 1},
+ )
+
+ ranked = _rerank(results, subgraph, [])
+
+ # chunk_a: semantic=1/(1+0.5)=0.667, bonus=0.5/2=0.25, total=0.917
+ # chunk_b: semantic=1/(1+0.4)=0.714, no bonus, total=0.714
+ assert ranked[0].id == id_a
+
+
+def test_depth_1_ranks_higher_than_depth_2():
+ id_a, id_b = uuid4(), uuid4()
+ chunk_a = _chunk(id_a)
+ chunk_b = _chunk(id_b)
+
+ results = [_result(chunk_a, distance=1.0), _result(chunk_b, distance=1.0)]
+
+ subgraph = Subgraph(
+ edges=[
+ _edge("NodeA", "X", chunk_id=id_a),
+ _edge("NodeB", "X", chunk_id=id_b),
+ ],
+ depths={"NodeA": 1, "NodeB": 2},
+ )
+
+ ranked = _rerank(results, subgraph, [])
+
+ assert ranked[0].id == id_a
+
+
+def test_graph_only_chunk_included_in_results():
+ id_a, id_b = uuid4(), uuid4()
+ chunk_a = _chunk(id_a)
+ chunk_b = _chunk(id_b)
+
+ results = [_result(chunk_a, distance=0.2)]
+
+ subgraph = Subgraph(
+ edges=[_edge("NodeB", "X", chunk_id=id_b)],
+ depths={"NodeB": 1},
+ )
+
+ ranked = _rerank(results, subgraph, [chunk_b])
+
+ ids = [c.id for c in ranked]
+ assert id_b in ids
+
+
+def test_chunk_in_both_gets_combined_score():
+ id_a = uuid4()
+ chunk_a = _chunk(id_a)
+ results = [_result(chunk_a, distance=1.0)]
+
+ subgraph = Subgraph(
+ edges=[_edge("NodeA", "X", chunk_id=id_a)],
+ depths={"NodeA": 0},
+ )
+
+ ranked = _rerank(results, subgraph, [])
+
+ # semantic = 1/(1+1.0) = 0.5, graph bonus = 0.5/(0+1) = 0.5, total = 1.0
+ # Just verify it's in the results; score correctness implied by other tests
+ assert ranked[0].id == id_a
+
+
+def test_minimum_depth_wins_when_reachable_via_two_paths():
+ id_a, id_b = uuid4(), uuid4()
+ chunk_a = _chunk(id_a)
+ chunk_b = _chunk(id_b)
+
+ # Both chunks have the same semantic score
+ results = [_result(chunk_a, distance=1.0), _result(chunk_b, distance=1.0)]
+
+ # chunk_a is reachable at depth 1 AND depth 2 — minimum (depth 1) should be used
+ # chunk_b is only reachable at depth 2
+ subgraph = Subgraph(
+ edges=[
+ _edge("NodeShallow", "X", chunk_id=id_a),
+ _edge("NodeDeep", "X", chunk_id=id_a),
+ _edge("NodeDeep2", "X", chunk_id=id_b),
+ ],
+ depths={"NodeShallow": 1, "NodeDeep": 2, "NodeDeep2": 2},
+ )
+
+ ranked = _rerank(results, subgraph, [])
+
+ assert ranked[0].id == id_a
+
+
+def test_results_are_sorted_descending():
+ ids = [uuid4() for _ in range(4)]
+ chunks = [_chunk(i) for i in ids]
+ results = [
+ _result(chunks[0], distance=0.9),
+ _result(chunks[1], distance=0.1),
+ _result(chunks[2], distance=0.5),
+ _result(chunks[3], distance=0.3),
+ ]
+
+ ranked = _rerank(results, Subgraph(), [])
+
+ scores = [1.0 / (1.0 + r.score) for r in results]
+ expected_order = [c for _, c in sorted(zip(scores, chunks), reverse=True)]
+ assert [c.id for c in ranked] == [c.id for c in expected_order]
diff --git a/tests/test_sqlite_graph_db.py b/tests/test_sqlite_graph_db.py
new file mode 100644
index 0000000..3d5e891
--- /dev/null
+++ b/tests/test_sqlite_graph_db.py
@@ -0,0 +1,147 @@
+from uuid import uuid4
+
+from codeembed.graph_db.models import Edge
+from codeembed.graph_db.sqlite_adapter import SqliteGraphDb
+
+
+def _db() -> SqliteGraphDb:
+ return SqliteGraphDb(":memory:")
+
+
+def _edge(source: str, target: str, relation: str = "CALLS", file_path: str = "f.py") -> Edge:
+ return Edge(source=source, target=target, relation=relation, file_path=file_path, chunk_id=uuid4())
+
+
+def test_expand_nodes_empty_graph_returns_no_edges():
+ db = _db()
+ result = db.expand_nodes(["A"])
+ assert result.edges == []
+
+
+def test_expand_nodes_depth_1_returns_direct_neighbor():
+ db = _db()
+ db.add_edge(_edge("A", "B"))
+
+ result = db.expand_nodes(["A"], max_depth=1)
+
+ sources = {e.source for e in result.edges}
+ targets = {e.target for e in result.edges}
+ assert "A" in sources
+ assert "B" in targets
+
+
+def test_expand_nodes_depth_1_does_not_return_two_hop_neighbor():
+ db = _db()
+ db.add_edge(_edge("A", "B"))
+ db.add_edge(_edge("B", "C"))
+
+ result = db.expand_nodes(["A"], max_depth=1)
+
+ targets = {e.target for e in result.edges}
+ assert "C" not in targets
+
+
+def test_expand_nodes_depth_2_returns_two_hop_neighbor():
+ db = _db()
+ db.add_edge(_edge("A", "B"))
+ db.add_edge(_edge("B", "C"))
+
+ result = db.expand_nodes(["A"], max_depth=2)
+
+ targets = {e.target for e in result.edges}
+ assert "B" in targets
+ assert "C" in targets
+
+
+def test_expand_nodes_depths_are_correct():
+ db = _db()
+ db.add_edge(_edge("A", "B"))
+ db.add_edge(_edge("B", "C"))
+
+ result = db.expand_nodes(["A"], max_depth=2)
+
+ assert result.depths["A"] == 0
+ assert result.depths["B"] == 1
+ assert result.depths["C"] == 2
+
+
+def test_expand_nodes_cycle_does_not_loop():
+ db = _db()
+ db.add_edge(_edge("A", "B"))
+ db.add_edge(_edge("B", "A"))
+
+ result = db.expand_nodes(["A"], max_depth=5)
+
+ assert len(result.edges) == 2
+
+
+def test_expand_nodes_relation_filter_excludes_other_relations():
+ db = _db()
+ db.add_edge(_edge("A", "B", relation="CALLS"))
+ db.add_edge(_edge("A", "C", relation="IMPORTS"))
+
+ result = db.expand_nodes(["A"], max_depth=1, relations={"CALLS"})
+
+ targets = {e.target for e in result.edges}
+ assert "B" in targets
+ assert "C" not in targets
+
+
+def test_expand_nodes_returns_chunk_id_on_edge():
+ db = _db()
+ chunk_id = uuid4()
+ db.add_edge(Edge(source="A", target="B", relation="CALLS", file_path="f.py", chunk_id=chunk_id))
+
+ result = db.expand_nodes(["A"], max_depth=1)
+
+ assert len(result.edges) == 1
+ assert result.edges[0].chunk_id == chunk_id
+
+
+def test_delete_edges_by_file_path_removes_matching_edges():
+ db = _db()
+ db.add_edge(_edge("A", "B", file_path="auth.py"))
+ db.add_edge(_edge("C", "D", file_path="utils.py"))
+
+ db.delete_edges_by_file_path("auth.py")
+
+ result = db.expand_nodes(["A"], max_depth=1)
+ assert result.edges == []
+
+ result = db.expand_nodes(["C"], max_depth=1)
+ assert len(result.edges) == 1
+
+
+def test_delete_edge_removes_specific_edge():
+ db = _db()
+ db.add_edge(_edge("A", "B", relation="CALLS"))
+ db.add_edge(_edge("A", "C", relation="IMPORTS"))
+
+ db.delete_edge("A", "B", "CALLS")
+
+ result = db.expand_nodes(["A"], max_depth=1)
+ targets = {e.target for e in result.edges}
+ assert "B" not in targets
+ assert "C" in targets
+
+
+def test_add_edge_is_idempotent():
+ db = _db()
+ edge = _edge("A", "B")
+ db.add_edge(edge)
+ db.add_edge(edge)
+
+ result = db.expand_nodes(["A"], max_depth=1)
+ assert len(result.edges) == 1
+
+
+def test_expand_nodes_multiple_starting_nodes():
+ db = _db()
+ db.add_edge(_edge("A", "C"))
+ db.add_edge(_edge("B", "D"))
+
+ result = db.expand_nodes(["A", "B"], max_depth=1)
+
+ targets = {e.target for e in result.edges}
+ assert "C" in targets
+ assert "D" in targets