Skip to content

Commit

Permalink
Dealt with some serialization issues
Browse files Browse the repository at this point in the history
- Need to handle exception instances embedded within others.

  I dealt with this in msgpack using a "default" option (essentially a
  msgpack/json form of reduce).

  For pickle, we're still creating instance pickles in this case. :/

- Use a python-msgpack option to produce tuples rather than
  lists.  The ZEO protocol uses tuples far more often than lists.

  This really mostly or entirely affects tests.

  Removed workarounds for some test code that expected tuples and
  added some for test code that expects lists. :)
  • Loading branch information
Jim Fulton committed Nov 13, 2016
1 parent 28f7a92 commit 70ed5e5
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 25 deletions.
20 changes: 17 additions & 3 deletions src/ZEO/asyncio/marshal.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@

logger = logging.getLogger(__name__)

def encoder(protocol):
def encoder(protocol, server=False):
"""Return a non-thread-safe encoder
"""

if protocol[:1] == b'M':
from msgpack import packb
default = server_default if server else None
def encode(*args):
return packb(args, use_bin_type=True)
return packb(
args, use_bin_type=True, default=default)

return encode
else:
assert protocol[:1] == b'Z'
Expand Down Expand Up @@ -69,7 +72,7 @@ def decoder(protocol):
from msgpack import unpackb
def msgpack_decode(data):
"""Decodes msg and returns its parts"""
return unpackb(data, encoding='utf-8')
return unpackb(data, encoding='utf-8', use_list=False)
return msgpack_decode
else:
assert protocol[:1] == b'Z'
Expand Down Expand Up @@ -113,6 +116,17 @@ def pickle_server_decode(msg):
logger.error("can't decode message: %s" % short_repr(msg))
raise

def server_default(obj):
if isinstance(obj, Exception):
return reduce_exception(obj)
else:
return obj

def reduce_exception(exc):
class_ = exc.__class__
class_ = "%s.%s" % (class_.__module__, class_.__name__)
return class_, exc.__dict__ or exc.args

_globals = globals()
_silly = ('__doc__',)

Expand Down
9 changes: 3 additions & 6 deletions src/ZEO/asyncio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from . import base
from .compat import asyncio, new_event_loop
from .marshal import server_decoder, encoder
from .marshal import server_decoder, encoder, reduce_exception

