Skip to content

Commit

Permalink
Patch pickle-serialize
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi committed Dec 8, 2020
1 parent 20df74d commit 344377d
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 9 deletions.
9 changes: 6 additions & 3 deletions mars/_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ cdef class Tokenizer:
if clz in self._handlers:
handler = self._handlers[object_type] = self._handlers[clz]
return handler(obj)
raise TypeError(f'Cannot generate token for {obj}, type: {object_type}')
try:
return cloudpickle.dumps(obj)
except:
raise TypeError(f'Cannot generate token for {obj}, type: {object_type}') from None


cdef inline list iterative_tokenize(object ob):
Expand Down Expand Up @@ -259,7 +262,7 @@ def tokenize_function(ob):


@lru_cache(500)
def tokenize_pickled_with_cache(ob):
    """Tokenize *ob* by pickling it, memoizing up to 500 results.

    Intended for small, hashable, immutable objects (e.g. ``tzinfo``
    instances) whose pickle bytes are stable, so caching is safe.
    """
    token = pickle.dumps(ob)
    return token


Expand Down Expand Up @@ -288,7 +291,7 @@ tokenize_handler.register(pd.DataFrame, tokenize_pandas_dataframe)
tokenize_handler.register(pd.Categorical, tokenize_pandas_categorical)
tokenize_handler.register(pd.CategoricalDtype, tokenize_categories_dtype)
tokenize_handler.register(pd.IntervalDtype, tokenize_interval_dtype)
tokenize_handler.register(tzinfo, tokenize_pickled)
tokenize_handler.register(tzinfo, tokenize_pickled_with_cache)
tokenize_handler.register(pd.arrays.DatetimeArray, tokenize_pandas_time_arrays)
tokenize_handler.register(pd.arrays.TimedeltaArray, tokenize_pandas_time_arrays)
tokenize_handler.register(pd.arrays.PeriodArray, tokenize_pandas_time_arrays)
Expand Down
1 change: 1 addition & 0 deletions mars/serialize/core.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ cpdef enum ExtendType:
freq = 29
namedtuple = 30
regex = 31
pickled = 32767


cdef class Identity:
Expand Down
1 change: 1 addition & 0 deletions mars/serialize/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ cdef class ValueType:
freq = ExtendType.freq
namedtuple = ExtendType.namedtuple
regex = ExtendType.regex
pickled = ExtendType.pickled

identity = Identity()

Expand Down
22 changes: 21 additions & 1 deletion mars/serialize/jsonserializer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ cdef dict EXTEND_TYPE_TO_NAME = {
ValueType.complex128: 'complex128',
ValueType.namedtuple: 'namedtuple',
ValueType.regex: 'regex',
ValueType.pickled: 'pickled',
}


Expand Down Expand Up @@ -447,6 +448,20 @@ cdef class JsonSerializeProvider(Provider):
value = obj['value']
return to_offset(value)

# Fallback JSON serializer: cloudpickle the value and base64-encode the
# bytes so they survive as a JSON string.  Used when no typed handler
# matches the value.
cdef inline _serialize_pickled(self, value):
    return {
        'type': 'pickled',
        # self._to_str converts the base64 bytes to str for JSON output;
        # self.pickle_protocol is the provider-configured pickle protocol.
        'value': self._to_str(base64.b64encode(cloudpickle.dumps(value, protocol=self.pickle_protocol))),
    }

cdef inline _deserialize_pickled(self, obj, list callbacks):
    """Inverse of ``_serialize_pickled``: base64-decode then unpickle.

    ``callbacks`` is accepted for signature parity with the other
    ``_deserialize_*`` handlers and is not used here.

    NOTE(review): cloudpickle.loads on untrusted input can execute
    arbitrary code — this assumes serialized payloads are trusted.
    """
    # base64.b64decode always returns bytes (it raises on bad input,
    # never returns None), so the former ``if v is not None`` guard was
    # dead code and has been removed.
    return cloudpickle.loads(base64.b64decode(obj['value']))

