diff --git a/examples/model_examples/scikit-learn/digit_loader.py b/examples/model_examples/scikit-learn/digit_loader.py
new file mode 100644
index 00000000..2db1a833
--- /dev/null
+++ b/examples/model_examples/scikit-learn/digit_loader.py
@@ -0,0 +1,26 @@
+"""Module to load digit data."""
+
+import numpy as np
+
+from sklearn import datasets
+from sklearn import utils
+
+
+def digit_data() -> utils.Bunch:
+    """Loads the scikit-learn digits data set."""
+    return datasets.load_digits()
+
+
+def target(digit_data: utils.Bunch) -> np.ndarray:
+    """Returns the `target` array of the digits data set."""
+    return digit_data.target
+
+
+def target_names(digit_data: utils.Bunch) -> np.ndarray:
+    """Returns the `target_names` array of the digits data set."""
+    return digit_data.target_names
+
+
+def feature_matrix(digit_data: utils.Bunch) -> np.ndarray:
+    """Returns the feature matrix (`data`, which scikit-learn documents as the flattened data matrix)."""
+    return digit_data.data
diff --git a/examples/model_examples/scikit-learn/iris_loader.py b/examples/model_examples/scikit-learn/iris_loader.py
new file mode 100644
index 00000000..f82639fb
--- /dev/null
+++ b/examples/model_examples/scikit-learn/iris_loader.py
@@ -0,0 +1,26 @@
+"""Module to load iris data."""
+
+import numpy as np
+
+from sklearn import datasets
+from sklearn import utils
+
+
+def iris_data() -> utils.Bunch:
+    """Loads the scikit-learn iris data set."""
+    return datasets.load_iris()
+
+
+def target(iris_data: utils.Bunch) -> np.ndarray:
+    """Returns the `target` array of the iris data set."""
+    return iris_data.target
+
+
+def target_names(iris_data: utils.Bunch) -> np.ndarray:
+    """Returns the `target_names` array of the iris data set."""
+    return iris_data.target_names
+
+
+def feature_matrix(iris_data: utils.Bunch) -> np.ndarray:
+    """Returns the feature matrix (`data`) of the iris data set."""
+    return iris_data.data
diff --git a/examples/model_examples/scikit-learn/my_train_evaluate_logic.py b/examples/model_examples/scikit-learn/my_train_evaluate_logic.py
new file mode 100644
index 00000000..97628f0e
--- /dev/null
+++ b/examples/model_examples/scikit-learn/my_train_evaluate_logic.py
@@ -0,0 +1,88 @@
+"""Generic model training & evaluation logic; the classifier built is chosen via the `clf` config value."""
+
+from typing import Dict, List
+
+import numpy as np
+from sklearn import base
+from sklearn import metrics
+from sklearn import svm
+from sklearn import linear_model
+from sklearn.model_selection import train_test_split
+
+from hamilton import function_modifiers
+
+
+@function_modifiers.config.when(clf='svm')
+def prefit_clf__svm(gamma: float = 0.001) -> base.ClassifierMixin:
+    """Returns an unfitted SVM classifier object.
+
+    :param gamma: value passed through as the `gamma` argument of `svm.SVC`.
+    :return: an unfitted `svm.SVC` instance.
+    """
+    return svm.SVC(gamma=gamma)
+
+
+@function_modifiers.config.when(clf='logistic')
+def prefit_clf__logreg(penalty: str) -> base.ClassifierMixin:
+    """Returns an unfitted Logistic Regression classifier object.
+
+    :param penalty: value passed through as the `penalty` argument of `linear_model.LogisticRegression`.
+    :return: an unfitted `linear_model.LogisticRegression` instance.
+    """
+    # pass by keyword so the value cannot silently bind to a different parameter.
+    return linear_model.LogisticRegression(penalty=penalty)
+
+
+@function_modifiers.extract_fields(
+    {'X_train': np.ndarray, 'X_test': np.ndarray, 'y_train': np.ndarray, 'y_test': np.ndarray})
+def train_test_split_func(feature_matrix: np.ndarray,
+                          target: np.ndarray,
+                          test_size_fraction: float,
+                          shuffle_train_test_split: bool) -> Dict[str, np.ndarray]:
+    """Function that creates the training & test splits.
+
+    It is then extracted out into constituent components and used downstream.
+
+    :param feature_matrix: the features to split.
+    :param target: the target values to split alongside `feature_matrix`.
+    :param test_size_fraction: passed as `test_size` to `train_test_split`.
+    :param shuffle_train_test_split: passed as `shuffle` to `train_test_split`.
+    :return: dict with the keys 'X_train', 'X_test', 'y_train' and 'y_test'.
+    """
+    X_train, X_test, y_train, y_test = train_test_split(
+        feature_matrix, target, test_size=test_size_fraction, shuffle=shuffle_train_test_split
+    )
+    return {
+        'X_train': X_train, 'X_test': X_test, 'y_train': y_train, 'y_test': y_test
+    }
+
+
+def y_test_with_labels(y_test: np.ndarray, target_names: np.ndarray) -> np.ndarray:
+    """Adds labels to the target output."""
+    return np.array([target_names[idx] for idx in y_test])
+
+
+def fit_clf(prefit_clf: base.ClassifierMixin, X_train: np.ndarray, y_train: np.ndarray) -> base.ClassifierMixin:
+    """Calls fit on the classifier object; it mutates it."""
+    prefit_clf.fit(X_train, y_train)
+    return prefit_clf
+
+
+def predicted_output(fit_clf: base.ClassifierMixin, X_test: np.ndarray) -> np.ndarray:
+    """Exercises the fit classifier to perform a prediction."""
+    return fit_clf.predict(X_test)
+
+
+def predicted_output_with_labels(predicted_output: np.ndarray, target_names: np.ndarray) -> np.ndarray:
+    """Replaces the predictions with the desired labels."""
+    return np.array([target_names[idx] for idx in predicted_output])
+
+
+def classification_report(predicted_output_with_labels: np.ndarray, y_test_with_labels: np.ndarray) -> str:
+    """Returns a classification report."""
+    return metrics.classification_report(y_test_with_labels, predicted_output_with_labels)
+
+
+def confusion_matrix(predicted_output_with_labels: np.ndarray, y_test_with_labels: np.ndarray) -> np.ndarray:
+    """Returns a confusion matrix."""
+    return metrics.confusion_matrix(y_test_with_labels, predicted_output_with_labels)
diff --git a/examples/model_examples/scikit-learn/requirements.txt b/examples/model_examples/scikit-learn/requirements.txt
new file mode 100644
index 00000000..3f69ad5c
--- /dev/null
+++ b/examples/model_examples/scikit-learn/requirements.txt
@@ -0,0 +1,2 @@
+scikit-learn
+sf-hamilton
diff --git 
a/examples/model_examples/scikit-learn/run.py b/examples/model_examples/scikit-learn/run.py
new file mode 100644
index 00000000..bc0e9ee2
--- /dev/null
+++ b/examples/model_examples/scikit-learn/run.py
@@ -0,0 +1,76 @@
+"""
+Example script showing how one might set up a generic model training pipeline that is quickly configurable.
+"""
+
+from hamilton import driver
+from hamilton import base
+
+import my_train_evaluate_logic
+import digit_loader
+import iris_loader
+
+
+def get_data_loader(data_set: str):
+    """Returns the module to load that will procure data -- the data loaders all have to define the same functions.
+
+    :param data_set: name of the data set; one of 'iris' or 'digits'.
+    :return: the data loading module for that data set.
+    :raises ValueError: if the data set name is not known.
+    """
+    if data_set == 'iris':
+        return iris_loader
+    elif data_set == 'digits':
+        return digit_loader
+    else:
+        raise ValueError(f'Unknown data_name {data_set}.')
+
+
+def get_model_config(model_type: str) -> dict:
+    """Returns model type specific configuration.
+
+    The `clf` value is what the `config.when(clf=...)` functions in `my_train_evaluate_logic` are keyed on.
+
+    :param model_type: one of 'svm' or 'logistic'.
+    :return: configuration entries to merge into the DAG config.
+    :raises ValueError: if the model type is not supported.
+    """
+    if model_type == 'svm':
+        return {'clf': 'svm', 'gamma': 0.001}
+    elif model_type == 'logistic':
+        return {'clf': 'logistic', 'penalty': 'l2'}
+    else:
+        raise ValueError(f'Unsupported model {model_type}.')
+
+
+if __name__ == '__main__':
+    import sys
+    if len(sys.argv) < 3:
+        print('Error: required arguments are [iris|digits] [svm|logistic]')
+        sys.exit(1)
+    _data_set = sys.argv[1]  # the data set to load
+    _model_type = sys.argv[2]  # the model type to fit and evaluate with
+
+    dag_config = {
+        'test_size_fraction': 0.5,
+        'shuffle_train_test_split': True,
+    }
+    # augment config
+    dag_config.update(get_model_config(_model_type))
+    # get module with functions to load data
+    data_module = get_data_loader(_data_set)
+    # set the desired result container we want
+    adapter = base.SimplePythonGraphAdapter(base.DictResult())
+    """
+    What's cool about this, is that by simply changing the `dag_config` and the `data_module` we can
+    reuse the logic in the `my_train_evaluate_logic` module very easily for different contexts and purposes if
+    we want to set up a generic model fitting and prediction dataflow!
+    E.g. if we want to support a new data set, then we just need to add a new data loading module.
+    E.g. if we want to support a new model type, then we just need to add a single conditional function
+    to my_train_evaluate_logic.
+    """
+    dr = driver.Driver(dag_config, data_module, my_train_evaluate_logic, adapter=adapter)
+    # ensure you have done "pip install sf-hamilton[visualization]" for the following to work:
+    # dr.visualize_execution(['classification_report', 'confusion_matrix', 'fit_clf'], './model_dag.dot', {})
+    results = dr.execute(['classification_report', 'confusion_matrix', 'fit_clf'])
+    for k, v in results.items():
+        print(k, ':\n', v)
diff --git a/hamilton/function_modifiers.py b/hamilton/function_modifiers.py
index ba3752db..d90678dc 100644
--- a/hamilton/function_modifiers.py
+++ b/hamilton/function_modifiers.py
@@ -1,7 +1,9 @@
 import functools
 import functools
 import inspect
