Commit: address comment
davies committed Aug 4, 2014
1 parent 55b1c1a commit 4132f32
Showing 2 changed files with 33 additions and 14 deletions.
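In effect, once pyspark.serializers is imported (it calls _hijack_namedtuple() at module load, see the diff below), classes produced by collections.namedtuple become picklable even when they are defined in the driver's __main__ module. A minimal usage sketch under that assumption (Python 2, which this code targets; the Point class is only illustrative):

# Sketch only: assumes a pyspark installation is on the path; importing the
# module below runs _hijack_namedtuple() as a side effect.
from collections import namedtuple
from cPickle import dumps, loads

import pyspark.serializers  # patches collections.namedtuple in place

Point = namedtuple("Point", "x y")   # created after the hijack, e.g. in __main__
p = Point(1, 2)
# Point.__reduce__ now returns (_restore, ("Point", ("x", "y"), (1, 2))),
# so the instance survives a pickle round trip.
assert loads(dumps(p, 2)) == p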
python/pyspark/serializers.py: 36 changes (22 additions & 14 deletions)
@@ -65,6 +65,7 @@
 import marshal
 import struct
 import sys
+import types
 import collections

 from pyspark import cloudpickle
@@ -271,40 +272,47 @@ def dumps(self, obj):

 # Hook namedtuple, make it picklable

-old_namedtuple = collections.namedtuple
 __cls = {}


 def _restore(name, fields, value):
     """ Restore an object of namedtuple"""
     k = (name, fields)
     cls = __cls.get(k)
     if cls is None:
-        cls = namedtuple(name, fields)
+        cls = collections.namedtuple(name, fields)
         __cls[k] = cls
     return cls(*value)


 def _hack_namedtuple(cls):
     """ Make class generated by namedtuple picklable """
     name = cls.__name__
     fields = cls._fields
     def __reduce__(self):
         return (_restore, (name, fields, tuple(self)))
     cls.__reduce__ = __reduce__
     return cls

-def namedtuple(name, fields, verbose=False, rename=False):
-    cls = old_namedtuple(name, fields, verbose, rename)
-    return _hack_namedtuple(cls)

-namedtuple.__doc__ = old_namedtuple.__doc__
+def _hijack_namedtuple():
+    """ Hack namedtuple() to make it picklable """
+    global _old_namedtuple  # or it will put in closure

+    def _copy_func(f):
+        return types.FunctionType(f.func_code, f.func_globals, f.func_name,
+                                  f.func_defaults, f.func_closure)

-def _hijack_namedtuple():
-    collections.namedtuple = namedtuple
+    _old_namedtuple = _copy_func(collections.namedtuple)

+    def namedtuple(name, fields, verbose=False, rename=False):
+        cls = _old_namedtuple(name, fields, verbose, rename)
+        return _hack_namedtuple(cls)

-    # replace all the reference to the new hacked one
-    import gc
-    for ref in gc.get_referrers(old_namedtuple):
-        if type(ref) is dict and ref.get("namedtuple") is old_namedtuple:
-            ref["namedtuple"] = namedtuple
+    # replace namedtuple with new one
+    collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple
+    collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple
+    collections.namedtuple.func_code = namedtuple.func_code

     # hack the cls already generated by namedtuple
     # those created in other module can be pickled as normal,

@@ -313,7 +321,7 @@ def _hijack_namedtuple():
         if (type(o) is type and o.__base__ is tuple
                 and hasattr(o, "_fields")
                 and "__reduce__" not in o.__dict__):
-            _hack_namedtuple(o)
+            _hack_namedtuple(o)  # hack inplace


 _hijack_namedtuple()
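The added block works by swapping the code object of collections.namedtuple in place, so references captured earlier via "from collections import namedtuple" also pick up the hacked behaviour, while _copy_func preserves the original implementation for the wrapper to call; this replaces the old approach of rebinding collections.namedtuple and walking referrers with gc. A standalone sketch of the same trick on a throwaway function (hypothetical names; Python 2 attribute names, matching the patch):

import types

def greet(name):
    return "hello " + name

alias = greet   # stands in for a reference taken before the patch

def patched(name):
    # delegate to the preserved copy, then decorate the result
    return _old_greet(name).upper()

# keep the original behaviour reachable, as _copy_func/_old_namedtuple does above
_old_greet = types.FunctionType(greet.func_code, greet.func_globals, greet.func_name,
                                greet.func_defaults, greet.func_closure)

# make the name the swapped-in code relies on visible from greet's globals
# (here both are the same module dict), then replace the code object in place
greet.func_globals["_old_greet"] = _old_greet
greet.func_code = patched.func_code

print alias("spark")   # prints HELLO SPARK: the pre-existing alias runs the new code too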
python/pyspark/tests.py: 11 changes (11 additions & 0 deletions)
@@ -105,6 +105,17 @@ def test_huge_dataset(self):
         m._cleanup()


+class SerializationTestCase(unittest.TestCase):
+
+    def test_namedtuple(self):
+        from collections import namedtuple
+        from cPickle import dumps, loads
+        P = namedtuple("P", "x y")
+        p1 = P(1, 3)
+        p2 = loads(dumps(p1, 2))
+        self.assertEquals(p1, p2)


 class PySparkTestCase(unittest.TestCase):

     def setUp(self):
