Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from enum import Enum
from typing import TYPE_CHECKING, Any

import cloudpickle
import numpy as np
import polars as pl
import rustworkx as rx
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Query, Session, load_only
from sqlalchemy.orm import DeclarativeBase, Query, Session, aliased, load_only
from sqlalchemy.sql.type_api import TypeEngine

from tracksdata.attrs import AttrComparison, split_attr_comps
Expand Down Expand Up @@ -242,6 +243,26 @@ def _cast_boolean_columns(self, table_class: type[DeclarativeBase], df: pl.DataF
df = df.with_columns(pl.col(col).cast(pl.Boolean))
return df

def _unpickle_bytes_columns(
self,
df: pl.DataFrame,
) -> pl.DataFrame:
"""
Unpickle bytes columns from the database.

Parameters
----------
df : pl.DataFrame
The DataFrame to unpickle the bytes columns from.

Returns
-------
pl.DataFrame
The DataFrame with the bytes columns unpickled.
"""
df = df.with_columns(pl.col(pl.Binary).map_elements(cloudpickle.loads, return_dtype=pl.Object))
return df

def _update_max_id_per_time(self) -> None:
"""
Update the maximum node ID for each time point.
Expand Down Expand Up @@ -619,6 +640,7 @@ def _get_neighbors(
connection=session.connection(),
)
node_df = self._cast_boolean_columns(self.Node, node_df)
node_df = self._unpickle_bytes_columns(node_df)

if single_node:
return node_df
Expand Down Expand Up @@ -785,26 +807,36 @@ def subgraph(

with Session(self._engine) as session:
node_query = session.query(self.Node)
edge_query = session.query(self.Edge)

node_filtered = False

if node_ids is not None:
node_query = node_query.filter(self.Node.node_id.in_(node_ids))
node_filtered = True

edge_query = edge_query.filter(
self.Edge.source_id.in_(node_ids),
self.Edge.target_id.in_(node_ids),
)

if node_attr_comps:
node_query = _filter_query(node_query, self.Node, node_attr_comps)
node_ids = [i for (i,) in node_query.with_entities(self.Node.node_id).all()]
node_filtered = True

# selecting edges
edge_query = session.query(self.Edge)
SourceNode = aliased(self.Node)
TargetNode = aliased(self.Node)

if node_ids is not None:
edge_query = edge_query.filter(
self.Edge.source_id.in_(node_ids),
self.Edge.target_id.in_(node_ids),
edge_query = edge_query.join(
SourceNode,
self.Edge.source_id == SourceNode.node_id,
).join(
TargetNode,
self.Edge.target_id == TargetNode.node_id,
)
edge_query = _filter_query(edge_query, SourceNode, node_attr_comps)
edge_query = _filter_query(edge_query, TargetNode, node_attr_comps)

if edge_attr_comps:
edge_query = _filter_query(edge_query, self.Edge, edge_attr_comps)
Expand Down Expand Up @@ -919,6 +951,7 @@ def node_attrs(
connection=session.connection(),
)
nodes_df = self._cast_boolean_columns(self.Node, nodes_df)
nodes_df = self._unpickle_bytes_columns(nodes_df)

# match node_ids ordering
if node_ids is not None and not nodes_df.is_empty():
Expand Down Expand Up @@ -981,6 +1014,7 @@ def edge_attrs(
connection=session.connection(),
)
edges_df = self._cast_boolean_columns(self.Edge, edges_df)
edges_df = self._unpickle_bytes_columns(edges_df)

if unpack:
edges_df = unpack_array_attrs(edges_df)
Expand Down
2 changes: 1 addition & 1 deletion src/tracksdata/nodes/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def from_coordinates(
The mask.
"""
mask = _spherical_mask(radius, len(center))
center = np.round(center)
center = np.round(center).astype(int)

start = center - np.asarray(mask.shape) // 2
end = start + mask.shape
Expand Down
2 changes: 2 additions & 0 deletions src/tracksdata/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from dataclasses import dataclass
from typing import Any

__all__ = ["Options", "get_options", "options_context", "set_options"]

# Module-private mutable state
_options_stack: list["Options"] = []

Expand Down