Skip to content

Commit

Permalink
Add draggable Rectangle implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
sco1 committed Jun 12, 2024
1 parent 79f9def commit 9db0339
Showing 1 changed file with 161 additions and 24 deletions.
185 changes: 161 additions & 24 deletions matplotlib_window/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from matplotlib.axes import Axes
from matplotlib.backend_bases import Event, FigureCanvasBase, MouseEvent
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from numpy import typing as npt

COORD_T: t.TypeAlias = tuple[float, float]
CALLBACK_T: t.TypeAlias = t.Callable[[Event], t.Any]
PLOT_OBJ_T: t.TypeAlias = Line2D
PLOT_OBJ_T: t.TypeAlias = Line2D | Rectangle
NUMERIC_T: t.TypeAlias = float | int

COMMON_OBJ_ID = "dragobj" # Label created object(s) URL for downstream event filtering
Expand Down Expand Up @@ -42,10 +43,10 @@ class _DraggableObject:

# Defined by child classes prior to registration
on_motion: CALLBACK_T
snap_to: PLOT_OBJ_T | None
snap_to: Line2D | None

# Defined on registration
myobj: Line2D
myobj: PLOT_OBJ_T
parent_axes: Axes
parent_canvas: FigureCanvasBase
# The canvas retains only weak references so retain just in case
Expand Down Expand Up @@ -142,12 +143,45 @@ def on_release(self, event: Event) -> t.Any:
# Type narrowing, matplotlib dispatches a MouseEvent here so shouldn't ever trip this
return

self.disconnect()

def disconnect(self) -> None:
"""Disconnect the callbacks connected by `self.on_click`."""
self.clicked = False
self.parent_canvas.mpl_disconnect(self.mouse_motion)
self.parent_canvas.mpl_disconnect(self.click_release)
self.parent_canvas.draw()


def limit_drag(plotted_data: npt.ArrayLike, query: float) -> float:
"""Clamp the query value within the bounds of the provided dataset."""
# Data series may not be sorted, so use min/max
# I'm not sure how to properly type annotate this right now
min_val, max_val = plotted_data.min(), plotted_data.max() # type: ignore[union-attr]
if query > max_val:
return max_val # type: ignore[no-any-return]
elif query < min_val:
return min_val # type: ignore[no-any-return]
else:
return query


def validate_snap_to(snap_to: Line2D | None) -> Line2D | None:
"""
Validate that the `snap_to` object, if provided, actually contains x data.
If `snap_to` is `None`, or is a plot object that contains x data, it is returned unchanged.
Otherwise an exception is raised.
"""
if snap_to is not None:
try:
snap_to.get_xdata()
except AttributeError as e:
raise ValueError("Cannot provide an empty lineseries to snapto") from e

return snap_to


