Skip to content

Commit

Permalink
Simplified HeatMap and allowed any number of value dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Sep 5, 2016
1 parent 9f6753d commit 9ee8034
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 102 deletions.
87 changes: 7 additions & 80 deletions holoviews/element/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,7 @@ class HeatMap(Dataset, Element2D):
upsampling them to a dense representation, which can be visualized.
A HeatMap can be initialized with any dict or NdMapping type with
two-dimensional keys. Once instantiated the dense representation is
available via the .data property.
two-dimensional keys.
"""

group = param.String(default='HeatMap', constant=True)
Expand All @@ -382,86 +381,14 @@ class HeatMap(Dataset, Element2D):

vdims = param.List(default=[Dimension('z')])

def __init__(self, data, extents=None, **params):
super(HeatMap, self).__init__(data, **params)
data, self.raster = self._compute_raster()
self.data = data.data
self.interface = data.interface
self.depth = 1
if extents is None:
(d1, d2) = self.raster.shape[:2]
self.extents = (0, 0, d2, d1)
else:
self.extents = extents


def _compute_raster(self):
if self.interface.gridded:
return self, np.flipud(self.dimension_values(2, flat=False))
d1keys = self.dimension_values(0, False)
d2keys = self.dimension_values(1, False)
coords = [(d1, d2, np.NaN) for d1 in d1keys for d2 in d2keys]
dtype = 'dataframe' if pd else 'dictionary'
dense_data = Dataset(coords, kdims=self.kdims, vdims=self.vdims, datatype=[dtype])
concat_data = self.interface.concatenate([dense_data, Dataset(self)], datatype=dtype)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', r'Mean of empty slice')
data = concat_data.aggregate(self.kdims, np.nanmean)
array = data.dimension_values(2).reshape(len(d1keys), len(d2keys))
return data, np.flipud(array.T)


def __setstate__(self, state):
if '_data' in state:
data = state['_data']
if isinstance(data, NdMapping):
items = [tuple(k)+((v,) if np.isscalar(v) else tuple(v))
for k, v in data.items()]
kdims = state['kdims'] if 'kdims' in state else self.kdims
vdims = state['vdims'] if 'vdims' in state else self.vdims
data = Dataset(items, kdims=kdims, vdims=vdims).data
elif isinstance(data, Dataset):
data = data.data
kdims = data.kdims
vdims = data.vdims
state['data'] = data
state['kdims'] = kdims
state['vdims'] = vdims
self.__dict__ = state

if isinstance(self.data, NdElement):
self.interface = NdElementInterface
elif isinstance(self.data, np.ndarray):
self.interface = ArrayInterface
elif util.is_dataframe(self.data):
self.interface = PandasInterface
elif isinstance(self.data, dict):
self.interface = DictInterface
self.depth = 1
data, self.raster = self._compute_raster()
self.interface = data.interface
self.data = data.data
if 'extents' not in state:
(d1, d2) = self.raster.shape[:2]
self.extents = (0, 0, d2, d1)

super(HeatMap, self).__setstate__(state)

def dense_keys(self):
d1keys = self.dimension_values(0, False)
d2keys = self.dimension_values(1, False)
return list(zip(*[(d1, d2) for d1 in d1keys for d2 in d2keys]))


def dframe(self, dense=False):
if dense:
keys1, keys2 = self.dense_keys()
dense_map = self.clone({(k1, k2): self._data.get((k1, k2), np.NaN)
for k1, k2 in product(keys1, keys2)})
return dense_map.dframe()
return super(HeatMap, self).dframe()
depth = 1


def __init__(self, data, **params):
super(HeatMap, self).__init__(data, **params)
shape = (len(self.dimension_values(1)), len(self.dimension_values(0)))
self.extents = (0., 0., shape[0], shape[1])


class Image(SheetCoordinateSystem, Raster):
"""
Expand Down
17 changes: 10 additions & 7 deletions holoviews/plotting/bokeh/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
except ImportError:
LogColorMapper = None

from ...core.util import cartesian_product
from ...core.util import cartesian_product, is_nan
from ...element import Image, Raster, RGB
from ..renderer import SkipRendering
from ..util import map_colors
from ..util import map_colors, get_2d_aggregate
from .element import ElementPlot, line_properties, fill_properties
from .util import mplcmap_to_palette, get_cmap, hsv_to_rgb

Expand Down Expand Up @@ -148,21 +148,24 @@ def _axes_props(self, plots, subplots, element, ranges):
'y_range': [str(y) for y in yvals]}
return ('auto', 'auto'), labels, plot_ranges


