Skip to content

Commit

Permalink
Merge pull request #75 from pypr/fix-jupyter-capture
Browse files Browse the repository at this point in the history
Fix issues with capturing stdio on notebooks.
  • Loading branch information
prabhuramachandran committed Jan 11, 2021
2 parents 6b569f0 + a115034 commit 47d2226
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 15 deletions.
29 changes: 28 additions & 1 deletion compyle/capture_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@
from tempfile import mktemp


def get_ipython_capture():
try:
# This will work inside IPython but not outside it.
name = get_ipython().__class__.__name__
if name.startswith('ZMQ'):
from IPython.utils.capture import capture_output
return capture_output
else:
return None
except NameError:
return None


class CaptureStream(object):
"""A context manager which captures any errors on a given stream (like
sys.stderr). The stream is captured and the outputs can be used.
Expand Down Expand Up @@ -79,15 +92,29 @@ def __init__(self, streams=None):
streams = (sys.stdout, sys.stderr) if streams is None else streams
self.streams = streams
self.captures = [CaptureStream(x) for x in streams]
cap = get_ipython_capture()
if cap:
self.jcap = cap(stdout=True, stderr=True, display=True)
else:
self.jcap = None
self.joutput = None

def __enter__(self):
for capture in self.captures:
capture.__enter__()
if self.jcap:
self.joutput = self.jcap.__enter__()
return self

def __exit__(self, type, value, tb):
for capture in self.captures:
capture.__exit__(type, value, tb)
if self.jcap:
self.jcap.__exit__(type, value, tb)

def get_output(self):
return tuple(x.get_output() for x in self.captures)
out = list(x.get_output() for x in self.captures)
if self.joutput:
out[0] += self.joutput.stdout
out[1] += self.joutput.stderr
return out
16 changes: 4 additions & 12 deletions compyle/ext_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
else:
from distutils.extension import Extension

PY3 = sys.version_info.major > 2

# Package imports.
from .config import get_config # noqa: 402
from .capture_stream import CaptureMultipleStreams # noqa: 402
Expand Down Expand Up @@ -78,13 +76,6 @@ def get_md5(data):
return hashlib.md5(data.encode()).hexdigest()


def get_unicode(s):
if PY3:
return s
else:
return unicode(s)


def get_openmp_flags():
"""Return the OpenMP flags for the platform.
Expand Down Expand Up @@ -197,7 +188,7 @@ def _try_to_lock():
def _write_source(self, path):
if not exists(path):
with io.open(path, 'w', encoding='utf-8') as f:
f.write(get_unicode(self.code))
f.write(self.code)

def _setup_root(self, root):
if root is None:
Expand Down Expand Up @@ -279,8 +270,9 @@ def build(self, force=False):
except (CompileError, LinkError):
hline = "*"*80
print(hline + "\nERROR")
print(stream.get_output()[0])
print(stream.get_output()[1])
s_out = stream.get_output()
print(s_out[0])
print(s_out[1])
msg = "Compilation of code failed, please check "\
"error messages above."
print(hline + "\n" + msg)
Expand Down
4 changes: 2 additions & 2 deletions compyle/tests/test_ext_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import compyle.ext_module

from ..ext_module import (get_md5, ExtModule, get_ext_extension, get_unicode,
from ..ext_module import (get_md5, ExtModule, get_ext_extension,
get_config_file_opts, get_openmp_flags)


Expand All @@ -30,7 +30,7 @@ def _check_write_source(root):

def _side_effect(*args, **kw):
with io_open(*args, **kw) as fp:
fp.write(get_unicode("junk"))
fp.write("junk")
return orig_side_effect(*args, **kw)
m.side_effect = _side_effect

Expand Down

0 comments on commit 47d2226

Please sign in to comment.