class DragLine(_DraggableObject):
"""
Draggable `Line2D` instance.
Expand All @@ -162,7 +196,7 @@ class DragLine(_DraggableObject):
def __init__(
self,
ax: Axes,
position: float | int,
position: NUMERIC_T,
orientation: Orientation = Orientation.VERTICAL,
snap_to: Line2D | None = None,
color: str = "limegreen",
Expand All @@ -179,15 +213,7 @@ def __init__(
raise ValueError(f"Unsupported orientation provided: '{orientation}'")

self.register_plot_object(obj, ax)

# If provided, check if snap_to is a valid lineseries with data in it
if snap_to is not None:
try:
snap_to.get_xdata()
except AttributeError as e:
raise ValueError("Cannot provide an empty lineseries to snapto") from e

self.snap_to = snap_to
self.snap_to = validate_snap_to(snap_to)

def on_motion(self, event: Event) -> t.Any:
"""
Expand All @@ -197,6 +223,7 @@ def on_motion(self, event: Event) -> t.Any:
fired. If `self.snap_to` is not `None`, motion of the line will be limited to the extent of
the data plotted by the specified `Line2D`.
"""
self.myobj: Line2D
if not isinstance(event, MouseEvent):
# Type narrowing, matplotlib dispatches a MouseEvent here so shouldn't ever trip this
return
Expand Down Expand Up @@ -227,14 +254,124 @@ def on_motion(self, event: Event) -> t.Any:
self.parent_canvas.draw()


def limit_drag(plotted_data: npt.ArrayLike, query: float) -> float:
"""Clamp the query value within the bounds of the provided dataset."""
# Data series may not be sorted, so use min/max
# I'm not sure how to properly type annotate this right now
min_val, max_val = plotted_data.min(), plotted_data.max() # type: ignore[union-attr]
if query > max_val:
return max_val # type: ignore[no-any-return]
elif query < min_val:
return min_val # type: ignore[no-any-return]
else:
return query
class RectParams(t.NamedTuple): # noqa: D101
xy: COORD_T
height: float


def transform_rect_params(ax: Axes, position: NUMERIC_T) -> RectParams:
"""
Transform the desired x position to full span rectangle parameters.
An xy coordinate pair is calculated that places the lower left corner of the rectangle at the
lower y-axis limit, along with a height value that will cause the rectangle to span the entire
y-axis bounds.
"""
y_lbound, y_ubound = ax.get_ylim()
xy = (position, y_lbound)
height = y_ubound - y_lbound

return RectParams(xy=xy, height=height)


class DragRect(_DraggableObject):
"""
Draggable `Rectangle` instance.
`position` specifies the x-coordinate of the left edge of the rectangle.
`snap_to` may be optionally specified as an instance of a `Line2D` object to prevent dragging of
the rectangle beyond the extent of the plotted data.
All kwargs not explicitly named by `__init__` are passed through to the `Rectangle` constructor,
allowing the user to specify custom line formatting in a form expected by `Rectangle`.
NOTE: Motion is constrained to the x-axis only.
"""

def __init__(
self,
ax: Axes,
position: NUMERIC_T,
width: NUMERIC_T,
snap_to: Line2D | None = None,
edgecolor: str = "limegreen",
facecolor: str = "limegreen",
alpha: NUMERIC_T = 0.4,
**kwargs: t.Any,
) -> None:
# Rectangle patches are located from their bottom left corner; because we want to span the
# full y range, we need to translate the y position to the bottom of the axes
rect_params = transform_rect_params(ax, position)

obj = Rectangle(
xy=rect_params.xy,
width=width,
height=rect_params.height,
edgecolor=edgecolor,
facecolor=facecolor,
alpha=alpha,
**kwargs,
)

self.oldxy = rect_params.xy # Used for drag deltas so the object doesn't jump to cursor
self.register_plot_object(obj, ax)
self.snap_to = validate_snap_to(snap_to)

def on_motion(self, event: Event) -> t.Any:
"""
On motion callback.
Update the position of the rectangle to follow the position of the mouse at the time the
event is fired. If `self.snap_to` is not `None`, motion of the rectangle will be limited to
the extent of the data plotted by the specified `Line2D`.
"""
self.myobj: Rectangle
if not isinstance(event, MouseEvent):
# Type narrowing, matplotlib dispatches a MouseEvent here so shouldn't ever trip this
return
if not self.clicked:
return
if event.inaxes != self.parent_axes:
return
if (event.xdata is None) or (event.ydata is None):
return

# Calculate the new xy position based on the movement of the cursor relative to the location
# of the bottom left corner when the object was clicked on. Because we can click anywhere on
# the patch to begin motion, the patch will jump to the mouse if just using the location of
# the MouseEvent.
old_x, _ = self.oldxy
dx = event.xdata - self.click_x
if self.snap_to:
if dx < 0:
# Moving left, check left edge
query = old_x + dx
new_x = limit_drag(self.snap_to.get_xdata(), query)
else:
# Moving right, check right edge
width = self.myobj.get_width()
query = old_x + width + dx
new_x = limit_drag(self.snap_to.get_xdata(), query) - width
else:
new_x = old_x + dx

rect_params = transform_rect_params(self.parent_axes, new_x)
self.myobj.xy = rect_params.xy
self.myobj.set_height(rect_params.height)

self.parent_canvas.draw()

def on_release(self, event: Event) -> t.Any:
"""
Mouse button release callback.
When the mouse button is released, cache the new corner location & disconnect the callbacks
connected by `self.on_click`.
"""
if not isinstance(event, MouseEvent):
# Type narrowing, matplotlib dispatches a MouseEvent here so shouldn't ever trip this
return

self.oldxy = self.myobj.get_xy()
self.disconnect()

0 comments on commit 9db0339

Please sign in to comment.