Skip to content

Commit

Permalink
Reverts changelist 570504216
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625769893
  • Loading branch information
jcai19 authored and tensorflower-gardener committed Apr 17, 2024
1 parent 7bfbad8 commit c98067c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 56 deletions.
Expand Up @@ -76,6 +76,36 @@ LogicalResult HasAttr(
return failure();
}

// Returns true if `graph` contains at least one `_Arg` node that is assigned
// to a device in the parameter server ("ps") job and whose "T" attribute is of
// type DT_RESOURCE, i.e. a resource variable argument placed on a parameter
// server. Returns false otherwise.
bool HasPsWithResourceVariable(const Graph& graph) {
  const std::string jobType = "ps";
  const std::string nodeType = "_Arg";
  const std::string attrKey = "T";
  for (const Node* node : graph.nodes()) {
    if (node->type_string() != nodeType) continue;

    // Only arguments whose assigned device parses to the "ps" job are of
    // interest; unparsable or job-less device names are skipped.
    const auto& device_name = node->assigned_device_name();
    DeviceNameUtils::ParsedName device;
    if (!DeviceNameUtils::ParseFullName(device_name, &device) ||
        !device.has_job || device.job != jobType) {
      continue;
    }

    // Inspect attributes by reference to avoid copying each key and
    // AttrValue proto.
    for (const auto& attr : node->attrs()) {
      const auto& attr_value = attr.second;
      if (attr.first == attrKey &&
          attr_value.value_case() == AttrValue::kType &&
          attr_value.type() == DT_RESOURCE) {
        return true;
      }
    }
  }
  return false;
}

bool IsNonReplicatedGraph(const Graph& graph,
const FunctionLibraryDefinition* function_library) {
auto predicate = [](const Graph& graph) {
Expand Down Expand Up @@ -171,7 +201,8 @@ bool AreFunctionsFromFlibDefInference(

// Returns true when `graph` can be handled by the non-replicated bridge: it
// must be a non-replicated graph (checked against `graph` and
// `function_library`) and must have a resource variable argument placed on a
// parameter server job.
bool IsSupportedByNonReplicatedBridge(
    const Graph& graph, const FunctionLibraryDefinition* function_library) {
  return IsNonReplicatedGraph(graph, function_library) &&
         HasPsWithResourceVariable(graph);
}

bool IsSupportedByReplicatedBridge(
Expand Down
Expand Up @@ -44,28 +44,19 @@ namespace tensorflow {

namespace {

// Builds the function `OuterXTimesTwo` (float -> float) whose single node is a
// StatefulPartitionedCall to `XTimesTwoFloat` with the kMustCompileAttr
// attribute set to true.
FunctionDef OuterXTimesTwo() {
  return FunctionDefHelper::Define(
      // Name
      "OuterXTimesTwo",
      // Args
      {"x: float"},
      // Return values
      {"y: float"},
      // Attr def
      {},
      {{{"y"},
        "StatefulPartitionedCall",
        {"x"},
        {{"Tin", DataTypeSlice{DT_FLOAT}},
         {"Tout", DataTypeSlice{DT_FLOAT}},
         {"f",
          FunctionDefHelper::FunctionRef("XTimesTwoFloat", {{"T", DT_FLOAT}})},
         {std::string(kMustCompileAttr), true}}}});
}

// Produce a valid graph with a resource-type input: the function forwards its
// resource-typed argument `in` to its output `out` through an Identity node.
FunctionDef PassThroughResource() {
  return FunctionDefHelper::Define(
      /*function_name=*/"PassThroughResource",
      /*arg_def=*/{"in: resource"},
      /*ret_def=*/{"out: resource"},
      /*attr_def=*/{},
      /*node_def=*/
      {{{"out"}, "Identity", {"in"}, {{"T", DataType::DT_RESOURCE}}}});
}

TEST(IsSupportedByNonReplicatedBridge, NonReplicatedGraph) {
const FunctionDef& fd = test::function::XTimesTwo();
const FunctionDef& fd = PassThroughResource();
FunctionDefLibrary flib;
*flib.add_function() = fd;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
Expand All @@ -76,7 +67,7 @@ TEST(IsSupportedByNonReplicatedBridge, NonReplicatedGraph) {
ConfigProto config = ConfigProto();
Scope root = Scope::NewRootScope().ExitOnError();

Output a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0);
Output a = ops::_Arg(root.WithOpName("A"), DT_RESOURCE, 0);
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});

Node* call;
Expand All @@ -85,50 +76,21 @@ TEST(IsSupportedByNonReplicatedBridge, NonReplicatedGraph) {
TF_ASSERT_OK(
NodeBuilder("B", "StatefulPartitionedCall", &root.graph()->flib_def())
.Input(inputs)
.Attr("Tin", {DT_FLOAT})
.Attr("Tout", {DT_FLOAT})
.Attr("Tin", {DT_RESOURCE})
.Attr("Tout", {DT_RESOURCE})
.Attr("f", f_name_attr)
.Finalize(root.graph(), &call));
call->AddAttr(std::string(kMustCompileAttr), true);

TF_ASSERT_OK(root.ToGraph(&graph));

EXPECT_TRUE(
IsSupportedByNonReplicatedBridge(graph, /*function_library=*/nullptr));
}

// Checks that HasAttr actually goes through function library.
TEST(IsSupportedByNonReplicatedBridge, NonReplicatedFunctionLibrary) {
const FunctionDef& fd = OuterXTimesTwo();
FunctionDefLibrary flib;
*flib.add_function() = fd;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
Graph graph(OpRegistry::Global());
graph.SetConstructionContext(ConstructionContext::kEagerRuntime);
tensorflow::set_tf2_execution(true);

ConfigProto config = ConfigProto();
Scope root = Scope::NewRootScope().ExitOnError();

Output a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0);
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});

// Builds a call without compilation markers that calls a function with Xla
// clusters.
Node* call;
NameAttrList f_name_attr;
f_name_attr.set_name(fd.signature().name());
TF_ASSERT_OK(
NodeBuilder("B", "StatefulPartitionedCall", &root.graph()->flib_def())
.Input(inputs)
.Attr("Tin", {DT_FLOAT})
.Attr("Tout", {DT_FLOAT})
.Attr("f", f_name_attr)
.Finalize(root.graph(), &call));
// Required for passing the PS server parameter check.
for (Node* node : graph.nodes()) {
node->set_assigned_device_name("/job:ps/replica:0/task:0/device:GPU:0");
}

TF_ASSERT_OK(root.ToGraph(&graph));
EXPECT_TRUE(
IsSupportedByNonReplicatedBridge(graph, /*function_library=*/&flib_def));
IsSupportedByNonReplicatedBridge(graph, /*function_library=*/nullptr));
}

TEST(IsSupportedByReplicatedBridge, ReplicatedGraph) {
Expand Down

0 comments on commit c98067c

Please sign in to comment.