Skip to content

Commit

Permalink
accept bytes in comms
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin Peter committed May 27, 2024
1 parent 7a8ad42 commit 872f5a3
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 25 deletions.
75 changes: 61 additions & 14 deletions spyder_kernels/comms/commbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Comms transmit data in a list of buffers, and in a json-able dictionnary.
Here, we only support json to avoid issues of compatibility between python
versions.
versions. In the abstraction below, buffers is used to send bytes.
The messages exchanged have the following msg_dict:
Expand All @@ -36,6 +36,8 @@
'settings': A dictionnary of settings,
'call_args': The function args,
'call_kwargs': The function kwargs,
'buffers_args_idx': The args index that are bytes,
'buffers_kwargs_keys': the kwargs keys that are bytes
}
- If the 'settings' has `'blocking' = True`, a reply is sent.
(spyder_msg_type = 'remote_call_reply'):
Expand Down Expand Up @@ -225,7 +227,9 @@ def remote_call(self, comm_id=None, callback=None, **settings):
return RemoteCallFactory(self, comm_id, callback, **settings)

# ---- Private -----
def _send_message(self, spyder_msg_type, content=None, comm_id=None):
def _send_message(
self, spyder_msg_type, content=None, comm_id=None, buffers=None
):
"""
Publish custom messages to the other side.
Expand All @@ -249,7 +253,7 @@ def _send_message(self, spyder_msg_type, content=None, comm_id=None):
'content': content,
}

self._comms[comm_id]['comm'].send(msg_dict)
self._comms[comm_id]['comm'].send(msg_dict, buffers=buffers)

@property
def _comm_name(self):
Expand Down Expand Up @@ -302,21 +306,34 @@ def _comm_message(self, msg):
# Get message dict
msg_dict = msg['content']['data']
spyder_msg_type = msg_dict['spyder_msg_type']
buffers = msg['buffers']

if spyder_msg_type in self._message_handlers:
self._message_handlers[spyder_msg_type](msg_dict)
self._message_handlers[spyder_msg_type](msg_dict, buffers)
else:
logger.debug("No such spyder message type: %s" % spyder_msg_type)

def _handle_remote_call(self, msg):
def _handle_remote_call(self, msg, buffers):
"""Handle a remote call."""
msg_dict = msg['content']
self.on_incoming_call(msg_dict)
try:
# read buffers
args = msg_dict['call_args']
kwargs = msg_dict['call_kwargs']

if buffers:
for idx in msg_dict['buffers_args_idx']:
args[idx] = buffers.pop(0)
for name in msg_dict['buffers_kwargs_keys']:
kwargs[name] = buffers.pop(0)
assert len(buffers) == 0

return_value = self._remote_callback(
msg_dict['call_name'],
msg_dict['call_args'],
msg_dict['call_kwargs'])
args,
kwargs
)
self._set_call_return_value(msg_dict, return_value)
except Exception:
exc_infos = CommsErrorWrapper(
Expand Down Expand Up @@ -350,6 +367,10 @@ def _set_call_return_value(self, call_dict, return_value, is_error=False):
if not send_reply:
# Nothing to send back
return
buffers = None
if isinstance(return_value, bytes):
buffers = [return_value]
return_value = None
content = {
'is_error': is_error,
'call_id': call_dict['call_id'],
Expand All @@ -358,7 +379,8 @@ def _set_call_return_value(self, call_dict, return_value, is_error=False):
}

self._send_message(
'remote_call_reply', content=content, comm_id=self.calling_comm_id
'remote_call_reply', content=content, comm_id=self.calling_comm_id,
buffers=buffers
)

def _register_call(self, call_dict, callback=None):
Expand All @@ -379,11 +401,11 @@ def on_incoming_call(self, call_dict):
"""A call was received"""
pass

def _send_call(self, call_dict, comm_id):
def _send_call(self, call_dict, comm_id, buffers=None):
"""Send call."""
call_dict = self.on_outgoing_call(call_dict)
self._send_message(
'remote_call', content=call_dict, comm_id=comm_id
'remote_call', content=call_dict, comm_id=comm_id, buffers=buffers
)

def _get_call_return_value(self, call_dict, comm_id):
Expand Down Expand Up @@ -426,7 +448,7 @@ def _wait_reply(self, comm_id, call_id, call_name, timeout):
"""
raise NotImplementedError

