Skip to content

Commit

Permalink
Merge pull request #1131 from wright-group/refactor-norm
Browse files Browse the repository at this point in the history
interact2D small feature update
  • Loading branch information
kameyer226 committed Jul 21, 2023
2 parents cc7217f + 407fe50 commit a854e54
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 134 deletions.
8 changes: 5 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,23 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).
### Fixed
- numpy deprecated the `np.float` alias, so use `np.float64` to be more precise
- artists support matplotlib >= 3.7
- interact2D: fixed bug where sliders did not change appearance on focus
- `interact2D`: fixed bug where sliders did not change appearance on focus
- `interact2D`: fixed buggy side plots windowing

### Changed
- data.join now has MultidimensionalAxisError exception message
- `Data.join` now has MultidimensionalAxisError exception message
- `Axis`: space character ("\s") in expressions are culled.
- fixed `interact2D` bug: channel/axes can now be specified with non-zero index arguments
- `interact2D`: side plots project the extremes along each axis, rather than the average.

### Added
- `interact2D` has informative figure window names
- `Data.translate_to_txt`: serialize channels and variables and write as a text file.
- Python supported versions: add 3.10, 3.11, and drop 3.7

## [3.4.6]

### Fixed

- `Data.chop` : fixed bug where chop did not succeed if axes did not span data ndim

## [3.4.5]
Expand Down
231 changes: 110 additions & 121 deletions WrightTools/artists/_interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,34 +103,48 @@ def get_colormap(signed):
return cmap


def get_norm(channel, current_state) -> object:
if channel.signed:
if not current_state.local:
norm = mpl.colors.CenteredNorm(vcenter=channel.null, halfrange=channel.mag())
else:
norm = mpl.colors.CenteredNorm(vcenter=channel.null)
norm.autoscale_None(current_state.dat[channel.natural_name][:])
if norm.halfrange == 0:
norm.halfrange = 1
else:
if not current_state.local:
norm = mpl.colors.Normalize(vmin=channel.null, vmax=channel.max())
class Norm:
def __init__(self, channel, current_state):
self.current_state = current_state
self.signed = channel.signed
self.update(channel)

def __call__(self, data):
out = self.norm(data)
return out

def update(self, channel):
if self.signed:
if not self.current_state.local:
norm = mpl.colors.CenteredNorm(vcenter=channel.null, halfrange=channel.mag())
else:
norm = mpl.colors.CenteredNorm(vcenter=channel.null)
norm.autoscale_None(
np.ma.masked_invalid(self.current_state.dat[channel.natural_name][:])
)
if norm.halfrange == 0:
norm.halfrange = 1
else:
norm = mpl.colors.Normalize(vmin=channel.null)
norm.autoscale_None(current_state.dat[channel.natural_name][:])
if norm.vmax == norm.vmin:
norm.vmax += 1
return norm

if not self.current_state.local:
norm = mpl.colors.Normalize(vmin=channel.null, vmax=np.nanmax(channel[:]))
else:
norm = mpl.colors.Normalize(vmin=channel.null)
norm.autoscale_None(
np.ma.masked_invalid(self.current_state.dat[channel.natural_name][:])
)
if norm.vmax == norm.vmin:
norm.vmax += 1
self.norm = norm

def norm_to_ticks(norm) -> np.array:
if type(norm) == mpl.colors.CenteredNorm:
vmin = norm.vcenter - norm.halfrange
vmax = norm.vcenter + norm.halfrange
else: # mpl.colors.Normalize
vmin = norm.vmin
vmax = norm.vmax
return np.linspace(vmin, vmax, 11)
@property
def ticks(self) -> np.array:
if type(self.norm) == mpl.colors.CenteredNorm:
vmin = self.norm.vcenter - self.norm.halfrange
vmax = self.norm.vcenter + self.norm.halfrange
else: # mpl.colors.Normalize
vmin = self.norm.vmin
vmax = self.norm.vmax
return np.linspace(vmin, vmax, 11)


def gen_ticklabels(points, signed=None):
Expand All @@ -153,16 +167,6 @@ def gen_ticklabels(points, signed=None):
return ticklabels


