// tensorflow/compiler/tf2xla/tf2xla_util.h
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
#include <map>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_defs.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Validates <config>. Returns OK iff <config> is valid, and a non-OK Status
// otherwise. The validity rules themselves are defined by the implementation
// and are not visible in this header.
Status ValidateConfig(const tf2xla::Config& config);
// Modifies <graph_def> to include placeholders for each fed tensor, and
// updates references to the fed tensors to refer to the placeholders.
// The existing nodes referenced by the feeds are not removed or modified
// (except where their input edges are modified by the replacement of other
// feeds).
//
// <config> and <op_registry> are read-only inputs; <feed_remapping> and
// <graph_def> are passed as mutable pointers. Returns a Status reporting
// success or failure.
// NOTE(review): <feed_remapping> presumably records, for each feed, the
// mapping between the original tensor name and its placeholder, and
// <op_registry> is presumably used to look up op definitions -- confirm
// against the implementation.
Status AddPlaceholdersForFeeds(
const tf2xla::Config& config, const OpRegistryInterface* op_registry,
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);
// Returns in <out> a copy of <in>, pruned to only include fetches from
// <config>. <in> is taken by const reference, so the input graph itself is
// not modified; the pruned result is written through <out>. The returned
// Status reports success or failure.
Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
GraphDef* out);
// Returns the string form of <id>, formatted as "node:port".
string TensorIdToString(const tf2xla::TensorId& id);
// Updates the sharding of <n> based on the sharding of its neighbors.
// If <out_edges> is true, outgoing edges from <n> are considered; else incoming
// edges are considered. The returned Status reports whether the update
// succeeded.
// NOTE(review): what happens when no neighbor carries sharding information is
// not visible from this header -- confirm against the implementation.
Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);
// Adds <dtype> as an allowed data type to the AttrConstraint with the given
// <name>. <kdef> is passed as a mutable pointer and is the object that gets
// updated.
// NOTE(review): the behavior when <kdef> has no constraint named <name> is not
// visible from this header -- confirm against the implementation.
void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
KernelDef* kdef);
// Returns the next random seed to use for seeding the XLA RNG.
// NOTE(review): "next" suggests the function keeps state between calls; how
// the sequence is generated and whether calls are thread-safe is not visible
// from this header -- confirm against the implementation.
uint32 GetXLARandomSeed();
// Describes the way a FunctionDef is tied to a graph node: the node may be a
// call to the function, a SymbolicGradient op, or may carry the function in
// one of its attrs. Instances can only be created through the static factory
// functions below; the constructor is private.
class AssociatedFunctionInfo {
 public:
  // Kind of association between the node and the function.
  enum AssociatedFunctionType {
    kFunctionAttr = 0,
    kFunctionCallNode = 1,
    kSymbolicGradient = 2,
  };

  // Factory for the case where the function is an attr of the node;
  // `attr_name` is stored and later returned by attr_name().
  static AssociatedFunctionInfo FunctionAttr(const string& func_name,
                                             const AttrValueMap& attrs,
                                             const string& attr_name) {
    AssociatedFunctionInfo info(kFunctionAttr, func_name, attrs, attr_name);
    return info;
  }

  // Factory for the case where the node is a function call.
  static AssociatedFunctionInfo FunctionCall(const string& func_name,
                                             const AttrValueMap& attrs) {
    // No attr holds the function here, so attr_name is left empty.
    AssociatedFunctionInfo info(kFunctionCallNode, func_name, attrs,
                                /*attr_name=*/"");
    return info;
  }

  // Factory for the case where the node is a SymbolicGradient op.
  static AssociatedFunctionInfo SymbolicGradient(const string& func_name,
                                                 const AttrValueMap& attrs) {
    // No attr holds the function here, so attr_name is left empty.
    AssociatedFunctionInfo info(kSymbolicGradient, func_name, attrs,
                                /*attr_name=*/"");
    return info;
  }

  // Read-only accessors for the stored fields.
  AssociatedFunctionType type() const { return type_; }
  const string& func_name() const { return func_name_; }
  const string& attr_name() const { return attr_name_; }
  const AttrValueMap& attrs() const { return attrs_; }

