In [1]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
from mpl_interactions import heatmap_slicer

## Comparing heatmaps

Sometimes I find myself wanting to compare horizontal or vertical slices across two different heatmaps with the same shape. The function `heatmap_slicer` makes this easy and should work for any number of heatmaps from 1 to many (likely not all the way $\inf$ though). 

The most important options to play with are `slices = {'both', 'vertical', 'horizontal'}`, and `interaction_type = {'move', 'click'}`

In [2]:
x = np.linspace(0,np.pi,100)
y = np.linspace(0,10,200)
X,Y = np.meshgrid(x,y)
data1 = np.sin(X)+np.exp(np.cos(Y))
data2 = np.cos(X)+np.exp(np.sin(Y))
fig,axes = heatmap_slicer(x,y,(data1,data2),slices='both',heatmap_names=('dataset 1','dataset 2'),labels=('Some wild X variable','Y axis'),interaction_type='move')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  ax.pcolormesh(X,Y,heatmaps[i])
  ax.pcolormesh(X,Y,heatmaps[i])


In [4]:
from matplotlib.image import PcolorImage

In [None]:
PcolorImage()

In [3]:
fig, ax = plt.subplots()
ax.convert_xunits??

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[0;31mSignature:[0m [0max[0m[0;34m.[0m[0mconvert_xunits[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
    [0;32mdef[0m [0mconvert_xunits[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mx[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0;34m"""[0m
[0;34m        Convert *x* using the unit type of the xaxis.[0m
[0;34m[0m
[0;34m        If the artist is not in contained in an Axes or if the xaxis does not[0m
[0;34m        have units, *x* itself is returned.[0m
[0;34m        """[0m[0;34m[0m
[0;34m[0m        [0max[0m [0;34m=[0m [0mgetattr[0m[0;34m([0m[0mself[0m[0;34m,[0m [0;34m'axes'[0m[0;34m,[0m [0;32mNone[0m[0;34m)[0m[0;34m[0m
[0;34m[0m        [0;32mif[0m [0max[0m [0;32mis[0m [0;32mNone[0m [0;32mor[0m [0max[0m[0;34m.[0m[0mxaxis[0m [0;32mis[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m
[0;34m[0m            [0;32mreturn[0m [0mx[0m[0;34m[0m
[0;34m[0m        [0;32mreturn[0m 

In [18]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
from mpl_interactions import heatmap_slicer
X = np.linspace(0, 10, 11)
Y = np.linspace(0, 5, 11)
Z = np.random.randn(10,10 )
plt.figure()
plt.pcolormesh(X, Y, Z) # works
heatmap_slicer(X[:-1],Y[:-1],Z) # breaks

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  ax.pcolormesh(X,Y,heatmaps[i])


(<Figure size 1800x900 with 2 Axes>,
 array([<AxesSubplot:title={'center':'heatmap_0'}, xlabel='X', ylabel='Y'>,
        <AxesSubplot:title={'center':'Horizontal'}>], dtype=object))

In [17]:
(X[1:] + X[:-1]) / 2

array([2.26013860e-03, 6.78041580e-03, 1.13006930e-02, 1.58209702e-02,
       2.03412474e-02, 2.48615246e-02, 2.93818018e-02, 3.39020790e-02,
       3.84223562e-02, 4.29426334e-02, 4.74629106e-02, 5.19831878e-02,
       5.65034650e-02, 6.10237422e-02, 6.55440194e-02, 7.00642966e-02,
       7.45845738e-02, 7.91048510e-02, 8.36251282e-02, 8.81454054e-02,
       9.26656826e-02, 9.71859598e-02, 1.01706237e-01, 1.06226514e-01,
       1.10746791e-01, 1.15267069e-01, 1.19787346e-01, 1.24307623e-01,
       1.28827900e-01, 1.33348177e-01, 1.37868455e-01, 1.42388732e-01,
       1.46909009e-01, 1.51429286e-01, 1.55949563e-01, 1.60469841e-01,
       1.64990118e-01, 1.69510395e-01, 1.74030672e-01, 1.78550949e-01,
       1.83071227e-01, 1.87591504e-01, 1.92111781e-01, 1.96632058e-01,
       2.01152335e-01, 2.05672613e-01, 2.10192890e-01, 2.14713167e-01,
       2.19233444e-01, 2.23753721e-01, 2.28273999e-01, 2.32794276e-01,
       2.37314553e-01, 2.41834830e-01, 2.46355107e-01, 2.50875385e-01,
      

In [8]:
# make these smaller to increase the resolution
dx, dy = 0.15, 0.05

# generate 2 2d grids for the x & y bounds
y, x = np.mgrid[-3:3+dy:dy, -3:3+dx:dx]
z = (1 - x/2 + x**5 + y**3) * np.exp(-x**2 - y**2)
# x and y are bounds, so z should be the value *inside* those bounds.
# Therefore, remove the last value from the z array.
z = z[:-1, :-1]
z_min, z_max = -abs(z).max(), abs(z).max()

fig, axs = plt.subplots(2, 2)

ax = axs[0, 0]
c = ax.pcolor(x, y, z, cmap='RdBu', vmin=z_min, vmax=z_max)
ax.set_title('pcolor')
fig.colorbar(c, ax=ax)

ax = axs[0, 1]
c = ax.pcolormesh(x, y, z, cmap='RdBu', vmin=z_min, vmax=z_max)
ax.set_title('pcolormesh')
fig.colorbar(c, ax=ax)

ax = axs[1, 0]
im = ax.imshow(z, cmap='RdBu', vmin=z_min, vmax=z_max,
              extent=[x.min(), x.max(), y.min(), y.max()],
              interpolation='nearest', origin='lower', aspect='auto')
ax.set_title('image (nearest, aspect="auto")')
fig.colorbar(c, ax=ax)

ax = axs[1, 1]
c = ax.pcolorfast(x, y, z, cmap='RdBu', vmin=z_min, vmax=z_max)
ax.set_title('pcolorfast')
fig.colorbar(c, ax=ax)

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
im

<matplotlib.image.AxesImage at 0x7fc582f69090>

In [10]:
nrows = 3
ncols = 5
Z = np.arange(nrows * ncols).reshape(nrows, ncols)
x = np.arange(ncols + 1)
y = np.arange(nrows + 1)

fig, ax = plt.subplots()
ax.pcolormesh(x, y, Z, shading='flat', vmin=Z.min(), vmax=Z.max())


def _annotate(ax, x, y, title):
    # this all gets repeated below:
    X, Y = np.meshgrid(x, y)
    ax.plot(X.flat, Y.flat, 'o', color='m')
    ax.set_xlim(-0.7, 5.2)
    ax.set_ylim(-0.7, 3.2)
    ax.set_title(title)

_annotate(ax, x, y, "shading='flat'")


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
x = np.arange(ncols)  # note *not* ncols + 1 as before
y = np.arange(nrows)
fig, ax = plt.subplots()
ax.pcolormesh(x, y, Z, shading='flat', vmin=Z.min(), vmax=Z.max())
_annotate(ax, x, y, "shading='flat': X, Y, C same shape")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  after removing the cwd from sys.path.


In [12]:
fig, ax = plt.subplots()
ax.pcolormesh(x, y, Z, shading='nearest', vmin=Z.min(), vmax=Z.max())
_annotate(ax, x, y, "shading='nearest'")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Possible cases

**mpl < 3.3**

- same shape
- shape + 1

**mpl >= 3.3**

