Skip to content

Commit

Permalink
feat: Converting matplotlib figures to ImageObjects during Py execution
Browse files Browse the repository at this point in the history
  • Loading branch information
beneboy committed Sep 2, 2019
1 parent 39406e5 commit e080f6b
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 51 deletions.
95 changes: 89 additions & 6 deletions py/executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import base64
import json
import logging
import re
Expand All @@ -8,9 +9,40 @@
from io import TextIOWrapper, BytesIO

from stencila.schema.types import Parameter, CodeChunk, Article, Entity, CodeExpression, ConstantSchema, EnumSchema, \
BooleanSchema, NumberSchema, IntegerSchema, StringSchema, ArraySchema, TupleSchema
BooleanSchema, NumberSchema, IntegerSchema, StringSchema, ArraySchema, TupleSchema, ImageObject, Datatable, \
DatatableColumn
from stencila.schema.util import from_json, to_json

# matplotlib is an optional dependency. When it cannot be imported, empty
# placeholder classes are defined under the same names so that isinstance()
# checks against MPLFigure / MPLArtist never hit an undefined name.
try:
    import matplotlib.figure
    import matplotlib.artist

    MPLFigure = matplotlib.figure.Figure
    MPLArtist = matplotlib.artist.Artist
    mpl_available = True
except ImportError:
    class MPLFigure:
        pass


    # Must be spelled exactly like the alias in the try branch above
    # (it was previously misspelled, leaving MPLArtist undefined here).
    class MPLArtist:
        pass


    mpl_available = False

# pandas (together with numpy) is optional: record whether both could be
# imported, and supply a stand-in DataFrame class when they could not.
try:
    from pandas import DataFrame
    import numpy
except ImportError:
    pandas_available = False


    class DataFrame:
        pass
else:
    pandas_available = True

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