 private:
  // Copies every argument into the corresponding member.
  AssociatedFunctionInfo(AssociatedFunctionType kind, const string& name,
                         const AttrValueMap& attr_map, const string& attr)
      : type_(kind), func_name_(name), attrs_(attr_map), attr_name_(attr) {}

  // Set for every instance, regardless of the association kind.
  AssociatedFunctionType type_;
  string func_name_;
  AttrValueMap attrs_;

  // Meaningful only for kFunctionAttr; empty for the other kinds.
  string attr_name_;
};
// Returns true if <node_def> has an associated function, and false otherwise.
// NOTE(review): <fld> is presumably consulted to decide whether the node
// refers to a function in the library -- confirm against the implementation.
bool HasAssociatedFunction(const NodeDef& node_def,
const FunctionLibraryDefinition* fld);
// Gets functions associated with the node. Current cases:
// 1. For function call node, its function name;
// 2. For SymbolicGradient op, returned func_name will be "SymbolicGradient",
// and returned attrs will be this node's attributes;
// 3. For nodes like XlaWhile/XlaIf, all their function attributes.
//
// The result holds one AssociatedFunctionInfo per associated function.
// NOTE(review): presumably the vector is empty when the node has no associated
// function -- confirm against the implementation.
std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
const Node& node, const FunctionLibraryDefinition* fld);
// Changes associated functions for the node. Current cases:
// 1. For function call node, creates a new node with the new function name and
// removes the old node;
// 2. For SymbolicGradient op, adds or replaces GradientDef in
// FunctionLibraryDefinition;
// 3. For nodes like XlaWhile/XlaIf, modifies their function attributes.
//
// <associated_function> describes the association to rewrite and
// <rewritten_function_name> is the new function name to use. <graph>, <node>
// and <fld> are passed as mutable pointers. The returned Status reports
// success or failure.
// NOTE(review): per case 1 the old node is removed, so <node> presumably must
// not be used after a successful rewrite of a function call node -- confirm
// against the implementation.
Status RewriteAssociatedFunction(
Graph* graph, Node* node, FunctionLibraryDefinition* fld,
const AssociatedFunctionInfo& associated_function,
const string& rewritten_function_name);
// Cache of FunctionLibraryRuntime::Handle objects. Handles obtained through
// GetOrInstantiate() are remembered in `handles_` and released by
// ReleaseAllHandles(), which the destructor also calls. The class is neither
// copyable nor assignable.
class CachedFunctionHandles {
 public:
  // Stores `flr` as a raw pointer; nothing in this header deletes it.
  // NOTE(review): `flr` presumably must outlive this object, since the
  // destructor releases handles -- confirm against the implementation.
  // `explicit` keeps a FunctionLibraryRuntime* from being converted into a
  // cache object implicitly.
  explicit CachedFunctionHandles(FunctionLibraryRuntime* flr) : flr_(flr) {}

  // Populates `handle` for requested function and attributes. If we have
  // instantiated the function with the same attributes before, `handle` will be
  // cached handle; otherwise instantiate the function and populate `handle`.
  Status GetOrInstantiate(const string& func_name, AttrSlice attrs,
                          FunctionLibraryRuntime::Handle* handle);

  // Releases all handles in the cache. Returns first non-OK status if any;
  // returns OK otherwise.
  Status ReleaseAllHandles();

  // Releases all cached handles. A destructor cannot report failure, so any
  // non-OK status from ReleaseAllHandles() is deliberately ignored here; call
  // ReleaseAllHandles() directly to observe errors.
  ~CachedFunctionHandles() { ReleaseAllHandles().IgnoreError(); }

 private:
  FunctionLibraryRuntime* flr_;

  // Cached handles, keyed by string.
  // NOTE(review): the key presumably encodes both the function name and its
  // attributes -- confirm against the implementation.
  std::map<string, FunctionLibraryRuntime::Handle> handles_;

