We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b0784e5 commit 16cdb25Copy full SHA for 16cdb25
tensorflow/core/common_runtime/colocation_graph.cc
@@ -805,7 +805,14 @@ Status ColocationGraph::AddHostOnlyDataTypesConstraints() {
805
absl::optional<bool> is_host_data_type;
806
807
auto edge_filter = [&](const Edge& edge) -> bool {
808
- return !is_host_data_type.has_value();
+ // We already found the underlying data type.
809
+ if (is_host_data_type.has_value()) return false;
810
+
811
+ // Otherwise follow only DT_VARIANT data edges.
812
+ auto edge_dtype = [&]() -> DataType {
813
+ return edge.src()->output_type(edge.src_output());
814
+ };
815
+ return !edge.IsControlEdge() && edge_dtype() == DT_VARIANT;
816
};
817
818
auto enter = [&](Node* n) -> void {
0 commit comments