Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make HeatMap more general #849

Merged
merged 22 commits into from Jan 9, 2017
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
bec8024
Added is_nan utility
philippjfr Sep 5, 2016
339f988
Added functions to generate dense 2D aggregate from coordinates
philippjfr Sep 5, 2016
1d3d57e
Simplified HeatMap and allowed any number of value dimensions
philippjfr Sep 5, 2016
69a9793
Fixes for HeatMap implementations
philippjfr Sep 19, 2016
efd4bd9
Fixed missing imports
philippjfr Jan 8, 2017
17651f0
Added backward compatible raster property on HeatMap
philippjfr Jan 8, 2017
f3543e6
HeatMap now pre-computes gridded representation
philippjfr Jan 8, 2017
843387c
Fixes for HeatMap aggregation
philippjfr Jan 8, 2017
29f47c9
Made the get_2d_aggregate helper function general
philippjfr Jan 8, 2017
3f4b073
Fixed bug in HeatmapPlot
philippjfr Jan 8, 2017
d68485f
Added unit tests for HeatMap aggregation
philippjfr Jan 8, 2017
143c301
Retain global ordering of y-value dimensions
philippjfr Jan 8, 2017
0a91dce
Made categorical_aggregate2d an ElementOperation
philippjfr Jan 8, 2017
03cebf6
Small optimizations for categorical_aggregate2D
philippjfr Jan 8, 2017
844c1ad
Cleaned up HeatMap plotting classes
philippjfr Jan 8, 2017
fb4b207
Improved formatting for NaNs in HeatMap hover and annotations
philippjfr Jan 8, 2017
fcac23e
Removed depth on HeatMap
philippjfr Jan 8, 2017
d380d08
Removed unused variable
philippjfr Jan 8, 2017
dcae11f
Fixes for categorical_aggregate2d ordering
philippjfr Jan 8, 2017
9082070
Fixed and simplified one-to-one mapping function
philippjfr Jan 8, 2017
f5998f2
Added docstrings for graph utility functions
philippjfr Jan 9, 2017
050c4c7
Split categorical_aggregate2d into a few methods
philippjfr Jan 9, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions holoviews/core/util.py
Expand Up @@ -996,3 +996,13 @@ def dt64_to_dt(dt64):
"""
ts = (dt64 - np.datetime64('1970-01-01T00:00:00Z')) / np.timedelta64(1, 's')
return dt.datetime.utcfromtimestamp(ts)


def is_nan(x):
    """
    Checks whether value is NaN on arbitrary types.

    Returns the result of np.isnan(x) when numpy can evaluate it.
    Values that np.isnan cannot handle (e.g. strings or None) are
    reported as not-NaN by returning False.
    """
    try:
        return np.isnan(x)
    except Exception:
        # np.isnan raises for unsupported input types; such values are
        # treated as not-NaN. Exception (rather than a bare except) is
        # caught so KeyboardInterrupt and SystemExit still propagate.
        return False
93 changes: 12 additions & 81 deletions holoviews/element/raster.py
Expand Up @@ -14,7 +14,7 @@
from ..core.util import pd
from .chart import Curve
from .tabular import Table
from .util import compute_edges, toarray
from .util import compute_edges, toarray, get_2d_aggregate

try:
from ..core.data import PandasInterface
Expand Down Expand Up @@ -365,16 +365,14 @@ def dimension_values(self, dimension, expanded=True, flat=True):
return super(QuadMesh, self).dimension_values(idx)



class HeatMap(Dataset, Element2D):
"""
HeatMap is an atomic Element used to visualize two dimensional
parameter spaces. It supports sparse or non-linear spaces, dynamically
upsampling them to a dense representation, which can be visualized.