  TF_DISALLOW_COPY_AND_ASSIGN(CachedFunctionHandles);
};
// Struct for node's output edge info.
//
// Every member has a default initializer so that a default-constructed
// OutEdgeInfo never holds indeterminate values. The defaults equal what value
// initialization (`OutEdgeInfo{}`) already produced, and the struct remains an
// aggregate, so `OutEdgeInfo{dst, src_output, dst_input}` keeps working.
struct OutEdgeInfo {
  // Node at the destination end of the edge (going by the field name).
  Node* dst = nullptr;
  // NOTE(review): presumably the output index on the source node and the
  // input index on `dst`, respectively -- confirm against the users.
  int src_output = 0;
  int dst_input = 0;
};
// Replaces node `n` with a new node whose NodeDef is `node_def`.
// On success the returned StatusOr holds a Node*; on failure it holds the
// error status.
// NOTE(review): the returned pointer is presumably the newly created node, and
// `n` presumably must not be used after a successful call -- confirm against
// the implementation.
StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def);
// Helper function that builds an Identity node.
// On success the returned StatusOr holds a Node*; on failure it holds the
// error status.
// NOTE(review): presumably the node is added to `graph` under the name
// `node_name`, has data type `dtype`, takes `input` as its input, and gets
// `requested_device` as its requested device when the optional holds a value
// -- confirm against the implementation.
StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
DataType dtype, const Node* input,
std::optional<string> requested_device);
// For "If"/"While" nodes, if some of their inputs are Const nodes, rewrite
// body functions to use the Const nodes instead of original _Arg nodes.
//
// For example, say we have the following computation:
// shape = constant_op.constant([1])
// return tf.cond(pred, lambda: tf.ones(shape), lambda: tf.zeros(shape))
// If we do not rewrite then/else function, they will use _Arg node as shape
// input for tf.ones/tf.zeros. But XLA requires that shape input to be compile
// time constant, so XLA compilation will fail. This rewriting process will
// change the shape input to Const node.
//
// `lookup_fld` is const while `fld` is mutable.
// NOTE(review): presumably existing functions are looked up in `lookup_fld`
// and the rewritten functions are written to `fld` -- confirm against the
// implementation.
Status PropagateConstIntoFunctionalNodes(
Graph* g, const FunctionLibraryDefinition* lookup_fld,
FunctionLibraryDefinition* fld);
// Prunes unreachable FunctionDefs from FunctionLibraryDefinition.
// `g` is taken by const reference and `fld` as a mutable pointer, so `fld` is
// the object that gets pruned. The returned Status reports success or failure.
// NOTE(review): "unreachable" presumably means not reachable from the nodes of
// `g` -- confirm against the implementation.
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
FunctionLibraryDefinition* fld);
// Finds the following pattern in the graph:
// 1) EmptyTensorList -> forward While op -> backward While op,
// 2) in forward While op, a Const node is pushed,
// 3) in backward While op, data is popped from the tensor list.
// And rewrites backward While op to use Const node instead of TensorListPopBack
// result.
// Both `g` and `fld` are passed as mutable pointers. The returned Status
// reports success or failure.
// TODO(b/128633174) remove the TensorList and related TensorList ops.
Status RewriteTensorListWithConstElement(Graph* g,
FunctionLibraryDefinition* fld);
// Returns true when `node` has type string "Identity" or "IdentityN", or when
// node->IsWhileNode() reports true; returns false otherwise.
// `node` is dereferenced unconditionally, so it must not be null.
// NOTE(review): going by the name, these are presumably the op types that
// constant analysis may look through -- confirm with the callers.
inline bool IsConstTraversableOpType(const Node* node) {
  const auto& op_type = node->type_string();
  if (op_type == "Identity" || op_type == "IdentityN") {
    return true;
  }
  return node->IsWhileNode();
}
// Determines whether a loop body is invariant for the given argument index.
// On success the returned StatusOr holds the boolean answer; on failure it
// holds the error status.
// NOTE(review): `index` presumably selects an argument of `loop_body`, and
// `lookup_fld` is presumably used to resolve functions called from the body
// -- confirm against the implementation.
StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
const FunctionLibraryDefinition* lookup_fld);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_