cdef inline object _serialize_typed_value(self, value, tp, bint weak_ref=False):
if type(tp) not in (List, Tuple, Dict) and weak_ref:
# not iterable, and is weak ref
Expand Down Expand Up @@ -570,7 +585,10 @@ cdef class JsonSerializeProvider(Provider):
elif callable(value):
return self._serialize_function(value)
else:
raise TypeError(f'Unknown type to serialize: {type(value)}')
try:
return self._serialize_pickled(value)
except:
raise TypeError(f'Unknown type to serialize: {type(value)}') from None

cdef inline object _serialize_value(self, value, tp=None, bint weak_ref=False):
if tp is None:
Expand Down Expand Up @@ -702,6 +720,8 @@ cdef class JsonSerializeProvider(Provider):
return self._deserialize_namedtuple(obj, callbacks)
elif tp is ValueType.regex:
return self._deserialize_regex(obj, callbacks)
elif tp is ValueType.pickled:
return self._deserialize_pickled(obj, callbacks)
else:
raise TypeError(f'Unknown type to deserialize {obj["type"]}')

Expand Down
14 changes: 13 additions & 1 deletion mars/serialize/pbserializer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,13 @@ cdef class ProtobufSerializeProvider(Provider):
cdef inline object _get_freq(self, obj):
return to_offset(obj.freq)

# Fallback protobuf serializer: store the cloudpickled bytes of ``value``
# in the message's ``pickled`` field.  ``tp`` is accepted for signature
# parity with the other ``_set_*`` handlers and is not used here.
# ``except -1`` lets Cython propagate Python exceptions from C callers.
cdef inline int _set_pickled(self, value, obj, tp=None) except -1:
    obj.pickled = cloudpickle.dumps(value, protocol=self.pickle_protocol)
    return 0

# Inverse of ``_set_pickled``: unpickle the message's ``pickled`` bytes.
# NOTE(review): assumes the protobuf payload is trusted — cloudpickle.loads
# on untrusted bytes can execute arbitrary code.
cdef inline object _get_pickled(self, obj):
    return cloudpickle.loads(obj.pickled)

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
Expand Down Expand Up @@ -570,7 +577,10 @@ cdef class ProtobufSerializeProvider(Provider):
elif callable(value):
self._set_function(value, obj)
else:
raise TypeError(f'Unknown type to serialize: {type(value)}')
try:
self._set_pickled(value, obj)
except:
raise TypeError(f'Unknown type to serialize: {type(value)}') from None

cdef inline void _set_value(cls, value, obj, tp=None, bint weak_ref=False) except *:
if tp is None:
Expand Down Expand Up @@ -865,6 +875,8 @@ cdef class ProtobufSerializeProvider(Provider):
return ref(self._get_namedtuple(obj))
elif field == 'regex':
return ref(self._get_regex(obj))
elif field == 'pickled':
return ref(self._get_pickled(obj))
else:
raise TypeError('Unknown type to deserialize')

Expand Down
1 change: 1 addition & 0 deletions mars/serialize/protos/value.proto
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ message Value {
string freq = 24; // pandas.tseries.offsets.Tick
bytes namedtuple = 25;
RegexValue regex = 26;
bytes pickled = 32767;
}

}
18 changes: 17 additions & 1 deletion mars/serialize/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@
nt = namedtuple('nt', 'a b')


class ClassToPickle:
    """Plain test helper with no registered serialize handler.

    Providers must fall back to their pickle path when they encounter an
    instance; tests compare ``a`` after a serialize/deserialize round trip.
    """

    def __init__(self, a):
        # Single payload attribute checked after the round trip.
        self.a = a


