Skip to content

Commit

Permalink
Implement EmbeddedKernel.
Browse files Browse the repository at this point in the history
  • Loading branch information
epatters committed Sep 10, 2012
1 parent e367f3e commit 8d53a55
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 48 deletions.
111 changes: 110 additions & 1 deletion IPython/embedded/ipkernel.py
Expand Up @@ -11,12 +11,121 @@
# Imports # Imports
#----------------------------------------------------------------------------- #-----------------------------------------------------------------------------


# Standard library imports
from contextlib import contextmanager
import logging
import sys

# Local imports. # Local imports.
from IPython.embedded.socket import DummySocket
from IPython.utils.jsonutil import json_clean
from IPython.utils.traitlets import Any, Instance, List
from IPython.zmq.ipkernel import Kernel from IPython.zmq.ipkernel import Kernel


#----------------------------------------------------------------------------- #-----------------------------------------------------------------------------
# Main kernel class # Main kernel class
#----------------------------------------------------------------------------- #-----------------------------------------------------------------------------


class EmbeddedKernel(Kernel): class EmbeddedKernel(Kernel):
pass
#-------------------------------------------------------------------------
# EmbeddedKernel interface
#-------------------------------------------------------------------------

frontends = List(
Instance('IPython.embedded.kernelmanager.EmbeddedKernelManager'))

raw_input_str = Any()
stdout = Any()
stderr = Any()

#-------------------------------------------------------------------------
# Kernel interface
#-------------------------------------------------------------------------

shell_streams = List()
control_stream = Any()
iopub_socket = Instance(DummySocket, ())
stdin_socket = Instance(DummySocket, ())

def __init__(self, **traits):
# When an InteractiveShell is instantiated by our base class, it binds
# the current values of sys.stdout and sys.stderr.
with self._redirected_io():
super(EmbeddedKernel, self).__init__(**traits)

self.iopub_socket.on_trait_change(self._io_dispatch, 'message_sent')

def execute_request(self, stream, ident, parent):
""" Override for temporary IO redirection. """
with self._redirected_io():
super(EmbeddedKernel, self).execute_request(stream, ident, parent)

def start(self):
""" Override registration of dispatchers for streams. """
self.shell.exit_now = False

def _abort_queue(self, stream):
""" The embedded kernel don't abort requests. """
pass

def _raw_input(self, prompt, ident, parent):
# Flush output before making the request.
self.raw_input_str = None
sys.stderr.flush()
sys.stdout.flush()

# Send the input request.
content = json_clean(dict(prompt=prompt))
msg = self.session.msg(u'input_request', content, parent)
for frontend in self.frontends:
if frontend.session.session == parent['header']['session']:
frontend.stdin_channel.call_handlers(msg)
break
else:
log.error('No frontend found for raw_input request')
return str()

# Await a response.
while self.raw_input_str is None:
frontend.stdin_channel.process_events()
return self.raw_input_str

#-------------------------------------------------------------------------
# Protected interface
#-------------------------------------------------------------------------

@contextmanager
def _redirected_io(self):
""" Temporarily redirect IO to the kernel.
"""
sys_stdout, sys_stderr = sys.stdout, sys.stderr
sys.stdout, sys.stderr = self.stdout, self.stderr
yield
sys.stdout, sys.stderr = sys_stdout, sys_stderr

#------ Trait change handlers --------------------------------------------

def _io_dispatch(self):
""" Called when a message is sent to the IO socket.
"""
ident, msg = self.session.recv(self.iopub_socket, copy=False)
for frontend in self.frontends:
frontend.sub_channel.call_handlers(msg)

#------ Trait initializers -----------------------------------------------

def _log_default(self):
return logging.getLogger(__name__)

def _session_default(self):
from IPython.zmq.session import Session
return Session(config=self.config)

def _stdout_default(self):
from IPython.zmq.iostream import OutStream
return OutStream(self.session, self.iopub_socket, u'stdout')

def _stderr_default(self):
from IPython.zmq.iostream import OutStream
return OutStream(self.session, self.iopub_socket, u'stderr')
55 changes: 50 additions & 5 deletions IPython/embedded/kernelmanager.py
Expand Up @@ -13,6 +13,7 @@


