/
tf2xla_defs.h
65 lines (55 loc) · 2.67 KB
/
tf2xla_defs.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_DEFS_H_
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_DEFS_H_
#include <array>
#include "absl/strings/string_view.h"
namespace tensorflow {

// Attribute that flags a node for XLA compilation. Its value names the device
// type the node is to be compiled for.
inline constexpr absl::string_view kCompileDeviceTypeAttr{
    "_xla_compile_device_type"};

// Attribute that flags a node for XLA compilation.
inline constexpr absl::string_view kMustCompileAttr{"_XlaMustCompile"};

// Attribute that flags a node for replication. Its value identifies the
// replication metadata op.
inline constexpr absl::string_view kReplicationInfoAttr{"_replication_info"};

// Attribute that flags a node for XLA-TPU compilation. Its value identifies
// the associated compilation cluster and replication metadata op.
inline constexpr absl::string_view kTpuReplicateAttr{"_tpu_replicate"};

// Attribute that flags a node living inside an XLA compilation cluster as one
// that should be placed outside of that cluster.
inline constexpr absl::string_view kXlaOutsideCompilationAttr{
    "_xla_outside_compilation"};

// Attribute carrying the frontend attributes ID.
inline constexpr absl::string_view kXlaFrontendAttributesAttrName{
    "_XlaFrontendAttributes"};

// Device-related names: the "device" attribute followed by the individual
// device type strings.
inline constexpr absl::string_view kDeviceAttr{"device"};
inline constexpr absl::string_view kCpuDevice{"CPU"};
inline constexpr absl::string_view kGpuDevice{"GPU"};
inline constexpr absl::string_view kTpuDevice{"TPU"};
inline constexpr absl::string_view kEmptyDevice{""};

// Every device type string accepted as valid. The empty string is included
// because ops such as TF.PartitionedCall may carry an empty device type.
inline constexpr std::array<absl::string_view, 4> kValidDeviceTypes{
    kCpuDevice, kGpuDevice, kTpuDevice, kEmptyDevice};

// Attributes that have to be carried over when nodes are rewritten (for
// example during functionalization).
inline constexpr std::array<absl::string_view, 5> kAttrsToPropagate{
    kCompileDeviceTypeAttr,
    kReplicationInfoAttr,
    kXlaFrontendAttributesAttrName,
    kXlaOutsideCompilationAttr,
    kTpuReplicateAttr,
};

}  // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_DEFS_H_