In [140]:
from laptrack import LapTrack
from laptrack import data_conversion
from laptrack import datasets
import networkx as nx
from typing import Optional, Sequence, List, Tuple
from laptrack._typing_utils import NumArray

In [160]:
# Maximum linking distance shared by frame-to-frame linking, splitting and
# merging, in the same units as the position columns.
MAX_LINK_DISTANCE = 15
# metric="sqeuclidean" compares squared distances, hence the squared cutoffs.
# splitting_metric / merging_metric are not passed, so LapTrack's defaults apply.
# NOTE(review): the squared splitting/merging cutoffs assume those default
# metrics are also squared-Euclidean -- confirm against the LapTrack docs.
lt = LapTrack(
    metric="sqeuclidean",
    cutoff=MAX_LINK_DISTANCE**2,
    splitting_cutoff=MAX_LINK_DISTANCE**2,
    merging_cutoff=MAX_LINK_DISTANCE**2,
)

In [179]:
# Reference run on the bundled example dataset: track once through the graph
# API and once through the dataframe API, then check both give the same tables.
spots_df = datasets.simple_tracks()
coords = data_conversion.convert_dataframe_to_coords(
    spots_df, ["position_x", "position_y"]
)

# Graph API; relabel nodes so that (frame, index) are plain Python ints.
tree1 = lt.predict(coords)
tree1 = nx.relabel_nodes(tree1, lambda old: (int(old[0]), int(old[1])))

# Dataframe API.
track_df, split_df, merge_df = lt.predict_dataframe(
    spots_df, ["position_x", "position_y"], frame_col="frame"
)

# Convert the graph result to dataframes and align its column labels with the
# predict_dataframe output before comparing.
track_df2, split_df2, merge_df2 = data_conversion.convert_tree_to_dataframe(
    tree1, coords
)
track_df2 = track_df2.set_axis(track_df.columns, axis="columns")

assert track_df.equals(track_df2)
assert split_df.equals(split_df2)
assert merge_df.equals(merge_df2)



In [None]:
import pandas as pd


def convert_dataframes_to_tree_coords(
    track_df: pd.DataFrame,
    split_df: pd.DataFrame,
    merge_df: pd.DataFrame,
    coordinate_cols: List[str],
    frame_col: str = "frame",
) -> Tuple[nx.DiGraph, List[NumArray]]:
    """Rebuild a track tree and per-frame coordinates from tracking dataframes.

    Intended to take the ``(track_df, split_df, merge_df)`` triple returned by
    ``LapTrack.predict_dataframe`` and produce a ``(tree, coords)`` pair like
    ``LapTrack.predict`` works with.

    Parameters
    ----------
    track_df : pd.DataFrame
        One row per detection, with ``coordinate_cols``, ``frame_col`` and a
        ``"track_id"`` column.
    split_df : pd.DataFrame
        Splitting events with ``"parent_track_id"`` and ``"child_track_id"``.
    merge_df : pd.DataFrame
        Merging events with ``"parent_track_id"`` and ``"child_track_id"``.
    coordinate_cols : List[str]
        Columns holding the spatial coordinates.
    frame_col : str, default "frame"
        Name of the frame column.

    Returns
    -------
    tree : nx.DiGraph
        Nodes are ``(frame, index)`` tuples of Python ints, where ``index`` is
        the value reported by
        ``data_conversion.convert_dataframe_to_coords_frame_index``. Edges join
        consecutive detections of a track, and for every split/merge event the
        last detection of the parent track to the first detection of the child.
    coords : List[NumArray]
        The coordinates returned by
        ``data_conversion.convert_dataframe_to_coords_frame_index``.

    Raises
    ------
    KeyError
        If an event references a track id that is absent from ``track_df``.
    """
    # A stable sort keeps the input row order within each frame, so the
    # within-frame node indices are deterministic across runs.
    # reset_index() is kept without drop=True: if frame_col is an index level
    # of track_df, this turns it into a column for the converter below.
    _track_df = track_df.sort_values(frame_col, kind="stable").reset_index()
    coords, frame_index = data_conversion.convert_dataframe_to_coords_frame_index(
        _track_df, coordinate_cols, frame_col=frame_col
    )
    # NOTE(review): relies on frame_index[i] describing row i of _track_df.
    nodes = [(int(frame), int(ind)) for frame, ind in frame_index]
    # Row positions of each track; groupby keeps row order, i.e. frame order.
    track_rows = {
        track_id: grp.index.astype(int)
        for track_id, grp in _track_df.groupby("track_id")
    }

    tree = nx.DiGraph()
    tree.add_nodes_from(nodes)
    # Link consecutive detections within each track.
    for rows in track_rows.values():
        tree.add_edges_from(
            (nodes[src], nodes[dst]) for src, dst in zip(rows[:-1], rows[1:])
        )
    # Splits and merges share the same layout (parent track -> child track):
    # connect the parent's last detection to the child's first detection.
    for events_df in (split_df, merge_df):
        if events_df.empty:
            continue
        for parent_id, child_id in zip(
            events_df["parent_track_id"], events_df["child_track_id"]
        ):
            tree.add_edge(
                nodes[track_rows[parent_id][-1]],
                nodes[track_rows[child_id][0]],
            )
    return tree, coords


# Rebuild a tree from the dataframes produced by predict_dataframe above.
tree2, coords2 = convert_dataframes_to_tree_coords(
    track_df=track_df,
    split_df=split_df,
    merge_df=merge_df,
    coordinate_cols=["position_x", "position_y"],
    frame_col="frame",
)

