Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

First pass pickle support for figures, transforms & artists.

  • Loading branch information...
commit 0db9429b5a3fd57b19b62e6eb23dcff1f12b5d03 1 parent dc535b4
@pelson authored
View
9 lib/matplotlib/artist.py
@@ -104,6 +104,13 @@ def __init__(self):
self.y_isdata = True # with y
self._snap = None
+ def __getstate__(self):
+ d = self.__dict__.copy()
+ # remove the unpicklable remove method, this will get re-added on load
+ d.pop('_remove_method')
+# axes_artist_collections = ['lines', 'collections', 'tables', '']
+ return d
+
def remove(self):
"""
Remove the artist from the figure if possible. The effect
@@ -123,7 +130,7 @@ def remove(self):
# the _remove_method attribute directly. This would be a protected
# attribute if Python supported that sort of thing. The callback
# has one parameter, which is the child to be removed.
- if self._remove_method != None:
+ if self._remove_method is not None:
self._remove_method(self)
else:
raise NotImplementedError('cannot remove artist')
View
80 lib/matplotlib/axes.py
@@ -172,6 +172,13 @@ def __init__(self, axes, command='plot'):
self.command = command
self.set_color_cycle()
+ def __getinitargs__(self):
+ # means that the color cycle will be lost.
+ return (self.axes, self.command)
+
+ def __getstate__(self):
+ return False
+
def set_color_cycle(self, clist=None):
if clist is None:
clist = rcParams['axes.color_cycle']
@@ -332,7 +339,7 @@ def _grab_next_args(self, *args, **kwargs):
for seg in self._plot_args(remaining[:isplit], kwargs):
yield seg
remaining=remaining[isplit:]
-
+
class Axes(martist.Artist):
"""
@@ -352,9 +359,10 @@ class Axes(martist.Artist):
_shared_x_axes = cbook.Grouper()
_shared_y_axes = cbook.Grouper()
-
+
def __str__(self):
return "Axes(%g,%g;%gx%g)" % tuple(self._position.bounds)
+
def __init__(self, fig, rect,
axisbg = None, # defaults to rc axes.facecolor
frameon = True,
@@ -1423,7 +1431,9 @@ def add_artist(self, a):
self.artists.append(a)
self._set_artist_props(a)
a.set_clip_path(self.patch)
- a._remove_method = lambda h: self.artists.remove(h)
+ def remove_fn(artist):
+ self.artists.remove(artist)
+ a._remove_method = remove_fn #lambda h: self.artists.remove(h)
return a
def add_collection(self, collection, autolim=True):
@@ -1445,7 +1455,11 @@ def add_collection(self, collection, autolim=True):
if collection._paths and len(collection._paths):
self.update_datalim(collection.get_datalim(self.transData))
- collection._remove_method = lambda h: self.collections.remove(h)
+ # XXX back to start
+ def remove_fn(artist):
+ self.collections.remove(artist)
+
+ collection._remove_method = remove_fn #lambda h: self.collections.remove(h)
return collection
def add_line(self, line):
@@ -1463,7 +1477,10 @@ def add_line(self, line):
if not line.get_label():
line.set_label('_line%d'%len(self.lines))
self.lines.append(line)
- line._remove_method = lambda h: self.lines.remove(h)
+# def remove_fn(artist):
+# self.lines.remove(artist)
+# line._remove_method = remove_fn #lambda h: self.lines.remove(h)
+ line._remove_method = self.lines.remove
return line
def _update_line_limits(self, line):
@@ -1489,7 +1506,9 @@ def add_patch(self, p):
p.set_clip_path(self.patch)
self._update_patch_limits(p)
self.patches.append(p)
- p._remove_method = lambda h: self.patches.remove(h)
+ def remove_fn(artist):
+ self.patches.remove(artist)
+ p._remove_method = remove_fn #lambda h: self.patches.remove(h)
return p
def _update_patch_limits(self, patch):
@@ -1524,7 +1543,9 @@ def add_table(self, tab):
self._set_artist_props(tab)
self.tables.append(tab)
tab.set_clip_path(self.patch)
- tab._remove_method = lambda h: self.tables.remove(h)
+ def remove_fn(artist):
+ self.tables.remove(artist)
+ tab._remove_method = remove_fn #lambda h: self.tables.remove(h)
return tab
def add_container(self, container):
@@ -1538,7 +1559,9 @@ def add_container(self, container):
if not label:
container.set_label('_container%d'%len(self.containers))
self.containers.append(container)
- container.set_remove_method(lambda h: self.containers.remove(container))
+ def remove_fn(artist):
+ self.containers.remove(artist)
+ container.set_remove_method(remove_fn)
return container
@@ -1599,13 +1622,13 @@ def _process_unit_info(self, xdata=None, ydata=None, kwargs=None):
if xdata is not None:
# we only need to update if there is nothing set yet.
if not self.xaxis.have_units():
- self.xaxis.update_units(xdata)
+ self.xaxis.update_units(xdata)
#print '\tset from xdata', self.xaxis.units
if ydata is not None:
# we only need to update if there is nothing set yet.
if not self.yaxis.have_units():
- self.yaxis.update_units(ydata)
+ self.yaxis.update_units(ydata)
#print '\tset from ydata', self.yaxis.units
# process kwargs 2nd since these will override default units
@@ -3330,7 +3353,9 @@ def text(self, x, y, s, fontdict=None,
if fontdict is not None: t.update(fontdict)
t.update(kwargs)
self.texts.append(t)
- t._remove_method = lambda h: self.texts.remove(h)
+ def remove_fn(artist):
+ self.texts.remove(artist)
+ t._remove_method = remove_fn #lambda h: self.texts.remove(h)
#if t.get_clip_on(): t.set_clip_box(self.bbox)
@@ -3359,7 +3384,9 @@ def annotate(self, *args, **kwargs):
self._set_artist_props(a)
if kwargs.has_key('clip_on'): a.set_clip_path(self.patch)
self.texts.append(a)
- a._remove_method = lambda h: self.texts.remove(h)
+ def remove_fn(artist):
+ self.texts.remove(artist)
+ a._remove_method = remove_fn #lambda h: self.texts.remove(h)
return a
#### Lines and spans
@@ -7022,7 +7049,9 @@ def imshow(self, X, cmap=None, norm=None, aspect=None,
im.set_extent(im.get_extent())
self.images.append(im)
- im._remove_method = lambda h: self.images.remove(h)
+ def remove_fn(artist):
+ self.images.remove(artist)
+ im._remove_method = remove_fn #lambda h: self.images.remove(h)
return im
@@ -8770,7 +8799,15 @@ def __init__(self, fig, *args, **kwargs):
# _axes_class is set in the subplot_class_factory
self._axes_class.__init__(self, fig, self.figbox, **kwargs)
-
+ def __reduce__(self):
+ # get the first axes class which does not inherit from a subplotbase
+ axes_class = filter(lambda klass: (issubclass(klass, Axes) and
+ not issubclass(klass, SubplotBase)),
+ self.__class__.mro())[0]
+ r = [_PicklableSubplotClassConstructor(),
+ (axes_class,),
+ self.__getstate__()]
+ return tuple(r)
def get_geometry(self):
"""get the subplot geometry, eg 2,2,3"""
@@ -8852,6 +8889,21 @@ def subplot_class_factory(axes_class=None):
# This is provided for backward compatibility
Subplot = subplot_class_factory()
+
+class _PicklableSubplotClassConstructor(object):
+ """
+ This stub class exists to return the appropriate subplot
+ class when __call__-ed with an axes class. This is purely to
+ allow Picking of Axes and Subplots."""
+ def __call__(self, axes_class):
+ # create a dummy object instance
+ subplot_instance = _PicklableSubplotClassConstructor()
+ subplot_class = subplot_class_factory(axes_class)
+ # update the class to the desired subplot class
+ subplot_instance.__class__ = subplot_class
+ return subplot_instance
+
+
docstring.interpd.update(Axes=martist.kwdoc(Axes))
docstring.interpd.update(Subplot=martist.kwdoc(Axes))
View
1  lib/matplotlib/axis.py
@@ -595,7 +595,6 @@ class Ticker:
formatter = None
-
class Axis(artist.Artist):
"""
View
172 lib/matplotlib/cbook.py
@@ -152,6 +152,90 @@ def __call__(self, s):
if self.is_missing(s): return self.missingval
return int(s)
+
+class _BoundMethodProxy(object):
+ '''
+ Our own proxy object which enables weak references to bound and unbound
+ methods and arbitrary callables. Pulls information about the function,
+ class, and instance out of a bound method. Stores a weak reference to the
+ instance to support garbage collection.
+
+ @organization: IBM Corporation
+ @copyright: Copyright (c) 2005, 2006 IBM Corporation
+ @license: The BSD License
+
+ Minor bugfixes by Michael Droettboom
+ '''
+ def __init__(self, cb):
+ try:
+ try:
+ self.inst = ref(cb.im_self)
+ except TypeError:
+ self.inst = None
+ self.func = cb.im_func
+ self.klass = cb.im_class
+ except AttributeError:
+ self.inst = None
+ self.func = cb
+ self.klass = None
+
+ def __getstate__(self):
+ d = self.__dict__.copy()
+ # de-weak reference inst
+ inst = d['inst']
+ if inst is not None:
+ d['inst'] = inst()
+ return d
+
+ def __setstate__(self, statedict):
+ self.__dict__ = statedict
+ inst = statedict['inst']
+ # turn inst back into a weakref
+ if inst is not None:
+ self.inst = ref(inst)
+
+ def __call__(self, *args, **kwargs):
+ '''
+ Proxy for a call to the weak referenced object. Take
+ arbitrary params to pass to the callable.
+
+ Raises `ReferenceError`: When the weak reference refers to
+ a dead object
+ '''
+ if self.inst is not None and self.inst() is None:
+ raise ReferenceError
+ elif self.inst is not None:
+ # build a new instance method with a strong reference to the instance
+ if sys.version_info[0] >= 3:
+ mtd = types.MethodType(self.func, self.inst())
+ else:
+ mtd = new.instancemethod(self.func, self.inst(), self.klass)
+ else:
+ # not a bound method, just return the func
+ mtd = self.func
+ # invoke the callable and return the result
+ return mtd(*args, **kwargs)
+
+ def __eq__(self, other):
+ '''
+ Compare the held function and instance with that held by
+ another proxy.
+ '''
+ try:
+ if self.inst is None:
+ return self.func == other.func and other.inst is None
+ else:
+ return self.func == other.func and self.inst() == other.inst()
+ except Exception:
+ return False
+
+ def __ne__(self, other):
+ '''
+ Inverse of __eq__.
+ '''
+ return not self.__eq__(other)
+
+
class CallbackRegistry:
"""
Handle registering and disconnecting for a set of signals and
@@ -190,72 +274,7 @@ def ondrink(x):
`"Mindtrove" blog
<http://mindtrove.info/articles/python-weak-references/>`_.
"""
- class BoundMethodProxy(object):
- '''
- Our own proxy object which enables weak references to bound and unbound
- methods and arbitrary callables. Pulls information about the function,
- class, and instance out of a bound method. Stores a weak reference to the
- instance to support garbage collection.
- @organization: IBM Corporation
- @copyright: Copyright (c) 2005, 2006 IBM Corporation
- @license: The BSD License
-
- Minor bugfixes by Michael Droettboom
- '''
- def __init__(self, cb):
- try:
- try:
- self.inst = ref(cb.im_self)
- except TypeError:
- self.inst = None
- self.func = cb.im_func
- self.klass = cb.im_class
- except AttributeError:
- self.inst = None
- self.func = cb
- self.klass = None
-
- def __call__(self, *args, **kwargs):
- '''
- Proxy for a call to the weak referenced object. Take
- arbitrary params to pass to the callable.
-
- Raises `ReferenceError`: When the weak reference refers to
- a dead object
- '''
- if self.inst is not None and self.inst() is None:
- raise ReferenceError
- elif self.inst is not None:
- # build a new instance method with a strong reference to the instance
- if sys.version_info[0] >= 3:
- mtd = types.MethodType(self.func, self.inst())
- else:
- mtd = new.instancemethod(self.func, self.inst(), self.klass)
- else:
- # not a bound method, just return the func
- mtd = self.func
- # invoke the callable and return the result
- return mtd(*args, **kwargs)
-
- def __eq__(self, other):
- '''
- Compare the held function and instance with that held by
- another proxy.
- '''
- try:
- if self.inst is None:
- return self.func == other.func and other.inst is None
- else:
- return self.func == other.func and self.inst() == other.inst()
- except Exception:
- return False
-
- def __ne__(self, other):
- '''
- Inverse of __eq__.
- '''
- return not self.__eq__(other)
def __init__(self, *args):
if len(args):
@@ -266,6 +285,13 @@ def __init__(self, *args):
self._cid = 0
self._func_cid_map = {}
+ def __getstate__(self):
+ # pickling of callbacks not yet handled/may never be handlable
+ return {'callbacks': {},
+ '_cid': self._cid,
+ '_func_cid_map': {},
+ }
+
def connect(self, s, func):
"""
register *func* to be called when a signal *s* is generated
@@ -279,7 +305,7 @@ def connect(self, s, func):
cid = self._cid
self._func_cid_map[s][func] = cid
self.callbacks.setdefault(s, dict())
- proxy = self.BoundMethodProxy(func)
+ proxy = _BoundMethodProxy(func)
self.callbacks[s][cid] = proxy
return cid
@@ -375,7 +401,7 @@ class silent_list(list):
"""
override repr when returning a list of matplotlib artists to
prevent long, meaningless output. This is meant to be used for a
- homogeneous list of a give type
+ homogeneous list of a given type
"""
def __init__(self, type, seq=None):
self.type = type
@@ -385,7 +411,16 @@ def __repr__(self):
return '<a list of %d %s objects>' % (len(self), self.type)
def __str__(self):
- return '<a list of %d %s objects>' % (len(self), self.type)
+ return repr(self)
+
+ def __getstate__(self):
+ # store a dictionary of this SilentList's state
+ return {'type': self.type, 'seq': self[:]}
+
+ def __setstate__(self, state):
+ self.type = state['type']
+ self.extend(state['seq'])
+
def strip_math(s):
'remove latex formatting from mathtext'
@@ -394,6 +429,7 @@ def strip_math(s):
for r in remove: s = s.replace(r,'')
return s
+
class Bunch:
"""
Often we want to just collect a bunch of stuff together, naming each
View
7 lib/matplotlib/contour.py
@@ -847,6 +847,13 @@ def __init__(self, ax, *args, **kwargs):
self.collections.append(col)
self.changed() # set the colors
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ # the C object Cntr cannot currently be pickled. This isn't a big issue
+ # as it is not actually used once the contour has been calculated
+ state['Cntr'] = None
+ return state
+
def legend_elements(self, variable_name='x', str_format=str):
"""
Return a list of artist and labels suitable for passing through
View
26 lib/matplotlib/figure.py
@@ -33,6 +33,7 @@
import matplotlib.cbook as cbook
from matplotlib import docstring
+from matplotlib import __version__ as _mpl_version
from operator import itemgetter
import os.path
@@ -1131,6 +1132,31 @@ def _gci(self):
return im
return None
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ # the axobservers cannot currently be pickled.
+ # Additionally, the canvas cannot currently be pickled, but this has
+ # the benefit of meaning that a figure can be detached from one canvas,
+ # and re-attached to another.
+ for attr_to_pop in ('_axobservers', 'show', 'canvas') :
+ state.pop(attr_to_pop)
+
+ # add version information to the state
+ state['__mpl_version__'] = _mpl_version
+
+ return state
+
+ def __setstate__(self, state):
+ version = state.pop('__mpl_version__')
+ if version != _mpl_version:
+ import warnings
+ warnings.warn("This figure was saved with matplotlib version %s "
+ "and is unlikely to function correctly." %
+ (version, ))
+ self.__dict__ = state
+ self._axobservers = []
+ self.canvas = None
+
def add_axobserver(self, func):
'whenever the axes state change, ``func(self)`` will be called'
self._axobservers.append(func)
View
10 lib/matplotlib/markers.py
@@ -113,6 +113,16 @@ def __init__(self, marker=None, fillstyle='full'):
self.set_marker(marker)
self.set_fillstyle(fillstyle)
+ def __getstate__(self):
+ d = self.__dict__.copy()
+ d.pop('_marker_function')
+ return d
+
+ def __setstate__(self, statedict):
+ self.__dict__ = statedict
+ self.set_marker(self._marker)
+ self._recache()
+
def _recache(self):
self._path = Path(np.empty((0,2)))
self._transform = IdentityTransform()
View
30 lib/matplotlib/ticker.py
@@ -133,33 +133,33 @@
from matplotlib import transforms as mtransforms
+class _DummyAxis:
+ def __init__(self):
+ self.dataLim = mtransforms.Bbox.unit()
+ self.viewLim = mtransforms.Bbox.unit()
-class TickHelper:
- axis = None
- class DummyAxis:
- def __init__(self):
- self.dataLim = mtransforms.Bbox.unit()
- self.viewLim = mtransforms.Bbox.unit()
+ def get_view_interval(self):
+ return self.viewLim.intervalx
- def get_view_interval(self):
- return self.viewLim.intervalx
+ def set_view_interval(self, vmin, vmax):
+ self.viewLim.intervalx = vmin, vmax
- def set_view_interval(self, vmin, vmax):
- self.viewLim.intervalx = vmin, vmax
+ def get_data_interval(self):
+ return self.dataLim.intervalx
- def get_data_interval(self):
- return self.dataLim.intervalx
+ def set_data_interval(self, vmin, vmax):
+ self.dataLim.intervalx = vmin, vmax
- def set_data_interval(self, vmin, vmax):
- self.dataLim.intervalx = vmin, vmax
+class TickHelper:
+ axis = None
def set_axis(self, axis):
self.axis = axis
def create_dummy_axis(self):
if self.axis is None:
- self.axis = self.DummyAxis()
+ self.axis = _DummyAxis()
def set_view_interval(self, vmin, vmax):
self.axis.set_view_interval(vmin, vmax)
View
20 lib/matplotlib/transforms.py
@@ -91,6 +91,17 @@ def __init__(self):
# computed for the first time.
self._invalid = 1
+ def __getstate__(self):
+ d = self.__dict__.copy()
+ # turn the weakkey dictionary into a normal dictionary
+ d['_parents'] = dict(self._parents.iteritems())
+ return d
+
+ def __setstate__(self, data_dict):
+ self.__dict__ = data_dict
+ # turn the normal dictionary back into a WeakKeyDict
+ self._parents = WeakKeyDictionary(self._parents)
+
def __copy__(self, *args):
raise NotImplementedError(
"TransformNode instances can not be copied. " +
@@ -1275,12 +1286,19 @@ def __init__(self, child):
be replaced with :meth:`set`.
"""
assert isinstance(child, Transform)
-
Transform.__init__(self)
self.input_dims = child.input_dims
self.output_dims = child.output_dims
self._set(child)
self._invalid = 0
+
+ def __getstate__(self):
+ # only store the child
+ return {'child': self._child}
+
+ def __setstate__(self, state):
+ # re-initialise the TransformWrapper with the state's child
+ self.__init__(state['child'])
def __repr__(self):
return "TransformWrapper(%r)" % self._child
Please sign in to comment.
Something went wrong with that request. Please try again.