Skip to content

Commit

Permalink
fix fastai#13
Browse files Browse the repository at this point in the history
  • Loading branch information
seeM committed Jul 4, 2022
1 parent 8cf9541 commit 3a72de8
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 23 deletions.
37 changes: 25 additions & 12 deletions execnb/shell.py
Expand Up @@ -5,10 +5,12 @@
from fastcore.basics import *
from fastcore.imports import *
from fastcore.script import call_parse
from fastcore.test import *

from IPython.core.interactiveshell import InteractiveShell
from IPython.core.displayhook import DisplayHook
from IPython.core.displaypub import DisplayPublisher
from base64 import b64encode
from io import StringIO

from .fastshell import FastInteractiveShell
Expand Down Expand Up @@ -37,9 +39,19 @@ def publish(self, data, metadata=None, **kwargs): self.shell._add_out(data, meta
# %% ../nbs/02_shell.ipynb 5
# These are the standard notebook formats for exception and stream data (e.g stdout)
def _out_exc(ename, evalue, traceback): return dict(ename=str(ename), evalue=str(evalue), output_type='error', traceback=traceback)
def _out_stream(text): return dict(name='stdout', output_type='stream', text=text.splitlines(False))

# %% ../nbs/02_shell.ipynb 7
def _out_stream(text): return dict(name='stdout', output_type='stream', text=text.splitlines(True))

# %% ../nbs/02_shell.ipynb 6
def _format_mimedata(k, v):
"Format mime-type keyed data consistently with Jupyter"
if k.startswith('text/'): return v.splitlines(True)
if k.startswith('image/') and isinstance(v, bytes):
print('*******')
v = b64encode(v).decode()
return v+'\n' if not v.endswith('\n') else v
return v

# %% ../nbs/02_shell.ipynb 8
class CaptureShell(FastInteractiveShell):
"Execute the IPython/Jupyter source code"
def __init__(self,
Expand Down Expand Up @@ -72,11 +84,12 @@ def _showtraceback(self, etype, evalue, stb: str):
self.out.append(_out_exc(etype, evalue, stb))
self.exc = (etype, evalue, '\n'.join(stb))

def _add_out(self, data, meta, typ='execute_result', **kwargs): self.out.append(dict(data=data, metadata=meta, output_type=typ, **kwargs))
def _add_out(self, data, meta, typ='execute_result', **kwargs):
fd = {k:_format_mimedata(k,v) for k,v in data.items()}
self.out.append(dict(data=fd, metadata=meta, output_type=typ, **kwargs))

def _add_exec(self, result, meta, typ='execute_result'):
fd = {k:v.splitlines(True) for k,v in result.items()}
self._add_out(fd, meta, execution_count=self.count)
self._add_out(result, meta, execution_count=self.count)
self.count += 1

def _result(self, result):
Expand All @@ -87,7 +100,7 @@ def _stream(self, std):
text = std.getvalue()
if text: self.out.append(_out_stream(text))

# %% ../nbs/02_shell.ipynb 10
# %% ../nbs/02_shell.ipynb 11
@patch
def run(self:CaptureShell,
code:str, # Python/IPython code to run
Expand All @@ -105,7 +118,7 @@ def run(self:CaptureShell,
self._stream(stdout)
return [*self.out]

# %% ../nbs/02_shell.ipynb 19
# %% ../nbs/02_shell.ipynb 20
@patch
def cell(self:CaptureShell, cell, stdout=True, stderr=True):
"Run `cell`, skipping if not code, and store outputs back in cell"
Expand All @@ -117,7 +130,7 @@ def cell(self:CaptureShell, cell, stdout=True, stderr=True):
for o in outs:
if 'execution_count' in o: cell['execution_count'] = o['execution_count']

# %% ../nbs/02_shell.ipynb 23
# %% ../nbs/02_shell.ipynb 24
def _false(o): return False

@patch
Expand All @@ -137,7 +150,7 @@ def run_all(self:CaptureShell,
postproc(cell)
if self.exc and exc_stop: raise self.exc[1] from None

# %% ../nbs/02_shell.ipynb 37
# %% ../nbs/02_shell.ipynb 38
@patch
def execute(self:CaptureShell,
src:str|Path, # Notebook path to read from
Expand All @@ -158,7 +171,7 @@ def execute(self:CaptureShell,
inject_code=inject_code, inject_idx=inject_idx)
if dest: write_nb(nb, dest)

# %% ../nbs/02_shell.ipynb 40
# %% ../nbs/02_shell.ipynb 41
@patch
def prettytb(self:CaptureShell,
fname:str|Path=None): # filename to print alongside the traceback
Expand All @@ -170,7 +183,7 @@ def prettytb(self:CaptureShell,
fname_str = f' in {fname}' if fname else ''
return f"{type(self.exc[1]).__name__}{fname_str}:\n{_fence}\n{cell_str}\n"

# %% ../nbs/02_shell.ipynb 52
# %% ../nbs/02_shell.ipynb 56
@call_parse
def exec_nb(
src:str, # Notebook path to read from
Expand Down
84 changes: 73 additions & 11 deletions nbs/02_shell.ipynb
Expand Up @@ -30,10 +30,12 @@
"from fastcore.basics import *\n",
"from fastcore.imports import *\n",
"from fastcore.script import call_parse\n",
"from fastcore.test import *\n",
"\n",
"from IPython.core.interactiveshell import InteractiveShell\n",
"from IPython.core.displayhook import DisplayHook\n",
"from IPython.core.displaypub import DisplayPublisher\n",
"from base64 import b64encode\n",
"from io import StringIO\n",
"\n",
"from execnb.fastshell import FastInteractiveShell\n",
Expand Down Expand Up @@ -73,7 +75,24 @@
"#|export\n",
"# These are the standard notebook formats for exception and stream data (e.g stdout)\n",
"def _out_exc(ename, evalue, traceback): return dict(ename=str(ename), evalue=str(evalue), output_type='error', traceback=traceback)\n",
"def _out_stream(text): return dict(name='stdout', output_type='stream', text=text.splitlines(False))"
"def _out_stream(text): return dict(name='stdout', output_type='stream', text=text.splitlines(True))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"def _format_mimedata(k, v):\n",
" \"Format mime-type keyed data consistently with Jupyter\"\n",
" if k.startswith('text/'): return v.splitlines(True)\n",
" if k.startswith('image/') and isinstance(v, bytes):\n",
" print('*******')\n",
" v = b64encode(v).decode()\n",
" return v+'\\n' if not v.endswith('\\n') else v\n",
" return v"
]
},
{
Expand Down Expand Up @@ -122,11 +141,12 @@
" self.out.append(_out_exc(etype, evalue, stb))\n",
" self.exc = (etype, evalue, '\\n'.join(stb))\n",
"\n",
" def _add_out(self, data, meta, typ='execute_result', **kwargs): self.out.append(dict(data=data, metadata=meta, output_type=typ, **kwargs))\n",
" def _add_out(self, data, meta, typ='execute_result', **kwargs):\n",
" fd = {k:_format_mimedata(k,v) for k,v in data.items()}\n",
" self.out.append(dict(data=fd, metadata=meta, output_type=typ, **kwargs))\n",
"\n",
" def _add_exec(self, result, meta, typ='execute_result'):\n",
" fd = {k:v.splitlines(True) for k,v in result.items()}\n",
" self._add_out(fd, meta, execution_count=self.count)\n",
" self._add_out(result, meta, execution_count=self.count)\n",
" self.count += 1\n",
"\n",
" def _result(self, result):\n",
Expand Down Expand Up @@ -187,7 +207,7 @@
{
"data": {
"text/plain": [
"[{'name': 'stdout', 'output_type': 'stream', 'text': ['1']}]"
"[{'name': 'stdout', 'output_type': 'stream', 'text': ['1\\n']}]"
]
},
"execution_count": null,
Expand Down Expand Up @@ -220,8 +240,8 @@
" 'execution_count': 1},\n",
" {'name': 'stdout',\n",
" 'output_type': 'stream',\n",
" 'text': ['CPU times: user 3 us, sys: 1 us, total: 4 us',\n",
" 'Wall time: 6.2 us']}]"
" 'text': ['CPU times: user 3 us, sys: 2 us, total: 5 us\\n',\n",
" 'Wall time: 7.15 us\\n']}]"
]
},
"execution_count": null,
Expand Down Expand Up @@ -382,7 +402,7 @@
" 'metadata': {},\n",
" 'output_type': 'execute_result',\n",
" 'execution_count': 2},\n",
" {'name': 'stdout', 'output_type': 'stream', 'text': ['1']}]"
" {'name': 'stdout', 'output_type': 'stream', 'text': ['1\\n']}]"
]
},
"execution_count": null,
Expand Down Expand Up @@ -608,7 +628,7 @@
" 'metadata': {},\n",
" 'output_type': 'execute_result',\n",
" 'execution_count': 10},\n",
" {'name': 'stdout', 'output_type': 'stream', 'text': ['1']}]"
" {'name': 'stdout', 'output_type': 'stream', 'text': ['1\\n']}]"
]
},
"execution_count": null,
Expand Down Expand Up @@ -671,7 +691,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[{'data': {'text/plain': ['2']}, 'execution_count': 2, 'metadata': {}, 'output_type': 'execute_result'}, {'name': 'stdout', 'output_type': 'stream', 'text': ['1']}]\n"
"[{'data': {'text/plain': ['2']}, 'execution_count': 2, 'metadata': {}, 'output_type': 'execute_result'}, {'name': 'stdout', 'output_type': 'stream', 'text': ['1\\n']}]\n"
]
}
],
Expand Down Expand Up @@ -766,6 +786,48 @@
"assert any('image/png' in o['data'] for o in res),'Image not captured'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#|hide\n",
"#Streams are split on and keep newlines\n",
"res = CaptureShell().run(r\"print('a\\nb'); print('c\\n\\n'); print('d')\")\n",
"test_eq(res[0]['text'], ['a\\n', 'b\\n', 'c\\n', '\\n', '\\n', 'd\\n'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#|hide\n",
"#Text mime data are split on and keep newlines\n",
"res = CaptureShell().run(r\"from IPython.display import Markdown; display(Markdown('a\\nb'))\")\n",
"test_eq(res[0]['data']['text/markdown'], ['a\\n', 'b'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#|hide\n",
"#Binary image mime data are base64-encoded and end in a single `\\n`\n",
"from PIL import Image\n",
"\n",
"def _pil2b64(im): return b64encode(im._repr_png_()).decode()+'\\n'\n",
"im = Image.new('RGB', (3,3), 'red')\n",
"imb64 = _pil2b64(im)\n",
"\n",
"res = CaptureShell().run(\"from PIL import Image; Image.new('RGB', (3,3), 'red')\")\n",
"test_eq(res[0]['data']['image/png'], imb64)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -825,7 +887,7 @@
" 'execution_count': None,\n",
" 'id': 'ea528db5',\n",
" 'metadata': {},\n",
" 'outputs': [{'name': 'stdout', 'output_type': 'stream', 'text': ['2']}],\n",
" 'outputs': [{'name': 'stdout', 'output_type': 'stream', 'text': ['2\\n']}],\n",
" 'source': 'print(a)',\n",
" 'idx_': 1}]"
]
Expand Down

0 comments on commit 3a72de8

Please sign in to comment.