diff --git a/py/executor.py b/py/executor.py index b0f20e9410..a388efd191 100644 --- a/py/executor.py +++ b/py/executor.py @@ -11,7 +11,7 @@ import astor from stencila.schema.types import Parameter, CodeChunk, Article, Entity, CodeExpression, ConstantSchema, EnumSchema, \ BooleanSchema, NumberSchema, IntegerSchema, StringSchema, ArraySchema, TupleSchema, ImageObject, Datatable, \ - DatatableColumn + DatatableColumn, SchemaTypes, SoftwareSourceCode from stencila.schema.util import from_json, to_json try: @@ -58,58 +58,168 @@ def write(self, string: typing.Union[bytes, str]) -> int: return super(StdoutBuffer, self).buffer.write(string) -class DocumentParser: +class Function: + name: str + parameters: typing.List[Parameter] + returns: SchemaTypes + + +class Variable: + name: str + schema: typing.Optional[SchemaTypes] + + def __init__(self, name: str, schema: typing.Optional[SchemaTypes] = None): + self.name = name + self.schema = schema + + +class DocumentCompilationResult: + parameters: typing.List[Parameter] = [] + code: typing.List[ExecutableCode] = [] + declares: typing.List[typing.Union[Function, Variable]] = [] + imports: typing.List[str] = [] + + +class CodeChunkParseResult(typing.NamedTuple): + imports: typing.List[typing.Union[str, SoftwareSourceCode]] = [] + declares: typing.List[typing.Union[Function, Variable]] = [] + + +def annotation_name_to_schema(name: str) -> typing.Optional[SchemaTypes]: + if name == 'bool': + return BooleanSchema() + elif name == 'str': + return StringSchema() + elif name == 'int': + return IntegerSchema() + elif name == 'float': + return NumberSchema() + elif 'list' in name.lower(): + return ArraySchema() + elif 'tuple' in name.lower(): + return TupleSchema() + + return None + + +def parse_code_chunk(chunk: CodeChunk) -> CodeChunkParseResult: + imports: typing.List[str] = [] + declares: typing.List[typing.Union[Function, Variable]] = [] + + seen_vars: typing.Set[str] = set() + + for statement in ast.parse(chunk.text).body: + if isinstance(statement, ast.ImportFrom): + if statement.module not in imports: + imports.append(statement.module) + elif isinstance(statement, ast.Import): + for module_name in statement.names: + if module_name.name not in imports: + imports.append(module_name.name) + elif isinstance(statement, ast.FunctionDef): + f = Function() + f.parameters = [] + + for i, arg in enumerate(statement.args.args): + p = Parameter(arg.arg) + + if arg.annotation: + p.schema = annotation_name_to_schema(arg.annotation.id) + + default_index = len(statement.args.defaults) - len(statement.args.args) + i + # Only the last len(statement.args.defaults) can have defaults (since they must come after non-default + # parameters) + if default_index >= 0: + p.default = statement.args.defaults[default_index].value + p.required = False + else: + p.required = True + + f.parameters.append(p) + + declares.append(f) + elif isinstance(statement, (ast.Assign, ast.AnnAssign)): + if hasattr(statement, 'targets'): + targets = statement.targets + elif hasattr(statement, 'target'): + targets = [statement.target] + else: + raise TypeError('statement has no target or targets') + + for target in targets: + target_name = target.id + + if target_name not in seen_vars: + v = Variable(target_name) + + if hasattr(statement, 'annotation'): + # assignment with Type Annotation + v.schema = annotation_name_to_schema(statement.annotation.id) + + declares.append(v) + seen_vars.add(target_name) + return CodeChunkParseResult(imports, declares) + + +class DocumentCompiler: """Parse an executable document (`Article`) and cache references to its parameters and code nodes.""" + TARGET_LANGUAGE = 'python' + parameters: typing.List[Parameter] = [] code: typing.List[ExecutableCode] = [] + function_depth: int = 0 - def parse(self, source: Article) -> None: + def compile(self, source: Article) -> DocumentCompilationResult: # todo: this traverses the article twice. Make it less hard coded, maybe pass through a lookup table that maps # a found type to its destination - self.handle_item(source, Parameter, self.parameters, None) - self.handle_item(source, (CodeChunk, CodeExpression), self.code, {'language': 'python'}) - def handle_item(self, item: typing.Any, - search_type: typing.Union[typing.Type[Entity], typing.Iterable[typing.Type[Entity]]], - destination: typing.List[Entity], - attr_match: typing.Optional[typing.Dict[str, typing.Any]]) -> None: + dcr = DocumentCompilationResult() + + self.handle_item(source, dcr) + return dcr + + def handle_item(self, item: typing.Any, compilation_result: DocumentCompilationResult) -> None: if isinstance(item, dict): - self.traverse_dict(item, search_type, destination, attr_match) + self.traverse_dict(item, compilation_result) elif isinstance(item, list): - self.traverse_list(item, search_type, destination, attr_match) + self.traverse_list(item, compilation_result) elif isinstance(item, Entity): - if isinstance(item, search_type): - can_add = True - if attr_match: - for k, v in attr_match.items(): - if getattr(item, k, None) != v: - can_add = False - break - if can_add: + if isinstance(item, (CodeChunk, CodeExpression)): + if item.language == self.TARGET_LANGUAGE: # Only add Python code + + if isinstance(item, CodeChunk): + parse_code_chunk(item) + + compilation_result.code.append(item) logger.debug('Adding {}'.format(type(item))) - destination.append(item) - self.traverse_dict(item.__dict__, search_type, destination, attr_match) - def traverse_dict(self, d: dict, - search_type: typing.Union[typing.Type[Entity], typing.Iterable[typing.Type[Entity]]], - destination: typing.List[Entity], - attr_match: typing.Optional[typing.Dict[str, typing.Any]]) -> None: + elif isinstance(item, Parameter) and self.function_depth == 0: + compilation_result.parameters.append(item) + logger.debug('Adding {}'.format(type(item))) + + if isinstance(item, Function): + self.function_depth += 1 + + self.traverse_dict(item.__dict__, compilation_result) + + if isinstance(item, Function): + self.function_depth -= 1 + + def traverse_dict(self, d: dict, compilation_result: DocumentCompilationResult) -> None: for child in d.values(): - self.handle_item(child, search_type, destination, attr_match) + self.handle_item(child, compilation_result) - def traverse_list(self, l: typing.List, - search_type: typing.Union[typing.Type[Entity], typing.Iterable[typing.Type[Entity]]], - destination: typing.List[Entity], - attr_match: typing.Optional[typing.Dict[str, typing.Any]]) -> None: + def traverse_list(self, l: typing.List, compilation_result: DocumentCompilationResult) -> None: for child in l: - self.handle_item(child, search_type, destination, attr_match) + self.handle_item(child, compilation_result) class Executor: """Execute a list of code blocks, maintaining its own `globals` scope for this execution run.""" globals: typing.Optional[typing.Dict[str, typing.Any]] + functions: typing.Dict[str, ast.FunctionDef] = {} def execute_code_chunk(self, chunk: CodeChunk, _locals: typing.Dict[str, typing.Any]) -> None: cc_outputs = [] @@ -117,6 +227,9 @@ def execute_code_chunk(self, chunk: CodeChunk, _locals: typing.Dict[str, typing. tree = ast.parse(chunk.text, 'exec') for statement in tree.body: + if isinstance(statement, ast.FunctionDef): + self.functions[statement.name] = statement + capture_result = False if isinstance(statement, ast.Expr): @@ -305,8 +418,8 @@ def execute_document(cli_args: typing.List[str]): if not isinstance(article, Article): raise TypeError('Decoded JSON was not an Article') - doc_parser = DocumentParser() - doc_parser.parse(article) + doc_parser = DocumentCompiler() + doc_parser.compile(article) e = Executor()