Skip to content
2 changes: 1 addition & 1 deletion benchmarks/graph_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _build_pipeline(
EdgeAttr(DEFAULT_ATTR_KEYS.SOLUTION) == True,
).subgraph(),
),
("assing_tracks", lambda graph: graph.assign_track_ids()),
("assign_tracks", lambda graph: graph.assign_track_ids()),
]


Expand Down
2 changes: 1 addition & 1 deletion src/tracksdata/functional/_napari.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def to_napari_format(
solution_graph = graph

tracks_graph = solution_graph.assign_track_ids(output_track_id_key)
dict_graph = {child: parent for parent, child in tracks_graph.edge_list()}
dict_graph = {tracks_graph[child]: tracks_graph[parent] for parent, child in tracks_graph.edge_list()}

spatial_cols = ["z", "y", "x"][-len(shape) + 1 :]

Expand Down
2 changes: 1 addition & 1 deletion src/tracksdata/functional/_rx.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def _assign_track_ids(
n_tracks = len(track_ids)

tracks_graph = rx.PyDiGraph(node_count_hint=n_tracks, edge_count_hint=n_tracks)
tracks_graph.add_node(0)

node_ids = tracks_graph.add_nodes_from(track_ids)
track_id_to_rx_node_id = _numba_build_dict(
np.asarray(track_ids, dtype=np.int64),
Expand Down
22 changes: 14 additions & 8 deletions src/tracksdata/functional/_test/test_rx.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_single_path() -> None:
assert np.array_equal(node_ids, [0, 1, 2])
assert np.array_equal(track_ids, [1, 1, 1])
assert isinstance(tracks_graph, rx.PyDiGraph)
assert tracks_graph.num_nodes() == 1 + 1 # Single track (includes null node (0))
assert tracks_graph.num_nodes() == 1 # Single track


def test_symmetric_branching_path() -> None:
Expand All @@ -53,7 +53,7 @@ def test_symmetric_branching_path() -> None:
assert len(track_ids) == 3
assert len(np.unique(track_ids)) == 3 # Three unique track IDs
assert isinstance(tracks_graph, rx.PyDiGraph)
assert tracks_graph.num_nodes() == 3 + 1 # Three tracks (includes null node (0))
assert tracks_graph.num_nodes() == 3 # Three tracks


def test_asymmetric_branching_path() -> None:
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_asymmetric_branching_path() -> None:
assert len(track_ids) == 4
assert len(np.unique(track_ids)) == 3 # Three unique track IDs
assert isinstance(tracks_graph, rx.PyDiGraph)
assert tracks_graph.num_nodes() == 3 + 1 # Three tracks (includes null node (0))
assert tracks_graph.num_nodes() == 3 # Three tracks


def test_invalid_multiple_parents() -> None:
Expand Down Expand Up @@ -139,15 +139,19 @@ def test_complex_valid_branching() -> None:
np.testing.assert_array_equal(node_ids, [0, 1, 3, 4, 2, 5])
np.testing.assert_array_equal(track_ids, [1, 2, 2, 2, 3, 4])

assert set(tracks_graph.successor_indices(1)) == {2, 3}
assert set(tracks_graph.successor_indices(3)) == {4}
assert set(tracks_graph.successor_indices(tracks_graph.find_node_by_weight(1))) == set(
map(tracks_graph.find_node_by_weight, {2, 3})
)
assert set(tracks_graph.successor_indices(tracks_graph.find_node_by_weight(3))) == set(
map(tracks_graph.find_node_by_weight, {4})
)
assert tracks_graph.num_edges() == 3

assert len(node_ids) == 6
assert len(track_ids) == 6
assert len(np.unique(track_ids)) == 4 # {0, {1, 3, 4}, 2, 5}
assert isinstance(tracks_graph, rx.PyDiGraph)
assert tracks_graph.num_nodes() == 4 + 1 # Five tracks (includes null node (0))
assert tracks_graph.num_nodes() == 4 # Four tracks


def test_three_children() -> None:
Expand All @@ -169,7 +173,9 @@ def test_three_children() -> None:
graph.add_edge(nodes[0], nodes[3], None)

_, track_ids, tracks_graph = _assign_track_ids(graph, track_id_offset=1)
assert set(tracks_graph.successor_indices(track_ids[0])) == set(track_ids[1:])
track_graphs_node_id = tracks_graph.find_node_by_weight(track_ids[0])
successor_node_ids = tracks_graph.successor_indices(track_graphs_node_id)
assert {tracks_graph[i] for i in successor_node_ids} == set(track_ids[1:])


def test_multiple_roots() -> None:
Expand All @@ -188,4 +194,4 @@ def test_multiple_roots() -> None:
assert len(track_ids) == 4
assert len(np.unique(track_ids)) == 2 # Two unique track IDs
assert isinstance(tracks_graph, rx.PyDiGraph)
assert tracks_graph.num_nodes() == 2 + 1 # Two separate tracks (includes null node (0))
assert tracks_graph.num_nodes() == 2 # Two separate tracks
26 changes: 26 additions & 0 deletions src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,32 @@ def bbox_spatial_filter(

return BBoxSpatialFilter(self, frame_attr_key=frame_attr_key, bbox_attr_key=bbox_attr_key)

@abc.abstractmethod
def assign_track_ids(
self,
output_key: str = DEFAULT_ATTR_KEYS.TRACK_ID,
reset: bool = True,
track_id_offset: int = 1,
) -> rx.PyDiGraph:
"""
Compute and assign track ids to nodes.

Parameters
----------
output_key : str
The key of the output track id attribute.
reset : bool
Whether to reset the track ids of the graph. If True, the track ids will be reset to -1.
track_id_offset : int
The starting track id, useful when assigning track ids to a subgraph.

Returns
-------
rx.PyDiGraph
A compressed graph (parent -> child) with track ids lineage relationships.
If node_ids is provided, it will only include linages including those nodes.
"""

def tracklet_graph(
self,
track_id_key: str = DEFAULT_ATTR_KEYS.TRACK_ID,
Expand Down
47 changes: 0 additions & 47 deletions src/tracksdata/graph/_graph_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from tracksdata.attrs import AttrComparison
from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.functional._rx import _assign_track_ids
from tracksdata.graph._base_graph import BaseGraph
from tracksdata.graph._mapped_graph_mixin import MappedGraphMixin
from tracksdata.graph._rustworkx_graph import IndexedRXGraph, RustWorkXGraph, RXFilter
Expand Down Expand Up @@ -495,52 +494,6 @@ def update_edge_attrs(
else:
self._out_of_sync = True

def assign_track_ids(
self,
output_key: str = DEFAULT_ATTR_KEYS.TRACK_ID,
reset: bool = True,
track_id_offset: int = 1,
) -> rx.PyDiGraph:
"""
Compute and assign track ids to nodes.

Parameters
----------
output_key : str
The key of the output track id attribute.
reset : bool
Whether to reset all track ids before assigning new ones.
track_id_offset : int
The starting track id, useful when assigning track ids to a subgraph.

Returns
-------
rx.PyDiGraph
A compressed graph (parent -> child) with track ids lineage relationships.
"""
try:
node_ids, track_ids, tracks_graph = _assign_track_ids(self.rx_graph, track_id_offset)
except RuntimeError as e:
raise RuntimeError(
"Are you sure this graph is a valid lineage graph?\n"
"This function expects a solved graph.\n"
"Often used from `graph.subgraph(edge_attr_filter={'solution': True})`"
) from e

node_ids = self._map_to_external(node_ids)

if output_key not in self.node_attr_keys:
self.add_node_attr_key(output_key, -1)
elif reset:
self.update_node_attrs(attrs={output_key: -1})

self.update_node_attrs(
node_ids=node_ids,
attrs={output_key: track_ids},
)

return tracks_graph

def in_degree(self, node_ids: list[int] | int | None = None) -> list[int] | int:
"""
Get the in-degree of a list of nodes.
Expand Down
82 changes: 60 additions & 22 deletions src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,11 +1050,13 @@ def update_edge_attrs(
for key, value in attrs.items():
edge_attr[key] = value[i]

# Do I map the previous and re-calculated track IDs?
def assign_track_ids(
self,
output_key: str = DEFAULT_ATTR_KEYS.TRACK_ID,
reset: bool = True,
track_id_offset: int = 1,
node_ids: Sequence[int] | None = None,
) -> rx.PyDiGraph:
"""
Compute and assign track ids to nodes.
Expand All @@ -1067,35 +1069,71 @@ def assign_track_ids(
Whether to reset the track ids of the graph. If True, the track ids will be reset to -1.
track_id_offset : int
The starting track id, useful when assigning track ids to a subgraph.
node_ids : Sequence[int] | None
The IDs of the nodes to include in the track id assignment.
If None, all nodes are used.

Returns
-------
rx.PyDiGraph
A compressed graph (parent -> child) with track ids lineage relationships.
"""
try:
node_ids, track_ids, tracks_graph = _assign_track_ids(self.rx_graph, track_id_offset)
except RuntimeError as e:
raise RuntimeError(
"Are you sure this graph is a valid lineage graph?\n"
"This function expects a solved graph.\n"
"Often used from `graph.subgraph(edge_attr_filter={'solution': True})`"
) from e

if output_key not in self.node_attr_keys:
self.add_node_attr_key(output_key, -1)
elif reset:
self.update_node_attrs(node_ids=self.node_ids(), attrs={output_key: -1})

# node_ids are rustworkx graph ids, therefore we don't need node_id mapping
# and we must use RustWorkXGraph for IndexedRXGraph
RustWorkXGraph.update_node_attrs(
self,
node_ids=node_ids,
attrs={output_key: track_ids},
)

return tracks_graph
# If node_ids is not None, get the extended graph so that it include
# all tracklets containing the nodes in the list.
if node_ids is not None:
track_node_ids = set()
active_ids = set(self.node_ids())
while len(active_ids) > 0:
track_node_ids.update(active_ids)
successors = [
df[DEFAULT_ATTR_KEYS.NODE_ID].first()
for df in self._root.successors(node_ids=active_ids).values()
if len(df) == 1
] # Only consider non-branching nodes
predecessors = [
df[DEFAULT_ATTR_KEYS.NODE_ID].first()
for df in self._root.predecessors(node_ids=active_ids).values()
if len(df) == 1 # Only consider non-branching nodes
]
out_degrees = self._root.out_degree(predecessors)
predecessors = [node for node, degree in zip(predecessors, out_degrees, strict=True) if degree == 1]
active_ids = set(successors + predecessors) - track_node_ids

return (
self.filter(node_ids=list(track_node_ids))
.subgraph(node_attr_keys=[output_key], edge_attr_keys=[])
.assign_track_ids(
output_key=output_key,
reset=reset,
track_id_offset=track_id_offset,
)
)
else:
try:
track_node_ids, track_ids, tracks_graph = _assign_track_ids(self.rx_graph, track_id_offset)
except RuntimeError as e:
raise RuntimeError(
"Are you sure this graph is a valid lineage graph?\n"
"This function expects a solved graph.\n"
"Often used from `graph.subgraph(edge_attr_filter={'solution': True})`"
) from e

# For the IndexedRXGraph, we need to map the track_node_ids to the external node ids
if hasattr(self, "_map_to_external"):
track_node_ids = self._map_to_external(track_node_ids)

if output_key not in self.node_attr_keys:
self.add_node_attr_key(output_key, -1)
elif reset:
self.update_node_attrs(attrs={output_key: -1})

self.update_node_attrs(
node_ids=track_node_ids,
attrs={output_key: track_ids},
)

return tracks_graph

def in_degree(self, node_ids: list[int] | int | None = None) -> list[int] | int:
"""
Expand Down
38 changes: 38 additions & 0 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,6 +1395,44 @@ def _get_degree(

return [degree.get(node_id, 0) for node_id in node_ids]

def assign_track_ids(
self,
output_key: str = DEFAULT_ATTR_KEYS.TRACK_ID,
reset: bool = True,
track_id_offset: int = 1,
) -> rx.PyDiGraph:
"""
Compute and assign track ids to nodes.

Parameters
----------
output_key : str
The key of the output track id attribute.
reset : bool
Whether to reset the track ids of the graph. If True, the track ids will be reset to -1.
track_id_offset : int
The starting track id, useful when assigning track ids to a subgraph.

Returns
-------
rx.PyDiGraph
A compressed graph (parent -> child) with track ids lineage relationships.
If node_ids is provided, it will only include linages including those nodes.
"""
if output_key in self.node_attr_keys:
node_attr_keys = [DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.T, output_key]
else:
node_attr_keys = [DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.T]
return (
self.filter()
.subgraph(node_attr_keys=node_attr_keys)
.assign_track_ids(
output_key=output_key,
reset=reset,
track_id_offset=track_id_offset,
)
)

def in_degree(self, node_ids: list[int] | int | None = None) -> list[int] | int:
"""
Get the in-degree of a list of nodes.
Expand Down
Loading
Loading