class Node1(Serializable):
a = IdentityField('a', ValueType.string)
b1 = Int8Field('b1')
Expand All @@ -81,6 +86,7 @@ class Node1(Serializable):
e = BoolField('e')
f1 = KeyField('f1')
f2 = AnyField('f2')
f3 = AnyField('f3')
g = ReferenceField('g', 'Node2')
h = ListField('h')
i = ListField('i', ValueType.reference('self'))
Expand Down Expand Up @@ -233,6 +239,7 @@ def testPBSerialize(self, *_):
e=False,
f1=Node2Entity(node2),
f2=Node2Entity(node2),
f3=ClassToPickle(1285),
g=Node2(a=[['1', '2'], ['3', '4']]),
h=[[2, 3], node2, True, {1: node2}, np.datetime64('1066-10-13'),
np.timedelta64(1, 'D'), np.complex64(1+2j), np.complex128(2+3j),
Expand Down Expand Up @@ -282,6 +289,8 @@ def testPBSerialize(self, *_):
self.assertEqual(node3.value.f1.a, d_node3.value.f1.a)
self.assertIsNot(node3.value.f2, d_node3.value.f2)
self.assertEqual(node3.value.f2.a, d_node3.value.f2.a)
self.assertIsNot(node3.value.f3, d_node3.value.f3)
self.assertEqual(node3.value.f3.a, d_node3.value.f3.a)
self.assertIsNot(node3.value.g, d_node3.value.g)
self.assertEqual(node3.value.g.a, d_node3.value.g.a)
self.assertEqual(node3.value.h[0], d_node3.value.h[0])
Expand Down Expand Up @@ -331,6 +340,7 @@ def testJSONSerialize(self):
e=False,
f1=Node2Entity(node2),
f2=Node2Entity(node2),
f3=ClassToPickle(1285),
g=Node2(a=[['1', '2'], ['3', '4']]),
h=[[2, 3], node2, True, {1: node2}, np.datetime64('1066-10-13'),
np.timedelta64(1, 'D'), np.complex64(1+2j), np.complex128(2+3j),
Expand Down Expand Up @@ -381,6 +391,8 @@ def testJSONSerialize(self):
self.assertEqual(node3.value.f1.a, d_node3.value.f1.a)
self.assertIsNot(node3.value.f2, d_node3.value.f2)
self.assertEqual(node3.value.f2.a, d_node3.value.f2.a)
self.assertIsNot(node3.value.f3, d_node3.value.f3)
self.assertEqual(node3.value.f3.a, d_node3.value.f3.a)
self.assertIsNot(node3.value.g, d_node3.value.g)
self.assertEqual(node3.value.g.a, d_node3.value.g.a)
self.assertEqual(node3.value.h[0], d_node3.value.h[0])
Expand Down Expand Up @@ -682,7 +694,11 @@ def testAttributeAsDict(self):
self.assertIsInstance(d_node62.rl[0], Node6)

def testException(self):
node1 = Node1(h=[object()])
class Unserializable:
def __getstate__(self):
raise SystemError

node1 = Node1(h=[Unserializable()])

pbs = ProtobufSerializeProvider()

Expand Down
1 change: 1 addition & 0 deletions mars/serialize/tests/testser.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ message Node1Def {
bool e = 13;
Value f1 = 14;
Value f2 = 15;
Value f3 = 27;
Node2Def g = 16;
repeated Value h = 17;
repeated Node1Def i = 20;
Expand Down
9 changes: 6 additions & 3 deletions mars/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,16 @@ class TestEnum(Enum):
v = [df, df.index, df.columns, df['data'], pd.Categorical(list('ABCD'))]
self.assertEqual(utils.tokenize(v), utils.tokenize(copy.deepcopy(v)))

non_tokenizable_cls = type('non_tokenizable_cls', (object,), {})
class NonTokenizableCls:
def __getstate__(self):
raise SystemError

with self.assertRaises(TypeError):
utils.tokenize(non_tokenizable_cls())
utils.tokenize(NonTokenizableCls())

class CustomizedTokenize(object):
def __mars_tokenize__(self):
return id(type(self)), id(non_tokenizable_cls)
return id(type(self)), id(NonTokenizableCls)

self.assertEqual(utils.tokenize(CustomizedTokenize()),
utils.tokenize(CustomizedTokenize()))
Expand Down

0 comments on commit 344377d

Please sign in to comment.