/
marshal.py
171 lines (140 loc) · 4.76 KB
/
marshal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Support for marshaling ZEO messages
Not to be confused with marshaling objects in ZODB.
We currently use pickle. In the future, we may use a
Python-independent format, or possibly a minimal pickle subset.
"""
import logging
from .._compat import Unpickler, Pickler, BytesIO, PY3, PYPY
from ..shortrepr import short_repr
# True when not running under Python 3 (the inverse of the PY3 flag
# imported from _compat).
PY2 = not PY3
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def encoder(protocol, server=False):
    """Return a non-thread-safe encoder for the given wire protocol.

    The first byte of ``protocol`` selects the serializer: ``b'M'`` uses
    msgpack and ``b'Z'`` uses pickle (protocol 3).  The returned callable
    accepts the message parts as positional arguments and returns them
    serialized together as one tuple.

    When ``server`` is true, the msgpack encoder is given
    ``server_default`` as its ``default`` hook.  The flag has no effect
    on the pickle encoder.

    The pickle encoder reuses a single buffer and pickler across calls,
    which is why the result must not be shared between threads.
    """
    kind = protocol[:1]

    if kind == b'M':
        from msgpack import packb
        default = server_default if server else None

        def encode(*args):
            return packb(
                args, use_bin_type=True, default=default)

        return encode

    assert kind == b'Z'

    # One buffer and one pickler serve every call of the returned function.
    buf = BytesIO()
    pickler = Pickler(buf, 3)
    pickler.fast = 1

    def encode(*args):
        # Drop whatever the previous call left in the buffer.
        buf.seek(0)
        buf.truncate()
        pickler.dump(args)
        return buf.getvalue()

    return encode
def encode(*args):
    """Serialize the message parts with a pickle (``b'Z'``) encoder.

    A new encoder is obtained from ``encoder`` on every call.
    """
    pickle_encode = encoder(b'Z')
    return pickle_encode(*args)
def decoder(protocol):
    """Return a function that decodes messages for the given protocol.

    The first byte of ``protocol`` selects the deserializer: ``b'M'``
    returns a msgpack-based decoder, ``b'Z'`` returns ``pickle_decode``.
    Any other prefix fails the assertion.

    The msgpack decoder returns sequences as tuples (``use_list=False``)
    and decodes msgpack strings to text.
    """
    if protocol[:1] == b'M':
        from msgpack import unpackb

        def msgpack_decode(data):
            """Decodes msg and returns its parts"""
            # ``raw=False`` decodes msgpack strings as UTF-8 text.  It
            # replaces ``encoding='utf-8'``, which msgpack deprecated and
            # then removed in 1.0 (passing it there raises TypeError).
            return unpackb(data, raw=False, use_list=False)

        return msgpack_decode
    else:
        assert protocol[:1] == b'Z'
        return pickle_decode
def pickle_decode(msg):
    """Unpickle ``msg`` and return the message parts it contains.

    Globals referenced by the pickle are resolved through
    ``find_global``.  If unpickling fails, the failure is logged with a
    shortened representation of ``msg`` and the exception is re-raised.
    """
    loader = Unpickler(BytesIO(msg))
    loader.find_global = find_global
    try:
        # PyPy, zodbpickle, the non-c-accelerated version
        loader.find_class = find_global
    except AttributeError:
        # This assignment is allowed to fail.
        pass

    try:
        parts = loader.load()  # msgid, flags, name, args
    except:  # intentionally bare: log any failure, then re-raise it
        logger.error("can't decode message: %s" % short_repr(msg))
        raise
    return parts
def server_decoder(protocol):
    """Return the server-side message decoder for the given protocol.

    For a ``b'M'`` protocol this is whatever ``decoder`` returns; for a
    ``b'Z'`` protocol it is ``pickle_server_decode``.  Any other prefix
    fails the assertion.
    """
    kind = protocol[:1]
    if kind == b'M':
        return decoder(protocol)

    assert kind == b'Z'
    return pickle_server_decode
def pickle_server_decode(msg):
    """Unpickle ``msg`` on the server and return its message parts.

    Globals referenced by the pickle are resolved through
    ``server_find_global``.  If unpickling fails, the failure is logged
    with a shortened representation of ``msg`` and the exception is
    re-raised.
    """
    loader = Unpickler(BytesIO(msg))
    loader.find_global = server_find_global
    try:
        # PyPy, zodbpickle, the non-c-accelerated version
        loader.find_class = server_find_global
    except AttributeError:
        # This assignment is allowed to fail.
        pass

    try:
        parts = loader.load()  # msgid, flags, name, args
    except:  # intentionally bare: log any failure, then re-raise it
        logger.error("can't decode message: %s" % short_repr(msg))
        raise
    return parts
def server_default(obj):
    """Convert exceptions via ``reduce_exception``; pass anything else through.

    ``encoder`` installs this as the msgpack ``default`` hook when it is
    called with ``server=True``.
    """
    return reduce_exception(obj) if isinstance(obj, Exception) else obj
def reduce_exception(exc):
    """Reduce an exception to a ``(dotted class name, state)`` pair.

    The dotted name is ``<module>.<class name>`` of the exception's
    class.  The state is the instance ``__dict__`` when that is
    non-empty, and the exception's ``args`` otherwise.
    """
    exc_type = exc.__class__
    dotted_name = "%s.%s" % (exc_type.__module__, exc_type.__name__)
    state = exc.__dict__ or exc.args
    return dotted_name, state
# Namespace passed to ``__import__`` (as both globals and locals) by the
# find_global helpers in this module.
_globals = globals()

# Non-empty ``fromlist`` passed to ``__import__`` by the same helpers.
_silly = ('__doc__',)

# The type of the ``Exception`` class itself.
exception_type_type = type(Exception)

# Modules whose globals the unpickler helpers accept by module name.
_SAFE_MODULE_NAMES = (
    'ZopeUndo.Prefix',
    'zodbpickle',
    'builtins',
    'copy_reg',
    '__builtin__',
)
def find_global(module, name):
    """Helper for message unpickler

    Import ``module`` and return its attribute ``name``, provided the
    attribute is considered safe.  An attribute is safe when any of
    these holds:

    - it has a true ``__no_side_effects__`` attribute;
    - running under Python 2 and ``module`` is in ``_SAFE_MODULE_NAMES``;
    - it is a class whose type is exactly ``type(Exception)`` and which
      subclasses ``Exception``.

    Raises ImportError if the module cannot be imported, if it lacks the
    attribute, or if the attribute is not safe.  Note that the module is
    imported before the safety checks run.
    """
    try:
        mod = __import__(module, _globals, _globals, _silly)
    except ImportError as err:
        raise ImportError("import error %s: %s" % (module, err))

    try:
        found = getattr(mod, name)
    except AttributeError:
        raise ImportError("module %s has no global %s" % (module, name))

    if getattr(found, '__no_side_effects__', 0):
        return found
    if PY2 and module in _SAFE_MODULE_NAMES:
        return found

    # TODO: is there a better way to do this?
    if type(found) == exception_type_type and issubclass(found, Exception):
        return found

    raise ImportError("Unsafe global: %s.%s" % (module, name))
def server_find_global(module, name):
    """Helper for message unpickler

    Return attribute ``name`` of ``module``, but only when ``module`` is
    listed in ``_SAFE_MODULE_NAMES``.  The allow-list is checked before
    anything is imported.

    Raises ImportError if the module is not allowed, cannot be imported,
    or has no such attribute.
    """
    # NOTE(review): the allow-list is per module, not per name -- every
    # attribute of an allowed module (``builtins`` included) is returned.
    # Confirm that this is intended for data received from clients.
    if module not in _SAFE_MODULE_NAMES:
        raise ImportError("Module not allowed: %s" % (module,))

    try:
        mod = __import__(module, _globals, _globals, _silly)
    except ImportError as err:
        raise ImportError("import error %s: %s" % (module, err))

    try:
        return getattr(mod, name)
    except AttributeError:
        raise ImportError("module %s has no global %s" % (module, name))