Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 133 additions & 27 deletions src/plopp/backends/pythreejs/scatter3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

import uuid
from typing import Literal
from typing import Any, Literal

import numpy as np
import scipp as sc
Expand Down Expand Up @@ -63,8 +63,6 @@ def __init__(
opacity: float = 1,
pixel_size: sc.Variable | float | None = None,
):
import pythreejs as p3

check_ndim(data, ndim=1, origin='Scatter3d')
self.uid = uid if uid is not None else uuid.uuid4().hex
self._canvas = canvas
Expand All @@ -73,6 +71,10 @@ def __init__(
self._x = x
self._y = y
self._z = z
self._unique_color = to_rgb(f'C{artist_number}' if color is None else color)
self._opacity = opacity
self._new_points = None
self._new_colors = None

# TODO: remove pixel_size in the next release
self._size = size if pixel_size is None else pixel_size
Expand All @@ -88,42 +90,64 @@ def __init__(
dtype=float, unit=self._data.coords[x].unit
).value

self.points = self._make_point_cloud()
self._canvas.add(self.points)

if self._colormapper is not None:
self._colormapper.add_artist(self.uid, self)
colors = self._colormapper.rgba(self.data)[..., :3].astype('float32')
else:
colors = np.broadcast_to(
np.array(to_rgb(f'C{artist_number}' if color is None else color)),
(self._data.coords[self._x].shape[0], 3),
).astype('float32')

self.geometry = p3.BufferGeometry(
def _make_point_cloud(self) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How long does this take approximately? Is it as fast as updating a 2d plot?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It depends on the number of points. I did some basic timings: it's about 0.01s for 100_000 points, and ~0.04s for 500_000 points.
But this is only run if the updated values have a different shape than the existing ones, which I am guessing is not super common.

A more relevant question is probably how long does it take to update the positions every time?
We are now running

        self.geometry.attributes["position"].array = np.array(
            [
                self._data.coords[self._x].values.astype('float32'),
                self._data.coords[self._y].values.astype('float32'),
                self._data.coords[self._z].values.astype('float32'),
            ]
        ).T

on every update, which is potentially quite a large allocation?
Before, we only have one array of floats for the colors, now we have 4 (colors + 3 positions).

We could only update if the coords have changed, but we would have to check something like

if any(not sc.identical(old_coords[dim], self._data.coords[dim]) for dim in "xyz"):

I need to check the timings of such a check, maybe it's fast enough?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, check it before you optimise too much! 0.04s seems fine. I don't think anyone expects 60fps.

Just a guess, but maybe you can make the big allocation slightly cheaper by using np.stack or any of its variants instead of np.array([..]).T.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check can actually be pretty fast (10x or more) compared to setting the position, so I added it in.

"""
Create the point cloud geometry and material.
"""
import pythreejs as p3

self._backup_coords()

geometry = p3.BufferGeometry(
attributes={
'position': p3.BufferAttribute(
array=np.array(
array=np.stack(
[
self._data.coords[self._x].values.astype('float32'),
self._data.coords[self._y].values.astype('float32'),
self._data.coords[self._z].values.astype('float32'),
]
).T
],
axis=1,
)
),
'color': p3.BufferAttribute(
array=np.broadcast_to(
np.array(self._unique_color),
(self._data.coords[self._x].shape[0], 3),
).astype('float32')
),
'color': p3.BufferAttribute(array=colors),
}
)
self._new_positions = None

# TODO: a device pixel_ratio should probably be read from a config file
pixel_ratio = 1.0
# Note that an additional factor of 2.5 (obtained from trial and error) seems to
# be required to get the sizes right in the scene.
self.material = p3.PointsMaterial(
material = p3.PointsMaterial(
vertexColors='VertexColors',
size=2.5 * self._size * pixel_ratio,
transparent=True,
opacity=opacity,
opacity=self._opacity,
depthTest=self._opacity > 0.5,
)
self.points = p3.Points(geometry=self.geometry, material=self.material)
self._canvas.add(self.points)
return p3.Points(geometry=geometry, material=material)

def _backup_coords(self) -> None:
"""
Backup the current coordinates to be able to detect changes.
"""
self._old_coords = {
self._x: self._data.coords[self._x],
self._y: self._data.coords[self._y],
self._z: self._data.coords[self._z],
}

def notify_artist(self, message: str) -> None:
"""
Expand All @@ -135,15 +159,27 @@ def notify_artist(self, message: str) -> None:
message:
The message from the colormapper.
"""
self._update_colors()
self._new_colors = self._colormapper.rgba(self.data)[..., :3].astype('float32')
self._finalize_update()

def _update_colors(self):
def _update_positions(self) -> None:
"""
Set the point cloud's rgba colors:
Update the point cloud's positions from the data.
"""
self.geometry.attributes["color"].array = self._colormapper.rgba(self.data)[
..., :3
].astype('float32')
if all(
sc.identical(self._old_coords[dim], self._data.coords[dim])
for dim in [self._x, self._y, self._z]
):
return
self._backup_coords()
return np.stack(
[
self._data.coords[self._x].values.astype('float32'),
self._data.coords[self._y].values.astype('float32'),
self._data.coords[self._z].values.astype('float32'),
],
axis=1,
)

def update(self, new_values):
"""
Expand All @@ -155,19 +191,89 @@ def update(self, new_values):
New data to update the point cloud values from.
"""
check_ndim(new_values, ndim=1, origin='Scatter3d')
old_shape = self._data.shape
self._data = new_values
if self._colormapper is not None:
self._update_colors()

if self._data.shape != old_shape:
self._new_points = self._make_point_cloud()
else:
self._new_points = None
self._new_positions = self._update_positions()

if self._colormapper is None:
self._finalize_update()

def _finalize_update(self) -> None:
"""
Finalize the update of the point cloud.
This is called either at the end of the position update if there is no
colormapper, and after the colors are updated in the case of a colormapper.
We want to wait for both to be ready before updating the geometry.
"""
# We use the hold context manager to avoid multiple re-draws of the scene and
# thus prevent flickering.
with self._canvas.renderer.hold():
if self._new_points is not None:
self._canvas.remove(self.points)
self.points = self._new_points
if self._new_positions is not None:
self.position = self._new_positions
self._new_positions = None
if self._new_colors is not None:
self.color = self._new_colors
self._new_colors = None
# For some reason, adding the points to the scene before updating the colors
# still shows the old colors for a brief moment, even if hold() is active.
if self._new_points is not None:
self._new_points = None
self._canvas.add(self.points)

@property
def position(self) -> np.ndarray:
"""
The scatter points positions as a (N, 3) numpy array.
"""
return self.geometry.attributes['position'].array

@position.setter
def position(self, val: np.ndarray):
self.geometry.attributes['position'].array = val

@property
def color(self) -> np.ndarray:
"""
The scatter points colors as a (N, 3) numpy array.
"""
return self.geometry.attributes['color'].array

@color.setter
def color(self, val: np.ndarray):
self.geometry.attributes['color'].array = val

@property
def geometry(self) -> Any:
"""
The scatter points geometry.
"""
return self.points.geometry

@property
def material(self) -> Any:
"""
The scatter points material.
"""
return self.points.material

@property
def opacity(self) -> float:
"""
The scatter points opacity.
"""
return self.material.opacity
return self._opacity

@opacity.setter
def opacity(self, val: float):
self._opacity = val
self.material.opacity = val
self.material.depthTest = val > 0.5

Expand Down
2 changes: 2 additions & 0 deletions src/plopp/graphics/graphicalview.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ def update(self, *args, **kwargs) -> None:

if self._autoscale:
self.fit_to_data()
elif self.colormapper is not None:
self.colormapper.notify_artists()

self.canvas.draw()

Expand Down
91 changes: 51 additions & 40 deletions src/plopp/widgets/clip3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ def _xor(x: list[sc.Variable]) -> sc.Variable:
}


def select(da: sc.DataArray, s: tuple[str, sc.Variable]) -> sc.DataArray:
return da[s]


class Clip3dTool(ipw.HBox):
"""
A tool that provides a slider to extract a slab of points in a three-dimensional
Expand Down Expand Up @@ -191,6 +187,16 @@ class ClippingPlanes(ipw.HBox):
"""
A widget to make clipping planes for spatial cutting (see :class:`Clip3dTool`) to
make spatial cuts in the X, Y, and Z directions on a three-dimensional scatter plot.
The widget provides buttons to add/remove cuts, toggle the visibility of the cuts,
and set the operation to combine multiple cuts (OR, AND, XOR). The opacity of the
original point clouds is reduced when at least one cut is active, to provide
context.

The selection from all cuts are combined to either create or update a
second point cloud which is included in the scene.
When the position/range of a cut is changed, only the outlines of
the cuts are moved in real time, which is cheap. The actual point cloud gets
updated less frequently using a debounce mechanism.

.. versionadded:: 24.04.0

Expand Down Expand Up @@ -228,7 +234,7 @@ def __init__(self, fig: BaseFig):

self.tabs = ipw.Tab(layout={'width': '550px'})
self._original_nodes = list(self._view.graph_nodes.values())
self._nodes = {}
# self._nodes = {}

self.add_cut_label = ipw.Label('Add cut:')
layout = {'width': '45px', 'padding': '0px 0px 0px 0px'}
Expand Down Expand Up @@ -295,6 +301,15 @@ def __init__(self, fig: BaseFig):
)
self.delete_cut.on_click(self._remove_cut)

