Skip to content

Commit

Permalink
add label for xlaop
Browse files Browse the repository at this point in the history
  • Loading branch information
Agoniii committed Nov 28, 2019
1 parent 3cfd65b commit 6a74e16
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 11 deletions.
5 changes: 5 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/data_format_ops.cc
Expand Up @@ -143,6 +143,11 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
REGISTER_XLA_OP(
Name("DataFormatVecPermute").TypeConstraint("T", {DT_INT32, DT_INT64}),
DataFormatVecPermuteOp);
REGISTER_XLA_OP(
Name("DataFormatVecPermute")
.Label("host")
.TypeConstraint("T", {DT_INT32, DT_INT64}),
DataFormatVecPermuteOp);

} // namespace
} // namespace tensorflow
11 changes: 0 additions & 11 deletions tensorflow/compiler/tf2xla/xla_compiler.cc
Expand Up @@ -723,17 +723,6 @@ Status XlaCompiler::CompileFunction(

std::unique_ptr<Graph> graph = GetGraph(fbody);

// Clear the "_kernel" attribute if it is set to "host". This is used to
// indicate that a computation should happen on the host instead of the
// accelerator, but doesn't make sense in XLA.
const char* const kKernelAttr = "_kernel";
for (Node* n : graph->nodes()) {
string value;
if (TryGetNodeAttr(n->attrs(), kKernelAttr, &value) && value == "host") {
n->ClearAttr(kKernelAttr);
}
}

// _Arg and _Retval nodes don't exist in the stored subgraph for the function;
// they are added by the function body looked up. Therefore, they don't have
// core assignments here.
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/compiler/tf2xla/xla_op_registry.cc
Expand Up @@ -61,6 +61,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
/* static */ bool XlaOpRegistry::IsCompatible(const OpRegistration& x,
const OpRegistration& y) {
if (x.name != y.name) return true;
if (x.label != y.label) return true;
// The registrations refer to the same Op: ensures they are compatible and
// are restricted to different device whitelists.
if (x.compilation_only != y.compilation_only) {
Expand Down Expand Up @@ -256,6 +257,7 @@ void XlaOpRegistry::RegisterCompilationKernels() {
std::unique_ptr<KernelDef> kdef(new KernelDef);
kdef->set_op(op_registration->name);
kdef->set_device_type(backend.first);
kdef->set_label(op_registration->label);

// Constrain each type attribute to the intersection of:
// a) the types supported by the backend, and
Expand Down Expand Up @@ -539,6 +541,12 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() {
return *this;
}

XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Label(
absl::string_view label) {
registration_->label = string(label);
return *this;
}

std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build(
XlaOpRegistry::Factory factory) {
registration_->factory = factory;
Expand Down
5 changes: 5 additions & 0 deletions tensorflow/compiler/tf2xla/xla_op_registry.h
Expand Up @@ -270,6 +270,8 @@ class XlaOpRegistry {
// operands and not their values.
bool is_metadata_op = false;

string label;

// Factory used to build OpKernels that perform symbolic execution.
Factory factory;
};
Expand Down Expand Up @@ -350,6 +352,9 @@ class XlaOpRegistrationBuilder {
// operands and not their values.
XlaOpRegistrationBuilder& IsMetadataOp();

// Specifies a particular value for the "_kernel" attr.
XlaOpRegistrationBuilder& Label(absl::string_view label);

std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
XlaOpRegistry::Factory factory);

Expand Down

0 comments on commit 6a74e16

Please sign in to comment.