# Local imports. # Local imports.
from IPython.config.loader import Config from IPython.config.loader import Config
from IPython.embedded.socket import DummySocket
from IPython.utils.traitlets import HasTraits, Any, Instance, Type from IPython.utils.traitlets import HasTraits, Any, Instance, Type


#----------------------------------------------------------------------------- #-----------------------------------------------------------------------------
Expand Down Expand Up @@ -77,6 +78,10 @@ class ShellEmbeddedChannel(EmbeddedChannel):
# flag for whether execute requests should be allowed to call raw_input # flag for whether execute requests should be allowed to call raw_input
allow_stdin = True allow_stdin = True


#--------------------------------------------------------------------------
# ShellChannel interface
#--------------------------------------------------------------------------

def execute(self, code, silent=False, store_history=True, def execute(self, code, silent=False, store_history=True,
user_variables=[], user_expressions={}, allow_stdin=None): user_variables=[], user_expressions={}, allow_stdin=None):
"""Execute code in the kernel. """Execute code in the kernel.
Expand Down Expand Up @@ -115,7 +120,15 @@ def execute(self, code, silent=False, store_history=True,
------- -------
The msg_id of the message sent. The msg_id of the message sent.
""" """
raise NotImplementedError if allow_stdin is None:
allow_stdin = self.allow_stdin
content = dict(code=code, silent=silent, store_history=store_history,
user_variables=user_variables,
user_expressions=user_expressions,
allow_stdin=allow_stdin)
msg = self.manager.session.msg('execute_request', content)
self._dispatch_to_kernel(msg)
return msg['header']['msg_id']


def complete(self, text, line, cursor_pos, block=None): def complete(self, text, line, cursor_pos, block=None):
"""Tab complete text in the kernel's namespace. """Tab complete text in the kernel's namespace.
Expand All @@ -137,7 +150,10 @@ def complete(self, text, line, cursor_pos, block=None):
------- -------
The msg_id of the message sent. The msg_id of the message sent.
""" """
raise NotImplementedError content = dict(text=text, line=line, block=block, cursor_pos=cursor_pos)
msg = self.manager.session.msg('complete_request', content)
self._dispatch_to_kernel(msg)
return msg['header']['msg_id']


def object_info(self, oname, detail_level=0): def object_info(self, oname, detail_level=0):
"""Get metadata information about an object. """Get metadata information about an object.
Expand All @@ -153,7 +169,10 @@ def object_info(self, oname, detail_level=0):
------- -------
The msg_id of the message sent. The msg_id of the message sent.
""" """
raise NotImplementedError content = dict(oname=oname, detail_level=detail_level)
msg = self.manager.session.msg('object_info_request', content)
self._dispatch_to_kernel(msg)
return msg['header']['msg_id']


def history(self, raw=True, output=False, hist_access_type='range', **kwds): def history(self, raw=True, output=False, hist_access_type='range', **kwds):
"""Get entries from the history list. """Get entries from the history list.
Expand Down Expand Up @@ -187,7 +206,11 @@ def history(self, raw=True, output=False, hist_access_type='range', **kwds):
------- -------
The msg_id of the message sent. The msg_id of the message sent.
""" """
raise NotImplementedError content = dict(raw=raw, output=output,
hist_access_type=hist_access_type, **kwds)
msg = self.manager.session.msg('history_request', content)
self._dispatch_to_kernel(msg)
return msg['header']['msg_id']


def shutdown(self, restart=False): def shutdown(self, restart=False):
""" Request an immediate kernel shutdown. """ Request an immediate kernel shutdown.
Expand All @@ -197,6 +220,25 @@ def shutdown(self, restart=False):
# FIXME: What to do here? # FIXME: What to do here?
raise NotImplementedError('Shutdown not supported for embedded kernel') raise NotImplementedError('Shutdown not supported for embedded kernel')


#--------------------------------------------------------------------------
# Protected interface
#--------------------------------------------------------------------------

def _dispatch_to_kernel(self, msg):
""" Send a message to the kernel and handle a reply.
"""
kernel = self.manager.kernel
if kernel is None:
raise RuntimeError('Cannot send request. No kernel exists.')

stream = DummySocket()
self.manager.session.send(stream, msg)
msg_parts = stream.recv_multipart()
kernel.dispatch_shell(stream, msg_parts)

idents, reply_msg = self.manager.session.recv(stream, copy=False)
self.call_handlers_later(reply_msg)



class SubEmbeddedChannel(EmbeddedChannel): class SubEmbeddedChannel(EmbeddedChannel):
"""The SUB channel which listens for messages that the kernel publishes. """The SUB channel which listens for messages that the kernel publishes.
Expand All @@ -216,7 +258,10 @@ class StdInEmbeddedChannel(EmbeddedChannel):
def input(self, string): def input(self, string):
""" Send a string of raw input to the kernel. """ Send a string of raw input to the kernel.
""" """
raise NotImplementedError kernel = self.manager.kernel
if kernel is None:
raise RuntimeError('Cannot send input reply. No kernel exists.')
kernel.raw_input_str = string




class HBEmbeddedChannel(EmbeddedChannel): class HBEmbeddedChannel(EmbeddedChannel):
Expand Down
43 changes: 43 additions & 0 deletions IPython/embedded/socket.py
@@ -0,0 +1,43 @@
""" Defines a dummy socket implementing (part of) the zmq.Socket interface. """

#-----------------------------------------------------------------------------
# Copyright (C) 2012 The IPython Development Team
#
# Distributed under the terms of the BSD License. The full license is in
# the file COPYING, distributed as part of this software.
#-----------------------------------------------------------------------------

#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------

# Standard library imports.
import Queue

# System library imports.
import zmq

# Local imports.
from IPython.utils.traitlets import HasTraits, Instance, Int

#-----------------------------------------------------------------------------
# Dummy socket class
#-----------------------------------------------------------------------------

class DummySocket(HasTraits):
""" A dummy socket implementing (part of) the zmq.Socket interface. """

queue = Instance(Queue.Queue, ())
message_sent = Int(0) # Should be an Event

#-------------------------------------------------------------------------
# zmq.Socket interface
#-------------------------------------------------------------------------

def recv_multipart(self, flags=0, copy=True, track=False):
return self.queue.get_nowait()

def send_multipart(self, msg_parts, flags=0, copy=True, track=False):
msg_parts = map(zmq.Message, msg_parts)
self.queue.put_nowait(msg_parts)
self.message_sent += 1
4 changes: 2 additions & 2 deletions IPython/zmq/datapub.py
Expand Up @@ -15,7 +15,7 @@
from IPython.config import Configurable from IPython.config import Configurable


from IPython.utils.jsonutil import json_clean from IPython.utils.jsonutil import json_clean
from IPython.utils.traitlets import Instance, Dict, CBytes from IPython.utils.traitlets import Any, Instance, Dict, CBytes


from IPython.zmq.serialize import serialize_object from IPython.zmq.serialize import serialize_object
from IPython.zmq.session import Session, extract_header from IPython.zmq.session import Session, extract_header
Expand All @@ -29,7 +29,7 @@ class ZMQDataPublisher(Configurable):


topic = topic = CBytes(b'datapub') topic = topic = CBytes(b'datapub')
session = Instance(Session) session = Instance(Session)
pub_socket = Instance('zmq.Socket') pub_socket = Any()
parent_header = Dict({}) parent_header = Dict({})


def set_parent(self, parent): def set_parent(self, parent):
Expand Down
4 changes: 2 additions & 2 deletions IPython/zmq/displayhook.py
Expand Up @@ -3,7 +3,7 @@


from IPython.core.displayhook import DisplayHook from IPython.core.displayhook import DisplayHook
from IPython.utils.jsonutil import encode_images from IPython.utils.jsonutil import encode_images
from IPython.utils.traitlets import Instance, Dict from IPython.utils.traitlets import Any, Instance, Dict
from session import extract_header, Session from session import extract_header, Session


class ZMQDisplayHook(object): class ZMQDisplayHook(object):
Expand Down Expand Up @@ -37,7 +37,7 @@ class ZMQShellDisplayHook(DisplayHook):
topic=None topic=None


session = Instance(Session) session = Instance(Session)
pub_socket = Instance('zmq.Socket') pub_socket = Any()
parent_header = Dict({}) parent_header = Dict({})


def set_parent(self, parent): def set_parent(self, parent):
Expand Down
31 changes: 0 additions & 31 deletions IPython/zmq/ipkernel.py
Expand Up @@ -644,7 +644,6 @@ def clear_request(self, stream, idents, parent):
# Protected interface # Protected interface
#--------------------------------------------------------------------------- #---------------------------------------------------------------------------



def _wrap_exception(self, method=None): def _wrap_exception(self, method=None):
# import here, because _wrap_exception is only used in parallel, # import here, because _wrap_exception is only used in parallel,
# and parallel has higher min pyzmq version # and parallel has higher min pyzmq version
Expand Down Expand Up @@ -739,36 +738,6 @@ def _complete(self, msg):
cpos = len(c['line']) cpos = len(c['line'])
return self.shell.complete(c['text'], c['line'], cpos) return self.shell.complete(c['text'], c['line'], cpos)


def _object_info(self, context):
symbol, leftover = self._symbol_from_context(context)
if symbol is not None and not leftover:
doc = getattr(symbol, '__doc__', '')
else:
doc = ''
object_info = dict(docstring = doc)
return object_info

def _symbol_from_context(self, context):
if not context:
return None, context

base_symbol_string = context[0]
symbol = self.shell.user_ns.get(base_symbol_string, None)
if symbol is None:
symbol = __builtin__.__dict__.get(base_symbol_string, None)
if symbol is None:
return None, context

context = context[1:]
for i, name in enumerate(context):
new_symbol = getattr(symbol, name, None)
if new_symbol is None:
return symbol, context[i:]
else:
symbol = new_symbol

return symbol, []

def _at_shutdown(self): def _at_shutdown(self):
"""Actions taken at shutdown by the kernel, called by python's atexit. """Actions taken at shutdown by the kernel, called by python's atexit.
""" """
Expand Down
8 changes: 3 additions & 5 deletions IPython/zmq/session.py
Expand Up @@ -558,11 +558,9 @@ def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
msg : dict msg : dict
The constructed message. The constructed message.
""" """

if not isinstance(stream, zmq.Socket):
if not isinstance(stream, (zmq.Socket, ZMQStream)): # ZMQStreams and dummy sockets do not support tracking.
raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream)) track = False
elif track and isinstance(stream, ZMQStream):
raise TypeError("ZMQStream cannot track messages")


