From 95bf960b61036979ed9065c4539304236db71de4 Mon Sep 17 00:00:00 2001 From: skrawczyk Date: Sun, 6 Feb 2022 22:36:13 -0800 Subject: [PATCH 1/5] Adds simple case to help motivate @extract_outputs If someone wanted to use Hamilton to model a modeling dataflow, they would struggle. Need a new decorator to handle extracting outputs from functions that return multiple things and aren't a dataframe. --- .../scikit-learn/model_logic.py | 42 +++++++++++++++++++ .../scikit-learn/requirements.txt | 2 + examples/model_examples/scikit-learn/run.py | 25 +++++++++++ 3 files changed, 69 insertions(+) create mode 100644 examples/model_examples/scikit-learn/model_logic.py create mode 100644 examples/model_examples/scikit-learn/requirements.txt create mode 100644 examples/model_examples/scikit-learn/run.py diff --git a/examples/model_examples/scikit-learn/model_logic.py b/examples/model_examples/scikit-learn/model_logic.py new file mode 100644 index 00000000..348835c6 --- /dev/null +++ b/examples/model_examples/scikit-learn/model_logic.py @@ -0,0 +1,42 @@ +from typing import Dict + +import numpy as np +import pandas as pd +from sklearn import datasets, svm, metrics +from sklearn.model_selection import train_test_split + +from hamilton import function_modifiers + + +def flattened_digits(digit_data: np.ndarray) -> np.ndarray: + return digit_data.reshape((len(digit_data), -1)) + + +@function_modifiers.config.when(clf='svm') +def prefit_clf__svm(gamma: float = 0.001) -> svm.SVC: + return svm.SVC(gamma=gamma) + + +@function_modifiers.extract_columns(*['X_train', 'X_test', 'y_train', 'y_test']) +def train_test_split_func(flattened_digits: np.ndarray, digits_targets: np.ndarray, test_size_fraction: float = 0.5, + shuffle_train_test_split: bool = False) -> Dict[str, np.ndarray]: + X_train, X_test, y_train, y_test = train_test_split( + flattened_digits, digits_targets, test_size=test_size_fraction, shuffle=shuffle_train_test_split + ) + return { + 'X_train': X_train, 'X_test': X_test, 
'y_train': y_train, 'y_test': y_test + } + +def fit_clf(prefit_clf: svm.SVC, X_train: np.ndarray, y_train: np.ndarray) -> svm.SVC: + """This mutates the prefit_clf""" + prefit_clf.fit(X_train, y_train) + return prefit_clf + +def predict_output(fit_clf: svm.SVC, X_test: np.ndarray) -> np.ndarray: + return fit_clf.predict(X_test) + +def classification_report(predict_output: np.ndarray, y_test: np.ndarray) -> str: + return metrics.classification_report(y_test, predict_output) + +# def confusion_matrix(predict_output: np.ndarray, y_test: np.ndarray) -> str: + # return metrics.ConfusionMatrixDisplay.from_predictions(y_test, predict_output) diff --git a/examples/model_examples/scikit-learn/requirements.txt b/examples/model_examples/scikit-learn/requirements.txt new file mode 100644 index 00000000..3f69ad5c --- /dev/null +++ b/examples/model_examples/scikit-learn/requirements.txt @@ -0,0 +1,2 @@ +scikit-learn +sf-hamilton diff --git a/examples/model_examples/scikit-learn/run.py b/examples/model_examples/scikit-learn/run.py new file mode 100644 index 00000000..a7a88f50 --- /dev/null +++ b/examples/model_examples/scikit-learn/run.py @@ -0,0 +1,25 @@ +from sklearn import datasets + +from hamilton import driver +from hamilton import base + +import model_logic + +digits = datasets.load_digits() + +config_or_data = { + 'digit_data': digits.images, # TODO: move this to loaders.py + 'digits_targets': digits.target, # TODO: move this to loaders.py + 'gamma': 0.001, + 'test_size_fraction': 0.5, + 'shuffle_train_test_split': False, + 'clf': 'svm' +} + +adapter = base.SimplePythonGraphAdapter(base.DictResult()) + +dr = driver.Driver(config_or_data, model_logic, adapter=adapter) + +results = dr.execute(['classification_report']) +for k, v in results.items(): + print(v) From 1f4f6deb3536a71b363d877d75e418c7dba78123 Mon Sep 17 00:00:00 2001 From: skrawczyk Date: Tue, 8 Feb 2022 23:34:19 -0800 Subject: [PATCH 2/5] Adds extract_fields decorator for operating over dicts The API to use it 
looks like this: ```python @function_modifiers.extract_fields( {'X_train': np.ndarray, 'X_test': np.ndarray, 'y_train': np.ndarray, 'y_test': np.ndarray}) def train_test_split_func( ... , ... ) -> dict ``` I decided to go with a straight dict of `field_name` to `field_type` because that seemed the simplest thing to define. Note, we use the documentation for the original function, rather than enabling individual doc strings for the types. I think this suffices for now. To support TypedDict, I didn't want to have to import typing_extensions to handle it. Also you can't inline define a TypedDict class, so it would be more verbose which is less than ideal. We can always add TypedDict support later. Also punted on `Tuple` support -- that might be another decorator... --- hamilton/function_modifiers.py | 92 ++++++++++++++++++++++++++++++++++ requirements.txt | 1 + 2 files changed, 93 insertions(+) diff --git a/hamilton/function_modifiers.py b/hamilton/function_modifiers.py index ba3752db..9989d940 100644 --- a/hamilton/function_modifiers.py +++ b/hamilton/function_modifiers.py @@ -1,7 +1,9 @@ import functools import functools import inspect +import typing from typing import Dict, Callable, Collection, Tuple, Union, Any, Type +import typing_inspect import pandas as pd @@ -190,6 +192,96 @@ def extractor_fn(column_to_extract: str = column, **kwargs) -> pd.Series: # avo return output_nodes + +class extract_fields(NodeExpander): + """Extracts fields from a dictionary of output.""" + + def __init__(self, fields: dict, fill_with: Any = None): + """Constructor for a modifier that expands a single function into the following nodes: + - n functions, each of which take in the original dict and output a specific field + - 1 function that outputs the original dict + + :param fields: Fields to extract. A dict of 'field_name' -> 'field_type'. + :param fill_with: If you want to extract a field that doesn't exist, do you want to fill it with a default value? 
+ Or do you want to error out? Leave empty/None to error out, set fill_value to dynamically create a field value. + """ + if not fields: + raise InvalidDecoratorException('Error an empty dict, or no dict, passed to extract_fields decorator.') + elif not isinstance(fields, dict): + raise InvalidDecoratorException(f'Error, please pass in a dict or typed dict, not {type(fields)}') + else: + errors = [] + for field, field_type in fields.items(): + if not isinstance(field, str): + errors.append(f'{field} is not a string. All keys must be strings.') + if not isinstance(field_type, type): + errors.append(f'{field} does not declare a type. Instead it passes {field_type}.') + + if errors: + raise InvalidDecoratorException(f'Error, found these {errors}. ' + f'Please pass in a dict of string to types. ') + self.fields = fields + self.fill_with = fill_with + + def validate(self, fn: Callable): + """A function is invalid if it is not annotated with a dict or typing.Dict return type. + + :param fn: Function to validate. + :raises: InvalidDecoratorException If the function is not annotated with a dict or typing.Dict type as output. + """ + output_type = inspect.signature(fn).return_annotation + if typing_inspect.is_generic_type(output_type): + base = typing_inspect.get_origin(output_type) + if base == dict or base == typing.Dict: + pass + else: + raise InvalidDecoratorException( + f'For extracting fields, output type must be a dict or typing.Dict, not: {output_type}') + elif output_type == dict: + pass + else: + raise InvalidDecoratorException( + f'For extracting fields, output type must be a dict or typing.Dict, not: {output_type}') + + def expand_node(self, node_: node.Node, config: Dict[str, Any], fn: Callable) -> Collection[node.Node]: + """For each field to extract, output a node that extracts that field. Also, output the original TypedDict + generator. + + :param node_: + :param config: + :param fn: Function to extract columns from. Must output a dataframe. 
+ :return: A collection of nodes -- + one for the original dataframe generator, and another for each column to extract. + """ + fn = node_.callable + base_doc = node_.documentation + + @functools.wraps(fn) + def dict_generator(*args, **kwargs): + dict_generated = fn(*args, **kwargs) + if self.fill_with is not None: + for field in self.fields: + if field not in dict_generated: + dict_generated[field] = self.fill_with + return dict_generated + + output_nodes = [node.Node(node_.name, typ=dict, doc_string=base_doc, callabl=dict_generator)] + + for field, field_type in self.fields.items(): + doc_string = base_doc # default doc string of base function. + + def extractor_fn(field_to_extract: str = field, **kwargs) -> field_type: # avoiding problems with closures + dt = kwargs[node_.name] + if field_to_extract not in dt: + raise InvalidDecoratorException(f'No such field: {field_to_extract} produced by {node_.name}. ' + f'It only produced {list(dt.keys())}') + return kwargs[node_.name][field_to_extract] + + output_nodes.append( + node.Node(field, field_type, doc_string, extractor_fn, input_types={node_.name: dict})) + return output_nodes + + # the following are empty functions that we can compare against to ensure that @does uses an empty function def _empty_function(): pass diff --git a/requirements.txt b/requirements.txt index 45785f6c..f1314b21 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ dataclasses; python_version < '3.7' numpy pandas +typing_inspect From bb896ad57928e4bfa704e4ad886ec3e6d2f5d8b1 Mon Sep 17 00:00:00 2001 From: skrawczyk Date: Tue, 8 Feb 2022 23:41:27 -0800 Subject: [PATCH 3/5] Refactors model example to show how to be generic This example is interesting because it shows how one might build a "bank" of hamilton functions that do some generic modeling -- while keeping it generic so that adding new contexts/ running it with new models, results in a small amount of work. 
The key things to get this to work are: - different python modules to load data. They have to output what's required to link with the my_train_evaluate_logic functions. - config & @config.when to add the correct model function to the dataflow. So if you want to switch between model types: easy -- change config. So if you want to switch fitting models on different data: easy -- change the data loading module. --- .../scikit-learn/digit_loader.py | 24 ++++++ .../scikit-learn/iris_loader.py | 23 +++++ .../scikit-learn/model_logic.py | 42 --------- .../scikit-learn/my_train_evaluate_logic.py | 86 +++++++++++++++++++ examples/model_examples/scikit-learn/run.py | 71 +++++++++++---- 5 files changed, 188 insertions(+), 58 deletions(-) create mode 100644 examples/model_examples/scikit-learn/digit_loader.py create mode 100644 examples/model_examples/scikit-learn/iris_loader.py delete mode 100644 examples/model_examples/scikit-learn/model_logic.py create mode 100644 examples/model_examples/scikit-learn/my_train_evaluate_logic.py diff --git a/examples/model_examples/scikit-learn/digit_loader.py b/examples/model_examples/scikit-learn/digit_loader.py new file mode 100644 index 00000000..2db1a833 --- /dev/null +++ b/examples/model_examples/scikit-learn/digit_loader.py @@ -0,0 +1,24 @@ +import numpy as np + +from sklearn import datasets +from sklearn import utils + +""" +Module to load digit data. 
+""" + +def digit_data() -> utils.Bunch: + return datasets.load_digits() + + +def target(digit_data: utils.Bunch) -> np.ndarray: + return digit_data.target + + +def target_names(digit_data: utils.Bunch) -> np.ndarray: + return digit_data.target_names + + +def feature_matrix(digit_data: utils.Bunch) -> np.ndarray: + # return digit_data.images.reshape((len(digit_data), -1)) + return digit_data.data diff --git a/examples/model_examples/scikit-learn/iris_loader.py b/examples/model_examples/scikit-learn/iris_loader.py new file mode 100644 index 00000000..f82639fb --- /dev/null +++ b/examples/model_examples/scikit-learn/iris_loader.py @@ -0,0 +1,23 @@ +import numpy as np + +from sklearn import datasets +from sklearn import utils + +""" +Module to load iris data. +""" + +def iris_data() -> utils.Bunch: + return datasets.load_iris() + + +def target(iris_data: utils.Bunch) -> np.ndarray: + return iris_data.target + + +def target_names(iris_data: utils.Bunch) -> np.ndarray: + return iris_data.target_names + + +def feature_matrix(iris_data: utils.Bunch) -> np.ndarray: + return iris_data.data diff --git a/examples/model_examples/scikit-learn/model_logic.py b/examples/model_examples/scikit-learn/model_logic.py deleted file mode 100644 index 348835c6..00000000 --- a/examples/model_examples/scikit-learn/model_logic.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Dict - -import numpy as np -import pandas as pd -from sklearn import datasets, svm, metrics -from sklearn.model_selection import train_test_split - -from hamilton import function_modifiers - - -def flattened_digits(digit_data: np.ndarray) -> np.ndarray: - return digit_data.reshape((len(digit_data), -1)) - - -@function_modifiers.config.when(clf='svm') -def prefit_clf__svm(gamma: float = 0.001) -> svm.SVC: - return svm.SVC(gamma=gamma) - - -@function_modifiers.extract_columns(*['X_train', 'X_test', 'y_train', 'y_test']) -def train_test_split_func(flattened_digits: np.ndarray, digits_targets: np.ndarray, 
test_size_fraction: float = 0.5, - shuffle_train_test_split: bool = False) -> Dict[str, np.ndarray]: - X_train, X_test, y_train, y_test = train_test_split( - flattened_digits, digits_targets, test_size=test_size_fraction, shuffle=shuffle_train_test_split - ) - return { - 'X_train': X_train, 'X_test': X_test, 'y_train': y_train, 'y_test': y_test - } - -def fit_clf(prefit_clf: svm.SVC, X_train: np.ndarray, y_train: np.ndarray) -> svm.SVC: - """This mutates the prefit_clf""" - prefit_clf.fit(X_train, y_train) - return prefit_clf - -def predict_output(fit_clf: svm.SVC, X_test: np.ndarray) -> np.ndarray: - return fit_clf.predict(X_test) - -def classification_report(predict_output: np.ndarray, y_test: np.ndarray) -> str: - return metrics.classification_report(y_test, predict_output) - -# def confusion_matrix(predict_output: np.ndarray, y_test: np.ndarray) -> str: - # return metrics.ConfusionMatrixDisplay.from_predictions(y_test, predict_output) diff --git a/examples/model_examples/scikit-learn/my_train_evaluate_logic.py b/examples/model_examples/scikit-learn/my_train_evaluate_logic.py new file mode 100644 index 00000000..97628f0e --- /dev/null +++ b/examples/model_examples/scikit-learn/my_train_evaluate_logic.py @@ -0,0 +1,86 @@ +from typing import Dict, List + +import numpy as np +from sklearn import base +from sklearn import metrics +from sklearn import svm +from sklearn import linear_model +from sklearn.model_selection import train_test_split + +from hamilton import function_modifiers + + + +@function_modifiers.config.when(clf='svm') +def prefit_clf__svm(gamma: float = 0.001) -> base.ClassifierMixin: + """Returns an unfitted SVM classifier object. + + :param gamma: ... + :return: + """ + return svm.SVC(gamma=gamma) + + +@function_modifiers.config.when(clf='logistic') +def prefit_clf__logreg(penalty: str) -> base.ClassifierMixin: + """Returns an unfitted Logistic Regression classifier object. 
+ + :param penalty: + :return: + """ + return linear_model.LogisticRegression(penalty) + + +@function_modifiers.extract_fields( + {'X_train': np.ndarray, 'X_test': np.ndarray, 'y_train': np.ndarray, 'y_test': np.ndarray}) +def train_test_split_func(feature_matrix: np.ndarray, + target: np.ndarray, + test_size_fraction: float, + shuffle_train_test_split: bool) -> Dict[str, np.ndarray]: + """Function that creates the training & test splits. + + It this then extracted out into constituent components and used downstream. + + :param feature_matrix: + :param target: + :param test_size_fraction: + :param shuffle_train_test_split: + :return: + """ + X_train, X_test, y_train, y_test = train_test_split( + feature_matrix, target, test_size=test_size_fraction, shuffle=shuffle_train_test_split + ) + return { + 'X_train': X_train, 'X_test': X_test, 'y_train': y_train, 'y_test': y_test + } + + +def y_test_with_labels(y_test: np.ndarray, target_names: np.ndarray) -> np.ndarray: + """Adds labels to the target output.""" + return np.array([target_names[idx] for idx in y_test]) + + +def fit_clf(prefit_clf: base.ClassifierMixin, X_train: np.ndarray, y_train: np.ndarray) -> base.ClassifierMixin: + """Calls fit on the classifier object; it mutates it.""" + prefit_clf.fit(X_train, y_train) + return prefit_clf + + +def predicted_output(fit_clf: base.ClassifierMixin, X_test: np.ndarray) -> np.ndarray: + """Exercised the fit classifier to perform a prediction.""" + return fit_clf.predict(X_test) + + +def predicted_output_with_labels(predicted_output: np.ndarray, target_names: np.ndarray) -> np.ndarray: + """Replaces the predictions with the desired labels.""" + return np.array([target_names[idx] for idx in predicted_output]) + + +def classification_report(predicted_output_with_labels: np.ndarray, y_test_with_labels: np.ndarray) -> str: + """Returns a classification report.""" + return metrics.classification_report(y_test_with_labels, predicted_output_with_labels) + + +def 
confusion_matrix(predicted_output_with_labels: np.ndarray, y_test_with_labels: np.ndarray) -> str: + """Returns a confusion matrix report.""" + return metrics.confusion_matrix(y_test_with_labels, predicted_output_with_labels) diff --git a/examples/model_examples/scikit-learn/run.py b/examples/model_examples/scikit-learn/run.py index a7a88f50..bc0e9ee2 100644 --- a/examples/model_examples/scikit-learn/run.py +++ b/examples/model_examples/scikit-learn/run.py @@ -1,25 +1,64 @@ -from sklearn import datasets +""" +Example script showing how one might setup a generic model training pipeline that is quickly configurable. +""" from hamilton import driver from hamilton import base -import model_logic +import my_train_evaluate_logic +import digit_loader +import iris_loader -digits = datasets.load_digits() -config_or_data = { - 'digit_data': digits.images, # TODO: move this to loaders.py - 'digits_targets': digits.target, # TODO: move this to loaders.py - 'gamma': 0.001, - 'test_size_fraction': 0.5, - 'shuffle_train_test_split': False, - 'clf': 'svm' -} +def get_data_loader(data_set: str): + """Returns the module to load that will procur data -- the data loaders all have to define the same functions.""" + if data_set == 'iris': + return iris_loader + elif data_set == 'digits': + return digit_loader + else: + raise ValueError(f'Unknown data_name {data_set}.') -adapter = base.SimplePythonGraphAdapter(base.DictResult()) -dr = driver.Driver(config_or_data, model_logic, adapter=adapter) +def get_model_config(model_type: str) -> dict: + """Returns model type specific configuration""" + if model_type == 'svm': + return {'clf': 'svm', 'gamma': 0.001} + elif model_type == 'logistic': + return {'logistic': 'svm', 'penalty': 'l2'} + else: + raise ValueError(f'Unsupported model {model_type}.') -results = dr.execute(['classification_report']) -for k, v in results.items(): - print(v) + +if __name__ == '__main__': + import sys + if len(sys.argv) < 3: + print('Error: required arguments are 
[iris|digits] [svm|logistic]') + sys.exit(1) + _data_set = sys.argv[1] # the data set to load + _model_type = sys.argv[2] # the model type to fit and evaluate with + + dag_config = { + 'test_size_fraction': 0.5, + 'shuffle_train_test_split': True, + } + # augment config + dag_config.update(get_model_config(_model_type)) + # get module with functions to load data + data_module = get_data_loader(_data_set) + # set the desired result container we want + adapter = base.SimplePythonGraphAdapter(base.DictResult()) + """ + What's cool about this, is that by simply changing the `dag_config` and the `data_module` we can + reuse the logic in the `my_train_evaluate_logic` module very easily for different contexts and purposes if + want to setup a generic model fitting and prediction dataflow! + E.g. if we want to support a new data set, then we just need to add a new data loading module. + E.g. if we want to support a new model type, then we just need to add a single conditional function + to my_train_evaluate_logic. + """ + dr = driver.Driver(dag_config, data_module, my_train_evaluate_logic, adapter=adapter) + # ensure you have done "pip install sf-hamilton[visualization]" for the following to work: + # dr.visualize_execution(['classification_report', 'confusion_matrix', 'fit_clf'], './model_dag.dot', {}) + results = dr.execute(['classification_report', 'confusion_matrix', 'fit_clf']) + for k, v in results.items(): + print(k, ':\n', v) From f5dfc77aed4d50082ec644b9a21ba94cef6389e8 Mon Sep 17 00:00:00 2001 From: skrawczyk Date: Wed, 9 Feb 2022 11:38:32 -0800 Subject: [PATCH 4/5] Adds unit tests for extract_fields decorator Helps prove things work as intended! 
--- hamilton/function_modifiers.py | 4 +- tests/test_function_modifiers.py | 109 +++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 2 deletions(-) diff --git a/hamilton/function_modifiers.py b/hamilton/function_modifiers.py index 9989d940..d90678dc 100644 --- a/hamilton/function_modifiers.py +++ b/hamilton/function_modifiers.py @@ -208,7 +208,7 @@ def __init__(self, fields: dict, fill_with: Any = None): if not fields: raise InvalidDecoratorException('Error an empty dict, or no dict, passed to extract_fields decorator.') elif not isinstance(fields, dict): - raise InvalidDecoratorException(f'Error, please pass in a dict or typed dict, not {type(fields)}') + raise InvalidDecoratorException(f'Error, please pass in a dict, not {type(fields)}') else: errors = [] for field, field_type in fields.items(): @@ -232,7 +232,7 @@ def validate(self, fn: Callable): output_type = inspect.signature(fn).return_annotation if typing_inspect.is_generic_type(output_type): base = typing_inspect.get_origin(output_type) - if base == dict or base == typing.Dict: + if base == dict or base == typing.Dict: # different python versions return different things 3.7+ vs 3.6. 
pass else: raise InvalidDecoratorException( diff --git a/tests/test_function_modifiers.py b/tests/test_function_modifiers.py index 3eab1bed..c3a6f27c 100644 --- a/tests/test_function_modifiers.py +++ b/tests/test_function_modifiers.py @@ -1,5 +1,6 @@ from typing import Any, List, Dict +import numpy as np import pandas as pd import pytest @@ -368,3 +369,111 @@ def config_when_fn() -> int: annotation = function_modifiers.config.when(key='value', name='new_function_name') assert annotation.resolve(config_when_fn, {'key': 'value'}).__name__ == 'new_function_name' + + +@pytest.mark.parametrize('fields', [ + (None), # empty + ('string_input'), # not a dict + (['string_input']), # not a dict + ({}), # empty dict + ({1: 'string', 'field': str}), # invalid dict + ({'field': lambda x: x, 'field2': int} ), # invalid dict +]) +def test_extract_fields_constructor_errors(fields): + with pytest.raises(function_modifiers.InvalidDecoratorException): + function_modifiers.extract_fields(fields) + + +@pytest.mark.parametrize('fields', [ + ({'field': np.ndarray, 'field2': str}), + ({'field': dict, 'field2': int, 'field3': list, 'field4': float, 'field5': str}), +]) +def test_extract_fields_constructor_happy(fields): + """Tests that we are happy with good arguments.""" + function_modifiers.extract_fields(fields) + + +@pytest.mark.parametrize('return_type', [ + (dict), + (Dict), + (Dict[str, str]), + (Dict[str, Any]), +]) +def test_extract_fields_validate_happy(return_type): + def return_dict() -> return_type: + return {} + + annotation = function_modifiers.extract_fields({'test': int}) + annotation.validate(return_dict) + + +@pytest.mark.parametrize('return_type', [ + (int), (list), (np.ndarray), (pd.DataFrame) +]) +def test_extract_fields_validate_errors(return_type): + def return_dict() -> return_type: + return {} + annotation = function_modifiers.extract_fields({'test': int}) + with pytest.raises(function_modifiers.InvalidDecoratorException): + annotation.validate(return_dict) + + 
+def test_valid_extract_fields(): + """Tests whole extract_fields decorator.""" + annotation = function_modifiers.extract_fields({'col_1': list, 'col_2': int, 'col_3': np.ndarray}) + + def dummy_dict_generator() -> dict: + """dummy doc""" + return {'col_1': [1, 2, 3, 4], + 'col_2': 1, + 'col_3': np.ndarray([1, 2, 3, 4])} + + nodes = list(annotation.expand_node(node.Node.from_fn(dummy_dict_generator), {}, dummy_dict_generator)) + assert len(nodes) == 4 + assert nodes[0] == node.Node(name=dummy_dict_generator.__name__, + typ=dict, + doc_string=dummy_dict_generator.__doc__, + callabl=dummy_dict_generator) + assert nodes[1].name == 'col_1' + assert nodes[1].type == list + assert nodes[1].documentation == 'dummy doc' # we default to base function doc. + assert nodes[1].input_types == {dummy_dict_generator.__name__: (dict, DependencyType.REQUIRED)} + assert nodes[2].name == 'col_2' + assert nodes[2].type == int + assert nodes[2].documentation == 'dummy doc' + assert nodes[2].input_types == {dummy_dict_generator.__name__: (dict, DependencyType.REQUIRED)} + assert nodes[3].name == 'col_3' + assert nodes[3].type == np.ndarray + assert nodes[3].documentation == 'dummy doc' + assert nodes[3].input_types == {dummy_dict_generator.__name__: (dict, DependencyType.REQUIRED)} + + +def test_extract_fields_fill_with(): + def dummy_dict() -> dict: + """dummy doc""" + return {'col_1': [1, 2, 3, 4], + 'col_2': 1, + 'col_3': np.ndarray([1, 2, 3, 4])} + + annotation = function_modifiers.extract_fields({'col_2': int, 'col_4': float}, fill_with=1.0) + original_node, extracted_field_node, missing_field_node = annotation.expand_node(node.Node.from_fn(dummy_dict), + {}, + dummy_dict) + original_dict = original_node.callable() + extracted_field = extracted_field_node.callable(dummy_dict=original_dict) + missing_field = missing_field_node.callable(dummy_dict=original_dict) + assert extracted_field == 1 + assert missing_field == 1.0 + + +def test_extract_fields_no_fill_with(): + def dummy_dict() 
-> dict: + """dummy doc""" + return {'col_1': [1, 2, 3, 4], + 'col_2': 1, + 'col_3': np.ndarray([1, 2, 3, 4])} + + annotation = function_modifiers.extract_fields({'col_4': int}) + nodes = list(annotation.expand_node(node.Node.from_fn(dummy_dict), {}, dummy_dict)) + with pytest.raises(function_modifiers.InvalidDecoratorException): + nodes[1].callable(dummy_dict=dummy_dict()) From b359b3b61d431f94ae5962acff4513bbfc57b639 Mon Sep 17 00:00:00 2001 From: skrawczyk Date: Wed, 9 Feb 2022 11:39:09 -0800 Subject: [PATCH 5/5] Fixes graphviz test I wonder if this is flakey somehow? Anyway adding this to see if circleci complains or not. Seems like there could be a version mismatch somewhere that causes this, i.e. my local env, versus what circleci installs, etc. --- tests/test_graph.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_graph.py b/tests/test_graph.py index a5d20136..d527da98 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -364,8 +364,12 @@ def test_create_graphviz_graph(): '\tA -> C', '}', '']) + if '' in expected: + expected.remove('') digraph = graph.create_graphviz_graph(nodes, user_nodes, 'test-graph') actual = sorted(str(digraph).split('\n')) + if '' in actual: + actual.remove('') assert actual == expected