Skip to content

Commit

Permalink
Merge pull request #1073 from wright-group/interact2d-cmap
Browse files Browse the repository at this point in the history
interact2d cmap options
  • Loading branch information
kameyer226 committed Jun 17, 2022
2 parents 2fecf16 + f89eac6 commit 0b07857
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
python-version: [3.6, 3.7, 3.8, 3.9]
python-version: [3.7, 3.8, 3.9]

steps:
- uses: actions/checkout@v2
Expand Down
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).

## [Unreleased]

## Added
### Added
- `artists.interact2D` supports `cmap` kwarg.
- iPython integration: autocomplete includes axis, variable, and channel names

### Changed
- `artists.interact2D` uses matplotlib norm objects to control colormap scaling

### Fixed
- `kit.fft`: fixed bug where Fourier coefficients were off by a scalar factor.

Expand Down
94 changes: 54 additions & 40 deletions WrightTools/artists/_interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,31 +75,42 @@ def get_channel(data, channel):
return channel


def get_colormap(channel):
if channel.signed:
cmap = "signed"
else:
cmap = "default"
def get_colormap(signed):
cmap = "signed" if signed else "default"
cmap = colormaps[cmap]
cmap.set_bad([0.75] * 3, 1.0)
cmap.set_under([0.75] * 3, 1.0)
return cmap


def get_clim(channel, current_state):
if current_state.local:
arr = current_state.dat[channel.natural_name][:]
if channel.signed:
mag = np.nanmax(np.abs(arr))
clim = [-mag, mag]
def get_norm(channel, current_state) -> object:
if channel.signed:
if not current_state.local:
norm = mpl.colors.CenteredNorm(vcenter=channel.null, halfrange=channel.mag())
else:
clim = [0, np.nanmax(arr)]
norm = mpl.colors.CenteredNorm(vcenter=channel.null)
norm.autoscale_None(current_state.dat[channel.natural_name][:])
if norm.halfrange == 0:
norm.halfrange = 1
else:
if channel.signed:
clim = [-channel.mag(), channel.mag()]
if not current_state.local:
norm = mpl.colors.Normalize(vmin=channel.null, vmax=channel.max())
else:
clim = [0, channel.max()]
return clim
norm = mpl.colors.Normalize(vmin=channel.null)
norm.autoscale_None(current_state.dat[channel.natural_name][:])
if norm.vmax == norm.vmin:
norm.vmax += 1
return norm


def norm_to_ticks(norm) -> np.array:
if type(norm) == mpl.colors.CenteredNorm:
vmin = norm.vcenter - norm.halfrange
vmax = norm.vcenter + norm.halfrange
else: # mpl.colors.Normalize
vmin = norm.vmin
vmax = norm.vmax
return np.linspace(vmin, vmax, 11)


def gen_ticklabels(points, signed=None):
Expand Down Expand Up @@ -133,7 +144,14 @@ def norm(arr, signed, ignore_zero=True):


