Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 49 additions & 16 deletions tensorflow_model_analysis/api/impl/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,13 +288,39 @@ def _ExtractOutput( # pylint: disable=invalid-name
main=_ExtractOutputDoFn.OUTPUT_TAG_METRICS)


def PredictExtractor(eval_saved_model_path, add_metrics_callbacks,
shared_handle, desired_batch_size):
# Map function which loads and runs the eval_saved_model against every
# example, yielding an types.ExampleAndExtracts containing a
# FeaturesPredictionsLabels value (where key is 'fpl').
return types.Extractor(
stage_name='Predict',
ptransform=predict_extractor.TFMAPredict(
eval_saved_model_path=eval_saved_model_path,
add_metrics_callbacks=add_metrics_callbacks,
shared_handle=shared_handle,
desired_batch_size=desired_batch_size))


@beam.ptransform_fn
def Extract(examples, extractors):
"""Performs Extractions serially in provided order."""
augmented = examples

for extractor in extractors:
augmented = augmented | extractor.stage_name >> extractor.ptransform

return augmented


@beam.ptransform_fn
# No typehint for output type, since it's a multi-output DoFn result that
# Beam doesn't support typehints for yet (BEAM-3280).
def Evaluate(
# pylint: disable=invalid-name
examples,
eval_saved_model_path,
extractors = None,
add_metrics_callbacks = None,
slice_spec = None,
desired_batch_size = None,
Expand All @@ -309,6 +335,8 @@ def Evaluate(
(e.g. string containing CSV row, TensorFlow.Example, etc).
eval_saved_model_path: Path to EvalSavedModel. This directory should contain
the saved_model.pb file.
extractors: Optional list of Extractors to execute prior to slicing and
aggregating the metrics. If not provided, a default set will be run.
add_metrics_callbacks: Optional list of callbacks for adding additional
metrics to the graph. The names of the metrics added by the callbacks
should not conflict with existing metrics, or metrics added by other
Expand Down Expand Up @@ -349,24 +377,22 @@ def add_metrics_callback(features_dict, predictions_dict, labels):

shared_handle = shared.Shared()

if not extractors:
extractors = [
PredictExtractor(eval_saved_model_path, add_metrics_callbacks,
shared_handle, desired_batch_size),
]

# pylint: disable=no-value-for-parameter
return (
examples
# Our diagnostic outputs, pass types.ExampleAndExtracts throughout,
# however our aggregating functions do not use this interface.
| 'ToExampleAndExtracts' >>
beam.Map(lambda x: types.ExampleAndExtracts(example=x, extracts={}))
| Extract(extractors=extractors)

# Map function which loads and runs the eval_saved_model against every
# example, yielding an types.ExampleAndExtracts containing a
# FeaturesPredictionsLabels value (where key is 'fpl').
| 'Predict' >> predict_extractor.TFMAPredict(
eval_saved_model_path=eval_saved_model_path,
add_metrics_callbacks=add_metrics_callbacks,
shared_handle=shared_handle,
desired_batch_size=desired_batch_size)

# Input: one example fpl at a time
# Input: one example at a time
# Output: one fpl example per slice key (notice that the example turns
# into n, replicated once per applicable slice key)
| 'Slice' >> slice_api.Slice(slice_spec)
Expand Down Expand Up @@ -395,6 +421,7 @@ def BuildDiagnosticTable(
# pylint: disable=invalid-name
examples,
eval_saved_model_path,
extractors = None,
desired_batch_size = None):
"""Build diagnostics for the spacified EvalSavedModel and example collection.

Expand All @@ -403,18 +430,24 @@ def BuildDiagnosticTable(
(e.g. string containing CSV row, TensorFlow.Example, etc).
eval_saved_model_path: Path to EvalSavedModel. This directory should contain
the saved_model.pb file.
extractors: Optional list of Extractors to execute prior to slicing and
aggregating the metrics. If not provided, a default set will be run.
desired_batch_size: Optional batch size for batching in Predict and
Aggregate.

Returns:
PCollection of ExampleAndExtracts
"""

if not extractors:
extractors = [
PredictExtractor(eval_saved_model_path, None, shared.Shared(),
desired_batch_size),
types.Extractor(
stage_name='ExtractFeatures',
ptransform=feature_extractor.ExtractFeatures()),
]
return (examples
| 'ToExampleAndExtracts' >>
beam.Map(lambda x: types.ExampleAndExtracts(example=x, extracts={}))
| 'Predict' >> predict_extractor.TFMAPredict(
eval_saved_model_path,
add_metrics_callbacks=None,
shared_handle=shared.Shared(),
desired_batch_size=desired_batch_size)
| 'ExtractFeatures' >> feature_extractor.ExtractFeatures())
| Extract(extractors=extractors))
5 changes: 5 additions & 0 deletions tensorflow_model_analysis/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import copy

import apache_beam as beam
import numpy as np
import tensorflow as tf

Expand Down Expand Up @@ -66,6 +67,10 @@ def is_tensor(obj):
DictOfExtractedValues = Dict[Text, Any]


Extractor = NamedTuple('Extractor', [('stage_name', bytes),
('ptransform', beam.PTransform)])


class ExampleAndExtracts(
NamedTuple('ExampleAndExtracts', [('example', bytes),
('extracts', DictOfExtractedValues)])):
Expand Down