-
Notifications
You must be signed in to change notification settings - Fork 74k
/
composite_tensor_variant.h
96 lines (75 loc) · 3.67 KB
/
composite_tensor_variant.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_COMPOSITE_TENSOR_VARIANT_H_
#define TENSORFLOW_CORE_KERNELS_COMPOSITE_TENSOR_VARIANT_H_
#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
namespace tensorflow {
class CompositeTensorVariantMetadata;
// Encoding for a `tf.ExtensionType` value that can be saved as a Variant.
//
// `tf.ExtensionType` (also known as `CompositeTensor`) is a Python base class
// used to define Python types that are supported by TensorFlow APIs. Example
// ExtensionTypes include `tf.RaggedTensor` and `tf.SparseTensor`.
//
// `CompositeTensorVariant` decomposes the `ExtensionType` value into two
// parts:
//
// * `components`: A list of Tensors, which encodes the value's dynamic
// data -- i.e., data that may change for different executions of a graph.
// * `metadata`: A serialized TypeSpec, which encodes the value's
// static data -- i.e., data that is the same for all executions of a graph.
//
// CompositeTensorVariant can be stored in a Tensor with dtype=DT_VARIANT.
// Typically, extension type values are encoded with a scalar tensor containing
// a single CompositeTensorVariant value.
class CompositeTensorVariant {
 public:
  // Creates a value from the given TypeSpec metadata and component tensors.
  // Defined out of line.
  CompositeTensorVariant(const CompositeTensorVariantMetadata& metadata,
                         absl::Span<Tensor> flat_components);

  // Default constructor (defined out of line).
  // NOTE(review): presumably used to create a value that is subsequently
  // filled in by `Decode()` -- confirm against the .cc file.
  CompositeTensorVariant();

  // Copy constructor (defined out of line).
  //
  // Because a copy constructor (and assignment operators) are user-declared
  // here and no move constructor is declared, the compiler does not generate
  // a move constructor: constructing a CompositeTensorVariant from an rvalue
  // also selects this copy constructor.
  CompositeTensorVariant(const CompositeTensorVariant& other);

  // Move assignment is defaulted: it move-assigns `flat_components_` and
  // transfers ownership of `metadata_`, leaving `other.metadata_` null.
  // `metadata()` dereferences `metadata_` unconditionally, so a moved-from
  // object should only be assigned to or destroyed.
  CompositeTensorVariant& operator=(CompositeTensorVariant&& other) = default;

  // Copy assignment is disabled.
  CompositeTensorVariant& operator=(const CompositeTensorVariant& other) =
      delete;

  // Returns the list of Tensor components that encode this value's dynamic
  // data. The returned span views storage owned by this object, so it must
  // not be used after this object is modified or destroyed.
  absl::Span<const Tensor> flat_components() const {
    return absl::MakeConstSpan(flat_components_);
  }

  // Returns the serialized TypeSpec that encodes the value's static data.
  // Must not be called on an object that has been moved from via the move
  // assignment operator (its `metadata_` is null).
  const CompositeTensorVariantMetadata& metadata() const { return *metadata_; }

  // Variant methods.

  // Returns `kTypeName`.
  string TypeName() const { return kTypeName; }

  // Updates `VariantTensorData` with an encoding for this value.
  void Encode(VariantTensorData* data) const;

  // Updates this value to match the encoding in a given `VariantTensorData`.
  // The returned bool reports success or failure; see the .cc file for the
  // exact failure conditions.
  bool Decode(const VariantTensorData& data);

  // Returns a string summary for this value.
  string DebugString() const;

  // Name of this type (used for variant serialization).
  static constexpr const char kTypeName[] = "CompositeTensorVariant";

 private:
  // Tensor components for this value.
  std::vector<Tensor> flat_components_;

  // TypeSpec for this value. CompositeTensorVariantMetadata is a thin wrapper
  // around a TypeSpecProto, which is used to retain flexibility to change the
  // variant encoding.
  //
  // Note: we use a unique_ptr, because header files in the kernels/ directory
  // are not allowed to import .pb.h files.
  //
  // NOTE(review): the destructor is implicitly declared and the move
  // assignment operator is defaulted inline. Both destroy the owned pointee
  // through std::unique_ptr, which requires the complete type. Any translation
  // unit that destroys or move-assigns a CompositeTensorVariant must
  // therefore see the full definition of CompositeTensorVariantMetadata.
  std::unique_ptr<CompositeTensorVariantMetadata> metadata_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_COMPOSITE_TENSOR_VARIANT_H_