From 4459d996e511417df410f1c23e3eb1c660b76713 Mon Sep 17 00:00:00 2001 From: Kanvi Khanna Date: Fri, 20 Dec 2019 15:12:17 -0800 Subject: [PATCH 1/4] Initial changes for using the fused BMMT --- ngraph_bridge/ngraph_builder.cc | 93 +++------------------------------ 1 file changed, 8 insertions(+), 85 deletions(-) diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index 89b3271e3..7e886b835 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -812,9 +812,11 @@ static Status TranslateBatchMatMulOp( auto ng_lhs_axes = out_axes; auto ng_rhs_axes = out_axes; - // Get the backend name, if the backend is CPU and n_dims >= 3 - // then use the BatchMatMul op supported by nGraph - if (n_dims >= 3 && backend_name == "CPU") { + if (n_dims == 3) { + SaveNgOp(ng_op_map, op->name(), + ConstructNgNode( + op->name(), ng_lhs, ng_rhs, tf_adj_x, tf_adj_y)); + } else { // Transpose X if AdjX = true if (tf_adj_x) { ng_lhs_axes.push_back(n_dims - 1); @@ -838,9 +840,9 @@ static Status TranslateBatchMatMulOp( ng_rhs_axes.push_back(n_dims - 1); } - if (n_dims == 3) { - SaveNgOp(ng_op_map, op->name(), ConstructNgNode( - op->name(), ng_lhs, ng_rhs)); + if (n_dims == 2) { + SaveNgOp(ng_op_map, op->name(), + ConstructNgNode(op->name(), ng_lhs, ng_rhs)); } else { // Find the compound size for dim1 so as to reshape to 3D size_t compound_size = 1; @@ -870,85 +872,6 @@ static Status TranslateBatchMatMulOp( ConstructNgNode(op->name(), batchmatmul, tmp_axes, output_shape)); } - } else { - if (tf_adj_x) { - ng_lhs_axes.push_back(n_dims - 1); - ng_lhs_axes.push_back(n_dims - 2); - ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes); - Builder::SetTracingInfo(op->name(), ng_lhs); - } - if (tf_adj_y) { - ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2); - ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1); - ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes); - Builder::SetTracingInfo(op->name(), ng_rhs); - } else { - ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1); - ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2); - ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes); - Builder::SetTracingInfo(op->name(), ng_rhs); - } - - ng_lhs_shape = ng_lhs->get_shape(); - ng_rhs_shape = ng_rhs->get_shape(); - - if (ng_lhs_shape[n_dims - 1] != ng_rhs_shape[0]) { - return errors::InvalidArgument( - "The last dimension of ng_lhs and the first dimension of ng_rhs " - "should have the same size"); - } - - if (n_dims == 2) { - SaveNgOp(ng_op_map, op->name(), - ConstructNgNode(op->name(), ng_lhs, ng_rhs)); - } else { - auto output_shape = ng_lhs_shape; - output_shape[n_dims - 1] = ng_rhs_shape[1]; - auto dot_output = - ConstructNgNode(op->name(), ng_lhs, ng_rhs); - - size_t compound_size = 1; - for (size_t i = 0; i < out_axes.size(); i++) { - compound_size *= output_shape[i]; - } - auto dot_axes = out_axes; - dot_axes.push_back(n_dims - 2); - dot_axes.push_back(n_dims - 1); - for (size_t i = 0; i < out_axes.size(); i++) { - dot_axes.push_back(n_dims + i); - } - ng::Shape dot_shape = {compound_size, ng_lhs_shape[n_dims - 2], - ng_rhs_shape[1], compound_size}; - std::shared_ptr dot_reshape; - if (n_dims == 3) { - dot_reshape = dot_output; - } else { - dot_reshape = ConstructNgNode( - op->name(), dot_output, dot_axes, dot_shape); - } - ng::Shape tmp_shape = {1, ng_lhs_shape[n_dims - 2], ng_rhs_shape[1]}; - vector> tmp_tensors; - for (size_t i = 0; i < dot_shape[0]; i++) { - const std::vector lower_bound{i, 0, 0, i}; - const std::vector upper_bound{i + 1, dot_shape[1], dot_shape[2], - i + 1}; - auto slice_out = ConstructNgNode( - op->name(), dot_reshape, lower_bound, upper_bound); - auto reshape_out = ConstructNgNode( - op->name(), slice_out, ng::AxisVector{0, 1, 2, 3}, tmp_shape); - tmp_tensors.push_back(reshape_out); - } - auto concat_op = - ConstructNgNode(op->name(), tmp_tensors, 0); - if (n_dims == 3) { - SaveNgOp(ng_op_map, op->name(), concat_op); - } else { - SaveNgOp( - ng_op_map, op->name(), - ConstructNgNode( - op->name(), concat_op, ng::AxisVector{0, 1, 2}, output_shape)); - } - } } return Status::OK(); } From 57973b9a8d2de9bfb139c48d58fd4fc040760439 Mon Sep 17 00:00:00 2001 From: Kanvi Khanna Date: Tue, 24 Dec 2019 12:30:29 -0800 Subject: [PATCH 2/4] Make modifications to support >3 rank size --- ngraph_bridge/ngraph_builder.cc | 79 +++++++++++++++++---------------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index 7e886b835..9dd085213 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -798,7 +798,7 @@ static Status TranslateBatchMatMulOp( if (ng_lhs_shape[i] != ng_rhs_shape[i]) { return errors::InvalidArgument( "ng_lhs_shape and ng_rhs_shape must be the same for BatchMatMul " - "for each dimension", + "for each dimension ", i); } out_axes.push_back(i); @@ -811,12 +811,7 @@ static Status TranslateBatchMatMulOp( auto ng_lhs_axes = out_axes; auto ng_rhs_axes = out_axes; - - if (n_dims == 3) { - SaveNgOp(ng_op_map, op->name(), - ConstructNgNode( - op->name(), ng_lhs, ng_rhs, tf_adj_x, tf_adj_y)); - } else { + if (n_dims == 2) { // Transpose X if AdjX = true if (tf_adj_x) { ng_lhs_axes.push_back(n_dims - 1); @@ -839,39 +834,45 @@ static Status TranslateBatchMatMulOp( ng_rhs_axes.push_back(n_dims - 2); ng_rhs_axes.push_back(n_dims - 1); } - - if (n_dims == 2) { - SaveNgOp(ng_op_map, op->name(), - ConstructNgNode(op->name(), ng_lhs, ng_rhs)); - } else { - // Find the compound size for dim1 so as to reshape to 3D - size_t compound_size = 1; - for (size_t i = 0; i < out_axes.size(); i++) { - compound_size *= ng_lhs_shape[i]; - } - - ng::Shape tmp_lhs_shape = {compound_size, ng_lhs_shape[n_dims - 2], - ng_lhs_shape[n_dims - 1]}; - ng::Shape tmp_rhs_shape = {compound_size, ng_rhs_shape[n_dims - 2], - ng_rhs_shape[n_dims - 1]}; - - auto output_shape = ng_lhs_shape; - output_shape[n_dims - 1] = ng_rhs_shape[n_dims - 1]; - ng::AxisVector tmp_axes = {0, 1, 2}; - - std::shared_ptr lhs_reshape = - ConstructNgNode(op->name(), ng_lhs, ng_lhs_axes, - tmp_lhs_shape); - std::shared_ptr rhs_reshape = - ConstructNgNode(op->name(), ng_rhs, ng_rhs_axes, - tmp_rhs_shape); - std::shared_ptr batchmatmul = - ConstructNgNode(op->name(), lhs_reshape, - rhs_reshape); - SaveNgOp(ng_op_map, op->name(), - ConstructNgNode(op->name(), batchmatmul, - tmp_axes, output_shape)); + SaveNgOp(ng_op_map, op->name(), + ConstructNgNode(op->name(), ng_lhs, ng_rhs)); + } else if (n_dims == 3) { + SaveNgOp(ng_op_map, op->name(), + ConstructNgNode( + op->name(), ng_lhs, ng_rhs, tf_adj_x, tf_adj_y)); + } else { + ng_lhs_axes.push_back(n_dims - 2); + ng_lhs_axes.push_back(n_dims - 1); + ng_rhs_axes.push_back(n_dims - 2); + ng_rhs_axes.push_back(n_dims - 1); + + size_t compound_size = 1; + for (size_t i = 0; i < out_axes.size(); i++) { + compound_size *= ng_lhs_shape[i]; } + + ng::Shape tmp_lhs_shape = {compound_size, ng_lhs_shape[n_dims - 2], + ng_lhs_shape[n_dims - 1]}; + ng::Shape tmp_rhs_shape = {compound_size, ng_rhs_shape[n_dims - 2], + ng_rhs_shape[n_dims - 1]}; + + auto output_shape = ng_lhs_shape; + output_shape[n_dims - 2] = ng_lhs_shape[n_dims - (tf_adj_x ? 1 : 2)]; + output_shape[n_dims - 1] = ng_rhs_shape[n_dims - (tf_adj_y ? 2 : 1)]; + ng::AxisVector tmp_axes = {0, 1, 2}; + + std::shared_ptr lhs_reshape = + ConstructNgNode(op->name(), ng_lhs, ng_lhs_axes, + tmp_lhs_shape); + std::shared_ptr rhs_reshape = + ConstructNgNode(op->name(), ng_rhs, ng_rhs_axes, + tmp_rhs_shape); + std::shared_ptr batchmatmul_transpose = + ConstructNgNode( + op->name(), lhs_reshape, rhs_reshape, tf_adj_x, tf_adj_y); + SaveNgOp(ng_op_map, op->name(), + ConstructNgNode( + op->name(), batchmatmul_transpose, tmp_axes, output_shape)); } return Status::OK(); } From 25ad660a1fc77c3b96aa2e5395757b573b15839f Mon Sep 17 00:00:00 2001 From: Kanvi Khanna Date: Tue, 31 Dec 2019 13:13:10 -0800 Subject: [PATCH 3/4] Address comments --- ngraph_bridge/ngraph_builder.cc | 49 ++++++++++++++------------------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index 9ecbddc52..88f7074d4 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -786,23 +786,26 @@ static Status TranslateBatchMatMulOp( if (ng_lhs_shape.size() != ng_rhs_shape.size()) { return errors::InvalidArgument( - "Dimensions of two input args are not the same for BatchMatMul"); + "Dimensions of two input args are not the same for BatchMatMul. Left " + "shape is ", + ng::join(ng_lhs_shape), " of rank ", ng_lhs_shape.size(), + " and Right shape is ", ng::join(ng_rhs_shape), " of rank ", + ng_rhs_shape.size()); } size_t n_dims = ng_lhs_shape.size(); if (n_dims < 2) { return errors::InvalidArgument( - "Dimensions of input args for BatchMatMul must be >=2", n_dims); + "Dimensions of input args for BatchMatMul must be >=2 but is ", n_dims); } - ng::AxisVector out_axes; for (size_t i = 0; i < n_dims - 2; ++i) { if (ng_lhs_shape[i] != ng_rhs_shape[i]) { return errors::InvalidArgument( "ng_lhs_shape and ng_rhs_shape must be the same for BatchMatMul " - "for each dimension ", - i); + "for each dimension but found ", + i, "th dimension different. Left shape is ", ng::join(ng_lhs_shape), + "and Right shape is ", ng::join(ng_rhs_shape)); } - out_axes.push_back(i); } bool tf_adj_x = false; @@ -810,30 +813,18 @@ static Status TranslateBatchMatMulOp( TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "adj_x", &tf_adj_x)); TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "adj_y", &tf_adj_y)); - auto ng_lhs_axes = out_axes; - auto ng_rhs_axes = out_axes; if (n_dims == 2) { // Transpose X if AdjX = true if (tf_adj_x) { - ng_lhs_axes.push_back(n_dims - 1); - ng_lhs_axes.push_back(n_dims - 2); - ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes); + ng_lhs = ng::builder::numpy_transpose(ng_lhs, {1, 0}); Builder::SetTracingInfo(op->name(), ng_lhs); ng_lhs_shape = ng_lhs->get_shape(); - } else { - ng_lhs_axes.push_back(n_dims - 2); - ng_lhs_axes.push_back(n_dims - 1); } // Transpose Y if AdjY = true if (tf_adj_y) { - ng_rhs_axes.push_back(n_dims - 1); - ng_rhs_axes.push_back(n_dims - 2); - ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes); + ng_rhs = ng::builder::numpy_transpose(ng_rhs, {1, 0}); Builder::SetTracingInfo(op->name(), ng_rhs); ng_rhs_shape = ng_rhs->get_shape(); - } else { - ng_rhs_axes.push_back(n_dims - 2); - ng_rhs_axes.push_back(n_dims - 1); } SaveNgOp(ng_op_map, op->name(), ConstructNgNode(op->name(), ng_lhs, ng_rhs)); @@ -842,13 +833,11 @@ static Status TranslateBatchMatMulOp( ConstructNgNode( op->name(), ng_lhs, ng_rhs, tf_adj_x, tf_adj_y)); } else { - ng_lhs_axes.push_back(n_dims - 2); - ng_lhs_axes.push_back(n_dims - 1); - ng_rhs_axes.push_back(n_dims - 2); - ng_rhs_axes.push_back(n_dims - 1); + ng::AxisVector out_axes(n_dims); + std::iota(out_axes.begin(), out_axes.end(), 0); size_t compound_size = 1; - for (size_t i = 0; i < out_axes.size(); i++) { + for (size_t i = 0; i < n_dims - 2; i++) { compound_size *= ng_lhs_shape[i]; } @@ -863,10 +852,10 @@ static Status TranslateBatchMatMulOp( ng::AxisVector tmp_axes = {0, 1, 2}; std::shared_ptr lhs_reshape = - ConstructNgNode(op->name(), ng_lhs, ng_lhs_axes, + ConstructNgNode(op->name(), ng_lhs, out_axes, tmp_lhs_shape); std::shared_ptr rhs_reshape = - ConstructNgNode(op->name(), ng_rhs, ng_rhs_axes, + ConstructNgNode(op->name(), ng_rhs, out_axes, tmp_rhs_shape); std::shared_ptr batchmatmul_transpose = ConstructNgNode( @@ -890,7 +879,11 @@ static Status TranslateBatchMatMulV2Op( if (ng_lhs_shape.size() != ng_rhs_shape.size()) { return errors::InvalidArgument( - "Dimensions of two input args are not the same for BatchMatMul"); + "Dimensions of two input args are not the same for BatchMatMul. Left " + "shape is ", + ng::join(ng_lhs_shape), " of rank ", ng_lhs_shape.size(), + " and Right shape is ", ng::join(ng_rhs_shape), " of rank ", + ng_rhs_shape.size()); } for (size_t i = 0; i < n_dims - 2; ++i) { From e05bdafa0be24a737890a066278b8264ee42c3bc Mon Sep 17 00:00:00 2001 From: Kanvi Khanna Date: Tue, 31 Dec 2019 14:09:28 -0800 Subject: [PATCH 4/4] more comments addressed --- ngraph_bridge/ngraph_builder.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index 88f7074d4..05b3baeb4 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -818,13 +818,11 @@ static Status TranslateBatchMatMulOp( if (tf_adj_x) { ng_lhs = ng::builder::numpy_transpose(ng_lhs, {1, 0}); Builder::SetTracingInfo(op->name(), ng_lhs); - ng_lhs_shape = ng_lhs->get_shape(); } // Transpose Y if AdjY = true if (tf_adj_y) { ng_rhs = ng::builder::numpy_transpose(ng_rhs, {1, 0}); Builder::SetTracingInfo(op->name(), ng_rhs); - ng_rhs_shape = ng_rhs->get_shape(); } SaveNgOp(ng_op_map, op->name(), ConstructNgNode(op->name(), ng_lhs, ng_rhs));