Skip to content

Commit

Permalink
Finish FlexibleRect
Browse files Browse the repository at this point in the history
  • Loading branch information
sco1 committed Jun 27, 2024
1 parent 7e23b61 commit 50b2533
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 136 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ repos:
- id: python-check-blanket-type-ignore
- id: python-use-type-annotations
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.9
rev: v0.4.10
hooks:
- id: ruff
77 changes: 65 additions & 12 deletions matplotlib_window/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typing as t
from collections import abc
from enum import StrEnum
from functools import partial

from matplotlib.axes import Axes
from matplotlib.backend_bases import Event, FigureCanvasBase, MouseEvent
Expand All @@ -8,7 +10,7 @@
from numpy import typing as npt

COORD_T: t.TypeAlias = tuple[float, float]
CALLBACK_T: t.TypeAlias = t.Callable[[Event], t.Any]
CALLBACK_T: t.TypeAlias = abc.Callable[[Event], t.Any]
PLOT_OBJ_T: t.TypeAlias = Line2D | Rectangle
NUMERIC_T: t.TypeAlias = float | int

Expand Down Expand Up @@ -44,6 +46,7 @@ class _DraggableObject:
# Defined by child classes prior to registration
on_motion: CALLBACK_T
snap_to: Line2D | None
redraw_callback: abc.Callable[[], None] | None

# Defined on registration
myobj: PLOT_OBJ_T
Expand Down Expand Up @@ -150,7 +153,7 @@ def disconnect(self) -> None:
self.clicked = False
self.parent_canvas.mpl_disconnect(self.mouse_motion)
self.parent_canvas.mpl_disconnect(self.click_release)
self.parent_canvas.draw()
self._redraw()

def validate_snap_to(self, snap_to: Line2D | None) -> Line2D | None:
"""
Expand All @@ -177,6 +180,12 @@ def _disable_click(self) -> None:
self.parent_canvas.mpl_disconnect(self.click_press)
self.click_press = -1

def _redraw(self) -> None:
if self.redraw_callback is not None:
self.redraw_callback()

self.parent_canvas.draw()


def limit_drag(plotted_data: npt.ArrayLike, query: float) -> float:
"""Clamp the query value within the bounds of the provided dataset."""
Expand All @@ -198,6 +207,10 @@ class DragLine(_DraggableObject):
`snap_to` may be optionally specified as an instance of another `Line2D` object to prevent
dragging of the line beyond the extent of the plotted data.
`redraw_callback` may be optionally specified as a callable which gets called whenever the
location of the line has been changed. This callable is expected to take no arguments and has no
return.
All kwargs not explicitly named by `__init__` are passed through to the `Line2D` constructor,
allowing the user to specify custom line formatting in a form expected by `Line2D`.
"""
Expand All @@ -208,10 +221,12 @@ def __init__(
position: NUMERIC_T,
orientation: Orientation = Orientation.VERTICAL,
snap_to: Line2D | None = None,
redraw_callback: abc.Callable[[], None] | None = None,
color: str = "limegreen",
**kwargs: t.Any,
) -> None:
self.orientation = orientation
self.redraw_callback = redraw_callback

line_pos = (position, position) # matplotlib expectes a coordinate pair
if orientation == Orientation.HORIZONTAL:
Expand Down Expand Up @@ -252,9 +267,9 @@ def on_motion(self, event: Event) -> t.Any:
if self.snap_to:
new_pos = limit_drag(self.snap_to.get_ydata(), event.ydata)
else:
new_pos = event.xdata
new_pos = event.ydata

self.myobj.set_xdata(self.parent_axes.get_xlim())
self.myobj.set_ydata((new_pos, new_pos))
elif self.orientation == Orientation.VERTICAL:
if self.snap_to:
new_pos = limit_drag(self.snap_to.get_xdata(), event.xdata)
Expand All @@ -263,7 +278,7 @@ def on_motion(self, event: Event) -> t.Any:

self.myobj.set_xdata((new_pos, new_pos))

self.parent_canvas.draw()
self._redraw()

