Skip to content

Commit

Permalink
Expose batch APIs for linestrips
Browse files Browse the repository at this point in the history
  • Loading branch information
jleibs committed Jul 25, 2023
1 parent c265cf6 commit 2518d11
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 7 deletions.
4 changes: 3 additions & 1 deletion rerun_py/rerun_sdk/rerun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
"log_image_file",
"log_line_segments",
"log_line_strip",
"log_line_strips_2d",
"log_line_strips_3d",
"log_mesh",
"log_mesh_file",
"log_meshes",
Expand Down Expand Up @@ -101,7 +103,7 @@
from .log.extension_components import log_extension_components
from .log.file import ImageFormat, MeshFormat, log_image_file, log_mesh_file
from .log.image import log_depth_image, log_image, log_segmentation_image
from .log.lines import log_line_segments, log_line_strip, log_path
from .log.lines import log_line_segments, log_line_strip, log_line_strips_2d, log_line_strips_3d, log_path
from .log.mesh import log_mesh, log_meshes
from .log.points import log_point, log_points
from .log.rects import RectFormat, log_rect, log_rects
Expand Down
16 changes: 12 additions & 4 deletions rerun_py/rerun_sdk/rerun/components/linestrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ def from_numpy_arrays(array: Iterable[npt.NDArray[np.float32]]) -> LineStrip2DAr
for line in array:
assert line.shape[1] == 2

offsets = itertools.chain([0], itertools.accumulate(len(line) for line in array))
values = np.concatenate(array) # type: ignore[call-overload]
offsets = list(itertools.chain([0], itertools.accumulate(len(line) for line in array)))
if len(offsets) > 1:
values = np.concatenate(array) # type: ignore[call-overload]
else:
values = np.array([], dtype=np.float32)

fixed = pa.FixedSizeListArray.from_arrays(values.flatten(), type=LineStrip2DType.storage_type.value_type)
storage = pa.ListArray.from_arrays(offsets, fixed, type=LineStrip2DType.storage_type)

Expand All @@ -46,8 +50,12 @@ def from_numpy_arrays(array: Iterable[npt.NDArray[np.float32]]) -> LineStrip3DAr
for line in array:
assert line.shape[1] == 3

offsets = itertools.chain([0], itertools.accumulate(len(line) for line in array))
values = np.concatenate(array) # type: ignore[call-overload]
offsets = list(itertools.chain([0], itertools.accumulate(len(line) for line in array)))
if len(offsets) > 1:
values = np.concatenate(array) # type: ignore[call-overload]
else:
values = np.array([], dtype=np.float32)

fixed = pa.FixedSizeListArray.from_arrays(values.flatten(), type=LineStrip3DType.storage_type.value_type)
storage = pa.ListArray.from_arrays(offsets, fixed, type=LineStrip3DType.storage_type)

Expand Down
223 changes: 221 additions & 2 deletions rerun_py/rerun_sdk/rerun/log/lines.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import Any, Iterable

import numpy as np
import numpy.typing as npt
Expand All @@ -12,14 +12,17 @@
from rerun.components.instance import InstanceArray
from rerun.components.linestrip import LineStrip2DArray, LineStrip3DArray
from rerun.components.radius import RadiusArray
from rerun.log import Color, _normalize_colors, _normalize_radii
from rerun.log import Color, Colors, _normalize_colors, _normalize_radii
from rerun.log.error_utils import _send_warning
from rerun.log.extension_components import _add_extension_components
from rerun.log.log_decorator import log_decorator
from rerun.recording_stream import RecordingStream

__all__ = [
"log_path",
"log_line_strip",
"log_line_strips_2d",
"log_line_strips_3d",
"log_line_segments",
]

