Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 143 additions & 26 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import binascii
import uuid
import weakref
from collections.abc import Callable, Sequence
from enum import Enum
from typing import TYPE_CHECKING, Any, TypeVar
Expand Down Expand Up @@ -56,6 +58,85 @@ def _data_numpy_to_native(data: dict[str, Any]) -> None:
data[k] = v.item()


# Module-level (not methods) so they can be registered with ``weakref.finalize``
# without holding a bound reference to the owning object, which would prevent
# it from ever being collected.
def _drop_scratch_tables(engine: sa.Engine, tables: list[sa.Table]) -> None:
"""Drop scratch tables, swallowing errors (e.g. at interpreter shutdown)."""
for table in tables:
try:
table.drop(engine)
except Exception as exc:
LOG.debug("Failed to drop scratch table %s: %s", table.name, exc)


class _SqlIdSet:
"""A set of ids usable in SQL ``IN`` clauses without overflowing bind limits.

Small sets compile to inline ``col.in_([...])``; larger sets are materialized
into a per-instance scratch table and matched via ``col.in_(SELECT id FROM
scratch)``. The same ``_SqlIdSet`` may be reused against multiple columns —
each call to :meth:`in_clause` emits a fresh expression backed by the same
underlying ids.

``occurrences`` is the maximum number of times the id set will be expanded
in a single compiled statement (e.g. filtering both ``source_id`` and
``target_id`` of an edge table counts as 2). The scratch-table cutoff is
divided by it so that ``len(ids) * occurrences`` stays safely under the
backend's bound-variable limit.
"""

def __init__(
self,
graph: "SQLGraph",
ids: Sequence[int],
*,
occurrences: int = 1,
) -> None:
if hasattr(ids, "tolist"):
ids = ids.tolist()
self._ids: list[int] = list(ids)
self._graph = graph

