Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for batching NdOverlay plots #717

Merged
merged 10 commits into from Jun 16, 2016

Unified how plots declare plot methods across bokeh and matplotlib

  • Loading branch information...
Philipp Rudiger Philipp Rudiger
Philipp Rudiger authored and Philipp Rudiger committed Jun 16, 2016
commit 45c6b0ad594fa3a447ec20372b0f710e5aafd983
@@ -8,7 +8,7 @@
class TextPlot(ElementPlot):

style_opts = text_properties
_plot_method = 'text'
_plot_methods = dict(single='text')

def get_data(self, element, ranges=None, empty=False):
mapping = dict(x='x', y='y', text='text')
@@ -62,7 +62,7 @@ class SplinePlot(ElementPlot):
"""

style_opts = line_properties
_plot_method = 'bezier'
_plot_methods = dict(single='bezier')

def get_data(self, element, ranges=None, empty=False):
data_attrs = ['x0', 'y0', 'x1', 'y1',
@@ -46,9 +46,7 @@ class PointPlot(ElementPlot):
'unselected_color'] +
line_properties + fill_properties)

_plot_method = 'scatter'
_batched_plot_method = 'scatter'
_batched = True
_plot_methods = dict(single='scatter', batched='scatter')

def get_data(self, element, ranges=None, empty=False):
style = self.style[self.cyclic_index]
@@ -124,17 +122,16 @@ def _init_glyph(self, plot, mapping, properties):
renderer = plot.add_glyph(source, selected, selection_glyph=selected,
nonselection_glyph=unselected)
else:
renderer = getattr(plot, self._plot_method)(**dict(properties, **mapping))
plot_method = self._plot_methods.get('batched' if self.batched else 'single')
renderer = getattr(plot, plot_method)(**dict(properties, **mapping))
return renderer, renderer.glyph


class CurvePlot(ElementPlot):

style_opts = ['color'] + line_properties
_plot_method = 'line'
_batched_plot_method = 'multi_line'
_plot_methods = dict(single='line', batched='multi_line')
_mapping = {p: p for p in ['xs', 'ys', 'color', 'line_alpha']}
_batched = True

def get_data(self, element, ranges=None, empty=False):
x = element.get_dimension(0).name
@@ -214,7 +211,7 @@ def get_data(self, element, ranges=None, empty=None):
class HistogramPlot(ElementPlot):

style_opts = ['color'] + line_properties + fill_properties
_plot_method = 'quad'
_plot_methods = dict(single='quad')

def get_data(self, element, ranges=None, empty=None):
mapping = dict(top='top', bottom=0, left='left', right='right')
@@ -137,11 +137,6 @@ class ElementPlot(BokehPlot, GenericElementPlot):
tick locations or bokeh Ticker object. If set to None
default bokeh ticking behavior is applied.""")

# A string corresponding to the glyph being drawn by the
# ElementPlot
_plot_method = None
_batched = False

# The plot objects to be updated on each frame
# Any entries should be existing keys in the handles
# instance attribute.
@@ -398,7 +393,7 @@ def _init_glyph(self, plot, mapping, properties):
Returns a Bokeh glyph object.
"""
properties = mpl_to_bokeh(properties)
plot_method = self._batched_plot_method if self.batched else self._plot_method
plot_method = self._plot_methods.get('batched' if self.batched else 'single')
renderer = getattr(plot, plot_method)(**dict(properties, **mapping))
return renderer, renderer.glyph

@@ -15,7 +15,7 @@ class PathPlot(ElementPlot):
Whether to show legend for the plot.""")

style_opts = ['color'] + line_properties
_plot_method = 'multi_line'
_plot_methods = dict(single='multi_line')
_mapping = dict(xs='xs', ys='ys')

def get_data(self, element, ranges=None, empty=False):
@@ -27,9 +27,7 @@ def get_data(self, element, ranges=None, empty=False):
class PolygonPlot(PathPlot):

style_opts = ['color', 'cmap', 'palette'] + line_properties + fill_properties
_plot_method = 'patches'
_batched_plot_method = 'patches'
_batched = True
_plot_methods = dict(single='patches', batched='patches')

