diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index c478c1bcd66fc..02822b8a8f992 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -65,6 +65,7 @@ import marshal import struct import sys +import types import collections from pyspark import cloudpickle @@ -271,18 +272,21 @@ def dumps(self, obj): # Hook namedtuple, make it picklable -old_namedtuple = collections.namedtuple __cls = {} + def _restore(name, fields, value): + """ Restore an object of namedtuple""" k = (name, fields) cls = __cls.get(k) if cls is None: - cls = namedtuple(name, fields) + cls = collections.namedtuple(name, fields) __cls[k] = cls return cls(*value) + def _hack_namedtuple(cls): + """ Make class generated by namedtuple picklable """ name = cls.__name__ fields = cls._fields def __reduce__(self): @@ -290,21 +294,25 @@ def __reduce__(self): cls.__reduce__ = __reduce__ return cls -def namedtuple(name, fields, verbose=False, rename=False): - cls = old_namedtuple(name, fields, verbose, rename) - return _hack_namedtuple(cls) -namedtuple.__doc__ = old_namedtuple.__doc__ +def _hijack_namedtuple(): + """ Hack namedtuple() to make it picklable """ + global _old_namedtuple # or it will put in closure + def _copy_func(f): + return types.FunctionType(f.func_code, f.func_globals, f.func_name, + f.func_defaults, f.func_closure) -def _hijack_namedtuple(): - collections.namedtuple = namedtuple + _old_namedtuple = _copy_func(collections.namedtuple) + + def namedtuple(name, fields, verbose=False, rename=False): + cls = _old_namedtuple(name, fields, verbose, rename) + return _hack_namedtuple(cls) - # replace all the reference to the new hacked one - import gc - for ref in gc.get_referrers(old_namedtuple): - if type(ref) is dict and ref.get("namedtuple") is old_namedtuple: - ref["namedtuple"] = namedtuple + # replace namedtuple with new one + collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple + collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple + collections.namedtuple.func_code = namedtuple.func_code # hack the cls already generated by namedtuple # those created in other module can be pickled as normal, @@ -313,7 +321,7 @@ def _hijack_namedtuple(): if (type(o) is type and o.__base__ is tuple and hasattr(o, "_fields") and "__reduce__" not in o.__dict__): - _hack_namedtuple(o) + _hack_namedtuple(o) # hack inplace _hijack_namedtuple() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index a2d2c0dd1d0b3..ce981881d2e6b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -105,6 +105,17 @@ def test_huge_dataset(self): m._cleanup() +class SerializationTestCase(unittest.TestCase): + + def test_namedtuple(self): + from collections import namedtuple + from cPickle import dumps, loads + P = namedtuple("P", "x y") + p1 = P(1, 3) + p2 = loads(dumps(p1, 2)) + self.assertEquals(p1, p2) + + class PySparkTestCase(unittest.TestCase): def setUp(self):