def _handle_remote_call_reply(self, msg_dict):
def _handle_remote_call_reply(self, msg_dict, buffers):
"""
A blocking call received a reply.
"""
Expand All @@ -435,6 +457,12 @@ def _handle_remote_call_reply(self, msg_dict):
call_name = content['call_name']
is_error = content['is_error']
return_value = content['call_return_value']
if is_error:
return_value = CommsErrorWrapper.from_json(return_value)
elif buffers:
assert len(buffers) == 1
return_value = buffers[0]
content['call_return_value'] = return_value

# Unexpected reply
if call_id not in self._reply_waitlist:
Expand Down Expand Up @@ -463,13 +491,13 @@ def _async_error(self, error_wrapper):
"""
Handle an error that was raised on the other side asyncronously.
"""
CommsErrorWrapper.from_json(error_wrapper).print_error()
error_wrapper.print_error()

def _sync_error(self, error_wrapper):
"""
Handle an error that was raised on the other side syncronously.
"""
CommsErrorWrapper.from_json(error_wrapper).raise_error()
error_wrapper.raise_error()


class RemoteCallFactory:
Expand Down Expand Up @@ -511,6 +539,23 @@ def __call__(self, *args, **kwargs):
"""
blocking = 'blocking' in self._settings and self._settings['blocking']
self._settings['send_reply'] = blocking or self._callback is not None

# put all bytes in a buffer
buffers = []
buffers_args_idx = []
args = list(args)
for i, arg in enumerate(args):
if isinstance(arg, bytes):
buffers.append(arg)
buffers_args_idx.append(i)
args[i] = None
buffers_kwargs_keys = []
for name in kwargs:
arg = kwargs[name]
if isinstance(arg, bytes):
buffers.append(arg)
buffers_kwargs_keys.append(name)
kwargs[name] = None

call_id = uuid.uuid4().hex
call_dict = {
Expand All @@ -519,6 +564,8 @@ def __call__(self, *args, **kwargs):
'settings': self._settings,
'call_args': args,
'call_kwargs': kwargs,
'buffers_args_idx': buffers_args_idx,
'buffers_kwargs_keys': buffers_kwargs_keys
}

if not self._comms_wrapper.is_open(self._comm_id):
Expand All @@ -528,6 +575,6 @@ def __call__(self, *args, **kwargs):
logger.debug("Call to unconnected comm: %s" % self._name)
return
self._comms_wrapper._register_call(call_dict, self._callback)
self._comms_wrapper._send_call(call_dict, self._comm_id)
self._comms_wrapper._send_call(call_dict, self._comm_id, buffers)
return self._comms_wrapper._get_call_return_value(
call_dict, self._comm_id)
2 changes: 1 addition & 1 deletion spyder_kernels/comms/frontendcomm.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _async_error(self, error_wrapper):
"""
Send an async error back to the frontend to be displayed.
"""
self.remote_call()._async_error(error_wrapper)
self.remote_call()._async_error(error_wrapper.to_json())

def _register_comm(self, comm):
"""
Expand Down
10 changes: 3 additions & 7 deletions spyder_kernels/console/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
"""

# Standard library imports
import base64
import faulthandler
import json
import logging
Expand Down Expand Up @@ -363,19 +362,16 @@ def get_value(self, name, encoded=False):
ns = self.shell._get_current_namespace()
value = ns[name]
if encoded:
# Encode with cloudpickle and base64
value = base64.b64encode(
cloudpickle.dumps(value)
).decode()

# Encode with cloudpickle
value = cloudpickle.dumps(value)
return value

@comm_handler
def set_value(self, name, value, encoded=False):
"""Set the value of a variable"""
if encoded:
# Decode_value
value = cloudpickle.loads(base64.b64decode(value.encode()))
value = cloudpickle.loads(value)
ns = self.shell._get_reference_namespace(name)
ns[name] = value
self.log.debug(ns)
Expand Down
3 changes: 0 additions & 3 deletions spyder_kernels/console/tests/test_console_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import inspect
import uuid
from collections import namedtuple
import cloudpickle
import base64

# Test imports
import pytest
Expand Down Expand Up @@ -348,7 +346,6 @@ def test_set_value(kernel):
name = 'a'
asyncio.run(kernel.do_execute('a = 0', True))
value = 10
# Encode with cloudpickle and base64
kernel.set_value(name, value)
log_text = get_log_text(kernel)
assert "'__builtin__': <module " in log_text
Expand Down

0 comments on commit 872f5a3

Please sign in to comment.