Skip to content

Commit

Permalink
chore: Split Python document parsing out from Executor class
Browse files Browse the repository at this point in the history
  • Loading branch information
beneboy committed Sep 2, 2019
1 parent 02fd32b commit a9ac2a2
Showing 1 changed file with 68 additions and 44 deletions.
112 changes: 68 additions & 44 deletions py/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
# Matches a line that starts with a simple lower-case name followed by `=`:
# first character is `_` or a-z, then ONE OR MORE of a-z / 0-9, optional whitespace, then `=`.
# NOTE(review): as written this does not match single-character names (`x = 1`), names containing
# an underscore after the first character (`my_var = 1`) or upper-case letters, and it also matches
# a comparison such as `ab == 1` because only the first `=` is checked — confirm this is intended.
ASSIGNMENT_RE = re.compile(r'^[_a-z][a-z0-9]+\s*=')
# Matches a line that starts with `import ` or `from ` (keyword followed by a single space).
IMPORT_RE = re.compile(r'^(import|from) ')

# Alias for the two node types that carry code to run: a `CodeChunk` or a `CodeExpression`.
ExecutableCode = typing.Union[CodeChunk, CodeExpression]


class StdoutBuffer(TextIOWrapper):
def write(self, string: typing.Union[bytes, str]) -> int:
Expand All @@ -27,17 +29,59 @@ def write(self, string: typing.Union[bytes, str]) -> int:
return super(StdoutBuffer, self).buffer.write(string)