def get_data(self, element, ranges=None, empty=False):
xs = [] if empty else [path[:, 0] for path in element.data]
@@ -17,7 +17,7 @@ class RasterPlot(ElementPlot):
Whether to show legend for the plot.""")

style_opts = ['cmap']
_plot_method = 'image'
_plot_methods = dict(single='image')
_update_handles = ['color_mapper', 'source', 'glyph']

def __init__(self, *args, **kwargs):
@@ -74,7 +74,7 @@ def _update_glyph(self, glyph, properties, mapping):
class RGBPlot(RasterPlot):

style_opts = []
_plot_method = 'image_rgba'
_plot_methods = dict(single='image_rgba')

def get_data(self, element, ranges=None, empty=False):
data, mapping = super(RGBPlot, self).get_data(element, ranges, empty)
@@ -113,7 +113,7 @@ class HeatmapPlot(ElementPlot):
show_legend = param.Boolean(default=False, doc="""
Whether to show legend for the plot.""")

_plot_method = 'rect'
_plot_methods = dict(single='rect')
style_opts = ['cmap', 'color'] + line_properties + fill_properties

def _axes_props(self, plots, subplots, element, ranges):
@@ -148,7 +148,7 @@ class QuadMeshPlot(ElementPlot):
show_legend = param.Boolean(default=False, doc="""
Whether to show legend for the plot.""")

_plot_method = 'rect'
_plot_methods = dict(single='rect')
style_opts = ['cmap', 'color'] + line_properties + fill_properties

def get_data(self, element, ranges=None, empty=False):
Copy path View file
@@ -54,9 +54,7 @@ class CurvePlot(ChartPlot):

style_opts = ['alpha', 'color', 'visible', 'linewidth', 'linestyle', 'marker']

def init_artists(self, ax, plot_data, plot_kwargs):
return {'artist': ax.plot(*plot_data, **plot_kwargs)[0]}

_plot_methods = dict(single='plot')

def get_data(self, element, ranges, style):
xs = element.dimension_values(0)
@@ -87,6 +85,8 @@ class ErrorPlot(ChartPlot):
'markerfacecolor', 'markersize', 'solid_capstyle',
'solid_joinstyle', 'dashes', 'color']

_plot_methods = dict(single='errorbar')

def init_artists(self, ax, plot_data, plot_kwargs):
_, (bottoms, tops), verts = ax.errorbar(*plot_data, **plot_kwargs)
return {'bottoms': bottoms, 'tops': tops, 'verts': verts[0]}
@@ -143,6 +143,8 @@ class AreaPlot(ChartPlot):
'hatch', 'linestyle', 'joinstyle',
'fill', 'capstyle', 'interpolate']

_plot_methods = dict(single='fill_between')

def get_data(self, element, ranges, style):
xs = element.dimension_values(0)
ys = [element.dimension_values(vdim) for vdim in element.vdims]
@@ -455,10 +457,7 @@ class PointPlot(ChartPlot, ColorbarPlot):
'cmap', 'vmin', 'vmax']

_disabled_opts = ['size']

def init_artists(self, ax, plot_args, plot_kwargs):
return {'artist': ax.scatter(*plot_args, **plot_kwargs)}

_plot_methods = dict(single='scatter')

def get_data(self, element, ranges, style):
xs, ys = (element.dimension_values(i) for i in range(2))
@@ -546,6 +545,8 @@ class VectorFieldPlot(ColorbarPlot):
'scale', 'headlength', 'headaxislength', 'pivot',
'width','headwidth']

_plot_methods = dict(single='quiver')

def __init__(self, *args, **params):
super(VectorFieldPlot, self).__init__(*args, **params)
self._min_dist = self._get_map_info(self.hmap)
@@ -598,16 +599,12 @@ def get_data(self, element, ranges, style):
if 'pivot' not in style: style['pivot'] = 'mid'
if not self.arrow_heads:
style['headaxislength'] = 0
style.update(dict(scale=input_scale, angles=angles))
style.update(dict(scale=input_scale, angles=angles,
units='x', scale_units='x'))

return args, style, {}


def init_artists(self, ax, plot_args, plot_kwargs):
quiver = ax.quiver(*plot_args, units='x', scale_units='x', **plot_kwargs)
return {'artist': quiver}


def update_handles(self, key, axis, element, ranges, style):
args, style, axis_kwargs = self.get_data(element, ranges, style)

@@ -960,6 +957,8 @@ class BoxPlot(ChartPlot):
'whiskerprops', 'capprops', 'flierprops',
'medianprops', 'meanprops', 'meanline']

_plot_methods = dict(single='boxplot')

def get_extents(self, element, ranges):
return (np.NaN,)*4

@@ -987,11 +986,6 @@ def get_data(self, element, ranges, style):
element.vdims[0]]}


def init_artists(self, ax, plot_args, plot_kwargs):
boxplot = ax.boxplot(*plot_args, **plot_kwargs)
return {'artist': boxplot}


def teardown_handles(self):
for group in self.handles['artist'].values():
for v in group:
@@ -115,6 +115,8 @@ class Scatter3DPlot(Plot3D, PointPlot):
allow_None=True, doc="""
Index of the dimension from which the sizes will the drawn.""")

_plot_methods = dict(single='scatter')

def get_data(self, element, ranges, style):
xs, ys, zs = (element.dimension_values(i) for i in range(3))
self._compute_styles(element, ranges, style)
@@ -127,11 +129,6 @@ def get_data(self, element, ranges, style):
style['facecolors'] = color
return (xs, ys, zs), style, {}

def init_artists(self, ax, plot_data, plot_kwargs):
scatterplot = ax.scatter(*plot_data, **plot_kwargs)
ax.add_collection(scatterplot)
return {'artist': scatterplot}

def update_handles(self, key, axis, element, ranges, style):
artist = self.handles['artist']
artist._offsets3d, style, _ = self.get_data(element, ranges, style)
@@ -195,11 +192,10 @@ class TrisurfacePlot(Plot3D):

style_opts = ['cmap', 'color', 'shade', 'linewidth', 'edgecolor']

_plot_methods = dict(single='plot_trisurf')

def get_data(self, element, ranges, style):
dims = element.dimensions()
self._norm_kwargs(element, ranges, style, dims[2])
x, y, z = [element.dimension_values(d) for d in dims]
return (x, y, z), style, {}

def init_artists(self, ax, plot_data, plot_kwargs):
return {'artist': ax.plot_trisurf(*plot_data, **plot_kwargs)}
@@ -468,6 +468,18 @@ def initialize_plot(self, ranges=None):
return self._finalize_axis(self.keys[-1], ranges=ranges, **axis_kwargs)


def init_artists(self, ax, plot_args, plot_kwargs):
"""
Initializes the artist based on the plot method declared on
the plot.
"""
plot_method = self._plot_methods.get('batched' if self.batched else 'single')
plot_fn = getattr(ax, plot_method)
artist = plot_fn(*plot_args, **plot_kwargs)
return {'artist': artist[0] if isinstance(artist, list) and
len(artist) == 1 else artist}


def update_handles(self, key, axis, element, ranges, style):
"""
Update the elements of the plot.
@@ -30,6 +30,7 @@ class RasterPlot(ColorbarPlot):
style_opts = ['alpha', 'cmap', 'interpolation', 'visible',
'filterrad', 'clims', 'norm']