def norm(arr, signed, ignore_zero=True):
if signed:
norm = np.nanmax(np.abs(arr))
else:
norm = np.nanmax(arr)
if norm != 0 and ignore_zero:
arr /= norm
return arr


def interact2D(
data: wt_data.Data,
xaxis=0,
Expand Down Expand Up @@ -214,6 +218,7 @@ def interact2D(
raise DimensionalityError(">= 2", data.ndim)
# TODO: implement aspect; doesn't work currently because of our incorporation of colorbar
fig, gs = create_figure(width="single", nrows=7 + nsliders, cols=[1, 1, 1, 1, 1, "cbar"])
plt.get_current_fig_manager().set_window_title(f"interact2D: {data.natural_name}")
# create axes
ax0 = plt.subplot(gs[1:6, 0:5])
ax0.patch.set_facecolor("w")
Expand Down Expand Up @@ -253,51 +258,46 @@ def interact2D(

# create sliders
sliders = {}
for axis in data.axes:
if axis not in [xaxis, yaxis]:
if axis.size > np.prod(axis.shape):
raise NotImplementedError("Cannot use multivariable axis as a slider")
slider_axes = plt.subplot(gs[~len(sliders), :]).axes
slider = Slider(
slider_axes,
axis.label,
0,
axis.points.size - 1,
valinit=0,
valstep=1,
track_color="lightgrey",
)
sliders[axis.natural_name] = slider
slider_axes.set_gid(axis.natural_name)
slider.ax.vlines(
range(axis.points.size - 1),
*slider.ax.get_ylim(),
colors="k",
linestyle=":",
alpha=0.5,
)
slider.valtext.set_text(gen_ticklabels(axis.points)[0])
for axis in filter(lambda a: a not in [xaxis, yaxis], data.axes):
if axis.size > np.prod(axis.shape):
raise NotImplementedError("Cannot use multivariable axis as a slider")
slider_axes = plt.subplot(gs[~len(sliders), :]).axes
slider = Slider(
slider_axes,
axis.label,
0,
axis.points.size - 1,
valinit=0,
valstep=1,
track_color="lightgrey",
)
sliders[axis.natural_name] = slider
slider_axes.set_gid(axis.natural_name)
slider.ax.vlines(
range(axis.points.size - 1),
*slider.ax.get_ylim(),
colors="k",
linestyle=":",
alpha=0.5,
)
slider.valtext.set_text(gen_ticklabels(axis.points)[0])
current_state.focus = Focus([ax0] + [slider.ax for slider in sliders.values()], sliders)
# initial xyz start are from zero indices of additional axes
current_state.dat = data.chop(
xaxis.natural_name,
yaxis.natural_name,
at=_at_dict(data, sliders, xaxis, yaxis),
verbose=False,
)[0]
norm = get_norm(channel, current_state)
current_state.dat = data.at(**_at_dict(data, sliders, xaxis, yaxis))
current_state.dat.transform(xaxis.expression, yaxis.expression)
current_state.norm = Norm(channel, current_state)

gen_mesh = ax0.pcolormesh if not use_imshow else ax0.imshow
obj2D = gen_mesh(
current_state.dat,
cmap=cmap,
norm=norm,
norm=current_state.norm.norm,
ylabel=yaxis.label,
xlabel=xaxis.label,
)
ax0.grid(True)
# colorbar
ticks = norm_to_ticks(norm)
ticks = current_state.norm.ticks
ticklabels = gen_ticklabels(ticks, channel.signed)
colorbar = plot_colorbar(cax, cmap=cmap, label=channel.natural_name, ticks=ticks)
colorbar.set_ticklabels(ticklabels)
Expand All @@ -321,63 +321,57 @@ def draw_sideplot_projections():
)
> 1
).index(True)

norm = current_state.norm

if channel.signed:
temp_arr = np.ma.masked_array(arr, np.isnan(arr), copy=True)
temp_arr[temp_arr < 0] = 0
x_proj_pos = np.nanmean(temp_arr, axis=yind)
y_proj_pos = np.nanmean(temp_arr, axis=xind)
x_proj_pos = np.nanmax(temp_arr, axis=yind)
y_proj_pos = np.nanmax(temp_arr, axis=xind)

temp_arr = np.ma.masked_array(arr, np.isnan(arr), copy=True)
temp_arr[temp_arr > 0] = 0
x_proj_neg = np.nanmean(temp_arr, axis=yind)
y_proj_neg = np.nanmean(temp_arr, axis=xind)
x_proj_neg = np.nanmin(temp_arr, axis=yind)
y_proj_neg = np.nanmin(temp_arr, axis=xind)

x_proj = np.nanmean(arr, axis=yind)
y_proj = np.nanmean(arr, axis=xind)

alpha = 0.4
blue = "#517799" # start with #87C7FF and change saturation
red = "#994C4C" # start with #FF7F7F and change saturation

if current_state.bin_vs_x:
x_proj_norm = max(np.nanmax(x_proj_pos), np.nanmax(-x_proj_neg))
if x_proj_norm != 0:
x_proj_pos /= x_proj_norm
x_proj_neg /= x_proj_norm
x_proj /= x_proj_norm
try:
sp_x.fill_between(xaxis.points, x_proj_pos, 0, color=red, alpha=alpha)
sp_x.fill_between(xaxis.points, 0, x_proj_neg, color=blue, alpha=alpha)
sp_x.fill_between(xaxis.points, x_proj, 0, color="k", alpha=0.3)
sp_x.fill_between(xaxis.points, norm(x_proj_pos), 0.5, color=red, alpha=alpha)
sp_x.fill_between(xaxis.points, 0.5, norm(x_proj_neg), color=blue, alpha=alpha)
sp_x.fill_between(xaxis.points, norm(x_proj), 0.5, color="k", alpha=0.3)
except ValueError: # Input passed into argument is not 1-dimensional
current_state.bin_vs_x = False
sp_x.set_visible(False)
if current_state.bin_vs_y:
y_proj_norm = max(np.nanmax(y_proj_pos), np.nanmax(-y_proj_neg))
if y_proj_norm != 0:
y_proj_pos /= y_proj_norm
y_proj_neg /= y_proj_norm
y_proj /= y_proj_norm
try:
sp_y.fill_betweenx(yaxis.points, y_proj_pos, 0, color=red, alpha=alpha)
sp_y.fill_betweenx(yaxis.points, 0, y_proj_neg, color=blue, alpha=alpha)
sp_y.fill_betweenx(yaxis.points, y_proj, 0, color="k", alpha=0.3)
sp_y.fill_betweenx(yaxis.points, norm(y_proj_pos), 0.5, color=red, alpha=alpha)
sp_y.fill_betweenx(
yaxis.points, 0.5, norm(y_proj_neg), color=blue, alpha=alpha
)
sp_y.fill_betweenx(yaxis.points, norm(y_proj), 0.5, color="k", alpha=0.3)
except ValueError:
current_state.bin_vs_y = False
sp_y.set_visible(False)
else:
if current_state.bin_vs_x:
x_proj = np.nanmean(arr, axis=yind)
x_proj = norm(x_proj, channel.signed)
x_proj = np.nanmax(arr, axis=yind)
try:
sp_x.fill_between(xaxis.points, x_proj, 0, color="k", alpha=0.3)
sp_x.fill_between(xaxis.points, norm(x_proj), 0, color="k", alpha=0.3)
except ValueError:
current_state.bin_vs_x = False
sp_x.set_visible(False)
if current_state.bin_vs_y:
y_proj = np.nanmean(arr, axis=xind)
y_proj = norm(y_proj, channel.signed)
y_proj = np.nanmax(arr, axis=xind)
try:
sp_y.fill_betweenx(yaxis.points, y_proj, 0, color="k", alpha=0.3)
sp_y.fill_betweenx(yaxis.points, norm(y_proj), 0, color="k", alpha=0.3)
except ValueError:
current_state.bin_vs_y = False
sp_y.set_visible(False)
Expand All @@ -387,12 +381,8 @@ def draw_sideplot_projections():
ax0.set_xlim(xaxis.points.min(), xaxis.points.max())
ax0.set_ylim(yaxis.points.min(), yaxis.points.max())

if channel.signed:
sp_x.set_ylim(-1.1, 1.1)
sp_y.set_xlim(-1.1, 1.1)
else:
sp_x.set_ylim(0, 1.1)
sp_y.set_xlim(0, 1.1)
sp_x.set_ylim(0, 1)
sp_y.set_xlim(0, 1)

def update_sideplot_slices():
# TODO: if bins is only available along one axis, slicing should be valid along the other
Expand All @@ -410,28 +400,31 @@ def update_sideplot_slices():

at_dict = _at_dict(data, sliders, xaxis, yaxis)
at_dict[xaxis.natural_name] = (x0, xaxis.units)
side_plot_data = data.chop(yaxis.natural_name, at=at_dict, verbose=False)
side_plot = side_plot_data[0][channel.natural_name].points
side_plot = norm(side_plot, channel.signed)
side_plot_data = data.at(**at_dict)
side_plot = side_plot_data[channel.natural_name].points
side_plot = current_state.norm(side_plot)
line_sp_y.set_data(side_plot, yaxis.points)
side_plot_data.close()

at_dict = _at_dict(data, sliders, xaxis, yaxis)
at_dict[yaxis.natural_name] = (y0, yaxis.units)
side_plot_data = data.chop(xaxis.natural_name, at=at_dict, verbose=False)
side_plot = side_plot_data[0][channel.natural_name].points
side_plot = norm(side_plot, channel.signed)
side_plot_data = data.at(**at_dict)
side_plot = side_plot_data[channel.natural_name].points
side_plot = current_state.norm(side_plot)
line_sp_x.set_data(xaxis.points, side_plot)
side_plot_data.close()

def update_local(index):
if verbose:
print("normalization:", index)
current_state.local = radio.value_selected[1:] == "local"
norm = get_norm(channel, current_state)
obj2D.set_norm(norm)
ticklabels = gen_ticklabels(np.linspace(norm.vmin, norm.vmax, 11), channel.signed)
current_state.norm.update(channel)
obj2D.set_norm(current_state.norm.norm)
ticklabels = gen_ticklabels(current_state.norm.ticks, channel.signed)
colorbar.set_ticklabels(ticklabels)

update_sideplots(sp_x, sp_y, line_sp_x, line_sp_y)

fig.canvas.draw_idle()

def update_slider(info, use_imshow=use_imshow):
Expand All @@ -458,30 +451,26 @@ def update_slider(info, use_imshow=use_imshow):
obj2D.set_data(current_state.dat[channel.natural_name][:].transpose(transpose))
else:
obj2D.set_array(current_state.dat[channel.natural_name][:].ravel())
norm = get_norm(channel, current_state)
obj2D.set_norm(norm)
current_state.norm.update(channel)
obj2D.set_norm(current_state.norm.norm)

ticks = norm_to_ticks(norm)
ticks = current_state.norm.ticks
ticklabels = gen_ticklabels(ticks, channel.signed)
colorbar.set_ticklabels(ticklabels)

update_sideplots(sp_x, sp_y, line_sp_x, line_sp_y)
fig.canvas.draw_idle()

def update_sideplots(sp_x, sp_y, line_sp_x, line_sp_y):
[item.remove() for item in sp_x.collections]
[item.remove() for item in sp_y.collections]
if len(sp_x.collections) > 0: # mpl < 3.7
sp_x.collections.clear()
sp_y.collections.clear()

if channel.signed:
sp_x.set_ylim(-1.1, 1.1)
sp_y.set_xlim(-1.1, 1.1)
else:
sp_x.set_ylim(0, 1.1)
sp_y.set_xlim(0, 1.1)

draw_sideplot_projections()
if line_sp_x.get_visible() and line_sp_y.get_visible():
update_sideplot_slices()
fig.canvas.draw_idle()

def update_crosshairs(xarg, yarg, hide=False):
# find closest x and y pts in dataset
Expand Down

0 comments on commit a854e54

Please sign in to comment.