Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Commit

Permalink
Changes the way drivers handle parameters
Browse files Browse the repository at this point in the history
See #52

Note: this is backwards compatible. The changes are that
we've added the "inputs" parameter to the driver, as well as the
function_graph executor. This allows external inputs to the DAG
at runtime, not just construction time.
  • Loading branch information
elijahbenizzy committed Feb 6, 2022
1 parent c6cfeae commit 5041770
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 90 deletions.
68 changes: 52 additions & 16 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,52 @@ def __init__(self, config: Dict[str, Any], *modules: ModuleType, adapter: base.H
self.graph = graph.FunctionGraph(*modules, config=config, adapter=adapter)
self.adapter = adapter

def validate_inputs(self, user_nodes: Collection[node.Node], inputs: Dict[str, Any]):
"""Validates that inputs meet our expectations.
@staticmethod
def combine_and_validate_inputs(
config: Dict[str, Any],
inputs: Dict[str, Any]) -> typing.Tuple[Dict[str, Any], typing.Optional[str]]:
"""Combines runtime inputs and configs, validating that they are a mutually disjoint set.
Returns the error message if they aren't. Note the following:
1. Even if the values are the same duplicate inputs are not allowed
2. The result of the combined inputs when there are collisions is not guarenteed to have any stability.
Don't rely on it too much. This might produce an extra error, but this is the first one that the user
should deal with.
:param config: Config for the DAG
:param inputs: Inputs to the DAG
:return: A tuple consisting of (combined_inputs, error_message) if they both exist
"""

duplicated_inputs = [key for key in inputs if key in config]
error = None
if len(duplicated_inputs) > 0:
error = f'The following inputs are duplicated between the config and the runtime inputs. ' \
f'The graph does not know which to select, please choose which value you want, ' \
f'and specify in only one place: {duplicated_inputs}'
return {**config, **inputs}, error

def validate_inputs(self, user_nodes: Collection[node.Node], inputs: typing.Optional[Dict[str, Any]] = None):
"""Validates that inputs meet our expectations. This means that:
1. The runtime inputs don't clash with the graph's config
2. All expected graph inputs are provided, either in config or at runtime
:param user_nodes: The required nodes we need for computation.
:param inputs: the user inputs provided.
"""
# validate inputs
errors = []
if inputs is None:
inputs = {}
all_inputs, duplicated_error = Driver.combine_and_validate_inputs(self.graph.config, inputs)
if duplicated_error is not None:
errors.append(duplicated_error)
for user_node in user_nodes:
if user_node.name not in inputs:
if user_node.name not in all_inputs:
errors.append(f'Error: Required input {user_node.name} not provided '
f'for nodes: {[node.name for node in user_node.depended_on_by]}.')
elif (inputs[user_node.name] is not None
and not self.adapter.check_input_type(user_node.type, inputs[user_node.name])):
elif (all_inputs[user_node.name] is not None
and not self.adapter.check_input_type(user_node.type, all_inputs[user_node.name])):
errors.append(f'Error: Type requirement mismatch. Expected {user_node.name}:{user_node.type} '
f'got {inputs[user_node.name]} instead.')
f'got {all_inputs[user_node.name]} instead.')
if errors:
errors.sort()
error_str = f'{len(errors)} errors encountered:\n ' + '\n '.join(errors)
Expand All @@ -71,41 +101,47 @@ def validate_inputs(self, user_nodes: Collection[node.Node], inputs: Dict[str, A
def execute(self,
final_vars: List[str],
overrides: Dict[str, Any] = None,
display_graph: bool = False) -> pd.DataFrame:
display_graph: bool = False,
inputs: Dict[str, Any] = None,
) -> pd.DataFrame:
"""Executes computation.
:param final_vars: the final list of variables we want in the data frame.
:param overrides: the user defined input variables.
:param overrides: the user defined overrides.
:param display_graph: whether we want to display the graph being computed.
:param inputs: Runtime inputs to the DAG
:return: a data frame consisting of the variables requested.
"""
columns = self.raw_execute(final_vars, overrides, display_graph)
columns = self.raw_execute(final_vars, overrides, display_graph, inputs=inputs)
return self.adapter.build_result(**columns)

def raw_execute(self,
final_vars: List[str],
overrides: Dict[str, Any] = None,
display_graph: bool = False) -> Dict[str, Any]:
display_graph: bool = False,
inputs: Dict[str, Any] = None) -> Dict[str, Any]:
"""Raw execute function that does the meat of execute.
It does not try to stitch anything together. Thus allowing wrapper executes around this to shape the output
of the data.
:param final_vars:
:param overrides:
:param display_graph:
:param final_vars: Final variables to compute
:param overrides: Overrides to run.
:param inputs: Runtime inputs to the DAG
:param display_graph: Whether or not to display the graph when running it
:return:
"""
nodes, user_nodes = self.graph.get_required_functions(final_vars)
self.validate_inputs(user_nodes, self.graph.config) # TODO -- validate within the function graph itself

self.validate_inputs(user_nodes, inputs) # TODO -- validate within the function graph itself
if display_graph:
# TODO: fix hardcoded path.
try:
self.graph.display(nodes, user_nodes, output_file_path='test-output/execute.gv')
except ImportError as e:
logger.warning(f'Unable to import {e}', exc_info=True)
memoized_computation = dict() # memoized storage
self.graph.execute(nodes, memoized_computation, overrides)
self.graph.execute(nodes, memoized_computation, overrides, inputs)
columns = {c: memoized_computation[c] for c in final_vars} # only want request variables in df.
del memoized_computation # trying to cleanup some memory
return columns
Expand Down
71 changes: 0 additions & 71 deletions hamilton/experimental/decorators.py

This file was deleted.

8 changes: 6 additions & 2 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,16 @@ def dfs_traverse(node: node.Node, dependency_type: DependencyType = DependencyTy
def execute(self,
nodes: Collection[node.Node] = None,
computed: Dict[str, Any] = None,
overrides: Dict[str, Any] = None) -> Dict[str, Any]:
overrides: Dict[str, Any] = None,
inputs: Dict[str, Any] = None
) -> Dict[str, Any]:
if nodes is None:
nodes = self.get_nodes()
if inputs is None:
inputs = {}
return FunctionGraph.execute_static(
nodes=nodes,
inputs=self.config,
inputs={**inputs, **self.config},
adapter=self.executor,
computed=computed,
overrides=overrides,
Expand Down
2 changes: 2 additions & 0 deletions tests/resources/very_simple_dag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def b(a: int) -> int:
return a
30 changes: 29 additions & 1 deletion tests/test_hamilton_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,35 @@
from hamilton.driver import Driver
import tests.resources.very_simple_dag


def test_driver_validate_input_types():
dr = Driver({'a': 1}, [])
dr = Driver({'a': 1})
results = dr.raw_execute(['a'])
assert results == {'a': 1}


def test_driver_validate_runtime_input_types():
dr = Driver({}, tests.resources.very_simple_dag)
results = dr.raw_execute(['b'], inputs={'a': 1})
assert results == {'b': 1}


def test_combine_inputs_no_collision():
"""Tests the combine_and_validate_inputs functionality when there are no collisions"""
combined, error = Driver.combine_and_validate_inputs({'a': 1}, {'b': 2})
assert combined == {'a': 1, 'b': 2}
assert error is None


def test_combine_inputs_collision():
"""Tests the combine_and_validate_inputs functionality
when there are collisions of keys but not values"""
combined, error = Driver.combine_and_validate_inputs({'a': 1}, {'a': 2})
assert error is not None


def test_combine_inputs_collision_2():
combined, error = Driver.combine_and_validate_inputs({'a': 1}, {'a': 1})
"""Tests the combine_and_validate_inputs functionality
when there are collisions of keys but not values"""
assert error is not None

0 comments on commit 5041770

Please sign in to comment.