From 1ea7b29f2fe93bbdccdcced234b1c7da60078176 Mon Sep 17 00:00:00 2001
From: tf-model-analysis-team
Date: Tue, 11 Sep 2018 11:19:06 -0700
Subject: [PATCH] Project import generated by Copybara.

PiperOrigin-RevId: 212486359
---
 .../api/impl/evaluate.py           | 65 ++++++++++++++-----
 tensorflow_model_analysis/types.py |  5 ++
 2 files changed, 54 insertions(+), 16 deletions(-)

diff --git a/tensorflow_model_analysis/api/impl/evaluate.py b/tensorflow_model_analysis/api/impl/evaluate.py
index fbdebf2bb3..8f9e9daed0 100644
--- a/tensorflow_model_analysis/api/impl/evaluate.py
+++ b/tensorflow_model_analysis/api/impl/evaluate.py
@@ -288,6 +288,31 @@ def _ExtractOutput(  # pylint: disable=invalid-name
               main=_ExtractOutputDoFn.OUTPUT_TAG_METRICS)
 
 
+def PredictExtractor(eval_saved_model_path, add_metrics_callbacks,
+                     shared_handle, desired_batch_size):
+  # Map function which loads and runs the eval_saved_model against every
+  # example, yielding a types.ExampleAndExtracts containing a
+  # FeaturesPredictionsLabels value (where key is 'fpl').
+  return types.Extractor(
+      stage_name='Predict',
+      ptransform=predict_extractor.TFMAPredict(
+          eval_saved_model_path=eval_saved_model_path,
+          add_metrics_callbacks=add_metrics_callbacks,
+          shared_handle=shared_handle,
+          desired_batch_size=desired_batch_size))
+
+
+@beam.ptransform_fn
+def Extract(examples, extractors):
+  """Performs extractions serially in the provided order."""
+  augmented = examples
+
+  for extractor in extractors:
+    augmented = augmented | extractor.stage_name >> extractor.ptransform
+
+  return augmented
+
+
 @beam.ptransform_fn
 # No typehint for output type, since it's a multi-output DoFn result that
 # Beam doesn't support typehints for yet (BEAM-3280).
@@ -295,6 +320,7 @@ def Evaluate(  # pylint: disable=invalid-name
     examples,
     eval_saved_model_path,
+    extractors = None,
     add_metrics_callbacks = None,
     slice_spec = None,
     desired_batch_size = None,
@@ -309,6 +335,8 @@
       (e.g. string containing CSV row, TensorFlow.Example, etc).
     eval_saved_model_path: Path to EvalSavedModel. This directory should
       contain the saved_model.pb file.
+    extractors: Optional list of Extractors to execute prior to slicing and
+      aggregating the metrics. If not provided, a default set will be run.
     add_metrics_callbacks: Optional list of callbacks for adding additional
       metrics to the graph. The names of the metrics added by the callbacks
      should not conflict with existing metrics, or metrics added by other
@@ -349,6 +377,12 @@ def add_metrics_callback(features_dict, predictions_dict, labels):
 
   shared_handle = shared.Shared()
 
+  if not extractors:
+    extractors = [
+        PredictExtractor(eval_saved_model_path, add_metrics_callbacks,
+                         shared_handle, desired_batch_size),
+    ]
+
   # pylint: disable=no-value-for-parameter
   return (
       examples
@@ -356,17 +390,9 @@
       # however our aggregating functions do not use this interface.
       | 'ToExampleAndExtracts' >> beam.Map(lambda x: types.ExampleAndExtracts(example=x, extracts={}))
+      | Extract(extractors=extractors)
 
-      # Map function which loads and runs the eval_saved_model against every
-      # example, yielding an types.ExampleAndExtracts containing a
-      # FeaturesPredictionsLabels value (where key is 'fpl').
-      | 'Predict' >> predict_extractor.TFMAPredict(
-          eval_saved_model_path=eval_saved_model_path,
-          add_metrics_callbacks=add_metrics_callbacks,
-          shared_handle=shared_handle,
-          desired_batch_size=desired_batch_size)
-
-      # Input: one example fpl at a time
+      # Input: one example at a time
       # Output: one fpl example per slice key (notice that the example turns
       # into n, replicated once per applicable slice key)
       | 'Slice' >> slice_api.Slice(slice_spec)
@@ -395,6 +421,7 @@ def BuildDiagnosticTable(  # pylint: disable=invalid-name
     examples,
     eval_saved_model_path,
+    extractors = None,
     desired_batch_size = None):
   """Build diagnostics for the specified EvalSavedModel and example collection.
@@ -403,18 +430,24 @@
       (e.g. string containing CSV row, TensorFlow.Example, etc).
     eval_saved_model_path: Path to EvalSavedModel. This directory should
       contain the saved_model.pb file.
+    extractors: Optional list of Extractors to execute prior to slicing and
+      aggregating the metrics. If not provided, a default set will be run.
     desired_batch_size: Optional batch size for batching in Predict and
       Aggregate.
 
   Returns:
     PCollection of ExampleAndExtracts
   """
+
+  if not extractors:
+    extractors = [
+        PredictExtractor(eval_saved_model_path, None, shared.Shared(),
+                         desired_batch_size),
+        types.Extractor(
+            stage_name='ExtractFeatures',
+            ptransform=feature_extractor.ExtractFeatures()),
+    ]
   return (examples
           | 'ToExampleAndExtracts' >> beam.Map(lambda x: types.ExampleAndExtracts(example=x, extracts={}))
-          | 'Predict' >> predict_extractor.TFMAPredict(
-              eval_saved_model_path,
-              add_metrics_callbacks=None,
-              shared_handle=shared.Shared(),
-              desired_batch_size=desired_batch_size)
-          | 'ExtractFeatures' >> feature_extractor.ExtractFeatures())
+          | Extract(extractors=extractors))
diff --git a/tensorflow_model_analysis/types.py b/tensorflow_model_analysis/types.py
index 9d5dd4326c..2da42488b7 100644
--- a/tensorflow_model_analysis/types.py
+++ b/tensorflow_model_analysis/types.py
@@ -20,6 +20,7 @@
 
 import copy
 
+import apache_beam as beam
 import numpy as np
 import tensorflow as tf
 
@@ -66,6 +67,10 @@ def is_tensor(obj):
 DictOfExtractedValues = Dict[Text, Any]
 
 
+Extractor = NamedTuple('Extractor', [('stage_name', bytes),
+                                     ('ptransform', beam.PTransform)])
+
+
 class ExampleAndExtracts(
     NamedTuple('ExampleAndExtracts', [('example', bytes),
                                       ('extracts', DictOfExtractedValues)])):
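
Usage sketch (not part of the patch): a minimal, self-contained illustration of how
the Extractor/Extract pattern introduced above composes. The ExampleAndExtracts,
Extractor, and Extract definitions below are local stand-ins mirroring the ones added
in types.py and evaluate.py; the AddLength stage, the _add_length helper, and the toy
byte strings are hypothetical and exist only to show the mechanics of chaining
extraction stages and overriding the default extractors list.

from typing import Any, Dict, NamedTuple, Text

import apache_beam as beam


# Local stand-ins mirroring the NamedTuples added to types.py in this patch.
class ExampleAndExtracts(
    NamedTuple('ExampleAndExtracts', [('example', bytes),
                                      ('extracts', Dict[Text, Any])])):
  pass


Extractor = NamedTuple('Extractor', [('stage_name', Text),
                                     ('ptransform', beam.PTransform)])


@beam.ptransform_fn
def Extract(examples, extractors):  # pylint: disable=invalid-name
  """Applies each extractor's PTransform serially, in the provided order."""
  augmented = examples
  for extractor in extractors:
    augmented = augmented | extractor.stage_name >> extractor.ptransform
  return augmented


def _add_length(element):
  # Copy the extracts dict and record the byte length of the raw example.
  extracts = dict(element.extracts)
  extracts['length'] = len(element.example)
  return ExampleAndExtracts(example=element.example, extracts=extracts)


if __name__ == '__main__':
  # Hypothetical custom stage; a real pipeline would typically start from the
  # default list (e.g. the PredictExtractor above) and append to it.
  extractors = [Extractor(stage_name='AddLength',
                          ptransform=beam.Map(_add_length))]
  with beam.Pipeline() as p:
    _ = (p
         | 'Create' >> beam.Create([b'abc', b'de'])
         | 'ToExampleAndExtracts' >> beam.Map(
             lambda x: ExampleAndExtracts(example=x, extracts={}))
         | Extract(extractors=extractors)
         | 'Print' >> beam.Map(print))

Because each stage is just a named PTransform, callers can insert, reorder, or replace
extraction stages by passing a different extractors list to Evaluate or
BuildDiagnosticTable instead of relying on the default Predict-only set.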