A HeatMap can be initialized with any dict or NdMapping type with
two-dimensional keys. Once instantiated the dense representation is
available via the .data property.
two-dimensional keys.
"""

group = param.String(default='HeatMap', constant=True)
Expand All @@ -383,85 +381,18 @@ class HeatMap(Dataset, Element2D):

vdims = param.List(default=[Dimension('z')])

def __init__(self, data, extents=None, **params):
depth = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might have forgotten...what is this depth class attribute?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this may be wrong now, will have to look into it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wasn't needed at all in the end, removed it.


def __init__(self, data, **params):
super(HeatMap, self).__init__(data, **params)
data, self.raster = self._compute_raster()
self.data = data.data
self.interface = data.interface
self.depth = 1
if extents is None:
(d1, d2) = self.raster.shape[:2]
self.extents = (0, 0, d2, d1)
else:
self.extents = extents


def _compute_raster(self):
if self.interface.gridded:
return self, np.flipud(self.dimension_values(2, flat=False))
d1keys = self.dimension_values(0, False)
d2keys = self.dimension_values(1, False)
coords = [(d1, d2, np.NaN) for d1 in d1keys for d2 in d2keys]
dtype = 'dataframe' if pd else 'dictionary'
dense_data = Dataset(coords, kdims=self.kdims, vdims=self.vdims, datatype=[dtype])
concat_data = self.interface.concatenate([dense_data, Dataset(self)], datatype=dtype)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', r'Mean of empty slice')
data = concat_data.aggregate(self.kdims, np.nanmean)
array = data.dimension_values(2).reshape(len(d1keys), len(d2keys))
return data, np.flipud(array.T)


def __setstate__(self, state):
if '_data' in state:
data = state['_data']
if isinstance(data, NdMapping):
items = [tuple(k)+((v,) if np.isscalar(v) else tuple(v))
for k, v in data.items()]
kdims = state['kdims'] if 'kdims' in state else self.kdims
vdims = state['vdims'] if 'vdims' in state else self.vdims
data = Dataset(items, kdims=kdims, vdims=vdims).data
elif isinstance(data, Dataset):
data = data.data
kdims = data.kdims
vdims = data.vdims
state['data'] = data
state['kdims'] = kdims
state['vdims'] = vdims
self.__dict__ = state

if isinstance(self.data, NdElement):
self.interface = NdElementInterface
elif isinstance(self.data, np.ndarray):
self.interface = ArrayInterface
elif util.is_dataframe(self.data):
self.interface = PandasInterface
elif isinstance(self.data, dict):
self.interface = DictInterface
self.depth = 1
data, self.raster = self._compute_raster()
self.interface = data.interface
self.data = data.data
if 'extents' not in state:
(d1, d2) = self.raster.shape[:2]
self.extents = (0, 0, d2, d1)

super(HeatMap, self).__setstate__(state)

def dense_keys(self):
d1keys = self.dimension_values(0, False)
d2keys = self.dimension_values(1, False)
return list(zip(*[(d1, d2) for d1 in d1keys for d2 in d2keys]))


def dframe(self, dense=False):
if dense:
keys1, keys2 = self.dense_keys()
dense_map = self.clone({(k1, k2): self._data.get((k1, k2), np.NaN)
for k1, k2 in product(keys1, keys2)})
return dense_map.dframe()
return super(HeatMap, self).dframe()
self.gridded = get_2d_aggregate(self)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice to see how much HeatMap has been simplified!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That said, it isn't immediately obvious that gridded is now a Dataset. Not sure I am necessarily recommending changing the name as gridded_dataset is awkward...


@property
def raster(self):
self.warning("The .raster attribute on HeatMap is deprecated, "
"the 2D aggregate is now computed dynamically "
"during plotting.")
return self.gridded.dimension_values(2, flat=False)


class Image(SheetCoordinateSystem, Raster):
Expand Down
66 changes: 66 additions & 0 deletions holoviews/element/util.py
@@ -1,10 +1,19 @@
import numpy as np

from ..core import Dataset, OrderedDict
from ..core.util import pd, is_nan

try:
import dask
except:
dask = None

try:
import xarray as xr
except:
xr = None


def toarray(v, index_value=False):
"""
Interface helper function to turn dask Arrays into numpy arrays as
Expand All @@ -30,3 +39,60 @@ def compute_edges(edges):
raise ValueError('Centered bins have to be of equal width.')
edges -= width/2.
return np.concatenate([edges, [edges[-1]+width]])


def reduce_fn(x):
    """
    Aggregation function to get the first non-NaN value.

    If pandas is available and x is a pandas Series, its underlying
    values are iterated; otherwise x itself is iterated. Returns NaN
    when x is empty or every entry is NaN.
    """
    values = x.values if pd and isinstance(x, pd.Series) else x
    for v in values:
        if not is_nan(v):
            return v
    # np.nan is used instead of the np.NaN alias, which NumPy 2.0 removed.
    return np.nan