self._nodes = {}
self._cut_info_node = Node(self._get_visible_cuts)
for n in self._original_nodes:
self._nodes[n.id] = Node(
self._select_subset, da=n, cuts=self._cut_info_node
)
self._nodes[n.id].add_view(self._view)
self.update_state()

super().__init__(
[
self.tabs,
Expand Down Expand Up @@ -354,6 +369,10 @@ def update_controls(self):
self.opacity.disabled = not at_least_one_cut
opacity = self.opacity.value if at_least_one_cut else 1.0
self._set_opacity({'new': opacity})
# if not at_least_one_cut:
for n in self._original_nodes:
nid = self._nodes[n.id].id
self._view.artists[nid].visible = at_least_one_cut

def _set_opacity(self, change: dict[str, Any]):
"""
Expand Down Expand Up @@ -382,42 +401,34 @@ def change_operation(self, change: dict[str, Any]):
self._operation = change['new'].lower()
self.update_state()

def update_state(self):
def _get_visible_cuts(self) -> list[Clip3dTool]:
"""
Update the state, combining all the active cuts, using the selected binary
operation. The resulting selection is then used to either create or update a
second point cloud which is included in the scene.
The original point cloud is then set to be semi-transparent.
When the position/range of a cut is changed, this function is called via a
debounce mechanism to avoid updating the cloud too often. Only the outlines of
the cuts are moved in real time, which is cheap.
Return the list of visible cuts.
"""
for nodes in self._nodes.values():
self._view.remove(nodes['slice'].id)
nodes['slice'].remove()
self._nodes.clear()
return [cut for cut in self.cuts if cut.visible]

visible_cuts = [cut for cut in self.cuts if cut.visible]
if not visible_cuts:
return
def _select_subset(self, da: sc.DataArray, cuts: list[Clip3dTool]) -> sc.DataArray:
"""
Return the subset of the data array selected by the cuts, combined using the
selected operation.
"""
selections = []
npoints = 0
for cut in cuts:
xmin, xmax = cut.range
selection = (da.coords[cut.dim] >= xmin) & (da.coords[cut.dim] < xmax)
npoints += selection.sum().value
selections.append(selection)
# If no points are selected, return a dummy selection to avoid issues with
# empty selections.
if npoints == 0:
return da[0:1]
sel = OPERATIONS[self._operation](selections)
return da[sel]

for n in self._original_nodes:
da = n.request_data()
selections = []
for cut in visible_cuts:
xmin, xmax = cut.range
selections.append(
(da.coords[cut.dim] >= xmin) & (da.coords[cut.dim] < xmax)
)
selection = OPERATIONS[self._operation](selections)
if selection.sum().value > 0:
if n.id not in self._nodes:
select_node = Node(selection)
self._nodes[n.id] = {
'select': select_node,
'slice': Node(lambda da, s: da[s], da=n, s=select_node),
}
self._nodes[n.id]['slice'].add_view(self._view)
else:
self._nodes[n.id]['select'].func = lambda: selection # noqa: B023
self._nodes[n.id]['select'].notify_children("")
def update_state(self):
"""
Update the state of the cuts in the figure by triggering the node that
provides the list of visible cuts.
"""
self._cut_info_node.notify_children("")
Loading
Loading