From e080f6b3f00703eb746bd2c65caf463d756f1ed8 Mon Sep 17 00:00:00 2001 From: Ben Shaw Date: Thu, 15 Aug 2019 12:54:17 +1200 Subject: [PATCH] feat: Converting matplotlib figures to ImageObjects during Py execution --- py/executor.py | 95 +++++++++++++++++++++++++++++++++++--- py/stencila/schema/util.py | 46 +----------------- 2 files changed, 90 insertions(+), 51 deletions(-) diff --git a/py/executor.py b/py/executor.py index 44fec24bf2..ba88ad6a29 100644 --- a/py/executor.py +++ b/py/executor.py @@ -1,4 +1,5 @@ import argparse +import base64 import json import logging import re @@ -8,9 +9,40 @@ from io import TextIOWrapper, BytesIO from stencila.schema.types import Parameter, CodeChunk, Article, Entity, CodeExpression, ConstantSchema, EnumSchema, \ - BooleanSchema, NumberSchema, IntegerSchema, StringSchema, ArraySchema, TupleSchema + BooleanSchema, NumberSchema, IntegerSchema, StringSchema, ArraySchema, TupleSchema, ImageObject, Datatable, \ + DatatableColumn from stencila.schema.util import from_json, to_json +try: + import matplotlib.figure + import matplotlib.artist + + MPLFigure = matplotlib.figure.Figure + MPLArtist = matplotlib.artist.Artist + mpl_available = True +except ImportError: + class MPLFigure: + pass + + + class MLPArtist: + pass + + + mpl_available = False + +try: + from pandas import DataFrame + import numpy + + pandas_available = True +except ImportError: + class DataFrame: + pass + + + pandas_available = False + logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -102,18 +134,18 @@ def execute_code_chunk(self, chunk: CodeChunk, _locals: typing.Dict[str, typing. result = _locals[RESULT_CAPTURE_VAR] del _locals[RESULT_CAPTURE_VAR] if result is not None: - cc_outputs.append(result) + cc_outputs.append(self.decode_output(result)) s.seek(0) - output = s.buffer.read() + std_out_output = s.buffer.read() - if output: - cc_outputs.append(output.decode('utf8')) + if std_out_output: + cc_outputs.append(std_out_output.decode('utf8')) chunk.outputs = cc_outputs def execute_code_expression(self, expression: CodeExpression, _locals: typing.Dict[str, typing.Any]) -> None: - expression.output = eval(expression.text, self.globals, _locals) + expression.output = self.decode_output(eval(expression.text, self.globals, _locals)) def execute(self, code: typing.List[ExecutableCode], parameter_values: typing.Dict[str, typing.Any]) -> None: self.globals = {} @@ -128,6 +160,57 @@ def execute(self, code: typing.List[ExecutableCode], parameter_values: typing.Di else: raise TypeError('Unknown Code node type found: {}'.format(c)) + @staticmethod + def value_is_mpl(value: typing.Any) -> bool: + if not mpl_available: + return False + + return isinstance(value, (MPLFigure, MPLArtist)) or ( + isinstance(value, list) and len(value) == 1 and isinstance(value[0], MPLArtist)) + + @staticmethod + def decode_mpl() -> ImageObject: + image = BytesIO() + matplotlib.pyplot.savefig(image, format='png') + src = 'data:image/png;base64,' + base64.encodebytes(image.getvalue()).decode() + return ImageObject(src) + + @staticmethod + def decode_dataframe(df: DataFrame) -> Datatable: + columns = [] + + for column_name in df.columns: + column = df[column_name] + values = column.tolist() + if column.dtype in (numpy.bool_, numpy.bool8): + schema = BooleanSchema() + values = [bool(row) for row in values] + elif column.dtype in (numpy.int8, numpy.int16, numpy.int32, numpy.int64): + schema = IntegerSchema() + values = [int(row) for row in values] + elif column.dtype in (numpy.float16, numpy.float32, numpy.float64): + schema = NumberSchema() + values = [float(row) for row in values] + elif column.dtype in (numpy.str_, numpy.unicode_,): + schema = StringSchema() + else: + schema = None + + columns.append( + DatatableColumn(column_name, values, schema=ArraySchema(items=schema)) + ) + + return Datatable(columns) + + def decode_output(self, output: typing.Any) -> typing.Any: + if self.value_is_mpl(output): + return self.decode_mpl() + + if isinstance(output, DataFrame): + return self.decode_dataframe(output) + + return output + class ParameterParser: """ diff --git a/py/stencila/schema/util.py b/py/stencila/schema/util.py index 38c70a3c84..982df90128 100644 --- a/py/stencila/schema/util.py +++ b/py/stencila/schema/util.py @@ -1,57 +1,13 @@ """Utility functions for schema to/from JSON.""" - import json import typing from . import types -from .types import Node, Entity, Datatable, DatatableColumn, BooleanSchema, IntegerSchema, NumberSchema, StringSchema, \ - ArraySchema - -try: - from pandas import DataFrame - import numpy - - pandas_available = True -except ImportError: - class DataFrame: - pass - - pandas_available = False - - -def data_frame_to_data_table(df: DataFrame) -> Datatable: - columns = [] - - for column_name in df.columns: - column = df[column_name] - values = column.tolist() - if column.dtype in (numpy.bool_, numpy.bool8): - schema = BooleanSchema() - values = [bool(row) for row in values] - elif column.dtype in (numpy.int8, numpy.int16, numpy.int32, numpy.int64): - schema = IntegerSchema() - values = [int(row) for row in values] - elif column.dtype in (numpy.float16, numpy.float32, numpy.float64): - schema = NumberSchema() - values = [float(row) for row in values] - elif column.dtype in (numpy.str_, numpy.unicode_,): - schema = StringSchema() - else: - schema = None - - columns.append( - DatatableColumn(column_name, values, schema=ArraySchema(items=schema)) - ) - - return Datatable(columns) +from .types import Node def to_dict(node: typing.Any) -> dict: """Convert an Entity node to a dictionary""" - if pandas_available: - if isinstance(node, DataFrame): - node = data_frame_to_data_table(node) - node_dict = { "type": node.__class__.__name__ }