def get_data(self, element, ranges=None, empty=False):
x, y, z = element.dimensions(label=True)
x, y, z = element.dimensions(label=True)[:3]
aggregate = get_2d_aggregate(element)
if empty:
data = {x: [], y: [], z: [], 'color': []}
else:
style = self.style[self.cyclic_index]
cmap = style.get('palette', style.get('cmap', None))
cmap = get_cmap(cmap)
zvals = np.rot90(element.raster, 3).flatten()
zvals = aggregate.dimension_values(z)
colors = map_colors(zvals, ranges[z], cmap)
xvals, yvals = [[str(v) for v in element.dimension_values(i)]
xvals, yvals = [[str(v) for v in aggregate.dimension_values(i)]
for i in range(2)]
data = {x: xvals, y: yvals, z: zvals, 'color': colors}

if 'hover' in self.tools:
for vdim in element.vdims[1:]:
data[vdim.name] = ['' if is_nan(v) else v
for v in aggregate.dimension_values(vdim)]
return (data, {'x': x, 'y': y, 'fill_color': 'color', 'height': 1, 'width': 1})


Expand Down
30 changes: 15 additions & 15 deletions holoviews/plotting/mpl/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ...element.raster import Image, Raster, RGB
from .element import ColorbarPlot, OverlayPlot
from .plot import MPLPlot, GridPlot

from ..util import get_2d_aggregate

class RasterPlot(ColorbarPlot):

Expand Down Expand Up @@ -97,20 +97,20 @@ def _annotate_plot(self, ax, annotations):
handles = {}
for plot_coord, text in annotations.items():
handles[plot_coord] = ax.annotate(text, xy=plot_coord,
xycoords='axes fraction',
xycoords='data',
horizontalalignment='center',
verticalalignment='center')
return handles


def _annotate_values(self, element):
val_dim = element.vdims[0]
vals = np.rot90(element.raster, 3).flatten()
aggregate = get_2d_aggregate(element)
vals = aggregate.dimension_values(2)
d1uniq, d2uniq = [element.dimension_values(i, False) for i in range(2)]
num_x, num_y = len(d1uniq), len(d2uniq)
xstep, ystep = 1.0/num_x, 1.0/num_y
xpos = np.linspace(xstep/2., 1.0-xstep/2., num_x)
ypos = np.linspace(ystep/2., 1.0-ystep/2., num_y)
xpos = np.linspace(0.5, num_x-0.5, num_x)
ypos = np.linspace(0.5, num_y-0.5, num_y)
plot_coords = product(xpos, ypos)
annotations = {}
for plot_coord, v in zip(plot_coords, vals):
Expand All @@ -125,18 +125,15 @@ def _compute_ticks(self, element, ranges):
dim1_keys, dim2_keys = [element.dimension_values(i, False)
for i in range(2)]
num_x, num_y = len(dim1_keys), len(dim2_keys)
x0, y0, x1, y1 = element.extents
xstep, ystep = ((x1-x0)/num_x, (y1-y0)/num_y)
xpos = np.linspace(x0+xstep/2., x1-xstep/2., num_x)
ypos = np.linspace(y0+ystep/2., y1-ystep/2., num_y)
xpos = np.linspace(.5, num_x-0.5, num_x)
ypos = np.linspace(.5, num_y-0.5, num_y)
xlabels = [xdim.pprint_value(k) for k in dim1_keys]
ylabels = [ydim.pprint_value(k) for k in dim2_keys]
return list(zip(xpos, xlabels)), list(zip(ypos, ylabels))


def init_artists(self, ax, plot_args, plot_kwargs):
l, r, b, t = plot_kwargs['extent']
ax.set_aspect(float(r - l)/(t-b))
ax.set_aspect(plot_kwargs.pop('aspect', 1))

handles = {}
annotations = plot_kwargs.pop('annotations', None)
Expand All @@ -148,20 +145,23 @@ def init_artists(self, ax, plot_args, plot_kwargs):

def get_data(self, element, ranges, style):
_, style, axis_kwargs = super(HeatMapPlot, self).get_data(element, ranges, style)
data = element.raster
shape = tuple(len(element.dimension_values(i)) for i in range(2))
aggregate = get_2d_aggregate(element)
data = np.flipud(aggregate.dimension_values(2).reshape(shape[::-1]))
data = np.ma.array(data, mask=np.logical_not(np.isfinite(data)))
cmap_name = style.pop('cmap', None)
cmap = copy.copy(plt.cm.get_cmap('gray' if cmap_name is None else cmap_name))
cmap.set_bad('w', 1.)
style['cmap'] = cmap
style['annotations'] = self._annotate_values(element)
style['aspect'] = shape[0]/shape[1]
style['extent'] = (0, shape[0], 0, shape[1])
style['annotations'] = self._annotate_values(aggregate)
return [data], style, axis_kwargs


def update_handles(self, key, axis, element, ranges, style):
im = self.handles['artist']
data, style, axis_kwargs = self.get_data(element, ranges, style)
l, r, b, t = style['extent']
im.set_data(data[0])
im.set_extent((l, r, b, t))
im.set_clim((style['vmin'], style['vmax']))
Expand Down

0 comments on commit 9ee8034

Please sign in to comment.