class ServerProtocol(base.Protocol):
"""asyncio low-level ZEO server interface
Expand Down Expand Up @@ -70,7 +70,7 @@ def finish_connect(self, protocol_version):
logger.info("received handshake %r" %
str(protocol_version.decode('ascii')))
self.protocol_version = protocol_version
self.encode = encoder(protocol_version)
self.encode = encoder(protocol_version, True)
self.decode = server_decoder(protocol_version)
self.zeo_storage.notify_connected(self)
else:
Expand Down Expand Up @@ -135,10 +135,7 @@ def send_reply_threadsafe(self, message_id, result):
def send_error(self, message_id, exc, send_error=False):
"""Abstracting here so we can make this cleaner in the future
"""
class_ = exc.__class__
class_ = "%s.%s" % (class_.__module__, class_.__name__)
args = class_, exc.__dict__ or exc.args
self.send_reply(message_id, args, send_error, 2)
self.send_reply(message_id, reduce_exception(exc), send_error, 2)

def async(self, method, *args):
self.call_async(method, args)
Expand Down
29 changes: 15 additions & 14 deletions src/ZEO/asyncio/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
class Base(object):

enc = b'Z'
seq_type = list

def setUp(self):
super(Base, self).setUp()
Expand All @@ -39,11 +40,7 @@ def unsized(self, data, unpickle=False):
data = data[2:]
self.assertEqual(struct.unpack(">I", size)[0], len(message))
if unpickle:
message = tuple(self.decode(message))
if isinstance(message[-1], list):
message = message[:-1] + (tuple(message[-1]),)
if isinstance(message[0], list):
message = (tuple(message[-1]),) + message[1:]
message = self.decode(message)
result.append(message)

if len(result) == 1:
Expand Down Expand Up @@ -205,15 +202,15 @@ def testClientBasics(self):
self.assertEqual(self.pop(), (5, False, 'loadBefore', (b'1'*8, maxtid)))
# Note load_before uses the oid as the message id.
self.respond(5, (b'data', b'a'*8, None))
self.assertEqual(tuple(loaded.result()), (b'data', b'a'*8, None))
self.assertEqual(loaded.result(), (b'data', b'a'*8, None))

# If we make another request, it will be satisfied from the cache:
loaded = self.load_before(b'1'*8, maxtid)
self.assertEqual(loaded.result(), (b'data', b'a'*8, None))
self.assertFalse(transport.data)

# Let's send an invalidation:
self.send('invalidateTransaction', b'b'*8, [b'1'*8])
self.send('invalidateTransaction', b'b'*8, self.seq_type([b'1'*8]))

# Now, if we try to load current again, we'll make a server request.
loaded = self.load_before(b'1'*8, maxtid)
Expand All @@ -224,21 +221,21 @@ def testClientBasics(self):

self.assertEqual(self.pop(), (6, False, 'loadBefore', (b'1'*8, maxtid)))
self.respond(6, (b'data2', b'b'*8, None))
self.assertEqual(tuple(loaded.result()), (b'data2', b'b'*8, None))
self.assertEqual(tuple(loaded2.result()), (b'data2', b'b'*8, None))
self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
self.assertEqual(loaded2.result(), (b'data2', b'b'*8, None))

# Loading non-current data may also be satisfied from cache
loaded = self.load_before(b'1'*8, b'b'*8)
self.assertEqual(tuple(loaded.result()), (b'data', b'a'*8, b'b'*8))
self.assertEqual(loaded.result(), (b'data', b'a'*8, b'b'*8))
self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'c'*8)
self.assertEqual(tuple(loaded.result()), (b'data2', b'b'*8, None))
self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'_'*8)

self.assertEqual(self.pop(), (7, False, 'loadBefore', (b'1'*8, b'_'*8)))
self.respond(7, (b'data0', b'^'*8, b'_'*8))
self.assertEqual(tuple(loaded.result()), (b'data0', b'^'*8, b'_'*8))
self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))

# When committing transactions, we need to update the cache
# with committed data. To do this, we pass a (oid, data, resolved)
Expand Down Expand Up @@ -549,7 +546,8 @@ def test_invalidations_while_verifying(self):
self.pop(4)
self.send('invalidateTransaction', b'b'*8, [b'1'*8], called=False)
self.respond(2, b'a'*8)
self.send('invalidateTransaction', b'c'*8, [b'1'*8], no_output=False)
self.send('invalidateTransaction', b'c'*8, self.seq_type([b'1'*8]),
no_output=False)
self.assertEqual(self.pop(), (3, False, 'get_info', ()))

# We'll disconnect:
Expand All @@ -567,7 +565,8 @@ def test_invalidations_while_verifying(self):
self.pop(4)
self.send('invalidateTransaction', b'd'*8, [b'1'*8], called=False)
self.respond(2, b'c'*8)
self.send('invalidateTransaction', b'e'*8, [b'1'*8], no_output=False)
self.send('invalidateTransaction', b'e'*8, self.seq_type([b'1'*8]),
no_output=False)
self.assertEqual(self.pop(), (3, False, 'get_info', ()))

def test_flow_control(self):
Expand Down Expand Up @@ -691,6 +690,7 @@ def test_heartbeat(self):

class MsgpackClientTests(ClientTests):
enc = b'M'
seq_type = tuple

class MemoryCache(object):

Expand Down Expand Up @@ -830,6 +830,7 @@ def test_invalid_methods(self):

class MsgpackServerTests(ServerTests):
enc = b'M'
seq_type = tuple

def server_protocol(msgpack,
zeo_storage=None,
Expand Down
4 changes: 2 additions & 2 deletions src/ZEO/tests/protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ A current client should be able to connect to a old server:
>>> import ZEO, ZODB.blob, transaction
>>> db = ZEO.DB(addr, client='client', blob_dir='blobs')
>>> wait_connected(db.storage)
>>> str(db.storage.protocol_version.decode('ascii'))
'Z4'
>>> str(db.storage.protocol_version.decode('ascii'))[1:]
'4'

>>> conn = db.open()
>>> conn.root().x = 0
Expand Down

0 comments on commit 70ed5e5

Please sign in to comment.