Expand Down Expand Up @@ -129,6 +132,222 @@ def log_line_strip(
bindings.log_arrow_msg(entity_path, components=instanced, timeless=timeless, recording=recording)


@log_decorator
def log_line_strips_2d(
entity_path: str,
line_strips: Iterable[npt.ArrayLike] | None,
*,
identifiers: npt.ArrayLike | None = None,
stroke_widths: npt.ArrayLike | None = None,
colors: Color | Colors | None = None,
draw_order: float | None = None,
ext: dict[str, Any] | None = None,
timeless: bool = False,
recording: RecordingStream | None = None,
) -> None:
r"""
Log a batch of line strips through 2D space.
Each line strip is a list of points connected by line segments. It can be used to draw
approximations of smooth curves.
The points will be connected in order, like so:
```
2------3 5
/ \ /
0----1 \ /
4
```
Parameters
----------
entity_path:
Path to the path in the space hierarchy
line_strips:
An iterable of Nx2 arrays of points along the path.
To log an empty line_strip use `np.zeros((0,0,3))` or `np.zeros((0,0,2))`
identifiers:
Unique numeric id that shows up when you hover or select the line.
stroke_widths:
Optional widths of the line.
colors:
Optional colors of the lines.
RGB or RGBA in sRGB gamma-space as either 0-1 floats or 0-255 integers, with separate alpha.
draw_order:
An optional floating point value that specifies the 2D drawing order.
Objects with higher values are drawn on top of those with lower values.
The default for lines is 20.0.
ext:
Optional dictionary of extension components. See [rerun.log_extension_components][]
timeless:
If true, the path will be timeless (default: False).
recording:
Specifies the [`rerun.RecordingStream`][] to use.
If left unspecified, defaults to the current active data recording, if there is one.
See also: [`rerun.init`][], [`rerun.set_global_data_recording`][].
"""
recording = RecordingStream.to_native(recording)

colors = _normalize_colors(colors)
radii = _normalize_radii(stroke_widths)
radii = radii / 2.0

identifiers_np = np.array((), dtype="uint64")
if identifiers is not None:
try:
identifiers_np = np.require(identifiers, dtype="uint64")
except ValueError:
_send_warning("Only integer identifiers supported", 1)

# 0 = instanced, 1 = splat
comps = [{}, {}] # type: ignore[var-annotated]

if line_strips is not None:
line_strip_arrs = [np.require(line, dtype="float32") for line in line_strips]
dims = [line.shape[1] for line in line_strip_arrs]

if any(d != 2 for d in dims):
raise ValueError("All line strips must be Nx2")

comps[0]["rerun.linestrip2d"] = LineStrip2DArray.from_numpy_arrays(line_strip_arrs)

if len(identifiers_np):
comps[0]["rerun.instance_key"] = InstanceArray.from_numpy(identifiers_np)

if len(colors):
is_splat = len(colors.shape) == 1
if is_splat:
colors = colors.reshape(1, len(colors))
comps[is_splat]["rerun.colorrgba"] = ColorRGBAArray.from_numpy(colors)

# We store the stroke_width in radius
if len(radii):
is_splat = len(radii) == 1
comps[is_splat]["rerun.radius"] = RadiusArray.from_numpy(radii)

if draw_order is not None:
comps[1]["rerun.draw_order"] = DrawOrderArray.splat(draw_order)

if ext:
_add_extension_components(comps[0], comps[1], ext, identifiers_np)

if comps[1]:
comps[1]["rerun.instance_key"] = InstanceArray.splat()
bindings.log_arrow_msg(entity_path, components=comps[1], timeless=timeless, recording=recording)

# Always the primary component last so range-based queries will include the other data. See(#1215)
bindings.log_arrow_msg(entity_path, components=comps[0], timeless=timeless, recording=recording)


@log_decorator
def log_line_strips_3d(
entity_path: str,
line_strips: Iterable[npt.ArrayLike] | None,
*,
identifiers: npt.ArrayLike | None = None,
stroke_widths: npt.ArrayLike | None = None,
colors: Color | Colors | None = None,
draw_order: float | None = None,
ext: dict[str, Any] | None = None,
timeless: bool = False,
recording: RecordingStream | None = None,
) -> None:
r"""
Log a batch of line strips through 3D space.
Each line strip is a list of points connected by line segments. It can be used to draw approximations
of smooth curves.
The points will be connected in order, like so:
```
2------3 5
/ \ /
0----1 \ /
4
```
Parameters
----------
entity_path:
Path to the path in the space hierarchy
line_strips:
An iterable of Nx3 arrays of points along the path.
To log an empty line_strip use `np.zeros((0,0,3))` or `np.zeros((0,0,2))`
identifiers:
Unique numeric id that shows up when you hover or select the line.
stroke_widths:
Optional widths of the line.
colors:
Optional colors of the lines.
RGB or RGBA in sRGB gamma-space as either 0-1 floats or 0-255 integers, with separate alpha.
draw_order:
An optional floating point value that specifies the 2D drawing order.
Objects with higher values are drawn on top of those with lower values.
The default for lines is 20.0.
ext:
Optional dictionary of extension components. See [rerun.log_extension_components][]
timeless:
If true, the path will be timeless (default: False).
recording:
Specifies the [`rerun.RecordingStream`][] to use.
If left unspecified, defaults to the current active data recording, if there is one.
See also: [`rerun.init`][], [`rerun.set_global_data_recording`][].
"""
recording = RecordingStream.to_native(recording)

colors = _normalize_colors(colors)
radii = _normalize_radii(stroke_widths)
radii = radii / 2.0

identifiers_np = np.array((), dtype="uint64")
if identifiers is not None:
try:
identifiers_np = np.require(identifiers, dtype="uint64")
except ValueError:
_send_warning("Only integer identifiers supported", 1)

# 0 = instanced, 1 = splat
comps = [{}, {}] # type: ignore[var-annotated]

if line_strips is not None:
line_strip_arrs = [np.require(line, dtype="float32") for line in line_strips]
dims = [line.shape[1] for line in line_strip_arrs]

if any(d != 3 for d in dims):
raise ValueError("All line strips must be Nx3")

comps[0]["rerun.linestrip3d"] = LineStrip3DArray.from_numpy_arrays(line_strip_arrs)

if len(identifiers_np):
comps[0]["rerun.instance_key"] = InstanceArray.from_numpy(identifiers_np)

if len(colors):
is_splat = len(colors.shape) == 1
if is_splat:
colors = colors.reshape(1, len(colors))
comps[is_splat]["rerun.colorrgba"] = ColorRGBAArray.from_numpy(colors)

# We store the stroke_width in radius
if len(radii):
is_splat = len(radii) == 1
comps[is_splat]["rerun.radius"] = RadiusArray.from_numpy(radii)

if draw_order is not None:
comps[1]["rerun.draw_order"] = DrawOrderArray.splat(draw_order)

if ext:
_add_extension_components(comps[0], comps[1], ext, identifiers_np)

if comps[1]:
comps[1]["rerun.instance_key"] = InstanceArray.splat()
bindings.log_arrow_msg(entity_path, components=comps[1], timeless=timeless, recording=recording)

# Always the primary component last so range-based queries will include the other data. See(#1215)
bindings.log_arrow_msg(entity_path, components=comps[0], timeless=timeless, recording=recording)


@log_decorator
def log_line_segments(
entity_path: str,
Expand Down

0 comments on commit 2518d11

Please sign in to comment.