def interact2D(
data: wt_data.Data, xaxis=0, yaxis=1, channel=0, local=False, use_imshow=False, verbose=True
data: wt_data.Data,
xaxis=0,
yaxis=1,
channel=0,
cmap=None,
local=False,
use_imshow=False,
verbose=True,
):
"""Interactive 2D plot of the dataset.
Side plots show x and y projections of the slice (shaded gray).
Expand All @@ -151,10 +169,12 @@ def interact2D(
Expression or index of y axis. Default is 1.
channel : string, integer, or data.Channel object (optional)
Name or index of channel to plot. Default is 0.
cmap : string or cm object (optional)
Name of colormap, or explicit colormap object. Defaults to channel default.
local : boolean (optional)
Toggle plotting locally. Default is False.
use_imshow : boolean (optional)
If true, matplotlib imshow is used to render the 2D slice.
If True, matplotlib imshow is used to render the 2D slice.
Can give better performance, but is only accurate for
uniform grids. Default is False.
verbose : boolean (optional)
Expand All @@ -163,10 +183,10 @@ def interact2D(
# avoid changing passed data object
data = data.copy()
# unpack
data.prune(keep_channels=channel)
data.prune(keep_channels=channel, verbose=False)
channel = get_channel(data, channel)
xaxis, yaxis = get_axes(data, [xaxis, yaxis])
cmap = get_colormap(channel)
cmap = cmap if cmap is not None else get_colormap(channel.signed)
current_state = SimpleNamespace()
# create figure
nsliders = data.ndim - 2
Expand Down Expand Up @@ -229,7 +249,7 @@ def interact2D(
*slider.ax.get_ylim(),
colors="k",
linestyle=":",
alpha=0.5
alpha=0.5,
)
slider.valtext.set_text(gen_ticklabels(axis.points)[0])
current_state.focus = Focus([ax0] + [slider.ax for slider in sliders.values()])
Expand All @@ -240,25 +260,21 @@ def interact2D(
at=_at_dict(data, sliders, xaxis, yaxis),
verbose=False,
)[0]
clim = get_clim(channel, current_state)
ticklabels = gen_ticklabels(np.linspace(*clim, 11), channel.signed)
if clim[0] == clim[1]:
clim = [-1 if channel.signed else 0, 1]
norm = get_norm(channel, current_state)

gen_mesh = ax0.pcolormesh if not use_imshow else ax0.imshow
obj2D = gen_mesh(
current_state.dat,
cmap=cmap,
vmin=clim[0],
vmax=clim[1],
norm=norm,
ylabel=yaxis.label,
xlabel=xaxis.label,
)
ax0.grid(b=True)
# colorbar
colorbar = plot_colorbar(
cax, cmap=cmap, label=channel.natural_name, ticks=np.linspace(clim[0], clim[1], 11)
)
ticks = norm_to_ticks(norm)
ticklabels = gen_ticklabels(ticks, channel.signed)
colorbar = plot_colorbar(cax, cmap=cmap, label=channel.natural_name, ticks=ticks)
colorbar.set_ticklabels(ticklabels)
fig.canvas.draw_idle()

Expand Down Expand Up @@ -384,12 +400,10 @@ def update_local(index):
if verbose:
print("normalization:", index)
current_state.local = radio.value_selected[1:] == "local"
clim = get_clim(channel, current_state)
ticklabels = gen_ticklabels(np.linspace(*clim, 11), channel.signed)
norm = get_norm(channel, current_state)
obj2D.set_norm(norm)
ticklabels = gen_ticklabels(np.linspace(norm.vmin, norm.vmax, 11), channel.signed)
colorbar.set_ticklabels(ticklabels)
if clim[0] == clim[1]:
clim = [-1 if channel.signed else 0, 1]
obj2D.set_clim(*clim)
fig.canvas.draw_idle()

def update_slider(info, use_imshow=use_imshow):
Expand All @@ -416,11 +430,11 @@ def update_slider(info, use_imshow=use_imshow):
obj2D.set_data(current_state.dat[channel.natural_name][:].transpose(transpose))
else:
obj2D.set_array(current_state.dat[channel.natural_name][:].ravel())
clim = get_clim(channel, current_state)
ticklabels = gen_ticklabels(np.linspace(*clim, 11), channel.signed)
if clim[0] == clim[1]:
clim = [-1 if channel.signed else 0, 1]
obj2D.set_clim(*clim)
norm = get_norm(channel, current_state)
obj2D.set_norm(norm)

ticks = norm_to_ticks(norm)
ticklabels = gen_ticklabels(ticks, channel.signed)
colorbar.set_ticklabels(ticklabels)
sp_x.collections.clear()
sp_y.collections.clear()
Expand Down
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def read(fname):
name="WrightTools",
packages=find_packages(exclude=("tests", "tests.*")),
package_data=extra_files,
python_requires=">=3.6",
python_requires=">=3.7",
install_requires=[
"h5py",
"imageio",
"matplotlib>=3.3.0",
"matplotlib>=3.4.0",
"numexpr",
"numpy>=1.15.0",
"pint",
Expand Down Expand Up @@ -74,8 +74,9 @@ def read(fname):
"Framework :: Matplotlib",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering",
],
)

0 comments on commit 0b07857

Please sign in to comment.