+import typing
 from typing import Dict, Callable, Collection, Tuple, Union, Any, Type
+import typing_inspect
 
 import pandas as pd
 
@@ -190,6 +192,94 @@ def extractor_fn(column_to_extract: str = column, **kwargs) -> pd.Series:  # avo
         return output_nodes
 
 
+
+class extract_fields(NodeExpander):
+    """Extracts fields from a dictionary of output."""
+
+    def __init__(self, fields: dict, fill_with: Any = None):
+        """Constructor for a modifier that expands a single function into the following nodes:
+        - n functions, each of which take in the original dict and output a specific field
+        - 1 function that outputs the original dict
+
+        :param fields: Fields to extract. A dict of 'field_name' -> 'field_type'.
+        :param fill_with: If you want to extract a field that doesn't exist, do you want to fill it with a default value?
+            Or do you want to error out? Leave empty/None to error out, set fill_with to the value to fill missing fields with.
+        :raises InvalidDecoratorException: If fields is empty, is not a dict, or does not map strings to types.
+        """
+        if not fields:
+            raise InvalidDecoratorException('Error an empty dict, or no dict, passed to extract_fields decorator.')
+        elif not isinstance(fields, dict):
+            raise InvalidDecoratorException(f'Error, please pass in a dict, not {type(fields)}')
+        else:
+            errors = []
+            for field, field_type in fields.items():
+                if not isinstance(field, str):
+                    errors.append(f'{field} is not a string. All keys must be strings.')
+                if not isinstance(field_type, type):
+                    errors.append(f'{field} does not declare a type. Instead it passes {field_type}.')
+
+            if errors:
+                raise InvalidDecoratorException(f'Error, found these {errors}. '
+                                                f'Please pass in a dict of string to types. ')
+        self.fields = fields
+        self.fill_with = fill_with
+
+    def validate(self, fn: Callable):
+        """A function is invalid if it is not annotated with a dict or typing.Dict return type.
+
+        :param fn: Function to validate.
+        :raises: InvalidDecoratorException If the function is not annotated with a dict or typing.Dict type as output.
+        """
+        output_type = inspect.signature(fn).return_annotation
+        if typing_inspect.is_generic_type(output_type):
+            # different python versions return different things 3.7+ vs 3.6.
+            is_dict_type = typing_inspect.get_origin(output_type) in (dict, typing.Dict)
+        else:
+            is_dict_type = output_type == dict
+        if not is_dict_type:
+            raise InvalidDecoratorException(
+                f'For extracting fields, output type must be a dict or typing.Dict, not: {output_type}')
+
+    def expand_node(self, node_: node.Node, config: Dict[str, Any], fn: Callable) -> Collection[node.Node]:
+        """For each field to extract, output a node that extracts that field. Also, output the original dict
+        generator.
+
+        :param node_: Node whose callable produces the dict to extract fields from.
+        :param config: Not used by this expander.
+        :param fn: Function to extract fields from. Must output a dict. Note: the callable on node_ is what gets wrapped.
+        :return: A collection of nodes --
+                one for the original dict generator, and another for each field to extract.
+        """
+        fn = node_.callable
+        base_doc = node_.documentation
+
+        @functools.wraps(fn)
+        def dict_generator(*args, **kwargs):
+            dict_generated = fn(*args, **kwargs)
+            if self.fill_with is not None:
+                # fill in, in place, any requested field the function did not produce.
+                for field in self.fields:
+                    if field not in dict_generated:
+                        dict_generated[field] = self.fill_with
+            return dict_generated
+
+        output_nodes = [node.Node(node_.name, typ=dict, doc_string=base_doc, callabl=dict_generator)]
+
+        for field, field_type in self.fields.items():
+            doc_string = base_doc  # default doc string of base function.
+
+            def extractor_fn(field_to_extract: str = field, **kwargs) -> field_type:  # avoiding problems with closures
+                dt = kwargs[node_.name]
+                if field_to_extract not in dt:
+                    raise InvalidDecoratorException(f'No such field: {field_to_extract} produced by {node_.name}. '
+                                                    f'It only produced {list(dt.keys())}')
+                return dt[field_to_extract]
+
+            output_nodes.append(
+                node.Node(field, field_type, doc_string, extractor_fn, input_types={node_.name: dict}))
+        return output_nodes
+
+
 # the following are empty functions that we can compare against to ensure that @does uses an empty function
 def _empty_function():
     pass