def get_2d_aggregate(obj):
    """
    Generates a categorical 2D aggregate by inserting NaNs at all
    cross-product locations that do not already have a value assigned.
    Returns a 2D gridded Dataset object.

    If obj already uses a gridded interface it is returned unchanged.
    Raises a ValueError if obj.ndims is greater than two.
    """
    # NOTE(review): a reviewer asked whether this would be better
    # expressed as an operation, possibly with a minimal docstring
    # example in the class docstring.
    if obj.interface.gridded:
        return obj
    elif obj.ndims > 2:
        raise ValueError("Cannot aggregate more than two dimensions")

    dims = obj.dimensions(label=True)
    xdim, ydim = dims[:2]
    nvdims = len(dims) - 2
    d1keys = obj.dimension_values(xdim, False)
    d2keys = obj.dimension_values(ydim, False)

    # The keys are only sorted if the x-keys are already in sorted
    # order and, within every x-group, the y-values are sorted as well.
    is_sorted = np.array_equal(np.sort(d1keys), d1keys)
    if is_sorted:
        grouped = obj.groupby(xdim, container_type=OrderedDict,
                              group_type=Dataset).values()
        for group in grouped:
            d2vals = group.dimension_values(ydim)
            is_sorted &= np.array_equal(d2vals, np.sort(d2vals))

    if is_sorted:
        d1keys, d2keys = np.sort(d1keys), np.sort(d2keys)

    # Dense cross-product of keys with NaN placeholders for every value
    # dimension (np.nan rather than the np.NaN alias removed in NumPy 2.0).
    coords = [(d1, d2) + (np.nan,)*nvdims for d2 in d2keys for d1 in d1keys]

    dtype = 'dataframe' if pd else 'dictionary'
    dense_data = Dataset(coords, kdims=obj.kdims, vdims=obj.vdims, datatype=[dtype])
    # The NaN placeholders are concatenated ahead of the real data;
    # reduce_fn then keeps the first non-NaN value for each (x, y) key.
    concat_data = obj.interface.concatenate([dense_data, Dataset(obj)], datatype=dtype)
    agg = concat_data.reindex([xdim, ydim]).aggregate([xdim, ydim], reduce_fn)
    shape = (len(d2keys), len(d1keys))
    grid_data = {xdim: d1keys, ydim: d2keys}

    for vdim in dims[2:]:
        data = agg.dimension_values(vdim).reshape(shape)
        # Mask out the non-finite entries
        data = np.ma.array(data, mask=np.logical_not(np.isfinite(data)))
        grid_data[vdim] = data

    grid_type = 'xarray' if xr else 'grid'
    return agg.clone(grid_data, datatype=[grid_type])

25 changes: 18 additions & 7 deletions holoviews/plotting/bokeh/raster.py
@@ -1,7 +1,13 @@
import numpy as np
import param

from ...core.util import cartesian_product
from bokeh.models.mappers import LinearColorMapper
try:
from bokeh.models.mappers import LogColorMapper
except ImportError:
LogColorMapper = None