_plot_methods = dict(single='imshow')

def __init__(self, *args, **kwargs):
super(RasterPlot, self).__init__(*args, **kwargs)
@@ -74,12 +75,6 @@ def get_data(self, element, ranges, style):

return [data], style, {'xticks': xticks, 'yticks': yticks}


def init_artists(self, ax, plot_args, plot_kwargs):
im = ax.imshow(*plot_args, **plot_kwargs)
return {'artist': im}


def update_handles(self, key, axis, element, ranges, style):
im = self.handles['artist']
data, style, axis_kwargs = self.get_data(element, ranges, style)
@@ -192,6 +187,8 @@ class QuadMeshPlot(ColorbarPlot):
style_opts = ['alpha', 'cmap', 'clim', 'edgecolors', 'norm', 'shading',
'linestyles', 'linewidths', 'hatch', 'visible']

_plot_methods = dict(single='pcolormesh')

def get_data(self, element, ranges, style):
data = np.ma.array(element.data[2],
mask=np.logical_not(np.isfinite(element.data[2])))
Copy path View file
@@ -483,8 +483,11 @@ class GenericElementPlot(DimensionedPlot):
apply_extents = param.Boolean(default=True, doc="""
Whether to apply extent overrides on the Elements""")

# Whether the plotting class supports batched plotting
_batched = False
# A dictionary mapping of the plot methods used to draw the
# glyphs corresponding to the ElementPlot, can support two
# keyword arguments a 'single' implementation to draw an individual
# plot and a 'batched' method to draw multiple Elements at once
_plot_methods = {}

def __init__(self, element, keys=None, ranges=None, dimensions=None,
batched=False, overlaid=0, cyclic_index=0, zorder=0, style=None,
@@ -509,7 +512,7 @@ def __init__(self, element, keys=None, ranges=None, dimensions=None,
super(GenericElementPlot, self).__init__(keys=keys, dimensions=dimensions,
dynamic=dynamic,
**dict(params, **plot_opts))
if self.batched and self._batched:
if self.batched:
self.ordering = util.layer_sort(self.hmap)
self.style = self.lookup_options(self.hmap.last.last, 'style').max_cycles(len(self.ordering))
else:
@@ -723,7 +726,7 @@ def _create_subplots(self, ranges):
batched = self.batched and type(self.hmap.last) is NdOverlay
if batched:
batchedplot = registry.get(type(self.hmap.last.last))
if (batched and batchedplot and batchedplot._batched and
if (batched and batchedplot and 'batched' in batchedplot._plot_methods and
(not self.show_legend or len(ordering) > self.legend_limit)):
self.batched = True
keys, vmaps = [()], [self.hmap]
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.