6 changes: 6 additions & 0 deletions src/tracksdata/graph/_base_graph.py
@@ -287,6 +287,7 @@ def node_features(
*,
node_ids: Sequence[int] | None = None,
feature_keys: Sequence[str] | str | None = None,
unpack: bool = False,
) -> pl.DataFrame:
"""
Get the features of the nodes as a polars DataFrame.
@@ -299,6 +300,8 @@ def node_features(
feature_keys : Sequence[str] | str | None
The feature keys to get.
If None, all features are used.
unpack : bool
Whether to unpack array features into multiple scalar features.

Returns
-------
@@ -313,6 +316,7 @@ def edge_features(
node_ids: list[int] | None = None,
feature_keys: Sequence[str] | None = None,
include_targets: bool = False,
unpack: bool = False,
) -> pl.DataFrame:
"""
Get the features of the edges as a polars DataFrame.
@@ -328,6 +332,8 @@ def edge_features(
include_targets : bool
Whether to include edges out-going from the given node_ids even
if the target node is not in the given node_ids.
unpack : bool
Whether to unpack array features into multiple scalar features.
"""

@property
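For orientation, a minimal usage sketch of the new unpack flag (editorial illustration, not part of the diff; it assumes a BaseGraph backend already bound to graph and an array-valued feature registered beforehand, mirroring the tests further down):

import numpy as np

# Register an array-valued node feature, then add a node carrying it.
graph.add_node_feature_key("coordinates", np.array([0.0, 0.0]))
node = graph.add_node({"t": 0, "coordinates": np.array([10.0, 20.0])})

# Default: "coordinates" comes back as a single array column.
graph.node_features(node_ids=[node], feature_keys=["coordinates"])

# With unpack=True the array column is expected to be split into
# scalar columns named "coordinates_0" and "coordinates_1".
graph.node_features(node_ids=[node], feature_keys=["coordinates"], unpack=True)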
4 changes: 4 additions & 0 deletions src/tracksdata/graph/_graph_view.py
@@ -244,10 +244,12 @@ def node_features(
*,
node_ids: Sequence[int] | None = None,
feature_keys: Sequence[str] | str | None = None,
unpack: bool = False,
) -> pl.DataFrame:
node_dfs = super().node_features(
node_ids=map_ids(self._node_map_from_root, node_ids),
feature_keys=feature_keys,
unpack=unpack,
)
node_dfs = self._map_to_root_df_node_ids(node_dfs)
return node_dfs
@@ -258,11 +260,13 @@ def edge_features(
node_ids: Sequence[int] | None = None,
feature_keys: Sequence[str] | str | None = None,
include_targets: bool = False,
unpack: bool = False,
) -> pl.DataFrame:
edges_df = super().edge_features(
node_ids=map_ids(self._node_map_from_root, node_ids),
feature_keys=feature_keys,
include_targets=include_targets,
unpack=unpack,
)

edges_df = edges_df.with_columns(
19 changes: 17 additions & 2 deletions src/tracksdata/graph/_rustworkx_graph.py
@@ -9,6 +9,7 @@
from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.functional._rx import graph_track_ids
from tracksdata.graph._base_graph import BaseGraph
from tracksdata.utils._dataframe import unpack_array_features
from tracksdata.utils._logging import LOG

if TYPE_CHECKING:
@@ -386,6 +387,7 @@ def node_features(
*,
node_ids: Sequence[int] | None = None,
feature_keys: Sequence[str] | str | None = None,
unpack: bool = False,
) -> pl.DataFrame:
"""
Get the features of the nodes as a polars DataFrame.
@@ -398,6 +400,8 @@ def node_features(
feature_keys : Sequence[str] | None
The feature keys to get.
If None, all the features of the first node are used.
unpack : bool
Whether to unpack array features into multiple scalar features.

Returns
-------
@@ -435,14 +439,20 @@ def node_features(
columns[key] = np.asarray(columns[key])

# Create DataFrame and set node_id as index in one shot
return pl.DataFrame(columns)
df = pl.DataFrame(columns)

if unpack:
df = unpack_array_features(df)

return df

def edge_features(
self,
*,
node_ids: list[int] | None = None,
feature_keys: Sequence[str] | str | None = None,
include_targets: bool = False,
unpack: bool = False,
) -> pl.DataFrame:
"""
Get the features of the edges as a polars DataFrame.
@@ -458,6 +468,8 @@ def edge_features(
include_targets : bool
Whether to include edges out-going from the given node_ids even
if the target node is not in the given node_ids.
unpack : bool
Whether to unpack array features into multiple scalar features.
"""
if feature_keys is None:
feature_keys = self.edge_features_keys
@@ -507,7 +519,10 @@ def edge_features(

columns = {k: np.asarray(v) for k, v in columns.items()}

return pl.DataFrame(columns)
df = pl.DataFrame(columns)
if unpack:
df = unpack_array_features(df)
return df

@property
def num_edges(self) -> int:
17 changes: 14 additions & 3 deletions src/tracksdata/graph/_sql_graph.py
@@ -12,6 +12,7 @@

from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.graph._base_graph import BaseGraph
from tracksdata.utils._dataframe import unpack_array_features
from tracksdata.utils._logging import LOG

if TYPE_CHECKING:
@@ -415,6 +416,7 @@ def node_features(
*,
node_ids: Sequence[int] | None = None,
feature_keys: Sequence[str] | str | None = None,
unpack: bool = False,
) -> pl.DataFrame:
if isinstance(feature_keys, str):
feature_keys = [feature_keys]
@@ -448,16 +450,22 @@ def node_features(

# indices are included by default and must be removed
if feature_keys is not None:
return nodes_df.select([pl.col(c) for c in feature_keys])
df = nodes_df.select([pl.col(c) for c in feature_keys])
else:
return nodes_df.drop(DEFAULT_ATTR_KEYS.NODE_ID)
df = nodes_df.drop(DEFAULT_ATTR_KEYS.NODE_ID)

if unpack:
df = unpack_array_features(df)

return df

def edge_features(
self,
*,
node_ids: list[int] | None = None,
feature_keys: Sequence[str] | None = None,
include_targets: bool = False,
unpack: bool = False,
) -> pl.DataFrame:
if isinstance(feature_keys, str):
feature_keys = [feature_keys]
@@ -499,7 +507,10 @@ def edge_features(
connection=session.connection(),
)

return edges_df
if unpack:
edges_df = unpack_array_features(edges_df)

return edges_df

@property
def node_features_keys(self) -> list[str]:
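Both the rustworkx and SQL backends route the flag through the same unpack_array_features helper, so edge features behave the same way in either one. A hedged sketch (editorial, mirroring the backend test below):

import numpy as np

# Register an array-valued edge feature and add an edge carrying it.
graph.add_edge_feature_key("vector", np.array([0.0, 0.0]))
graph.add_edge(node1, node2, attributes={"vector": np.array([1.0, 2.0])})

# unpack=True is expected to yield scalar columns "vector_0" and "vector_1".
edges_df = graph.edge_features(feature_keys=["vector"], unpack=True)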
22 changes: 19 additions & 3 deletions src/tracksdata/graph/_test/test_graph_backends.py
@@ -1,5 +1,6 @@
from pathlib import Path

import numpy as np
import polars as pl
import pytest

@@ -142,27 +143,42 @@ def test_time_points(graph_backend: BaseGraph) -> None:
def test_node_features(graph_backend: BaseGraph) -> None:
"""Test retrieving node features."""
graph_backend.add_node_feature_key("x", None)
graph_backend.add_node_feature_key("coordinates", np.array([0.0, 0.0]))

node1 = graph_backend.add_node({"t": 0, "x": 1.0})
node2 = graph_backend.add_node({"t": 1, "x": 2.0})
node1 = graph_backend.add_node({"t": 0, "x": 1.0, "coordinates": np.array([10.0, 20.0])})
node2 = graph_backend.add_node({"t": 1, "x": 2.0, "coordinates": np.array([30.0, 40.0])})

df = graph_backend.node_features(node_ids=[node1, node2], feature_keys=["x"])
assert isinstance(df, pl.DataFrame)
assert df["x"].to_list() == [1.0, 2.0]

# Test unpack functionality
df_unpacked = graph_backend.node_features(node_ids=[node1, node2], feature_keys=["coordinates"], unpack=True)
if "coordinates_0" in df_unpacked.columns:
assert df_unpacked["coordinates_0"].to_list() == [10.0, 30.0]
assert df_unpacked["coordinates_1"].to_list() == [20.0, 40.0]


def test_edge_features(graph_backend: BaseGraph) -> None:
"""Test retrieving edge features."""
node1 = graph_backend.add_node({"t": 0})
node2 = graph_backend.add_node({"t": 1})

graph_backend.add_edge_feature_key("weight", 0.0)
graph_backend.add_edge(node1, node2, attributes={"weight": 0.5})
graph_backend.add_edge_feature_key("vector", np.array([0.0, 0.0]))

graph_backend.add_edge(node1, node2, attributes={"weight": 0.5, "vector": np.array([1.0, 2.0])})

df = graph_backend.edge_features(feature_keys=["weight"])
assert isinstance(df, pl.DataFrame)
assert df["weight"].to_list() == [0.5]

# Test unpack functionality
df_unpacked = graph_backend.edge_features(feature_keys=["vector"], unpack=True)
if "vector_0" in df_unpacked.columns:
assert df_unpacked["vector_0"].to_list() == [1.0]
assert df_unpacked["vector_1"].to_list() == [2.0]


def test_edge_features_subgraph_edge_ids(graph_backend: BaseGraph) -> None:
"""Test that edge_features preserves original edge IDs when using node_ids parameter."""
27 changes: 27 additions & 0 deletions src/tracksdata/utils/_dataframe.py
@@ -0,0 +1,27 @@
import polars as pl


def unpack_array_features(df: pl.DataFrame) -> pl.DataFrame:
"""
Unpack array features by converting array columns into multiple scalar columns.

Parameters
----------
df : pl.DataFrame
DataFrame with array features.

Returns
-------
pl.DataFrame
DataFrame with unpacked array features.
"""

array_cols = [name for name, dtype in df.schema.items() if isinstance(dtype, pl.Array)]

if len(array_cols) == 0:
return df

for col in array_cols:
df = df.with_columns(pl.col(col).arr.to_struct(lambda x: f"{col}_{x}")).unnest(col) # noqa: B023

return unpack_array_features(df)
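Why the trailing recursive call: for a nested pl.Array column, one arr.to_struct/unnest pass only peels off the outermost dimension, so a freshly created column can itself still be an array; the recursion repeats until only scalar columns remain. A short editorial sketch (mirroring the "other" column in the test below):

import numpy as np
import polars as pl

from tracksdata.utils._dataframe import unpack_array_features

# A (3, 1, 2) numpy array maps to a nested array column in polars.
df = pl.DataFrame({"other": np.asarray([[[3.1, 3.2]], [[4.1, 4.2]], [[5.1, 5.2]]])})

# Pass 1 unnests "other" into "other_0", which is still an array of width 2.
# The recursive call then unnests "other_0" into "other_0_0" and "other_0_1".
unpacked = unpack_array_features(df)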
33 changes: 33 additions & 0 deletions src/tracksdata/utils/_test/test_dataframe.py
@@ -0,0 +1,33 @@
import numpy as np
import polars as pl

from tracksdata.utils._dataframe import unpack_array_features


def test_unpack_array_features() -> None:
df = pl.DataFrame(
{
"id": [1, 2, 3],
"features": np.asarray([[0.1, 0.2], [1.1, 1.2], [2.1, 2.2]]),
"other": np.asarray([[[3.1, 3.2]], [[4.1, 4.2]], [[5.1, 5.2]]]),
}
)

unpackaged_df = unpack_array_features(df)

expected_df = pl.DataFrame(
{
"id": [1, 2, 3],
"features_0": [0.1, 1.1, 2.1],
"features_1": [0.2, 1.2, 2.2],
"other_0_0": [3.1, 4.1, 5.1],
"other_0_1": [3.2, 4.2, 5.2],
}
)

assert np.all(unpackaged_df.columns == expected_df.columns)

np.testing.assert_array_equal(
unpackaged_df.to_numpy(),
expected_df.to_numpy(),
)