Skip to content

Commit

Permalink
interact2D supports mpl 3.7 (#1127)
Browse files Browse the repository at this point in the history
* mpl-3.7

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support older matplotlib

* add focus to sliders

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update WrightTools/artists/_interact.py

Co-authored-by: Kyle Sunden <git@ksunden.space>

* Update CHANGELOG.md

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kyle Sunden <git@ksunden.space>
  • Loading branch information
3 people committed Jul 6, 2023
1 parent cdff886 commit f23f1e2
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 19 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).

### Fixed
- numpy deprecated the `np.float` alias, so use `np.float64` to be more precise
- artists support matplotlib >= 3.7
- interact2D: fixed bug where sliders did not change appearance on focus

### Changed
- data.join now has MultidimensionalAxisError exception message
Expand Down
66 changes: 52 additions & 14 deletions WrightTools/artists/_interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@


class Focus:
def __init__(self, axes, linewidth=2):
def __init__(self, axes, sliders, linewidth=2):
self.axes = axes
self.sliders = sliders
self.linewidth = linewidth
ax = axes[0]
for side in ["top", "bottom", "left", "right"]:
Expand All @@ -36,6 +37,11 @@ def __call__(self, ax):
if self.focus_axis == ax or ax not in self.axes:
return
else: # set new focus
if self.focus_axis.get_gid() in self.sliders.keys():
self.sliders[self.focus_axis.get_gid()].track.set_facecolor("lightgrey")
if ax.get_gid() in self.sliders.keys():
self.sliders[ax.get_gid()].track.set_facecolor("darkgrey")

for spine in ["top", "bottom", "left", "right"]:
self.focus_axis.spines[spine].set_linewidth(1)
ax.spines[spine].set_linewidth(self.linewidth)
Expand All @@ -50,6 +56,20 @@ def _at_dict(data, sliders, xaxis, yaxis):
}


def create_local_global_radio(ax, local):
if mpl.__version_info__ >= (3, 7):
radio = RadioButtons(ax, (" global", " local"), radio_props={"s": 100})
else:
radio = RadioButtons(ax, (" global", " local"))
for circle in radio.circles:
circle.set_radius(0.14)
if local:
radio.set_active(1)
else:
radio.set_active(0)
return radio


def get_axes(data, axes):
xaxis, yaxis = axes
if type(xaxis) in [int, str]:
Expand Down Expand Up @@ -226,24 +246,29 @@ def interact2D(
ydir = 1 if yaxis.points.flatten()[-1] - yaxis.points.flatten()[0] > 0 else -1
current_state.bin_vs_x = True
current_state.bin_vs_y = True

# create buttons
current_state.local = local
radio = RadioButtons(ax_local, (" global", " local"))
if local:
radio.set_active(1)
else:
radio.set_active(0)
for circle in radio.circles:
circle.set_radius(0.14)
radio = create_local_global_radio(ax_local, local)

# create sliders
sliders = {}
for axis in data.axes:
if axis not in [xaxis, yaxis]:
if axis.size > np.prod(axis.shape):
raise NotImplementedError("Cannot use multivariable axis as a slider")
slider_axes = plt.subplot(gs[~len(sliders), :]).axes
slider = Slider(slider_axes, axis.label, 0, axis.points.size - 1, valinit=0, valstep=1)
slider = Slider(
slider_axes,
axis.label,
0,
axis.points.size - 1,
valinit=0,
valstep=1,
track_color="lightgrey",
)
sliders[axis.natural_name] = slider
slider_axes.set_gid(axis.natural_name)
slider.ax.vlines(
range(axis.points.size - 1),
*slider.ax.get_ylim(),
Expand All @@ -252,7 +277,7 @@ def interact2D(
alpha=0.5,
)
slider.valtext.set_text(gen_ticklabels(axis.points)[0])
current_state.focus = Focus([ax0] + [slider.ax for slider in sliders.values()])
current_state.focus = Focus([ax0] + [slider.ax for slider in sliders.values()], sliders)
# initial xyz start are from zero indices of additional axes
current_state.dat = data.chop(
xaxis.natural_name,
Expand Down Expand Up @@ -365,6 +390,9 @@ def draw_sideplot_projections():
if channel.signed:
sp_x.set_ylim(-1.1, 1.1)
sp_y.set_xlim(-1.1, 1.1)
else:
sp_x.set_ylim(0, 1.1)
sp_y.set_xlim(0, 1.1)

def update_sideplot_slices():
# TODO: if bins is only available along one axis, slicing should be valid along the other
Expand Down Expand Up @@ -436,16 +464,26 @@ def update_slider(info, use_imshow=use_imshow):
ticks = norm_to_ticks(norm)
ticklabels = gen_ticklabels(ticks, channel.signed)
colorbar.set_ticklabels(ticklabels)
sp_x.collections.clear()
sp_y.collections.clear()

[item.remove() for item in sp_x.collections]
[item.remove() for item in sp_y.collections]
if len(sp_x.collections) > 0: # mpl < 3.7
sp_x.collections.clear()
sp_y.collections.clear()

if channel.signed:
sp_x.set_ylim(-1.1, 1.1)
sp_y.set_xlim(-1.1, 1.1)
else:
sp_x.set_ylim(0, 1.1)
sp_y.set_xlim(0, 1.1)

draw_sideplot_projections()
if line_sp_x.get_visible() and line_sp_y.get_visible():
update_sideplot_slices()
fig.canvas.draw_idle()

def update_crosshairs(xarg, yarg, hide=False):
# if x0 is None or y0 is None:
# raise TypeError((x0, y0))
# find closest x and y pts in dataset
current_state.xarg = xarg
current_state.yarg = yarg
Expand Down
10 changes: 5 additions & 5 deletions tests/artists/test_interact2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ def test_4D():

data = wt.data.Data(name="data")
data.create_channel("signal", values=signal, signed=True)
data.create_variable("w1", values=w1[:, None, None, None], units="wn")
data.create_variable("w2", values=w2[None, :, None, None], units="wn")
data.create_variable("w3", values=w3[None, None, :, None], units="wn")
data.create_variable("d1", values=tau[None, None, None, :], units="ps")
data.create_variable("w_1", values=w1[:, None, None, None], units="wn")
data.create_variable("w_2", values=w2[None, :, None, None], units="wn")
data.create_variable("w_3", values=w3[None, None, :, None], units="wn")
data.create_variable("d_1", values=tau[None, None, None, :], units="ps")

data.transform("w1", "w2", "w3", "d1")
data.transform("w_1", "w_2", "w_3", "d_1")
wt.artists.interact2D(data, xaxis=0, yaxis=1, local=True)


Expand Down

0 comments on commit f23f1e2

Please sign in to comment.