Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Added pickle test, still some outstanding problems (colorbar, legend …

…and geo axes known so far)
  • Loading branch information...
commit 1257bf927688bbaaf44b480f878fee2de5877be3 1 parent 3ec63ad
@pelson authored
View
1  lib/matplotlib/__init__.py
@@ -1068,6 +1068,7 @@ def tk_window_focus():
'matplotlib.tests.test_mathtext',
'matplotlib.tests.test_mlab',
'matplotlib.tests.test_patches',
+ 'matplotlib.tests.test_pickle',
'matplotlib.tests.test_rcparams',
'matplotlib.tests.test_simplification',
'matplotlib.tests.test_spines',
View
2  lib/matplotlib/artist.py
@@ -107,7 +107,7 @@ def __init__(self):
def __getstate__(self):
d = self.__dict__.copy()
# remove the unpicklable remove method, this will get re-added on load
- # if the artist lives on an axes.
+ # (by the axes) if the artist lives on an axes.
d['_remove_method'] = None
return d
View
16 lib/matplotlib/axes.py
@@ -154,9 +154,8 @@ def set_default_color_cycle(clist):
DeprecationWarning)
-class _process_plot_var_args:
+class _process_plot_var_args(object):
"""
-
Process variable length arguments to the plot command, so that
plot commands like the following are supported::
@@ -172,14 +171,13 @@ def __init__(self, axes, command='plot'):
self.command = command
self.set_color_cycle()
- def __getinitargs__(self):
- # note: __getinitargs__ only works for old-style classes
- # means that the color cycle will be lost.
- return (self.axes, self.command)
-
def __getstate__(self):
- # We don't need any state as we have the init args
- return False
+ # note: it is not possible to pickle a itertools.cycle instance
+ return {'axes': self.axes, 'command': self.command}
+
+ def __setstate__(self, state):
+ self.__dict__ = state.copy()
+ self.set_color_cycle()
def set_color_cycle(self, clist=None):
if clist is None:
View
19 lib/matplotlib/colorbar.py
@@ -185,6 +185,12 @@
docstring.interpd.update(colorbar_doc=colorbar_doc)
+def _set_ticks_on_axis_warn(*args, **kw):
+ # a top level function which gets put in at the axes'
+ # set_xticks set_yticks by _patch_ax
+ warnings.warn("Use the colorbar set_ticks() method instead.")
+
+
class ColorbarBase(cm.ScalarMappable):
'''
Draw a colorbar in an existing axes.
@@ -277,7 +283,7 @@ def __init__(self, ax, cmap=None,
# The rest is in a method so we can recalculate when clim changes.
self.config_axis()
self.draw_all()
-
+
def _extend_lower(self):
"""Returns whether the lower limit is open ended."""
return self.extend in ('both', 'min')
@@ -285,13 +291,12 @@ def _extend_lower(self):
def _extend_upper(self):
"""Returns whether the uper limit is open ended."""
return self.extend in ('both', 'max')
-
+
def _patch_ax(self):
- def _warn(*args, **kw):
- warnings.warn("Use the colorbar set_ticks() method instead.")
-
- self.ax.set_xticks = _warn
- self.ax.set_yticks = _warn
+ # bind some methods to the axes to warn users
+ # against using those methods.
+ self.ax.set_xticks = _set_ticks_on_axis_warn
+ self.ax.set_yticks = _set_ticks_on_axis_warn
def draw_all(self):
'''
View
48 lib/matplotlib/patches.py
@@ -1617,7 +1617,6 @@ def pprint_styles(klass):
"""
return _pprint_styles(klass._style_list)
-
@classmethod
def register(klass, name, style):
"""
@@ -1687,9 +1686,6 @@ def __init__(self):
"""
super(BoxStyle._Base, self).__init__()
-
-
-
def transmute(self, x0, y0, width, height, mutation_size):
"""
The transmute method is a very core of the
@@ -1701,8 +1697,6 @@ def transmute(self, x0, y0, width, height, mutation_size):
"""
raise NotImplementedError('Derived must override')
-
-
def __call__(self, x0, y0, width, height, mutation_size,
aspect_ratio=1.):
"""
@@ -1728,7 +1722,15 @@ def __call__(self, x0, y0, width, height, mutation_size,
else:
return self.transmute(x0, y0, width, height, mutation_size)
-
+ def __reduce__(self):
+ # because we have decided to nest thes classes, we need to
+ # add some more information to allow instance pickling.
+ import matplotlib.cbook as cbook
+ return (cbook._NestedClassGetter(),
+ (BoxStyle, self.__class__.__name__),
+ self.__dict__
+ )
+
class Square(_Base):
"""
@@ -2296,9 +2298,6 @@ def get_bbox(self):
return transforms.Bbox.from_bounds(self._x, self._y, self._width, self._height)
-
-
-
from matplotlib.bezier import split_bezier_intersecting_with_closedpath
from matplotlib.bezier import get_intersection, inside_circle, get_parallels
from matplotlib.bezier import make_wedged_bezier2
@@ -2359,7 +2358,7 @@ class _Base(object):
points. This base class defines a __call__ method, and few
helper methods.
"""
-
+
class SimpleEvent:
def __init__(self, xy):
self.x, self.y = xy
@@ -2401,7 +2400,6 @@ def insideB(xy_display):
return path
-
def _shrink(self, path, shrinkA, shrinkB):
"""
Shrink the path by fixed size (in points) with shrinkA and shrinkB
@@ -2441,6 +2439,15 @@ def __call__(self, posA, posB,
shrinked_path = self._shrink(clipped_path, shrinkA, shrinkB)
return shrinked_path
+
+ def __reduce__(self):
+ # because we have decided to nest thes classes, we need to
+ # add some more information to allow instance pickling.
+ import matplotlib.cbook as cbook
+ return (cbook._NestedClassGetter(),
+ (ConnectionStyle, self.__class__.__name__),
+ self.__dict__
+ )
class Arc3(_Base):
@@ -2771,7 +2778,6 @@ def connect(self, posA, posB):
{"AvailableConnectorstyles": _pprint_styles(_style_list)}
-
class ArrowStyle(_Style):
"""
:class:`ArrowStyle` is a container class which defines several
@@ -2867,8 +2873,6 @@ class and must be overriden in the subclasses. It receives
raise NotImplementedError('Derived must override')
-
-
def __call__(self, path, mutation_size, linewidth,
aspect_ratio=1.):
"""
@@ -2901,7 +2905,15 @@ def __call__(self, path, mutation_size, linewidth,
return path_mutated, fillable
else:
return self.transmute(path, mutation_size, linewidth)
-
+
+ def __reduce__(self):
+ # because we have decided to nest thes classes, we need to
+ # add some more information to allow instance pickling.
+ import matplotlib.cbook as cbook
+ return (cbook._NestedClassGetter(),
+ (ArrowStyle, self.__class__.__name__),
+ self.__dict__
+ )
class _Curve(_Base):
@@ -3048,7 +3060,6 @@ def __init__(self):
_style_list["-"] = Curve
-
class CurveA(_Curve):
"""
An arrow with a head at its begin point.
@@ -3087,7 +3098,6 @@ def __init__(self, head_length=.4, head_width=.2):
beginarrow=False, endarrow=True,
head_length=head_length, head_width=head_width )
- #_style_list["->"] = CurveB
_style_list["->"] = CurveB
@@ -3109,11 +3119,9 @@ def __init__(self, head_length=.4, head_width=.2):
beginarrow=True, endarrow=True,
head_length=head_length, head_width=head_width )
- #_style_list["<->"] = CurveAB
_style_list["<->"] = CurveAB
-
class CurveFilledA(_Curve):
"""
An arrow with filled triangle head at the begin.
View
BIN  lib/matplotlib/tests/baseline_images/test_pickle/multi_pickle.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
View
157 lib/matplotlib/tests/test_pickle.py
@@ -0,0 +1,157 @@
+from __future__ import print_function
+
+import numpy as np
+
+import matplotlib
+matplotlib.use('tkagg')
+
+from matplotlib.testing.decorators import cleanup, image_comparison
+import matplotlib.pyplot as plt
+
+from nose.tools import assert_equal, assert_not_equal
+
+# cpickle is faster, pickle gives better exceptions
+import cPickle as pickle
+import pickle
+
+from cStringIO import StringIO
+
+
+def recursive_pickle(obj, nested_info='top level object', memo=None):
+ """
+ Pickle the object's attributes recursively, storing a memo of the object
+ which have already been pickled.
+
+ If any pickling issues occur, a pickle.Pickle error will be raised with details.
+
+ This is not a completely general purpose routine, but will be useful for
+ debugging some pickle issues. HINT: cPickle is less verbose than Pickle
+
+
+ """
+ if memo is None:
+ memo = {}
+
+ if id(obj) in memo:
+ return
+
+ # put this object in the memo
+ memo[id(obj)] = obj
+
+ # start by pickling all of the object's attributes/contents
+
+ if isinstance(obj, list):
+ for i, item in enumerate(obj):
+ recursive_pickle(item, memo=memo, nested_info='list item #%s in (%s)' % (i, nested_info))
+ else:
+ if isinstance(obj, dict):
+ state = obj
+ elif hasattr(obj, '__getstate__'):
+ state = obj.__getstate__()
+ if not isinstance(state, dict):
+ state = {}
+ elif hasattr(obj, '__dict__'):
+ state = obj.__dict__
+ else:
+ state = {}
+
+ for key, value in state.iteritems():
+ recursive_pickle(value, memo=memo, nested_info='attribute "%s" in (%s)' % (key, nested_info))
+
+# print(id(obj), type(obj), nested_info)
+
+ # finally, try picking the object itself
+ try:
+ pickle.dump(obj, StringIO())#, pickle.HIGHEST_PROTOCOL)
+ except (pickle.PickleError, AssertionError), err:
+ print(pickle.PickleError('Pickling failed with nested info: [(%s) %s].'
+ '\nException: %s' % (type(obj),
+ nested_info,
+ err)))
+ # re-raise the exception for full traceback
+ raise
+
+
+@cleanup
+def test_simple():
+ fig = plt.figure()
+ # un-comment to debug
+ recursive_pickle(fig, 'figure')
+ pickle.dump(fig, StringIO(), pickle.HIGHEST_PROTOCOL)
+
+ ax = plt.subplot(121)
+# recursive_pickle(ax, 'ax')
+ pickle.dump(ax, StringIO(), pickle.HIGHEST_PROTOCOL)
+
+ ax = plt.axes(projection='polar')
+# recursive_pickle(ax, 'ax')
+ pickle.dump(ax, StringIO(), pickle.HIGHEST_PROTOCOL)
+
+# ax = plt.subplot(121, projection='hammer')
+# recursive_pickle(ax, 'figure')
+# pickle.dump(ax, StringIO(), pickle.HIGHEST_PROTOCOL)
+
+
+@image_comparison(baseline_images=['multi_pickle'],
+ extensions=['png'])
+def test_complete():
+ fig = plt.figure('Figure with a label?')
+
+ plt.suptitle('Can you fit any more in a figure?')
+
+ # make some arbitrary data
+ x, y = np.arange(8), np.arange(10)
+ data = u = v = np.linspace(0, 10, 80).reshape(10, 8)
+ v = np.sin(v * -0.6)
+
+ plt.subplot(3,3,1)
+ plt.plot(range(10))
+
+ plt.subplot(3, 3, 2)
+ plt.contourf(data, hatches=['//', 'ooo'])
+# plt.colorbar() # sadly, colorbar is currently failing. This might be an easy fix once
+ # its been identified what the problem is. (lambda functions in colorbar)
+
+ plt.subplot(3, 3, 3)
+ plt.pcolormesh(data)
+# cb = plt.colorbar()
+
+ plt.subplot(3, 3, 4)
+ plt.imshow(data)
+
+ plt.subplot(3, 3, 5)
+ plt.pcolor(data)
+
+ plt.subplot(3, 3, 6)
+ plt.streamplot(x, y, u, v)
+
+ plt.subplot(3, 3, 7)
+ plt.quiver(x, y, u, v)
+
+ plt.subplot(3, 3, 8)
+ plt.scatter(x, x**2, label='$x^2$')
+# plt.legend()
+
+ plt.subplot(3, 3, 9)
+ plt.errorbar(x, x * -0.5, xerr=0.2, yerr=0.4)
+
+
+ result_fh = StringIO()
+# recursive_pickle(fig, 'figure')
+ pickle.dump(fig, result_fh, pickle.HIGHEST_PROTOCOL)
+
+ plt.close('all')
+
+ # make doubly sure that there are no figures left
+ assert_equal(plt._pylab_helpers.Gcf.figs, {})
+
+ # wind back the fh and load in the figure
+ result_fh.seek(0)
+ fig = pickle.load(result_fh)
+
+ # make sure there is now a figure manager
+ assert_not_equal(plt._pylab_helpers.Gcf.figs, {})
+
+ assert_equal(fig.get_label(), 'Figure with a label?')
+
+
Please sign in to comment.
Something went wrong with that request. Please try again.