diff --git a/requirements.txt b/requirements.txt
index 45785f6c..f1314b21 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 dataclasses; python_version < '3.7'
 numpy
 pandas
+typing_inspect
diff --git a/tests/test_function_modifiers.py b/tests/test_function_modifiers.py
index 3eab1bed..c3a6f27c 100644
--- a/tests/test_function_modifiers.py
+++ b/tests/test_function_modifiers.py
@@ -1,5 +1,6 @@
 from typing import Any, List, Dict
 
+import numpy as np
 import pandas as pd
 import pytest
 
@@ -368,3 +369,116 @@ def config_when_fn() -> int:
 
     annotation = function_modifiers.config.when(key='value', name='new_function_name')
     assert annotation.resolve(config_when_fn, {'key': 'value'}).__name__ == 'new_function_name'
+
+
+@pytest.mark.parametrize('fields', [
+    (None),  # empty
+    ('string_input'),  # not a dict
+    (['string_input']),  # not a dict
+    ({}),  # empty dict
+    ({1: 'string', 'field': str}),  # invalid dict
+    ({'field': lambda x: x, 'field2': int}),  # invalid dict
+])
+def test_extract_fields_constructor_errors(fields):
+    """Tests that invalid `fields` arguments are rejected."""
+    with pytest.raises(function_modifiers.InvalidDecoratorException):
+        function_modifiers.extract_fields(fields)
+
+
+@pytest.mark.parametrize('fields', [
+    ({'field': np.ndarray, 'field2': str}),
+    ({'field': dict, 'field2': int, 'field3': list, 'field4': float, 'field5': str}),
+])
+def test_extract_fields_constructor_happy(fields):
+    """Tests that we are happy with good arguments."""
+    function_modifiers.extract_fields(fields)
+
+
+@pytest.mark.parametrize('return_type', [
+    (dict),
+    (Dict),
+    (Dict[str, str]),
+    (Dict[str, Any]),
+])
+def test_extract_fields_validate_happy(return_type):
+    """Tests that dict-like return annotations pass validation."""
+    def return_dict() -> return_type:
+        return {}
+
+    annotation = function_modifiers.extract_fields({'test': int})
+    annotation.validate(return_dict)
+
+
+@pytest.mark.parametrize('return_type', [
+    (int), (list), (np.ndarray), (pd.DataFrame)
+])
+def test_extract_fields_validate_errors(return_type):
+    """Tests that non-dict return annotations fail validation."""
+    def return_dict() -> return_type:
+        return {}
+    annotation = function_modifiers.extract_fields({'test': int})
+    with pytest.raises(function_modifiers.InvalidDecoratorException):
+        annotation.validate(return_dict)
+
+
+def test_valid_extract_fields():
+    """Tests whole extract_fields decorator."""
+    annotation = function_modifiers.extract_fields({'col_1': list, 'col_2': int, 'col_3': np.ndarray})
+
+    def dummy_dict_generator() -> dict:
+        """dummy doc"""
+        return {'col_1': [1, 2, 3, 4],
+                'col_2': 1,
+                'col_3': np.array([1, 2, 3, 4])}
+
+    nodes = list(annotation.expand_node(node.Node.from_fn(dummy_dict_generator), {}, dummy_dict_generator))
+    assert len(nodes) == 4
+    assert nodes[0] == node.Node(name=dummy_dict_generator.__name__,
+                                 typ=dict,
+                                 doc_string=dummy_dict_generator.__doc__,
+                                 callabl=dummy_dict_generator)
+    assert nodes[1].name == 'col_1'
+    assert nodes[1].type == list
+    assert nodes[1].documentation == 'dummy doc'  # we default to base function doc.
+    assert nodes[1].input_types == {dummy_dict_generator.__name__: (dict, DependencyType.REQUIRED)}
+    assert nodes[2].name == 'col_2'
+    assert nodes[2].type == int
+    assert nodes[2].documentation == 'dummy doc'
+    assert nodes[2].input_types == {dummy_dict_generator.__name__: (dict, DependencyType.REQUIRED)}
+    assert nodes[3].name == 'col_3'
+    assert nodes[3].type == np.ndarray
+    assert nodes[3].documentation == 'dummy doc'
+    assert nodes[3].input_types == {dummy_dict_generator.__name__: (dict, DependencyType.REQUIRED)}
+
+
+def test_extract_fields_fill_with():
+    """Tests that a missing field is filled with the `fill_with` value."""
+    def dummy_dict() -> dict:
+        """dummy doc"""
+        return {'col_1': [1, 2, 3, 4],
+                'col_2': 1,
+                'col_3': np.array([1, 2, 3, 4])}
+
+    annotation = function_modifiers.extract_fields({'col_2': int, 'col_4': float}, fill_with=1.0)
+    original_node, extracted_field_node, missing_field_node = annotation.expand_node(node.Node.from_fn(dummy_dict),
+                                                                                     {},
+                                                                                     dummy_dict)
+    original_dict = original_node.callable()
+    extracted_field = extracted_field_node.callable(dummy_dict=original_dict)
+    missing_field = missing_field_node.callable(dummy_dict=original_dict)
+    assert extracted_field == 1
+    assert missing_field == 1.0
+
+
+def test_extract_fields_no_fill_with():
+    """Tests that extracting a missing field errors when there is no `fill_with`."""
+    def dummy_dict() -> dict:
+        """dummy doc"""
+        return {'col_1': [1, 2, 3, 4],
+                'col_2': 1,
+                'col_3': np.array([1, 2, 3, 4])}
+
+    annotation = function_modifiers.extract_fields({'col_4': int})
+    nodes = list(annotation.expand_node(node.Node.from_fn(dummy_dict), {}, dummy_dict))
+    with pytest.raises(function_modifiers.InvalidDecoratorException):
+        nodes[1].callable(dummy_dict=dummy_dict())
diff --git a/tests/test_graph.py b/tests/test_graph.py
index a5d20136..d527da98 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -364,8 +364,12 @@ 
def test_create_graphviz_graph(): '\tA -> C', '}', '']) + if '' in expected: + expected.remove('') digraph = graph.create_graphviz_graph(nodes, user_nodes, 'test-graph') actual = sorted(str(digraph).split('\n')) + if '' in actual: + actual.remove('') assert actual == expected