Skip to content

Commit

Permalink
Implemented msgpack as an optional ZEO message encoding with basic te…
Browse files Browse the repository at this point in the history
…sts.
  • Loading branch information
Jim Fulton committed Nov 12, 2016
1 parent c318342 commit b31fed1
Show file tree
Hide file tree
Showing 12 changed files with 175 additions and 95 deletions.
7 changes: 7 additions & 0 deletions README.rst
Expand Up @@ -289,6 +289,13 @@ client-conflict-resolution
Flag indicating that clients should perform conflict
resolution. This option defaults to false.

msgpack
Use msgpack to serialize and de-serialize ZEO protocol messages.

An advantage of using msgpack for ZEO communication is that
it's a little bit faster and a ZEO server can support Python 2
or Python 3 clients (but not both).

Server SSL configuration
~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
8 changes: 6 additions & 2 deletions setup.py
Expand Up @@ -36,7 +36,7 @@
'zope.interface',
]

tests_require = ['zope.testing', 'manuel', 'random2', 'mock']
tests_require = ['zope.testing', 'manuel', 'random2', 'mock', 'msgpack-python']

if sys.version_info[:2] < (3, ):
install_requires.extend(('futures', 'trollius'))
Expand Down Expand Up @@ -128,7 +128,11 @@ def emit(self, record):
classifiers = classifiers,
test_suite="__main__.alltests", # to support "setup.py test"
tests_require = tests_require,
extras_require = dict(test=tests_require, uvloop=['uvloop >=0.5.1']),
extras_require = dict(
test=tests_require,
uvloop=['uvloop >=0.5.1'],
msgpack=['msgpack-python'],
),
install_requires = install_requires,
zip_safe = False,
entry_points = """
Expand Down
3 changes: 2 additions & 1 deletion src/ZEO/StorageServer.py
Expand Up @@ -663,6 +663,7 @@ def __init__(self, addr, storages,
ssl=None,
client_conflict_resolution=False,
Acceptor=Acceptor,
msgpack=False,
):
"""StorageServer constructor.
Expand Down Expand Up @@ -757,7 +758,7 @@ def __init__(self, addr, storages,
self.client_conflict_resolution = client_conflict_resolution

if addr is not None:
self.acceptor = Acceptor(self, addr, ssl)
self.acceptor = Acceptor(self, addr, ssl, msgpack)
if isinstance(addr, tuple) and addr[0]:
self.addr = self.acceptor.addr
else:
Expand Down
4 changes: 1 addition & 3 deletions src/ZEO/asyncio/base.py
Expand Up @@ -10,8 +10,6 @@
from struct import unpack
import sys

from .marshal import encoder

logger = logging.getLogger(__name__)

INET_FAMILIES = socket.AF_INET, socket.AF_INET6
Expand Down Expand Up @@ -129,13 +127,13 @@ def data_received(self, data):
self.getting_size = True
self.message_received(collected)
except Exception:
#import traceback; traceback.print_exc()
logger.exception("data_received %s %s %s",
self.want, self.got, self.getting_size)

def first_message_received(self, protocol_version):
# Handler for first/handshake message, set up in __init__
del self.message_received # use default handler from here on
self.encode = encoder()
self.finish_connect(protocol_version)

def call_async(self, method, args):
Expand Down
33 changes: 22 additions & 11 deletions src/ZEO/asyncio/client.py
Expand Up @@ -13,7 +13,7 @@

from . import base
from .compat import asyncio, new_event_loop
from .marshal import decode
from .marshal import encoder, decoder

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -63,7 +63,7 @@ class Protocol(base.Protocol):
# One place where special care was required was in cache setup on
# connect. See finish connect below.

protocols = b'Z309', b'Z310', b'Z3101', b'Z4', b'Z5'
protocols = b'309', b'310', b'3101', b'4', b'5'

def __init__(self, loop,
addr, client, storage_key, read_only, connect_poll=1,
Expand Down Expand Up @@ -150,6 +150,8 @@ def connection_lost(self, exc):
# We have to be careful processing the futures, because
# exception callbacks might modufy them.
for f in self.pop_futures():
if isinstance(f, tuple):
continue
f.set_exception(ClientDisconnected(exc or 'connection lost'))
self.closed = True
self.client.disconnected(self)
Expand All @@ -165,13 +167,17 @@ def finish_connect(self, protocol_version):
# lastTid before processing (and possibly missing) subsequent
# invalidations.

self.protocol_version = min(protocol_version, self.protocols[-1])

if self.protocol_version not in self.protocols:
version = min(protocol_version[1:], self.protocols[-1])
if version not in self.protocols:
self.client.register_failed(
self, ZEO.Exceptions.ProtocolError(protocol_version))
return

self.protocol_version = protocol_version[:1] + version
self.encode = encoder(protocol_version)
self.decode = decoder(protocol_version)
self.heartbeat_bytes = self.encode(-1, 0, '.reply', None)

self._write(self.protocol_version)

credentials = (self.credentials,) if self.credentials else ()
Expand Down Expand Up @@ -199,9 +205,12 @@ def finish_connect(self, protocol_version):

exception_type_type = type(Exception)
def message_received(self, data):
msgid, async, name, args = decode(data)
msgid, async, name, args = self.decode(data)
if name == '.reply':
future = self.futures.pop(msgid)
if isinstance(future, tuple):
future = self.futures.pop(future)

if (async): # ZEO 5 exception
class_, args = args
factory = exc_factories.get(class_)
Expand Down Expand Up @@ -245,13 +254,15 @@ def fut(self, method, *args):

def load_before(self, oid, tid):
# Special-case loadBefore, so we collapse outstanding requests
message_id = (oid, tid)
future = self.futures.get(message_id)
oid_tid = (oid, tid)
future = self.futures.get(oid_tid)
if future is None:
future = asyncio.Future(loop=self.loop)
self.futures[message_id] = future
self.futures[oid_tid] = future
self.message_id += 1
self.futures[self.message_id] = oid_tid
self._write(
self.encode(message_id, False, 'loadBefore', (oid, tid)))
self.encode(self.message_id, False, 'loadBefore', (oid, tid)))
return future

# Methods called by the server.
Expand All @@ -267,7 +278,7 @@ def load_before(self, oid, tid):

def heartbeat(self, write=True):
if write:
self._write(b'(J\xff\xff\xff\xffK\x00U\x06.replyNt.')
self._write(self.heartbeat_bytes)
self.heartbeat_handle = self.loop.call_later(
self.heartbeat_interval, self.heartbeat)

Expand Down
34 changes: 30 additions & 4 deletions src/ZEO/asyncio/marshal.py
Expand Up @@ -26,10 +26,18 @@

logger = logging.getLogger(__name__)

def encoder():
def encoder(protocol):
"""Return a non-thread-safe encoder
"""

if protocol[:1] == b'M':
from msgpack import packb
def encode(*args):
return packb(args, use_bin_type=True)
return encode
else:
assert protocol[:1] == b'Z'

if PY3 or PYPY:
f = BytesIO()
getvalue = f.getvalue
Expand All @@ -54,9 +62,20 @@ def encode(*args):

def encode(*args):

return encoder()(*args)
return encoder(b'Z')(*args)

def decode(msg):
def decoder(protocol):
if protocol[:1] == b'M':
from msgpack import unpackb
def msgpack_decode(data):
"""Decodes msg and returns its parts"""
return unpackb(data, encoding='utf-8')
return msgpack_decode
else:
assert protocol[:1] == b'Z'
return pickle_decode

def pickle_decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = find_global
Expand All @@ -71,7 +90,14 @@ def decode(msg):
logger.error("can't decode message: %s" % short_repr(msg))
raise

def server_decode(msg):
def server_decoder(protocol):
if protocol[:1] == b'M':
return decoder(protocol)
else:
assert protocol[:1] == b'Z'
return pickle_server_decode

def pickle_server_decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = server_find_global
Expand Down
5 changes: 3 additions & 2 deletions src/ZEO/asyncio/mtacceptor.py
Expand Up @@ -76,13 +76,14 @@ class Acceptor(asyncore.dispatcher):
And creates a separate thread for each.
"""

def __init__(self, storage_server, addr, ssl):
def __init__(self, storage_server, addr, ssl, msgpack):
self.storage_server = storage_server
self.addr = addr
self.__socket_map = {}
asyncore.dispatcher.__init__(self, map=self.__socket_map)

self.ssl_context = ssl
self.msgpack = msgpack
self._open_socket()

def _open_socket(self):
Expand Down Expand Up @@ -165,7 +166,7 @@ def handle_accept(self):
def run():
loop = new_event_loop()
zs = self.storage_server.create_client_handler()
protocol = ServerProtocol(loop, self.addr, zs)
protocol = ServerProtocol(loop, self.addr, zs, self.msgpack)
protocol.stop = loop.stop

if self.ssl_context is None:
Expand Down
29 changes: 19 additions & 10 deletions src/ZEO/asyncio/server.py
Expand Up @@ -11,13 +11,13 @@

from . import base
from .compat import asyncio, new_event_loop
from .marshal import server_decode
from .marshal import server_decoder, encoder

class ServerProtocol(base.Protocol):
"""asyncio low-level ZEO server interface
"""

protocols = (b'Z5', )
protocols = (b'5', )

name = 'server protocol'
methods = set(('register', ))
Expand All @@ -26,12 +26,16 @@ class ServerProtocol(base.Protocol):
ZODB.POSException.POSKeyError,
)

def __init__(self, loop, addr, zeo_storage):
def __init__(self, loop, addr, zeo_storage, msgpack):
"""Create a server's client interface
"""
super(ServerProtocol, self).__init__(loop, addr)
self.zeo_storage = zeo_storage

self.announce_protocol = (
(b'M' if msgpack else b'Z') + best_protocol_version
)

closed = False
def close(self):
logger.debug("Closing server protocol")
Expand All @@ -44,7 +48,7 @@ def close(self):
def connection_made(self, transport):
self.connected = True
super(ServerProtocol, self).connection_made(transport)
self._write(best_protocol_version)
self._write(self.announce_protocol)

def connection_lost(self, exc):
self.connected = False
Expand All @@ -61,10 +65,13 @@ def finish_connect(self, protocol_version):
self._write(json.dumps(self.zeo_storage.ruok()).encode("ascii"))
self.close()
else:
if protocol_version in self.protocols:
version = protocol_version[1:]
if version in self.protocols:
logger.info("received handshake %r" %
str(protocol_version.decode('ascii')))
self.protocol_version = protocol_version
self.encode = encoder(protocol_version)
self.decode = server_decoder(protocol_version)
self.zeo_storage.notify_connected(self)
else:
logger.error("bad handshake %s" % short_repr(protocol_version))
Expand All @@ -79,7 +86,7 @@ def call_soon_threadsafe(self, func, *args):

def message_received(self, message):
try:
message_id, async, name, args = server_decode(message)
message_id, async, name, args = self.decode(message)
except Exception:
logger.exception("Can't deserialize message")
self.close()
Expand Down Expand Up @@ -144,8 +151,8 @@ def async_threadsafe(self, method, *args):
ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8')
assert best_protocol_version in ServerProtocol.protocols

def new_connection(loop, addr, socket, zeo_storage):
protocol = ServerProtocol(loop, addr, zeo_storage)
def new_connection(loop, addr, socket, zeo_storage, msgpack):
protocol = ServerProtocol(loop, addr, zeo_storage, msgpack)
cr = loop.create_connection((lambda : protocol), sock=socket)
asyncio.async(cr, loop=loop)

Expand Down Expand Up @@ -213,10 +220,11 @@ def error(self, exc_info):

class Acceptor(object):

def __init__(self, storage_server, addr, ssl):
def __init__(self, storage_server, addr, ssl, msgpack):
self.storage_server = storage_server
self.addr = addr
self.ssl_context = ssl
self.msgpack = msgpack
self.event_loop = loop = new_event_loop()

if isinstance(addr, tuple):
Expand All @@ -243,7 +251,8 @@ def factory(self):
try:
logger.debug("Accepted connection")
zs = self.storage_server.create_client_handler()
protocol = ServerProtocol(self.event_loop, self.addr, zs)
protocol = ServerProtocol(
self.event_loop, self.addr, zs, self.msgpack)
except Exception:
logger.exception("Failure in protocol factory")

Expand Down

0 comments on commit b31fed1

Please sign in to comment.