Expand Down Expand Up @@ -102,18 +134,18 @@ def execute_code_chunk(self, chunk: CodeChunk, _locals: typing.Dict[str, typing.
result = _locals[RESULT_CAPTURE_VAR]
del _locals[RESULT_CAPTURE_VAR]
if result is not None:
cc_outputs.append(result)
cc_outputs.append(self.decode_output(result))

s.seek(0)
output = s.buffer.read()
std_out_output = s.buffer.read()

if output:
cc_outputs.append(output.decode('utf8'))
if std_out_output:
cc_outputs.append(std_out_output.decode('utf8'))

chunk.outputs = cc_outputs

def execute_code_expression(self, expression: CodeExpression, _locals: typing.Dict[str, typing.Any]) -> None:
    """
    Evaluate a code expression and store the decoded result on the node.

    The text of `expression` is evaluated exactly once, using `self.globals`
    as the global namespace and `_locals` as the local namespace. The raw
    value is passed through `self.decode_output` before being assigned to
    `expression.output`.
    """
    # NOTE(review): eval() executes arbitrary code taken from expression.text;
    # presumably the document being executed is trusted -- confirm with callers.
    result = eval(expression.text, self.globals, _locals)
    expression.output = self.decode_output(result)

def execute(self, code: typing.List[ExecutableCode], parameter_values: typing.Dict[str, typing.Any]) -> None:
self.globals = {}
Expand All @@ -128,6 +160,57 @@ def execute(self, code: typing.List[ExecutableCode], parameter_values: typing.Di
else:
raise TypeError('Unknown Code node type found: {}'.format(c))

@staticmethod
def value_is_mpl(value: typing.Any) -> bool:
    """
    Tell whether `value` should be treated as matplotlib output.

    Returns True when matplotlib is available and `value` is either a
    figure, an artist, or a list containing exactly one artist. Returns
    False in every other case, including when matplotlib is not available.
    """
    if not mpl_available:
        return False

    # A figure or artist on its own.
    if isinstance(value, (MPLFigure, MPLArtist)):
        return True

    # Otherwise only a one-element list wrapping an artist qualifies.
    if not isinstance(value, list) or len(value) != 1:
        return False

    return isinstance(value[0], MPLArtist)

@staticmethod
def decode_mpl() -> ImageObject:
    """
    Render the current matplotlib figure to an ImageObject.

    The figure is saved as PNG into an in-memory buffer via
    `matplotlib.pyplot.savefig` and embedded as a base64 `data:` URI, which
    is passed to `ImageObject` as its first argument.
    """
    # Imported here because the module only imports matplotlib.figure and
    # matplotlib.artist; this guarantees the pyplot submodule is loaded
    # before it is used.
    import matplotlib.pyplot

    image = BytesIO()
    matplotlib.pyplot.savefig(image, format='png')

    # b64encode yields one unbroken line; encodebytes would insert a newline
    # every 76 characters, putting line breaks inside the URI.
    encoded = base64.b64encode(image.getvalue()).decode('ascii')
    return ImageObject('data:image/png;base64,' + encoded)

@staticmethod
def decode_dataframe(df: DataFrame) -> Datatable:
    """
    Convert a pandas DataFrame into a Datatable.

    One DatatableColumn is built per DataFrame column, in column order. The
    column values come from `tolist()` and are coerced to plain `bool`,
    `int` or `float` for boolean, integer and floating point columns. Each
    column gets an `ArraySchema` whose `items` schema reflects the column
    dtype, or is `None` when the dtype is not one of the recognised families.
    """
    columns = []

    for column_name in df.columns:
        column = df[column_name]
        values = column.tolist()

        # Classify by the numpy dtype "kind" code instead of comparing with
        # scalar type aliases: numpy.bool8 and numpy.unicode_ no longer exist
        # in NumPy 2.0. Non-numpy (extension) dtypes are left unclassified.
        dtype = column.dtype
        kind = dtype.kind if isinstance(dtype, numpy.dtype) else None

        if kind == 'b':
            schema = BooleanSchema()
            values = [bool(row) for row in values]
        elif kind in ('i', 'u'):
            # Signed and unsigned integers of any width.
            schema = IntegerSchema()
            values = [int(row) for row in values]
        elif kind == 'f':
            schema = NumberSchema()
            values = [float(row) for row in values]
        elif kind == 'U':
            schema = StringSchema()
        else:
            schema = None

        columns.append(
            DatatableColumn(column_name, values, schema=ArraySchema(items=schema))
        )

    return Datatable(columns)

def decode_output(self, output: typing.Any) -> typing.Any:
    """
    Convert an execution result into its schema representation if needed.

    Values recognised by `value_is_mpl` are replaced by the result of
    `decode_mpl`, DataFrame instances by the result of `decode_dataframe`;
    any other value is returned unchanged.
    """
    if self.value_is_mpl(output):
        decoded = self.decode_mpl()
    elif isinstance(output, DataFrame):
        decoded = self.decode_dataframe(output)
    else:
        decoded = output

    return decoded


class ParameterParser:
"""
Expand Down
46 changes: 1 addition & 45 deletions py/stencila/schema/util.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,13 @@
"""Utility functions for schema to/from JSON."""

import json
import typing

from . import types
from .types import Node, Entity, Datatable, DatatableColumn, BooleanSchema, IntegerSchema, NumberSchema, StringSchema, \
ArraySchema

# Optional pandas/numpy support: assume it is present, then fall back to a
# placeholder DataFrame class and clear the flag if either import fails.
pandas_available = True

try:
    from pandas import DataFrame
    import numpy
except ImportError:
    pandas_available = False


    class DataFrame:
        pass


def data_frame_to_data_table(df: DataFrame) -> Datatable:
    """
    Convert a pandas DataFrame into a Datatable.

    One DatatableColumn is built per DataFrame column, in column order. The
    column values come from `tolist()` and are coerced to plain `bool`,
    `int` or `float` for boolean, integer and floating point columns. Each
    column gets an `ArraySchema` whose `items` schema reflects the column
    dtype, or is `None` when the dtype is not one of the recognised families.
    """
    columns = []

    for column_name in df.columns:
        column = df[column_name]
        values = column.tolist()

        # Classify by the numpy dtype "kind" code instead of comparing with
        # scalar type aliases: numpy.bool8 and numpy.unicode_ no longer exist
        # in NumPy 2.0. Non-numpy (extension) dtypes are left unclassified.
        dtype = column.dtype
        kind = dtype.kind if isinstance(dtype, numpy.dtype) else None

        if kind == 'b':
            schema = BooleanSchema()
            values = [bool(row) for row in values]
        elif kind in ('i', 'u'):
            # Signed and unsigned integers of any width.
            schema = IntegerSchema()
            values = [int(row) for row in values]
        elif kind == 'f':
            schema = NumberSchema()
            values = [float(row) for row in values]
        elif kind == 'U':
            schema = StringSchema()
        else:
            schema = None

        columns.append(
            DatatableColumn(column_name, values, schema=ArraySchema(items=schema))
        )

    return Datatable(columns)
from .types import Node


def to_dict(node: typing.Any) -> dict:
"""Convert an Entity node to a dictionary"""
if pandas_available:
if isinstance(node, DataFrame):
node = data_frame_to_data_table(node)

node_dict = {
"type": node.__class__.__name__
}
Expand Down

0 comments on commit e080f6b

Please sign in to comment.