class Executor:
class DocumentParser:
    """Parse an executable document (`Article`) and cache references to its parameters and code nodes.

    After `parse()` has run, `parameters` holds every `Parameter` found in the document and `code`
    holds every `CodeChunk`/`CodeExpression` whose `language` attribute equals `'python'`, in
    traversal order. The lists hold references to the original nodes, not copies.
    """

    parameters: typing.List[Parameter]
    code: typing.List[ExecutableCode]

    def __init__(self) -> None:
        # Create the result lists per instance. Declaring them as `= []` on the class would make
        # every DocumentParser append into the same two shared lists.
        self.parameters = []
        self.code = []

    def parse(self, source: Article) -> None:
        """Walk `source` and collect its parameters and Python code nodes into this parser.

        Results are appended to `self.parameters` and `self.code`; calling `parse` again on the
        same instance adds to the existing lists.
        """
        # todo: this traverses the article twice. Make it less hard coded, maybe pass through a lookup table that maps
        # a found type to its destination
        self.handle_item(source, Parameter, self.parameters, None)
        self.handle_item(source, (CodeChunk, CodeExpression), self.code, {'language': 'python'})

    def handle_item(self, item: typing.Any,
                    search_type: typing.Union[typing.Type[Entity], typing.Iterable[typing.Type[Entity]]],
                    destination: typing.List[Entity],
                    attr_match: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
        """Inspect one value from the document tree and recurse into anything it contains.

        :param item: the value to inspect; dicts, lists and `Entity` instances are traversed,
            anything else is ignored.
        :param search_type: type (or tuple of types, as accepted by `isinstance`) to collect.
        :param destination: list that matching entities are appended to.
        :param attr_match: optional attribute name -> required value mapping; when given, an
            entity is only collected if every listed attribute equals the required value
            (a missing attribute is treated as `None`).
        """
        if isinstance(item, dict):
            self.traverse_dict(item, search_type, destination, attr_match)
            return

        if isinstance(item, list):
            self.traverse_list(item, search_type, destination, attr_match)
            return

        if not isinstance(item, Entity):
            return

        if isinstance(item, search_type):
            can_add = True
            if attr_match:
                for k, v in attr_match.items():
                    if getattr(item, k, None) != v:
                        can_add = False
                        break
            if can_add:
                logger.debug('Adding {}'.format(type(item)))
                destination.append(item)

        # Always descend into the entity's own attributes, whether or not it matched, so that
        # nodes nested inside other entities are still found.
        self.traverse_dict(item.__dict__, search_type, destination, attr_match)

    def traverse_dict(self, d: dict,
                      search_type: typing.Union[typing.Type[Entity], typing.Iterable[typing.Type[Entity]]],
                      destination: typing.List[Entity],
                      attr_match: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
        """Pass every value of `d` (keys are not inspected) to `handle_item`."""
        for child in d.values():
            self.handle_item(child, search_type, destination, attr_match)

    def traverse_list(self, l: typing.List,
                      search_type: typing.Union[typing.Type[Entity], typing.Iterable[typing.Type[Entity]]],
                      destination: typing.List[Entity],
                      attr_match: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
        """Pass every element of `l` to `handle_item`."""
        for child in l:
            self.handle_item(child, search_type, destination, attr_match)


class Executor:
"""Execute a list of code blocks, maintaining its own `globals` scope for this execution run."""

globals: typing.Optional[typing.Dict[str, typing.Any]]

def execute_code_chunk(self, chunk: CodeChunk, _locals: typing.Dict[str, typing.Any]) -> None:
cc_outputs = []
for statement in chunk.text.split('\n'):
Expand Down Expand Up @@ -71,56 +115,26 @@ def execute_code_chunk(self, chunk: CodeChunk, _locals: typing.Dict[str, typing.
def execute_code_expression(self, expression: CodeExpression, _locals: typing.Dict[str, typing.Any]) -> None:
    """Evaluate `expression.text` and store the resulting value on `expression.output`.

    The text is evaluated with `self.globals` as the global namespace and `_locals` as the
    local namespace.
    """
    # NOTE: `eval` runs whatever text the document contains, so the document must be trusted.
    result = eval(expression.text, self.globals, _locals)
    expression.output = result

# Run each code node in order, dispatching on its type.
# NOTE(review): this span is diff output with both versions interleaved — the first `def` / `for`
# line of each pair appears to be the pre-commit form (code read from `self.code`) and the second
# the post-commit form (code passed in as the `code` argument).
def execute(self, parameter_values: typing.Dict[str, typing.Any]) -> None:
def execute(self, code: typing.List[ExecutableCode], parameter_values: typing.Dict[str, typing.Any]) -> None:
# Every call starts from an empty globals dict.
self.globals = {}

# Shallow copy: the dict handed to the chunk/expression executors is not the caller's own dict.
_locals = parameter_values.copy()

for c in self.code:
for c in code:
if isinstance(c, CodeChunk):
self.execute_code_chunk(c, _locals)
elif isinstance(c, CodeExpression):
self.execute_code_expression(c, _locals)
else:
# Anything other than the two known node types is rejected.
raise TypeError('Unknown Code node type found: {}'.format(c))

# Inspect one value from the document tree: dicts and lists are traversed, `Entity` instances are
# tested against `search_type` / `attr_match` and their `__dict__` is traversed; any other value
# is ignored.
# NOTE(review): indentation is lost in this view, so the exact nesting of the final
# `traverse_dict` call (inside or after the `isinstance(item, search_type)` check) cannot be
# confirmed from here.
def handle_item(self, item: typing.Any,
search_type: typing.Union[typing.Type[Entity], typing.Iterable[typing.Type[Entity]]],
destination: typing.List[Entity],
attr_match: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
if isinstance(item, dict):
self.traverse_dict(item, search_type, destination, attr_match)
elif isinstance(item, list):
self.traverse_list(item, search_type, destination, attr_match)
elif isinstance(item, Entity):
if isinstance(item, search_type):
can_add = True
# When attribute constraints are given, every one must equal the required value;
# a missing attribute is read as None.
if attr_match:
for k, v in attr_match.items():
if getattr(item, k, None) != v:
can_add = False
break
if can_add:
logger.debug('Adding {}'.format(type(item)))
destination.append(item)
# Recurse into the entity's own attribute values.
self.traverse_dict(item.__dict__, search_type, destination, attr_match)

def traverse_dict(self, d: dict,
                  search_type: typing.Union[typing.Type[Entity], typing.Iterable[typing.Type[Entity]]],
                  destination: typing.List[Entity],
                  attr_match: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
    """Hand each value stored in `d` to `handle_item`; the keys themselves are not inspected."""
    handle = self.handle_item
    for value in d.values():
        handle(value, search_type, destination, attr_match)

def traverse_list(self, l: typing.List,
                  search_type: typing.Union[typing.Type[Entity], typing.Iterable[typing.Type[Entity]]],
                  destination: typing.List[Entity],
                  attr_match: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
    """Hand each element of `l`, in order, to `handle_item`."""
    handle = self.handle_item
    for element in l:
        handle(element, search_type, destination, attr_match)


class ParameterParser:
"""
Parse parameters that the document requires, from the command line.
The `ArgumentParser` will fail if any required parameters are not passed in.
"""
parameters: typing.Dict[str, Parameter]
parameter_values: typing.Dict[str, typing.Any]

Expand All @@ -137,7 +151,8 @@ def parse_cli_args(self, cli_args: typing.List[str]) -> None:

for param in self.parameters.values():
if not isinstance(param.schema, ConstantSchema):
param_parser.add_argument('--' + param.name, dest=param.name, required=param.default is None)
param_parser.add_argument('--' + param.name, dest=param.name,
required=self.get_parameter_default(param) is None)

args, _ = param_parser.parse_known_args(cli_args)

Expand All @@ -150,6 +165,13 @@ def parse_cli_args(self, cli_args: typing.List[str]) -> None:
else:
self.parameter_values[param_name] = self.deserialize_parameter(self.parameters[param_name], cli_value)

@staticmethod
def get_parameter_default(parameter: Parameter) -> typing.Any:
    """Return the value to use for `parameter` when none is supplied.

    For a parameter whose schema is a `ConstantSchema`, the schema's `value` is returned unless
    it is `None`, in which case `parameter.default` is used. For every other schema the result
    is `parameter.default`.
    """
    if isinstance(parameter.schema, ConstantSchema):
        constant = parameter.schema.value
        # Test against None, not truthiness: a constant of 0, False or '' is a real value and
        # must not be replaced by the parameter's default.
        if constant is not None:
            return constant

    return parameter.default

@staticmethod
def deserialize_parameter(parameter: Parameter, value: typing.Any) -> typing.Any:
# Lots of TODOs here, might not care as passing this off to encoda soon
Expand Down Expand Up @@ -199,13 +221,15 @@ def execute_document(cli_args: typing.List[str]):
if not isinstance(article, Article):
raise TypeError('Decoded JSON was not an Article')

doc_parser = DocumentParser()
doc_parser.parse(article)

e = Executor()
e.parse(article)

pp = ParameterParser(e.parameters)
pp = ParameterParser(doc_parser.parameters)
pp.parse_cli_args(cli_args)

e.execute(pp.parameter_values)
e.execute(doc_parser.code, pp.parameter_values)

if args.output_file == '-':
sys.stdout.write(to_json(article))
Expand Down

0 comments on commit a9ac2a2

Please sign in to comment.