def limit_change(self, ax: Axes) -> None:
"""
Expand All @@ -276,7 +291,7 @@ def limit_change(self, ax: Axes) -> None:
else:
self.myobj.set_ydata(ax.get_ylim())

self.parent_canvas.draw()
self._redraw()

def validate_snap_to(self, snap_to: Line2D | None) -> Line2D | None:
"""
Expand Down Expand Up @@ -343,6 +358,10 @@ class DragRect(_DraggableObject):
`snap_to` may be optionally specified as an instance of a `Line2D` object to prevent dragging of
the rectangle beyond the extent of the plotted data.
`redraw_callback` may be optionally specified as a callable which gets called whenever the
location of the line has been changed. This callable is expected to take no arguments and has no
return.
All kwargs not explicitly named by `__init__` are passed through to the `Rectangle` constructor,
allowing the user to specify custom line formatting in a form expected by `Rectangle`.
Expand All @@ -355,6 +374,7 @@ def __init__(
position: NUMERIC_T,
width: NUMERIC_T,
snap_to: Line2D | None = None,
redraw_callback: abc.Callable[[], None] | None = None,
edgecolor: str | None = "limegreen",
facecolor: str = "limegreen",
alpha: NUMERIC_T = 0.4,
Expand All @@ -363,6 +383,8 @@ def __init__(
if width <= 0:
raise ValueError(f"Width value must be greater than 0. Received: {width}")

self.redraw_callback = None

# Rectangle patches are located from their bottom left corner; because we want to span the
# full y range, we need to translate the y position to the bottom of the axes
rect_params = transform_rect_params(ax, position)
Expand Down Expand Up @@ -424,7 +446,7 @@ def on_motion(self, event: Event) -> t.Any:
rect_params = transform_rect_params(self.parent_axes, new_x)
self.myobj.xy = rect_params.xy

self.parent_canvas.draw()
self._redraw()

def on_release(self, event: Event) -> t.Any:
"""
Expand All @@ -448,7 +470,7 @@ def limit_change(self, ax: Axes) -> None:
"""
rect_params = transform_rect_params(ax, 0) # Doesn't matter what the x is, only need height
self.myobj.set_height(rect_params.height)
self.parent_canvas.draw()
self._redraw()

def validate_snap_to(self, snap_to: Line2D | None) -> Line2D | None:
"""
Expand Down Expand Up @@ -491,6 +513,10 @@ class FlexibleRect:
`snap_to` may be optionally specified as an instance of a `Line2D` object to prevent dragging of
the rectangle beyond the extent of the plotted data.
`redraw_callback` may be optionally specified as a callable which gets called whenever the
location of the line has been changed. This callable is expected to take no arguments and has no
return.
NOTE: Motion is constrained to the x-axis only.
"""

Expand All @@ -500,6 +526,7 @@ def __init__(
position: NUMERIC_T,
width: NUMERIC_T,
snap_to: Line2D | None = None,
redraw_callback: abc.Callable[[], None] | None = None,
allow_face_drag: bool = False,
edgecolor: str = "limegreen",
facecolor: str = "limegreen",
Expand All @@ -508,21 +535,47 @@ def __init__(
if width <= 0:
raise ValueError(f"Width value must be greater than 0. Received: {width}")

self.parent_axes = ax
if ax.figure is not None:
self.parent_canvas = ax.figure.canvas
else:
raise ValueError("I don't know how we got here, but there's no figure.")

self.redraw_callback = redraw_callback

# snap_to validation handled by DragRect & DragLine
# Create edges after face so they're topmost & take click priority
self.face = DragRect(
ax=ax, position=position, width=width, facecolor=facecolor, edgecolor=None, alpha=alpha
)
self.edges = [
DragLine(ax=ax, position=position, color=edgecolor, snap_to=snap_to),
DragLine(ax=ax, position=(position + width), color=edgecolor, snap_to=snap_to),
]

line_p = partial(
DragLine,
ax=ax,
color=edgecolor,
snap_to=snap_to,
redraw_callback=self._respan_face,
)
self.edges = [line_p(position=position), line_p(position=(position + width))]

if not allow_face_drag:
self.face._disable_click()
else:
raise NotImplementedError

def _respan_face(self) -> None:
"""Update face dimensions to span the entirety of the y-axes between the two edges."""
left = min(edge.location for edge in self.edges)
right = max(edge.location for edge in self.edges)

rect_params = transform_rect_params(self.parent_axes, left)
width = right - left

self.face.myobj.set_xy(rect_params.xy)
self.face.myobj.set_width(width)

self.parent_canvas.draw() # Call directly to avoid infinitely spamming the callback

@property
def bounds(self) -> tuple[NUMERIC_T, NUMERIC_T]:
"""Return the x-axis locations of the left & right edges."""
Expand Down
Loading

0 comments on commit 50b2533

Please sign in to comment.