Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,8 @@ class Base(DeclarativeBase):
pass

if len(metadata.tables) > 0 and not overwrite:
for table in metadata.tables.values():
self._restore_pickled_column_types(table)
for table_name, table in metadata.tables.items():
cls = type(
table_name,
Expand Down Expand Up @@ -537,6 +539,11 @@ class Metadata(Base):
self.Overlap = Overlap
self.Metadata = Metadata

def _restore_pickled_column_types(self, table: sa.Table) -> None:
for column in table.columns:
if isinstance(column.type, sa.LargeBinary):
column.type = sa.PickleType()

def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaDict:
return {
**self._boolean_columns[table_class.__tablename__],
Expand Down
21 changes: 21 additions & 0 deletions src/tracksdata/graph/_test/test_graph_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,27 @@ def test_compute_overlaps_multiple_timepoints(graph_backend: BaseGraph) -> None:
assert [node1_t0, node2_t0] in valid_overlaps


def test_sql_graph_mask_update_survives_reload(tmp_path: Path) -> None:
"""Ensure SQLGraph keeps pickled column types after reloading from disk."""
db_path = tmp_path / "mask_graph.db"
graph = SQLGraph("sqlite", str(db_path))
graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)

mask_data = np.array([[True, False], [False, True]], dtype=bool)
mask = Mask(mask_data, bbox=np.array([0, 0, 2, 2]))
node_id = graph.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask})

# Dispose engine before reopening to ensure sqlite file is released.
graph._engine.dispose()

reloaded = SQLGraph("sqlite", str(db_path))
reloaded.update_node_attrs(node_ids=[node_id], attrs={DEFAULT_ATTR_KEYS.MASK: [mask]})
stored_mask = reloaded.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.MASK])[DEFAULT_ATTR_KEYS.MASK].to_list()[0]

assert isinstance(stored_mask, Mask)
np.testing.assert_array_equal(stored_mask.mask, mask_data)


def test_compute_overlaps_invalid_threshold(graph_backend: BaseGraph) -> None:
"""Test compute_overlaps with invalid threshold values."""
with pytest.raises(ValueError, match=r"iou_threshold must be between 0.0 and 1\.0"):
Expand Down
Loading