limit = max(1, graph._sql_chunk_size() // max(1, occurrences))
if len(self._ids) > limit:
self._scratch: sa.Table | None = graph._create_id_scratch_table(self._ids)
else:
self._scratch = None

@property
def ids(self) -> list[int]:
return self._ids

@property
def uses_scratch_table(self) -> bool:
return self._scratch is not None

def in_clause(self, column: sa.ColumnElement) -> sa.ColumnElement[bool]:
if self._scratch is None:
return column.in_(self._ids)
return column.in_(sa.select(self._scratch.c.id))

def close(self) -> None:
if self._scratch is not None:
_drop_scratch_tables(self._graph._engine, [self._scratch])
self._scratch = None

def __enter__(self) -> "_SqlIdSet":
return self

def __exit__(self, *exc: object) -> None:
self.close()


def _close_id_sets(id_sets: list[_SqlIdSet]) -> None:
for id_set in id_sets:
try:
id_set.close()
except Exception as exc:
LOG.debug("Failed to close _SqlIdSet: %s", exc)


def _filter_query(
query: sa.Select,
table: type[DeclarativeBase],
Expand Down Expand Up @@ -99,25 +180,28 @@ def __init__(
self._node_attr_comps, self._edge_attr_comps = split_attr_comps(attr_filters)
self._include_targets = include_targets
self._include_sources = include_sources
self._id_sets: list[_SqlIdSet] = []

# creating initial query
self._node_query: sa.Select = sa.select(self._graph.Node)
self._edge_query: sa.Select = sa.select(self._graph.Edge)
node_filtered = False

if node_ids is not None:
if hasattr(node_ids, "tolist"):
node_ids = node_ids.tolist()

self._node_query = self._node_query.filter(self._graph.Node.node_id.in_(node_ids))
# The id set is expanded once for Node.node_id, plus once each for
# Edge.target_id / Edge.source_id when those filters are not
# suppressed by ``include_targets`` / ``include_sources``. The
# scratch-table cutoff is divided accordingly so that the total
# number of bound variables stays under the backend's limit.
occurrences = 1 + (not self._include_targets) + (not self._include_sources)
id_set = _SqlIdSet(self._graph, node_ids, occurrences=occurrences)
self._id_sets.append(id_set)

self._node_query = self._node_query.filter(id_set.in_clause(self._graph.Node.node_id))
if not self._include_targets:
self._edge_query = self._edge_query.filter(
self._graph.Edge.target_id.in_(node_ids),
)
self._edge_query = self._edge_query.filter(id_set.in_clause(self._graph.Edge.target_id))
if not self._include_sources:
self._edge_query = self._edge_query.filter(
self._graph.Edge.source_id.in_(node_ids),
)
self._edge_query = self._edge_query.filter(id_set.in_clause(self._graph.Edge.source_id))
node_filtered = True

if self._node_attr_comps:
Expand Down Expand Up @@ -182,6 +266,13 @@ def __init__(

self._node_query = sa.union(*nodes_query)

if self._uses_scratch_tables():
weakref.finalize(self, _close_id_sets, self._id_sets)

def _uses_scratch_tables(self) -> bool:
"""Whether any id set backing this filter materialized a scratch table."""
return any(id_set.uses_scratch_table for id_set in self._id_sets)

@cache_method
def node_ids(self) -> list[int]:
"""
Expand Down Expand Up @@ -1092,19 +1183,18 @@ def overlaps(
"""
Get the overlaps between the nodes in `node_ids`.
"""
if hasattr(node_ids, "tolist"):
node_ids = node_ids.tolist()

with Session(self._engine) as session:
query = session.query(self.Overlap.source_id, self.Overlap.target_id)

if node_ids is not None:
if node_ids is None:
return [[source_id, target_id] for source_id, target_id in query.all()]

with _SqlIdSet(self, node_ids, occurrences=2) as id_set:
query = query.filter(
self.Overlap.source_id.in_(node_ids),
self.Overlap.target_id.in_(node_ids),
id_set.in_clause(self.Overlap.source_id),
id_set.in_clause(self.Overlap.target_id),
)

return [[source_id, target_id] for source_id, target_id in query.all()]
return [[source_id, target_id] for source_id, target_id in query.all()]

def has_overlaps(self) -> bool:
"""
Expand Down Expand Up @@ -1794,6 +1884,33 @@ def _sql_chunk_size(self) -> int:

return chunk_size

def _create_id_scratch_table(self, ids: Sequence[int]) -> sa.Table:
"""Create a uniquely-named helper table holding ``ids``.

Used to work around SQL bound-variable limits when filtering by large
``IN (...)`` lists: callers replace ``col.in_(ids)`` with
``col.in_(sa.select(table.c.id))``. The caller owns the returned table
and is responsible for dropping it.
"""
unique_ids = list({int(v) for v in ids})

name = f"_tracksdata_ids_{uuid.uuid4().hex}"
table = sa.Table(
name,
sa.MetaData(),
sa.Column("id", sa.BigInteger, primary_key=True),
)
table.create(self._engine)

chunk_size = max(1, self._sql_chunk_size())
with self._engine.begin() as conn:
for i in range(0, len(unique_ids), chunk_size):
conn.execute(
table.insert(),
[{"id": v} for v in unique_ids[i : i + chunk_size]],
)
return table

def _update_table(
self,
table_class: type[DeclarativeBase],
Expand Down Expand Up @@ -2009,18 +2126,18 @@ def _get_degree(
return int(session.execute(stmt).scalar())

stmt = sa.select(edge_key_col, sa.func.count()).group_by(edge_key_col)
if node_ids is not None:
stmt = stmt.where(edge_key_col.in_(node_ids))

with Session(self._engine) as session:
# get the number of edges for each using group by and count
degree = dict(session.execute(stmt).all())

if node_ids is None:
# this is necessary to make sure it's the same order as node_ids
with Session(self._engine) as session:
degree = dict(session.execute(stmt).all())
# preserve the canonical node ordering
return [degree.get(node_id, 0) for node_id in self.node_ids()]

return [degree.get(node_id, 0) for node_id in node_ids]
with _SqlIdSet(self, node_ids, occurrences=1) as id_set:
stmt = stmt.where(id_set.in_clause(edge_key_col))
with Session(self._engine) as session:
degree = dict(session.execute(stmt).all())
return [degree.get(node_id, 0) for node_id in id_set.ids]

def in_degree(self, node_ids: list[int] | int | None = None) -> list[int] | int:
"""
Expand Down
87 changes: 87 additions & 0 deletions src/tracksdata/graph/_test/test_subgraph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import re
from collections.abc import Callable
from contextlib import contextmanager
Expand Down Expand Up @@ -1302,3 +1303,89 @@ def test_edge_list(graph_backend: BaseGraph, use_subgraph: bool) -> None:
)
)
assert edge_list == expected_edge_list


def _build_chain_graph(graph: SQLGraph, n_nodes: int) -> list[int]:
node_ids: list[int] = []
for t in range(n_nodes):
node_ids.append(graph.add_node({DEFAULT_ATTR_KEYS.T: t}))
for src, tgt in itertools.pairwise(node_ids):
graph.add_edge(src, tgt, {})
graph.add_overlap(node_ids[0], node_ids[1])
graph.add_overlap(node_ids[2], node_ids[3])
return node_ids


def _scratch_table_count(graph: SQLGraph) -> int:
"""Count leftover ``_tracksdata_ids_*`` scratch tables in a SQLite graph."""
import sqlalchemy as sa

with graph._engine.connect() as conn:
return conn.execute(
sa.text("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name LIKE '_tracksdata_ids_%'")
).scalar()


def test_sql_graph_filter_large_node_ids(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Filtering with more ids than SQLite's variable limit must not raise.

Reproduces the ``OperationalError: too many SQL variables`` failure by
forcing the scratch-table code path via a tiny chunk size.
"""
graph = SQLGraph("sqlite", str(tmp_path / "scratch.db"))
n_nodes = 40
node_ids = _build_chain_graph(graph, n_nodes)

# Force scratch-table path on every call site.
monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 4)

# Context-manager paths (overlaps, _get_degree) must drop their scratch
# tables once the block exits — the count should return to baseline after
# each call regardless of whether the scratch path fired inside.
assert _scratch_table_count(graph) == 0
in_deg = graph.in_degree(node_ids)
assert _scratch_table_count(graph) == 0
out_deg = graph.out_degree(node_ids)
assert _scratch_table_count(graph) == 0
overlaps = graph.overlaps(node_ids)
assert _scratch_table_count(graph) == 0

assert sum(in_deg) == n_nodes - 1
assert sum(out_deg) == n_nodes - 1
assert sorted(map(tuple, overlaps)) == sorted([(node_ids[0], node_ids[1]), (node_ids[2], node_ids[3])])

filtered = graph.filter(node_ids=node_ids)
# Confirm the scratch-table code path was taken rather than raw IN (...).
assert filtered._uses_scratch_tables()
subgraph = filtered.subgraph()
assert subgraph.num_nodes() == n_nodes
assert subgraph.num_edges() == n_nodes - 1


def test_sql_graph_filter_borderline_node_ids(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None:
"""The scratch cutoff must account for how many times ids appear per statement.

With ``_sql_chunk_size() == 12`` and ``SQLFilter`` using ``occurrences=3``,
a list of 5 ids would compile to ~15 bound variables — above the limit —
even though ``len(node_ids) <= chunk_size``. The helper must still switch
to the scratch-table path in that band.
"""
graph = SQLGraph("sqlite", str(tmp_path / "scratch.db"))
n_nodes = 5
node_ids = _build_chain_graph(graph, n_nodes)

monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 12)

filtered = graph.filter(node_ids=node_ids)
# 5 ids fits under chunk_size=12 inline, but with occurrences=3 the
# effective cutoff is 12 // 3 == 4, so scratch must kick in.
assert filtered._uses_scratch_tables()
subgraph = filtered.subgraph()
assert subgraph.num_nodes() == n_nodes
assert subgraph.num_edges() == n_nodes - 1

# overlaps() uses occurrences=2 → cutoff 6, so len==5 stays inline.
# Still assert it returns the right data regardless of path.
assert sorted(map(tuple, graph.overlaps(node_ids))) == sorted(
[(node_ids[0], node_ids[1]), (node_ids[2], node_ids[3])]
)
Loading