diff --git a/Lib/pickle.py b/Lib/pickle.py index 1cee4d53fa4387..d767d85c1587d8 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -251,10 +251,7 @@ def get(self, i, pack=struct.pack): return GET + `i` + '\n' - def save(self, obj, - _builtin_type = (int, long, float, complex, str, unicode, - tuple, list, dict), - ): + def save(self, obj): # Check for persistent id (defined by a subclass) pid = self.persistent_id(obj) if pid: @@ -285,20 +282,24 @@ def save(self, obj, # Check copy_reg.dispatch_table reduce = dispatch_table.get(t) - if reduce: - rv = reduce(obj) - else: - # Check for __reduce__ method - reduce = getattr(obj, "__reduce__", None) - if not reduce: - # Check for instance of subclass of common built-in types - if self.proto >= 2 and isinstance(obj, _builtin_type): - assert t not in _builtin_type # Proper subclass + if not reduce: + # Check for a __reduce__ method. + # Subtle: get the unbound method from the class, so that + # protocol 2 can override the default __reduce__ that all + # classes inherit from object. This has the added + # advantage that the call always has the form reduce(obj) + reduce = getattr(t, "__reduce__", None) + if self.proto >= 2: + # Protocol 2 can do better than the default __reduce__ + if reduce is object.__reduce__: + reduce = None + if not reduce: self.save_newobj(obj) return + if not reduce: raise PicklingError("Can't pickle %r object: %r" % (t.__name__, obj)) - rv = reduce() + rv = reduce(obj) # Check for string returned by reduce(), meaning "save as global" if type(rv) is StringType: @@ -320,13 +321,6 @@ def save(self, obj, raise PicklingError("Tuple returned by %s must have " "exactly two or three elements" % reduce) - # XXX Temporary hack XXX - # Override the default __reduce__ for new-style class instances - if self.proto >= 2: - if func is _reconstructor: - self.save_newobj(obj) - return - # Save the reduce() output and finally memoize the object self.save_reduce(func, args, state) self.memoize(obj) @@ -375,10 +369,11 @@ def save_reduce(self, func, args, state=None): def save_newobj(self, obj): # Save a new-style class instance, using protocol 2. # XXX Much of this is still experimental. + assert self.proto >= 2 # This only works for protocol 2 t = type(obj) getnewargs = getattr(obj, "__getnewargs__", None) if getnewargs: - args = getnewargs() # This better not reference obj + args = getnewargs() # This bette not reference obj else: for cls in int, long, float, complex, str, unicode, tuple: if isinstance(obj, cls): @@ -409,10 +404,32 @@ def save_newobj(self, obj): getstate = getattr(obj, "__getstate__", None) if getstate: - state = getstate() - else: + try: + state = getstate() + except TypeError, err: + # XXX Catch generic exception caused by __slots__ + if str(err) != ("a class that defines __slots__ " + "without defining __getstate__ " + "cannot be pickled"): + print repr(str(err)) + raise # Not that specific exception + getstate = None + if not getstate: state = getattr(obj, "__dict__", None) - # XXX What about __slots__? + # If there are slots, the state becomes a tuple of two + # items: the first item the regular __dict__ or None, and + # the second a dict mapping slot names to slot values + names = _slotnames(t) + if names: + slots = {} + nil = [] + for name in names: + value = getattr(obj, name, nil) + if value is not nil: + slots[name] = value + if slots: + state = (state, slots) + if state is not None: save(state) write(BUILD) @@ -718,6 +735,24 @@ def save_global(self, obj, name = None): # Pickling helpers +def _slotnames(cls): + """Return a list of slot names for a given class. + + This needs to find slots defined by the class and its bases, so we + can't simply return the __slots__ attribute. We must walk down + the Method Resolution Order and concatenate the __slots__ of each + class found there. (This assumes classes don't modify their + __slots__ attribute to misrepresent their slots after the class is + defined.) + """ + if not hasattr(cls, "__slots__"): + return [] + names = [] + for c in cls.__mro__: + if "__slots__" in c.__dict__: + names += list(c.__dict__["__slots__"]) + return names + def _keep_alive(x, memo): """Keeps a reference to the object x in the memo. @@ -1152,22 +1187,29 @@ def load_setitems(self): def load_build(self): stack = self.stack - value = stack.pop() + state = stack.pop() inst = stack[-1] - try: - setstate = inst.__setstate__ - except AttributeError: + setstate = getattr(inst, "__setstate__", None) + if setstate: + setstate(state) + return + slotstate = None + if isinstance(state, tuple) and len(state) == 2: + state, slotstate = state + if state: try: - inst.__dict__.update(value) + inst.__dict__.update(state) except RuntimeError: - # XXX In restricted execution, the instance's __dict__ is not - # accessible. Use the old way of unpickling the instance - # variables. This is a semantic different when unpickling in - # restricted vs. unrestricted modes. - for k, v in value.items(): + # XXX In restricted execution, the instance's __dict__ + # is not accessible. Use the old way of unpickling + # the instance variables. This is a semantic + # difference when unpickling in restricted + # vs. unrestricted modes. + for k, v in state.items(): setattr(inst, k, v) - else: - setstate(value) + if slotstate: + for k, v in slotstate.items(): + setattr(inst, k, v) dispatch[BUILD] = load_build def load_mark(self):