Skip to content

Commit

Permalink
Merge pull request #1035 from wright-group/artists_imshow
Browse files Browse the repository at this point in the history
artists.imshow
  • Loading branch information
kameyer226 committed Nov 8, 2021
2 parents f1429ac + 9a68e95 commit a9242fd
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 25 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).

## [Unreleased]

### Added
- `artists` now has a wrapper for matplotlib's `imshow`. Make sure to use uniform grids.

### Fixed
- `artists._parse_limits` now recognizes channel `null` for signed data limits.

Expand Down
133 changes: 121 additions & 12 deletions WrightTools/artists/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _parse_limits(self, zi=None, data=None, channel_index=None, dynamic_range=Fa

def _parse_plot_args(self, *args, **kwargs):
plot_type = kwargs.pop("plot_type")
if plot_type not in ["pcolor", "pcolormesh", "contourf", "contour"]:
if plot_type not in ["pcolor", "pcolormesh", "contourf", "contour", "imshow"]:
raise NotImplementedError
args = list(args) # offer pop, append etc
dynamic_range = kwargs.pop("dynamic_range", False)
Expand All @@ -127,15 +127,39 @@ def _parse_plot_args(self, *args, **kwargs):
for sq, xs, ys in zip(squeeze, xa.shape, ya.shape):
if sq and (xs != 1 or ys != 1):
raise wt_exceptions.ValueError("Cannot squeeze axis to fit channel")
squeeze = tuple([0 if i else slice(None) for i in squeeze])
zi = data.channels[channel_index].points
xi = xa.full[squeeze]
yi = ya.full[squeeze]
if plot_type in ["pcolor", "pcolormesh", "contourf", "contour"]:
ndim = 2
if not zi.ndim == ndim:
if not zi.ndim == 2:
raise wt_exceptions.DimensionalityError(ndim, data.ndim)
args = [xi, yi, zi] + args
squeeze = tuple([0 if i else slice(None) for i in squeeze])
if plot_type == "imshow":
if "aspect" not in kwargs.keys():
kwargs["aspect"] = "auto"
if "origin" not in kwargs.keys():
kwargs["origin"] = "lower"
if "interpolation" not in kwargs.keys():
if max(zi.shape) < 10 ** 3: # TODO: better decision logic
kwargs["interpolation"] = "nearest"
else:
kwargs["interpolation"] = "antialiased"
xi = xa[:][squeeze]
yi = ya[:][squeeze]

zi = zi.transpose(_order_for_imshow(xi, yi))
# extract extent
if "extent" not in kwargs.keys():
xlim = [xi[0, 0], xi[-1, -1]]
ylim = [yi[0, 0], yi[-1, -1]]
xstep = (xlim[1] - xlim[0]) / (2 * xi.size)
ystep = (ylim[1] - ylim[0]) / (2 * yi.size)
x_extent = [xlim[0] - xstep, xlim[1] + xstep]
y_extent = [ylim[0] - ystep, ylim[1] + ystep]
extent = [*x_extent, *y_extent]
kwargs["extent"] = extent
args = [zi] + args
else:
xi = xa.full[squeeze]
yi = ya.full[squeeze]
args = [xi, yi, zi] + args
# limits
kwargs = self._parse_limits(
data=data, channel_index=channel_index, dynamic_range=dynamic_range, **kwargs
Expand All @@ -155,17 +179,19 @@ def _parse_plot_args(self, *args, **kwargs):
kwargs["colors"] = "k"
if "alpha" not in kwargs.keys():
kwargs["alpha"] = 0.5
if plot_type in ["pcolor", "pcolormesh", "contourf"]:
if plot_type in ["pcolor", "pcolormesh", "contourf", "imshow"]:
kwargs = self._parse_cmap(data=data, channel_index=channel_index, **kwargs)
else:
xi, yi, zi = args[:3]
if plot_type == "imshow":
kwargs = self._parse_limits(zi=args[0], **kwargs)
else:
kwargs = self._parse_limits(zi=args[2], **kwargs)
data = None
channel_index = 0
kwargs = self._parse_limits(zi=args[2], **kwargs)
if plot_type == "contourf":
if "levels" not in kwargs.keys():
kwargs["levels"] = np.linspace(kwargs["vmin"], kwargs["vmax"], 256)
if plot_type in ["pcolor", "pcolormesh", "contourf"]:
if plot_type in ["pcolor", "pcolormesh", "contourf", "imshow"]:
kwargs = self._parse_cmap(**kwargs)
# labels
self._apply_labels(
Expand Down Expand Up @@ -376,6 +402,65 @@ def pcolor(self, *args, **kwargs):
args, kwargs = self._parse_plot_args(*args, **kwargs, plot_type="pcolor")
return super().pcolor(*args, **kwargs)

def imshow(self, *args, **kwargs):
"""Create a pseudocolor plot of a 2-D array. The array is plotted
with uniform spacing. Quicker than pcolor, pcolormesh.
**Requires that the plotted axes are grid aligned (i.e. the `squeeze`
of each axis has ``ndim==1``).**
If a 3D or higher Data object is passed, a lower dimensional
channel can be plotted, provided the ``squeeze`` of the channel
has ``ndim==2``.
Defaults to ``aspect="auto"`` (pixels are stretched to fit the
subplot axes)
If `interpolation` method is not specified, defaults to either
"antialiased" (for large images) or "nearest" (for small arrays).
`extent` defaults to ensure that pixels are drawn bisecting point
positions.
Parameters
----------
data : 2D WrightTools.data.Data object
Data to plot.
channel : int or string (optional)
Channel index or name. Default is 0.
dynamic_range : boolean (optional)
Force plotting of all contours, overloading for major extent. Only applies to signed
data. Default is False.
autolabel : {'none', 'both', 'x', 'y'} (optional)
Parameterize application of labels directly from data object. Default is none.
xlabel : string (optional)
xlabel. Default is None.
ylabel : string (optional)
ylabel. Default is None.
**kwargs
matplotlib.axes.Axes.imshow__ optional keyword arguments.
__ https://matplotlib.org/api/_as_gen/matplotlib.pyplot.imshow.html
Returns
-------
matplotlib.image.AxesImage
"""
xlim, ylim = super().get_xlim(), super().get_ylim()
old_signs = list(map(lambda x: (x[1] - x[0]) > 0, [xlim, ylim]))

args, kwargs = self._parse_plot_args(*args, **kwargs, plot_type="imshow")
out = super().imshow(*args, **kwargs)

# undo axis order if it was flipped
xlim, ylim = super().get_xlim(), super().get_ylim()
new_signs = list(map(lambda x: (x[1] - x[0]) > 0, [xlim, ylim]))
if old_signs[0] != new_signs[0]:
super().invert_xaxis()
if old_signs[1] != new_signs[1]:
super().invert_yaxis()
return out

def pcolormesh(self, *args, **kwargs):
"""Create a pseudocolor plot of a 2-D array.
Expand Down Expand Up @@ -540,3 +625,27 @@ def apply_rcparams(kind="fast"):
matplotlib.rcParams["font.size"] = 14
matplotlib.rcParams["legend.edgecolor"] = "grey"
matplotlib.rcParams["contour.negative_linestyle"] = "solid"


def _order_for_imshow(xi, yi):
"""
looks at x and y axis shape to determine order of zi axes
**requires orthogonal, 1D axes**
returns 2-ple: the transpose order to apply to zi
"""
sx = np.array(xi.shape)
sy = np.array(yi.shape)
# check that each axis is 1D (i.e. for ndim, number of axes with size 1 is >= ndim - 1 )
if (sx.prod() == xi.size) and (sy.prod() == yi.size):
# check that axes are orthogonal and orient z accordingly
# determine index of x and y axes
if (sx[0] == 1) and (sy[1] == 1):
# zi[y,x]
return (0, 1)
elif (sx[1] == 1) and (sy[0] == 1):
# zi[x,y]; imshow expects zi[rows, cols]
return (1, 0)
else:
raise TypeError(f"x and y must be orthogonal; shapes are: {xi.shape}, {yi.shape}")
else:
raise TypeError(f"Axes are not 1D: {xi.shape}, {yi.shape}")
35 changes: 22 additions & 13 deletions WrightTools/artists/_interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, RadioButtons
from types import SimpleNamespace

from ._helpers import create_figure, plot_colorbar, add_sideplot
from ._base import _order_for_imshow
from ._colors import colormaps
from ..exceptions import DimensionalityError
from .. import kit as wt_kit
Expand Down Expand Up @@ -40,14 +42,6 @@ def __call__(self, ax):
self.focus_axis = ax


# http://code.activestate.com/recipes/52308-the-simple-but-handy-collector-of-a-bunch-of-named/?in=user-97991
# used to keep track of vars useful to widgets
class Bunch(dict):
def __init__(self, **kw):
dict.__init__(self, kw)
self.__dict__ = self


def _at_dict(data, sliders, xaxis, yaxis):
return {
a.natural_name: (a[:].flat[int(sliders[a.natural_name].val)], a.units)
Expand Down Expand Up @@ -138,7 +132,9 @@ def norm(arr, signed, ignore_zero=True):
return arr


def interact2D(data: wt_data.Data, xaxis=0, yaxis=1, channel=0, local=False, verbose=True):
def interact2D(
data: wt_data.Data, xaxis=0, yaxis=1, channel=0, local=False, use_imshow=False, verbose=True
):
"""Interactive 2D plot of the dataset.
Side plots show x and y projections of the slice (shaded gray).
Left clicks on the main axes draw 1D slices on side plots at the coordinates selected.
Expand All @@ -157,6 +153,10 @@ def interact2D(data: wt_data.Data, xaxis=0, yaxis=1, channel=0, local=False, ver
Name or index of channel to plot. Default is 0.
local : boolean (optional)
Toggle plotting locally. Default is False.
use_imshow : boolean (optional)
If true, matplotlib imshow is used to render the 2D slice.
Can give better performance, but is only accurate for
uniform grids. Default is False.
verbose : boolean (optional)
Toggle talkback. Default is True.
"""
Expand All @@ -167,7 +167,7 @@ def interact2D(data: wt_data.Data, xaxis=0, yaxis=1, channel=0, local=False, ver
channel = get_channel(data, channel)
xaxis, yaxis = get_axes(data, [xaxis, yaxis])
cmap = get_colormap(channel)
current_state = Bunch()
current_state = SimpleNamespace()
# create figure
nsliders = data.ndim - 2
if nsliders < 0:
Expand Down Expand Up @@ -244,7 +244,9 @@ def interact2D(data: wt_data.Data, xaxis=0, yaxis=1, channel=0, local=False, ver
ticklabels = gen_ticklabels(np.linspace(*clim, 11), channel.signed)
if clim[0] == clim[1]:
clim = [-1 if channel.signed else 0, 1]
obj2D = ax0.pcolormesh(

gen_mesh = ax0.pcolormesh if not use_imshow else ax0.imshow
obj2D = gen_mesh(
current_state.dat,
cmap=cmap,
vmin=clim[0],
Expand Down Expand Up @@ -390,7 +392,7 @@ def update_local(index):
obj2D.set_clim(*clim)
fig.canvas.draw_idle()

def update_slider(info):
def update_slider(info, use_imshow=use_imshow):
current_state.dat.close()
current_state.dat = data.chop(
xaxis.natural_name,
Expand All @@ -406,7 +408,14 @@ def update_slider(info):
s.valtext.set_text(
gen_ticklabels(data.axes[data.axis_names.index(k)].points)[int(s.val)]
)
obj2D.set_array(current_state.dat[channel.natural_name][:].ravel())
if use_imshow:
transpose = _order_for_imshow(
current_state[xaxis.natural_name][:],
current_state[yaxis.natural_name][:],
)
obj2D.set_data(current_state.dat[channel.natural_name][:].transpose(transpose))
else:
obj2D.set_array(current_state.dat[channel.natural_name][:].ravel())
clim = get_clim(channel, current_state)
ticklabels = gen_ticklabels(np.linspace(*clim, 11), channel.signed)
if clim[0] == clim[1]:
Expand Down
60 changes: 60 additions & 0 deletions tests/artists/test_imshow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#! /usr/bin/env python3


import WrightTools as wt
from WrightTools import datasets
from matplotlib import pyplot as plt
import numpy as np
import shutil
import os


def test_imshow_transform():
p = datasets.PyCMDS.d1_d2_000
data = wt.data.from_PyCMDS(p)

fig, gs = wt.artists.create_figure(cols=[1, 1])
ax0 = plt.subplot(gs[0])
data.transform("d1", "d2")
im1 = ax0.imshow(data)

ax1 = plt.subplot(gs[1])
data.transform("d2", "d1")
im2 = ax1.imshow(data)

assert np.all(im1.get_array() == im2.get_array().T)

data.close()


def test_imshow_approx_pcolormesh():
p = datasets.PyCMDS.w2_w1_000
data = wt.data.from_PyCMDS(p)

fig, gs = wt.artists.create_figure(cols=[1, 1])
ax0 = plt.subplot(gs[0])
mesh = ax0.pcolormesh(data)
ax1 = plt.subplot(gs[1])
image = ax1.imshow(data)

lim0 = ax0.get_xlim() + ax0.get_ylim()
lim1 = ax1.get_xlim() + ax1.get_ylim()
assert np.allclose(lim0, lim1, atol=1e-3, rtol=1), f"unequal axis limits: {lim0}, {lim1}"

bbox = mesh.get_datalim(ax0.transData)
meshbox = [bbox.x0, bbox.x1, bbox.y0, bbox.y1]
imagebox = image.get_extent()
imagebox = [*sorted(imagebox[:2]), *sorted(imagebox[2:])]
assert np.allclose(
meshbox, imagebox, atol=1e-3, rtol=1e-3
), f"unequal limits: mesh {meshbox} image {imagebox}"

assert np.isclose(mesh.norm.vmin, image.norm.vmin), "unequal norm.vmin"
assert np.isclose(mesh.norm.vmax, image.norm.vmax), "unequal norm.vmax"

data.close()


if __name__ == "__main__":
test_imshow_transform()
test_imshow_approx_pcolormesh()

0 comments on commit a9242fd

Please sign in to comment.