{((5, 0), (6, 0)), ((8, 1), (9, 1)), ((4, 3), (5, 3)), ((4, 0), (5, 2)), ((5, 3), (6, 2)), ((5, 1), (6, 0)), ((7, 2), (9, 0)), ((4, 2), (5, 1)), ((4, 1), (5, 0)), ((5, 2), (6, 1))}
{((5, 0), (6, 2)), ((4, 2), (5, 3)), ((7, 2), (9, 1)), ((5, 1), (6, 1)), ((5, 3), (6, 0)), ((5, 2), (6, 0)), ((8, 1), (9, 0)), ((4, 3), (5, 0)), ((4, 0), (5, 1)), ((4, 1), (5, 2))}


In [None]:
def compare_coords_nodes_edges(tree1, tree2, coords1, coords2) -> bool:
    """Check whether two track trees are identical up to node relabelling.

    Nodes are ``(frame, index)`` tuples where ``index`` selects a row of
    ``coords[frame]``. Two trees built from the same detections may number
    the detections within a frame differently. Each node is therefore
    identified by its frame and coordinate values rather than by its index.
    The trees are equal when they have the same set of such node keys and
    the same set of edges between them.

    NOTE(review): detections sharing both frame and coordinates collapse onto
    a single key, so such duplicates cannot be told apart.

    Parameters
    ----------
    tree1, tree2
        The graphs to compare (anything exposing ``.nodes`` and ``.edges``).
    coords1, coords2
        Per-frame coordinate arrays matching ``tree1`` and ``tree2``.

    Returns
    -------
    bool
        True if node and edge sets match after mapping nodes to coordinates.
    """

    def _node_key(node, coords):
        frame, index = node[0], node[1]
        return (frame, tuple(float(value) for value in coords[frame][index]))

    def _signature(tree, coords):
        node_keys = {_node_key(node, coords) for node in tree.nodes}
        edge_keys = {
            (_node_key(src, coords), _node_key(dst, coords))
            for src, dst in tree.edges
        }
        return node_keys, edge_keys

    return _signature(tree1, coords1) == _signature(tree2, coords2)


compare_coords_nodes_edges(tree1, tree2, coords, coords2)

True

In [195]:
# Negative control: a copy of tree1 with one extra edge must no longer match tree2.
_tree1 = tree1.copy()
_tree1.add_edges_from([((2, 0), (3, 1))])
compare_coords_nodes_edges(_tree1, tree2, coords, coords2)

False

In [162]:
def convert_digraph_to_geff_networkx(
    tree: nx.DiGraph,
    coords: Optional[Sequence[NumArray]] = None,
    attr_names: Optional[List[str]] = None,
) -> nx.DiGraph:
    """Convert a track tree to a networkx directed graph in the GEFF format.

    The result is a copy of ``tree`` (the input graph is not modified). Every
    node ``(frame, index)`` gets the frame number as an attribute. When
    ``coords`` is given, the node also gets its coordinate values.

    Parameters
    ----------
    tree : nx.DiGraph
        The directed graph representing the track tree. Nodes are
        ``(frame, index)`` tuples; ``index`` selects a row of ``coords[frame]``.
    coords : Optional[Sequence[NumArray]], default None
        Per-frame 2D coordinate arrays (rows = detections, columns = spatial
        dimensions). If None, only the frame attribute is added to the nodes.
    attr_names : Optional[List[str]], default None
        Node attribute names. ``attr_names[0]`` receives the frame number and
        ``attr_names[1:]`` receive the coordinate values. The length must
        therefore be the number of coordinate dimensions + 1, or 1 when
        ``coords`` is None. If None, the names "frame", "coord-0",
        "coord-1", ... are used.

    Returns
    -------
    geff_tree : nx.DiGraph
        A directed copy of ``tree`` with the node attributes
        ``attr_names[0]`` (frame), ``attr_names[1]``, ``attr_names[2]``, ...
        (coordinates).

    Raises
    ------
    ValueError
        If ``attr_names`` does not have the expected length.

    Example
    -------
    >>> from laptrack import LapTrack
    >>> import geff
    >>> tree = LapTrack().predict(coords)
    >>> geff_tree = convert_digraph_to_geff_networkx(
    ...     tree, coords, ["frame", "position_x", "position_y"]
    ... )
    >>> geff.write_nx(geff_tree, "save_path.zarr")

    """
    # Only look at coords when it was supplied; the default None is valid.
    n_dims = 0 if coords is None else coords[0].shape[1]
    if attr_names is None:
        attr_names = ["frame"] + [f"coord-{i}" for i in range(n_dims)]
    elif len(attr_names) != n_dims + 1:
        raise ValueError(
            f"attr_names must have length {n_dims + 1}, "
            f"but got {len(attr_names)}"
        )
    frame_attr, coord_attrs = attr_names[0], attr_names[1:]

    geff_tree = tree.copy()
    # Single pass over the nodes; node_attrs is the live attribute dict.
    for node, node_attrs in geff_tree.nodes(data=True):
        frame, index = node[0], node[1]
        node_attrs[frame_attr] = frame
        if coords is not None:
            row = coords[frame][index]
            for i, attr_name in enumerate(coord_attrs):
                node_attrs[attr_name] = row[i]
    return geff_tree


# Attach frame and position attributes to a copy of the predicted tree.
geff_tree = convert_digraph_to_geff_networkx(
    tree1,
    coords=coords,
    attr_names=["frame", "position_x", "position_y"],
)

{((4, 0), (5, 2)),
 ((4, 1), (5, 0)),
 ((4, 2), (5, 1)),
 ((4, 3), (5, 3)),
 ((5, 0), (6, 0)),
 ((5, 1), (6, 0)),
 ((5, 2), (6, 1)),
 ((5, 3), (6, 2)),
 ((7, 2), (9, 0)),
 ((8, 1), (9, 1))}