From 344377d735b2bbcd37face8bee056e86c0248828 Mon Sep 17 00:00:00 2001 From: "wenjun.swj" Date: Tue, 8 Dec 2020 15:21:08 +0800 Subject: [PATCH] Patch pickle-serialize --- mars/_utils.pyx | 9 ++++++--- mars/serialize/core.pxd | 1 + mars/serialize/core.pyx | 1 + mars/serialize/jsonserializer.pyx | 22 +++++++++++++++++++++- mars/serialize/pbserializer.pyx | 14 +++++++++++++- mars/serialize/protos/value.proto | 1 + mars/serialize/tests/test_serialize.py | 18 +++++++++++++++++- mars/serialize/tests/testser.proto | 1 + mars/tests/test_utils.py | 9 ++++++--- 9 files changed, 67 insertions(+), 9 deletions(-) diff --git a/mars/_utils.pyx b/mars/_utils.pyx index 9ed1e2f1d2..431253a3cc 100644 --- a/mars/_utils.pyx +++ b/mars/_utils.pyx @@ -117,7 +117,10 @@ cdef class Tokenizer: if clz in self._handlers: handler = self._handlers[object_type] = self._handlers[clz] return handler(obj) - raise TypeError(f'Cannot generate token for {obj}, type: {object_type}') + try: + return cloudpickle.dumps(obj) + except: + raise TypeError(f'Cannot generate token for {obj}, type: {object_type}') from None cdef inline list iterative_tokenize(object ob): @@ -259,7 +262,7 @@ def tokenize_function(ob): @lru_cache(500) -def tokenize_pickled(ob): +def tokenize_pickled_with_cache(ob): return pickle.dumps(ob) @@ -288,7 +291,7 @@ tokenize_handler.register(pd.DataFrame, tokenize_pandas_dataframe) tokenize_handler.register(pd.Categorical, tokenize_pandas_categorical) tokenize_handler.register(pd.CategoricalDtype, tokenize_categories_dtype) tokenize_handler.register(pd.IntervalDtype, tokenize_interval_dtype) -tokenize_handler.register(tzinfo, tokenize_pickled) +tokenize_handler.register(tzinfo, tokenize_pickled_with_cache) tokenize_handler.register(pd.arrays.DatetimeArray, tokenize_pandas_time_arrays) tokenize_handler.register(pd.arrays.TimedeltaArray, tokenize_pandas_time_arrays) tokenize_handler.register(pd.arrays.PeriodArray, tokenize_pandas_time_arrays) diff --git a/mars/serialize/core.pxd b/mars/serialize/core.pxd index a681503779..9af1638c0a 100644 --- a/mars/serialize/core.pxd +++ b/mars/serialize/core.pxd @@ -50,6 +50,7 @@ cpdef enum ExtendType: freq = 29 namedtuple = 30 regex = 31 + pickled = 32767 cdef class Identity: diff --git a/mars/serialize/core.pyx b/mars/serialize/core.pyx index efa0ed986d..1b58a24fd6 100644 --- a/mars/serialize/core.pyx +++ b/mars/serialize/core.pyx @@ -115,6 +115,7 @@ cdef class ValueType: freq = ExtendType.freq namedtuple = ExtendType.namedtuple regex = ExtendType.regex + pickled = ExtendType.pickled identity = Identity() diff --git a/mars/serialize/jsonserializer.pyx b/mars/serialize/jsonserializer.pyx index 70bf853e84..0243d8d0e7 100644 --- a/mars/serialize/jsonserializer.pyx +++ b/mars/serialize/jsonserializer.pyx @@ -79,6 +79,7 @@ cdef dict EXTEND_TYPE_TO_NAME = { ValueType.complex128: 'complex128', ValueType.namedtuple: 'namedtuple', ValueType.regex: 'regex', + ValueType.pickled: 'pickled', } @@ -447,6 +448,20 @@ cdef class JsonSerializeProvider(Provider): value = obj['value'] return to_offset(value) + cdef inline _serialize_pickled(self, value): + return { + 'type': 'pickled', + 'value': self._to_str(base64.b64encode(cloudpickle.dumps(value, protocol=self.pickle_protocol))), + } + + cdef inline _deserialize_pickled(self, obj, list callbacks): + value = obj['value'] + v = base64.b64decode(value) + + if v is not None: + return cloudpickle.loads(v) + return None + cdef inline object _serialize_typed_value(self, value, tp, bint weak_ref=False): if type(tp) not in (List, Tuple, Dict) and weak_ref: # not iterable, and is weak ref @@ -570,7 +585,10 @@ cdef class JsonSerializeProvider(Provider): elif callable(value): return self._serialize_function(value) else: - raise TypeError(f'Unknown type to serialize: {type(value)}') + try: + return self._serialize_pickled(value) + except: + raise TypeError(f'Unknown type to serialize: {type(value)}') from None cdef inline object _serialize_value(self, value, tp=None, bint weak_ref=False): if tp is None: @@ -702,6 +720,8 @@ cdef class JsonSerializeProvider(Provider): return self._deserialize_namedtuple(obj, callbacks) elif tp is ValueType.regex: return self._deserialize_regex(obj, callbacks) + elif tp is ValueType.pickled: + return self._deserialize_pickled(obj, callbacks) else: raise TypeError(f'Unknown type to deserialize {obj["type"]}') diff --git a/mars/serialize/pbserializer.pyx b/mars/serialize/pbserializer.pyx index dd0a093170..f60f4862fa 100644 --- a/mars/serialize/pbserializer.pyx +++ b/mars/serialize/pbserializer.pyx @@ -282,6 +282,13 @@ cdef class ProtobufSerializeProvider(Provider): cdef inline object _get_freq(self, obj): return to_offset(obj.freq) + cdef inline int _set_pickled(self, value, obj, tp=None) except -1: + obj.pickled = cloudpickle.dumps(value, protocol=self.pickle_protocol) + return 0 + + cdef inline object _get_pickled(self, obj): + return cloudpickle.loads(obj.pickled) + @cython.boundscheck(False) @cython.wraparound(False) @cython.nonecheck(False) @@ -570,7 +577,10 @@ cdef class ProtobufSerializeProvider(Provider): elif callable(value): self._set_function(value, obj) else: - raise TypeError(f'Unknown type to serialize: {type(value)}') + try: + self._set_pickled(value, obj) + except: + raise TypeError(f'Unknown type to serialize: {type(value)}') from None cdef inline void _set_value(cls, value, obj, tp=None, bint weak_ref=False) except *: if tp is None: @@ -865,6 +875,8 @@ cdef class ProtobufSerializeProvider(Provider): return ref(self._get_namedtuple(obj)) elif field == 'regex': return ref(self._get_regex(obj)) + elif field == 'pickled': + return ref(self._get_pickled(obj)) else: raise TypeError('Unknown type to deserialize') diff --git a/mars/serialize/protos/value.proto b/mars/serialize/protos/value.proto index 87afdd59e3..6afa756724 100644 --- a/mars/serialize/protos/value.proto +++ b/mars/serialize/protos/value.proto @@ -83,6 +83,7 @@ message Value { string freq = 24; // pandas.tseries.offsets.Tick bytes namedtuple = 25; RegexValue regex = 26; + bytes pickled = 32767; } } diff --git a/mars/serialize/tests/test_serialize.py b/mars/serialize/tests/test_serialize.py index 2072bc28d6..b2c8cf3001 100644 --- a/mars/serialize/tests/test_serialize.py +++ b/mars/serialize/tests/test_serialize.py @@ -63,6 +63,11 @@ nt = namedtuple('nt', 'a b') +class ClassToPickle: + def __init__(self, a): + self.a = a + + class Node1(Serializable): a = IdentityField('a', ValueType.string) b1 = Int8Field('b1') @@ -81,6 +86,7 @@ class Node1(Serializable): e = BoolField('e') f1 = KeyField('f1') f2 = AnyField('f2') + f3 = AnyField('f3') g = ReferenceField('g', 'Node2') h = ListField('h') i = ListField('i', ValueType.reference('self')) @@ -233,6 +239,7 @@ def testPBSerialize(self, *_): e=False, f1=Node2Entity(node2), f2=Node2Entity(node2), + f3=ClassToPickle(1285), g=Node2(a=[['1', '2'], ['3', '4']]), h=[[2, 3], node2, True, {1: node2}, np.datetime64('1066-10-13'), np.timedelta64(1, 'D'), np.complex64(1+2j), np.complex128(2+3j), @@ -282,6 +289,8 @@ def testPBSerialize(self, *_): self.assertEqual(node3.value.f1.a, d_node3.value.f1.a) self.assertIsNot(node3.value.f2, d_node3.value.f2) self.assertEqual(node3.value.f2.a, d_node3.value.f2.a) + self.assertIsNot(node3.value.f3, d_node3.value.f3) + self.assertEqual(node3.value.f3.a, d_node3.value.f3.a) self.assertIsNot(node3.value.g, d_node3.value.g) self.assertEqual(node3.value.g.a, d_node3.value.g.a) self.assertEqual(node3.value.h[0], d_node3.value.h[0]) @@ -331,6 +340,7 @@ def testJSONSerialize(self): e=False, f1=Node2Entity(node2), f2=Node2Entity(node2), + f3=ClassToPickle(1285), g=Node2(a=[['1', '2'], ['3', '4']]), h=[[2, 3], node2, True, {1: node2}, np.datetime64('1066-10-13'), np.timedelta64(1, 'D'), np.complex64(1+2j), np.complex128(2+3j), @@ -381,6 +391,8 @@ def testJSONSerialize(self): self.assertEqual(node3.value.f1.a, d_node3.value.f1.a) self.assertIsNot(node3.value.f2, d_node3.value.f2) self.assertEqual(node3.value.f2.a, d_node3.value.f2.a) + self.assertIsNot(node3.value.f3, d_node3.value.f3) + self.assertEqual(node3.value.f3.a, d_node3.value.f3.a) self.assertIsNot(node3.value.g, d_node3.value.g) self.assertEqual(node3.value.g.a, d_node3.value.g.a) self.assertEqual(node3.value.h[0], d_node3.value.h[0]) @@ -682,7 +694,11 @@ def testAttributeAsDict(self): self.assertIsInstance(d_node62.rl[0], Node6) def testException(self): - node1 = Node1(h=[object()]) + class Unserializable: + def __getstate__(self): + raise SystemError + + node1 = Node1(h=[Unserializable()]) pbs = ProtobufSerializeProvider() diff --git a/mars/serialize/tests/testser.proto b/mars/serialize/tests/testser.proto index 44c4dbc127..ae38fe7942 100644 --- a/mars/serialize/tests/testser.proto +++ b/mars/serialize/tests/testser.proto @@ -22,6 +22,7 @@ message Node1Def { bool e = 13; Value f1 = 14; Value f2 = 15; + Value f3 = 27; Node2Def g = 16; repeated Value h = 17; repeated Node1Def i = 20; diff --git a/mars/tests/test_utils.py b/mars/tests/test_utils.py index 2d139376f9..b2042bf4b1 100644 --- a/mars/tests/test_utils.py +++ b/mars/tests/test_utils.py @@ -132,13 +132,16 @@ class TestEnum(Enum): v = [df, df.index, df.columns, df['data'], pd.Categorical(list('ABCD'))] self.assertEqual(utils.tokenize(v), utils.tokenize(copy.deepcopy(v))) - non_tokenizable_cls = type('non_tokenizable_cls', (object,), {}) + class NonTokenizableCls: + def __getstate__(self): + raise SystemError + with self.assertRaises(TypeError): - utils.tokenize(non_tokenizable_cls()) + utils.tokenize(NonTokenizableCls()) class CustomizedTokenize(object): def __mars_tokenize__(self): - return id(type(self)), id(non_tokenizable_cls) + return id(type(self)), id(NonTokenizableCls) self.assertEqual(utils.tokenize(CustomizedTokenize()), utils.tokenize(CustomizedTokenize()))