if isinstance(msg_or_type, (Message, dict)): if isinstance(msg_or_type, (Message, dict)):
# We got a Message or message dict, not a msg_type so don't # We got a Message or message dict, not a msg_type so don't
Expand Down
4 changes: 2 additions & 2 deletions IPython/zmq/zmqshell.py
Expand Up @@ -42,7 +42,7 @@
from IPython.utils.jsonutil import json_clean, encode_images from IPython.utils.jsonutil import json_clean, encode_images
from IPython.utils.process import arg_split from IPython.utils.process import arg_split
from IPython.utils import py3compat from IPython.utils import py3compat
from IPython.utils.traitlets import Instance, Type, Dict, CBool, CBytes from IPython.utils.traitlets import Any, Instance, Type, Dict, CBool, CBytes
from IPython.utils.warn import warn, error from IPython.utils.warn import warn, error
from IPython.zmq.displayhook import ZMQShellDisplayHook from IPython.zmq.displayhook import ZMQShellDisplayHook
from IPython.zmq.datapub import ZMQDataPublisher from IPython.zmq.datapub import ZMQDataPublisher
Expand All @@ -57,7 +57,7 @@ class ZMQDisplayPublisher(DisplayPublisher):
"""A display publisher that publishes data using a ZeroMQ PUB socket.""" """A display publisher that publishes data using a ZeroMQ PUB socket."""


session = Instance(Session) session = Instance(Session)
pub_socket = Instance('zmq.Socket') pub_socket = Any()
parent_header = Dict({}) parent_header = Dict({})
topic = CBytes(b'displaypub') topic = CBytes(b'displaypub')


Expand Down

0 comments on commit 8d53a55

Please sign in to comment.