from ...core.util import cartesian_product, is_nan, unique_array
from ...element import Image, Raster, RGB
from ..renderer import SkipRendering
from ..util import map_colors
Expand Down Expand Up @@ -130,26 +136,31 @@ class HeatmapPlot(ColorbarPlot):
def _axes_props(self, plots, subplots, element, ranges):
dims = element.dimensions()
labels = self._get_axis_labels(dims)
xvals, yvals = [element.dimension_values(i, False)
agg = element.gridded
xvals, yvals = [unique_array(agg.dimension_values(i, False))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought gridded Datasets have the 1D coordinate arrays available. Is the uniqueness being applied over the 2D set of samples or the 1D sequence?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, good point, no longer any need for the unique_array here.

for i in range(2)]
if self.invert_yaxis: yvals = yvals[::-1]
plot_ranges = {'x_range': [str(x) for x in xvals],
'y_range': [str(y) for y in yvals]}
return ('auto', 'auto'), labels, plot_ranges


def get_data(self, element, ranges=None, empty=False):
x, y, z = element.dimensions(label=True)
x, y, z = element.dimensions(label=True)[:3]
aggregate = element.gridded
style = self.style[self.cyclic_index]
cmapper = self._get_colormapper(element.vdims[0], element, ranges, style)
if empty:
data = {x: [], y: [], z: [], 'color': []}
data = {x: [], y: [], z: []}
else:
zvals = np.rot90(element.raster, 3).flatten()
xvals, yvals = [[str(v) for v in element.dimension_values(i)]
zvals = aggregate.dimension_values(z)
xvals, yvals = [[str(v) for v in aggregate.dimension_values(i)]
for i in range(2)]
data = {x: xvals, y: yvals, z: zvals}

if 'hover' in self.tools+self.default_tools:
for vdim in element.vdims[1:]:
data[vdim.name] = ['' if is_nan(v) else v
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if an empty string really suggests NaN. 'NaN' would be explicit but might look noisy.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I'm now using masked arrays to represent the data; in matplotlib the NaNs are therefore represented by `-`, which might be better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think - might be a good compromise.

for v in aggregate.dimension_values(vdim)]
return (data, {'x': x, 'y': y, 'fill_color': {'field': z, 'transform': cmapper},
'height': 1, 'width': 1})

Expand Down
40 changes: 22 additions & 18 deletions holoviews/plotting/mpl/raster.py
Expand Up @@ -8,7 +8,7 @@

from ...core import CompositeOverlay, Element
from ...core import traversal
from ...core.util import match_spec, max_range, unique_iterator
from ...core.util import match_spec, max_range, unique_iterator, unique_array
from ...element.raster import Image, Raster, RGB
from .element import ColorbarPlot, OverlayPlot
from .plot import MPLPlot, GridPlot
Expand Down Expand Up @@ -105,20 +105,19 @@ def _annotate_plot(self, ax, annotations):
handles = {}
for plot_coord, text in annotations.items():
handles[plot_coord] = ax.annotate(text, xy=plot_coord,
xycoords='axes fraction',
xycoords='data',
horizontalalignment='center',
verticalalignment='center')
return handles


def _annotate_values(self, element):
val_dim = element.vdims[0]
vals = np.rot90(element.raster, 3).flatten()
vals = element.dimension_values(2)
d1uniq, d2uniq = [element.dimension_values(i, False) for i in range(2)]
num_x, num_y = len(d1uniq), len(d2uniq)
xstep, ystep = 1.0/num_x, 1.0/num_y
xpos = np.linspace(xstep/2., 1.0-xstep/2., num_x)
ypos = np.linspace(ystep/2., 1.0-ystep/2., num_y)
xpos = np.linspace(0.5, num_x-0.5, num_x)
ypos = np.linspace(0.5, num_y-0.5, num_y)
plot_coords = product(xpos, ypos)
annotations = {}
for plot_coord, v in zip(plot_coords, vals):
Expand All @@ -130,21 +129,19 @@ def _annotate_values(self, element):

def _compute_ticks(self, element, ranges):
xdim, ydim = element.kdims
dim1_keys, dim2_keys = [element.dimension_values(i, False)
agg = element.gridded
dim1_keys, dim2_keys = [unique_array(agg.dimension_values(i, False))
for i in range(2)]
num_x, num_y = len(dim1_keys), len(dim2_keys)
x0, y0, x1, y1 = element.extents
xstep, ystep = ((x1-x0)/num_x, (y1-y0)/num_y)
xpos = np.linspace(x0+xstep/2., x1-xstep/2., num_x)
ypos = np.linspace(y0+ystep/2., y1-ystep/2., num_y)
xpos = np.linspace(.5, num_x-0.5, num_x)
ypos = np.linspace(.5, num_y-0.5, num_y)
xlabels = [xdim.pprint_value(k) for k in dim1_keys]
ylabels = [ydim.pprint_value(k) for k in dim2_keys]
return list(zip(xpos, xlabels)), list(zip(ypos, ylabels))


def init_artists(self, ax, plot_args, plot_kwargs):
l, r, b, t = plot_kwargs['extent']
ax.set_aspect(float(r - l)/(t-b))
ax.set_aspect(plot_kwargs.pop('aspect', 1))

handles = {}
annotations = plot_kwargs.pop('annotations', None)
Expand All @@ -156,18 +153,25 @@ def init_artists(self, ax, plot_args, plot_kwargs):

def get_data(self, element, ranges, style):
_, style, axis_kwargs = super(HeatMapPlot, self).get_data(element, ranges, style)
mask = np.logical_not(np.isfinite(element.raster))
data = np.ma.array(element.raster, mask=mask)
style['annotations'] = self._annotate_values(element)
aggregate = element.gridded
data = np.flipud(aggregate.dimension_values(2, flat=False))
shape = data.shape
cmap_name = style.pop('cmap', None)
cmap = copy.copy(plt.cm.get_cmap('gray' if cmap_name is None else cmap_name))
cmap.set_bad('w', 1.)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want to make this a plot option at some point instead of hard coding 'w'.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, good point. We already expose this via clipping_colors, so it should be hooked in here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also find it curious that you are using copy.copy on a colormap - which suggests you are mutating it. I guess set_bad must have side-effects which explains the copying...

style['cmap'] = cmap
style['aspect'] = shape[0]/shape[1]
style['extent'] = (0, shape[0], 0, shape[1])
style['annotations'] = self._annotate_values(aggregate)
return [data], style, axis_kwargs


def update_handles(self, key, axis, element, ranges, style):
im = self.handles['artist']
data, style, axis_kwargs = self.get_data(element, ranges, style)
l, r, b, t = style['extent']
im.set_data(data[0])
im.set_extent((l, r, b, t))
shape = data[0].shape
im.set_extent((0, shape[1], 0, shape[0]))
im.set_clim((style['vmin'], style['vmax']))
if 'norm' in style:
im.norm = style['norm']
Expand Down
4 changes: 2 additions & 2 deletions holoviews/plotting/util.py
Expand Up @@ -4,8 +4,8 @@
import param

from ..core import (HoloMap, DynamicMap, CompositeOverlay, Layout,
GridSpace, NdLayout, Store, Callable, Overlay)
from ..core.spaces import get_nested_streams
Overlay, GridSpace, NdLayout, Store, Dataset)
from ..core.spaces import get_nested_streams, Callable
from ..core.util import (match_spec, is_number, wrap_tuple, basestring,
get_overlay_spec, unique_iterator, safe_unicode)

Expand Down