Skip to content

Commit 16cdb25

Browse files
ezhulenevtensorflower-gardener
authored andcommitted
[colocation_graph] Follow only DT_VARIANT edges to find the underlying DT_VARIANT data type
PiperOrigin-RevId: 339253107 Change-Id: I3d7c231ada3216bfecb94bb15442d5717615f1b1
1 parent b0784e5 commit 16cdb25

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

tensorflow/core/common_runtime/colocation_graph.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,14 @@ Status ColocationGraph::AddHostOnlyDataTypesConstraints() {
805805
absl::optional<bool> is_host_data_type;
806806

807807
auto edge_filter = [&](const Edge& edge) -> bool {
808-
return !is_host_data_type.has_value();
808+
// We already found the underlying data type.
809+
if (is_host_data_type.has_value()) return false;
810+
811+
// Otherwise follow only DT_VARIANT data edges.
812+
auto edge_dtype = [&]() -> DataType {
813+
return edge.src()->output_type(edge.src_output());
814+
};
815+
return !edge.IsControlEdge() && edge_dtype() == DT_VARIANT;
809816
};
810817

811818
auto enter = [&](Node* n) -> void {

0 commit comments

Comments
 (0)