From 03f16fb4c3005ef1a35058f7bbac27947024113c Mon Sep 17 00:00:00 2001
From: tf-transform-team
Date: Fri, 20 Oct 2017 12:02:45 -0700
Subject: [PATCH] Project import generated by Copybara.

PiperOrigin-RevId: 172914980
---
 tensorflow_transform/analyzers.py      | 48 +++++++++++--------
 tensorflow_transform/beam/impl.py      | 13 +++--
 .../coders/example_proto_coder.py      | 18 +++----
 .../coders/example_proto_coder_test.py | 30 ++++--------
 4 files changed, 52 insertions(+), 57 deletions(-)

diff --git a/tensorflow_transform/analyzers.py b/tensorflow_transform/analyzers.py
index 8dbc58db..2ae6b448 100644
--- a/tensorflow_transform/analyzers.py
+++ b/tensorflow_transform/analyzers.py
@@ -49,28 +49,36 @@ class Analyzer(object):
   Args:
     inputs: The inputs to the analyzer.
-    output_tensors_and_is_asset: List of pairs of `(Tensor, bool)` for each
-      output. The `Tensor`s are typically placeholders; they will be later
-      be replaced with analysis results. The boolean value states whether this
-      Tensor represents an asset filename or not.
+    output_dtype_shape_and_is_asset: List of tuples of `(DType, Shape, bool)`
+      for each output. A tf.placeholder with the given DType and Shape will
+      be constructed to represent each output of the analyzer; this
+      placeholder will eventually be replaced by the actual analysis result.
+      The boolean value states whether this Tensor represents an asset
+      filename or not.
     spec: A description of the computation to be done.
+    name: Similar to a TF op name. Used to define a unique scope for this
+      analyzer, which is useful for debugging.

   Raises:
     ValueError: If the inputs are not all `Tensor`s.
   """

-  def __init__(self, inputs, output_tensors_and_is_asset, spec):
+  def __init__(self, inputs, output_dtype_shape_and_is_asset, spec, name):
     for tensor in inputs:
       if not isinstance(tensor, tf.Tensor):
         raise ValueError('Analyzers can only accept `Tensor`s as inputs')
     self._inputs = inputs
-    for output_tensor, is_asset in output_tensors_and_is_asset:
-      if is_asset and output_tensor.dtype != tf.string:
-        raise ValueError(('Tensor {} cannot represent an asset, because it is '
-                          'not a string.').format(output_tensor.name))
-    self._outputs = [output_tensor
-                     for output_tensor, _ in output_tensors_and_is_asset]
-    self._output_is_asset_map = dict(output_tensors_and_is_asset)
+    self._outputs = []
+    self._output_is_asset_map = {}
+    with tf.name_scope(name) as scope:
+      self._name = scope
+      for dtype, shape, is_asset in output_dtype_shape_and_is_asset:
+        output_tensor = tf.placeholder(dtype, shape)
+        if is_asset and output_tensor.dtype != tf.string:
+          raise ValueError(('Tensor {} cannot represent an asset, because it '
+                            'is not a string.').format(output_tensor.name))
+        self._outputs.append(output_tensor)
+        self._output_is_asset_map[output_tensor] = is_asset
     self._spec = spec
     tf.add_to_collection(ANALYZER_COLLECTION, self)
@@ -86,6 +94,10 @@ def outputs(self):
   def spec(self):
     return self._spec

+  @property
+  def name(self):
+    return self._name
+
   def output_is_asset(self, output_tensor):
     return self._output_is_asset_map[output_tensor]
@@ -131,11 +143,9 @@ def _numeric_combine(x, combiner_type, reduce_instance_dims=True):
     # If reducing over batch dimensions, with unknown shape, the result will
     # also have unknown shape.
     shape = None
-  with tf.name_scope(combiner_type):
-    spec = NumericCombineSpec(x.dtype, combiner_type, reduce_instance_dims)
-    return Analyzer([x],
-                    [(tf.placeholder(x.dtype, shape), False)],
-                    spec).outputs[0]
+  spec = NumericCombineSpec(x.dtype, combiner_type, reduce_instance_dims)
+  return Analyzer(
+      [x], [(x.dtype, shape, False)], spec, combiner_type).outputs[0]


 def min(x, reduce_instance_dims=True):  # pylint: disable=redefined-builtin
@@ -381,9 +391,7 @@ def uniques(x, top_k=None, frequency_threshold=None,
   spec = UniquesSpec(tf.string, top_k, frequency_threshold,
                      vocab_filename, store_frequency)
-  return Analyzer([x],
-                  [(tf.placeholder(tf.string, []), True)],
-                  spec).outputs[0]
+  return Analyzer([x], [(tf.string, [], True)], spec, 'uniques').outputs[0]


 class QuantilesSpec(object):
diff --git a/tensorflow_transform/beam/impl.py b/tensorflow_transform/beam/impl.py
index f0916d5a..2b97ffa2 100644
--- a/tensorflow_transform/beam/impl.py
+++ b/tensorflow_transform/beam/impl.py
@@ -513,9 +513,8 @@ def __init__(self, analyzers, base_temp_dir):
   def expand(self, analyzer_input_values):
     # For each analyzer output, look up its input values (by tensor name)
     # and run the analyzer on these values.
-    #
     result = {}
-    for idx, analyzer in enumerate(self._analyzers):
+    for analyzer in self._analyzers:
       temp_assets_dir = _make_unique_temp_dir(self._base_temp_dir)
       tf.gfile.MkDir(temp_assets_dir)
       analyzer_impl = analyzer_impls._impl_for_analyzer(
@@ -525,10 +524,10 @@ def expand(self, analyzer_input_values):
       assert len(analyzer.inputs) == 1
       output_pcolls = (
           analyzer_input_values
-          | 'Extract_%d' % idx >> beam.Map(
+          | 'ExtractInput[%s]' % analyzer.name >> beam.Map(
               lambda batch, key: batch[key],
               key=analyzer.inputs[0].name)
-          | 'Analyze_%d' % idx >> analyzer_impl)
+          | 'Analyze[%s]' % analyzer.name >> analyzer_impl)
       assert len(analyzer.outputs) == len(output_pcolls), (
           'Analyzer outputs don\'t match the expected outputs from the '
           'Analyzer definition: %d != %d' %
@@ -537,7 +536,7 @@ def expand(self, analyzer_input_values):
       for collection_idx, (tensor, pcoll) in enumerate(
           zip(analyzer.outputs, output_pcolls)):
         is_asset = analyzer.output_is_asset(tensor)
-        pcoll |= ('WrapAsTensorValue_%d_%d' % (idx, collection_idx)
+        pcoll |= ('WrapAsTensorValue[%s][%d]' % (analyzer.name, collection_idx)
                   >> beam.Map(_TensorValue, is_asset))
         result[tensor.name] = pcoll
     return result
@@ -711,7 +710,7 @@ def expand(self, dataset):
           graph, inputs, analyzer_inputs, unbound_saved_model_dir)
       saved_model_dir = (
           tensor_pcoll_mapping
-          | 'CreateSavedModelForAnaylzerInputs_%d' % level
+          | 'CreateSavedModelForAnalyzerInputs[%d]' % level
           >> _ReplaceTensorsWithConstants(
               unbound_saved_model_dir, base_temp_dir, input_values.pipeline))
@@ -719,7 +718,7 @@ def expand(self, dataset):
       # analyzers.
       analyzer_input_values = (
           input_values
-          | 'ComputeAnalyzerInputs_%d' % level >> beam.ParDo(
+          | 'ComputeAnalyzerInputs[%d]' % level >> beam.ParDo(
               _RunMetaGraphDoFn(
                   input_schema,
                   analyzer_inputs_schema,
diff --git a/tensorflow_transform/coders/example_proto_coder.py b/tensorflow_transform/coders/example_proto_coder.py
index c45d2742..55bfd1a6 100644
--- a/tensorflow_transform/coders/example_proto_coder.py
+++ b/tensorflow_transform/coders/example_proto_coder.py
@@ -122,19 +122,21 @@ class _FixedLenFeatureHandler(object):
   def __init__(self, name, feature_spec):
     self._name = name
     self._np_dtype = feature_spec.dtype.as_numpy_dtype
+    self._default_value = feature_spec.default_value
     self._value_fn = _make_feature_value_fn(feature_spec.dtype)
     self._shape = feature_spec.shape
     self._rank = len(feature_spec.shape)
+    if self._rank > 0 and self._default_value is not None:
+      raise ValueError('FixedLenFeature %r got default value for rank > 0; '
+                       'only scalar default values are supported'
+                       % (self._name,))
+    if isinstance(self._default_value, list):
+      raise ValueError('FixedLenFeature %r got non-scalar default value; '
+                       'only scalar default values are supported' %
+                       (self._name,))
     self._size = 1
     for dim in feature_spec.shape:
       self._size *= dim
-    self._default_value = feature_spec.default_value
-    if self._default_value:
-      if list(np.asarray(self._default_value).shape) != self._shape:
-        raise ValueError(
-            'FixedLenFeature %r got default value with incorrect shape' %
-            (self._name,))
-      self._default_value = np.asarray(self._default_value).reshape(-1).tolist()

   @property
   def name(self):
@@ -152,7 +154,7 @@ def parse_value(self, feature_map):
       feature = feature_map[self._name]
       values = self._value_fn(feature)
     elif self._default_value is not None:
-      values = self._default_value
+      values = [self._default_value]
     else:
       values = []
diff --git a/tensorflow_transform/coders/example_proto_coder_test.py b/tensorflow_transform/coders/example_proto_coder_test.py
index 0629152e..60e8b75a 100644
--- a/tensorflow_transform/coders/example_proto_coder_test.py
+++ b/tensorflow_transform/coders/example_proto_coder_test.py
@@ -168,16 +168,8 @@ def test_example_proto_coder(self):
   def test_example_proto_coder_default_value(self):
     input_schema = dataset_schema.from_feature_spec({
-        'scalar_feature_3':
-            tf.FixedLenFeature(shape=[], dtype=tf.float32, default_value=1.0),
-        '1d_vector_feature':
-            tf.FixedLenFeature(
-                shape=[1], dtype=tf.float32, default_value=[2.0]),
-        '2d_vector_feature':
-            tf.FixedLenFeature(
-                shape=[2, 2],
-                dtype=tf.float32,
-                default_value=[[1.0, 2.0], [3.0, 4.0]])
+        'scalar_feature_3': tf.FixedLenFeature(shape=[], dtype=tf.float32,
+                                               default_value=1.0),
     })
     coder = example_proto_coder.ExampleProtoCoder(input_schema)
@@ -193,31 +185,25 @@ def test_example_proto_coder_default_value(self):
     # Assert the data is decoded into the expected format.
     expected_decoded = {
         'scalar_feature_3': 1.0,
-        '1d_vector_feature': [2.0],
-        '2d_vector_feature': [[1.0, 2.0], [3.0, 4.0]]
     }
     decoded = coder.decode(data)
     np.testing.assert_equal(expected_decoded, decoded)

   def test_example_proto_coder_bad_default_value(self):
     input_schema = dataset_schema.from_feature_spec({
-        '1d_vector_feature':
-            tf.FixedLenFeature(
-                shape=[2], dtype=tf.float32, default_value=[1.0]),
+        '1d_vector_feature': tf.FixedLenFeature(shape=[2], dtype=tf.float32,
+                                                default_value=[1.0, 2.0]),
     })
     with self.assertRaisesRegexp(ValueError,
-                                 'got default value with incorrect shape'):
+                                 'only scalar default values are supported'):
       example_proto_coder.ExampleProtoCoder(input_schema)

     input_schema = dataset_schema.from_feature_spec({
-        '2d_vector_feature':
-            tf.FixedLenFeature(
-                shape=[2, 3],
-                dtype=tf.float32,
-                default_value=[[1.0, 1.0], [1.0]]),
+        'scalar_feature_2': tf.FixedLenFeature(shape=[], dtype=tf.float32,
+                                               default_value=[1.0]),
     })
     with self.assertRaisesRegexp(ValueError,
-                                 'got default value with incorrect shape'):
+                                 'only scalar default values are supported'):
       example_proto_coder.ExampleProtoCoder(input_schema)

   def test_example_proto_coder_picklable(self):
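
A minimal sketch (not part of the patch) of how the new `name` argument surfaces in graph scopes and Beam stage labels; it assumes standard TF name-scope uniquification and that `min` passes 'min' as its combiner type, as `_numeric_combine` above suggests:

  import tensorflow as tf
  from tensorflow_transform import analyzers

  x = tf.placeholder(tf.float32, [None])
  y = tf.placeholder(tf.float32, [None])

  # Each Analyzer now opens its own name scope, so repeated analyzers get
  # distinct names ('min/', 'min_1/', ...).  The Beam impl reuses these
  # names in its stage labels, e.g. 'ExtractInput[min/]' and 'Analyze[min/]',
  # replacing the old positional labels 'Extract_0' and 'Analyze_0'.
  min_x = analyzers.min(x)  # output placeholder created under scope 'min/'
  min_y = analyzers.min(y)  # a second analyzer is uniquified to 'min_1/'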
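Continuing the snippet above, a sketch of constructing an Analyzer directly with the new `(dtype, shape, is_asset)` output signature, mirroring `_numeric_combine`; the 'min' combiner type and scalar output shape are illustrative assumptions:

  spec = analyzers.NumericCombineSpec(tf.float32, 'min', True)
  result = analyzers.Analyzer(
      [x],                        # inputs must all be `Tensor`s
      [(tf.float32, (), False)],  # one scalar output that is not an asset
      spec,
      'min').outputs[0]           # `name` becomes the scope and Beam label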
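A round-trip sketch of the new scalar-default handling in ExampleProtoCoder, reusing the feature spec from test_example_proto_coder_default_value; the dataset_schema import path is assumed to match the test module's:

  from tensorflow_transform.coders import example_proto_coder
  from tensorflow_transform.tf_metadata import dataset_schema

  schema = dataset_schema.from_feature_spec({
      'scalar_feature_3':
          tf.FixedLenFeature(shape=[], dtype=tf.float32, default_value=1.0),
  })
  coder = example_proto_coder.ExampleProtoCoder(schema)

  # Decoding an Example that omits the feature now falls back to the scalar
  # default, which parse_value wraps in a one-element list internally.
  decoded = coder.decode(tf.train.Example().SerializeToString())
  assert decoded['scalar_feature_3'] == 1.0

  # Non-scalar defaults are rejected at construction time: a default on a
  # shape=[2] feature, or a list default on a scalar feature, raises
  # ValueError('... only scalar default values are supported').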