diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index fce3509da..05b3baeb4 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -786,23 +786,26 @@ static Status TranslateBatchMatMulOp( if (ng_lhs_shape.size() != ng_rhs_shape.size()) { return errors::InvalidArgument( - "Dimensions of two input args are not the same for BatchMatMul"); + "Dimensions of two input args are not the same for BatchMatMul. Left " + "shape is ", + ng::join(ng_lhs_shape), " of rank ", ng_lhs_shape.size(), + " and Right shape is ", ng::join(ng_rhs_shape), " of rank ", + ng_rhs_shape.size()); } size_t n_dims = ng_lhs_shape.size(); if (n_dims < 2) { return errors::InvalidArgument( - "Dimensions of input args for BatchMatMul must be >=2", n_dims); + "Dimensions of input args for BatchMatMul must be >=2 but is ", n_dims); } - ng::AxisVector out_axes; for (size_t i = 0; i < n_dims - 2; ++i) { if (ng_lhs_shape[i] != ng_rhs_shape[i]) { return errors::InvalidArgument( "ng_lhs_shape and ng_rhs_shape must be the same for BatchMatMul " - "for each dimension", - i); + "for each dimension but found ", + i, "th dimension different. Left shape is ", ng::join(ng_lhs_shape), + "and Right shape is ", ng::join(ng_rhs_shape)); } - out_axes.push_back(i); } bool tf_adj_x = false; @@ -810,146 +813,54 @@ static Status TranslateBatchMatMulOp( TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "adj_x", &tf_adj_x)); TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "adj_y", &tf_adj_y)); - auto ng_lhs_axes = out_axes; - auto ng_rhs_axes = out_axes; - - // Get the backend name, if the backend is CPU and n_dims >= 3 - // then use the BatchMatMul op supported by nGraph - if (n_dims >= 3 && backend_name == "CPU") { + if (n_dims == 2) { // Transpose X if AdjX = true if (tf_adj_x) { - ng_lhs_axes.push_back(n_dims - 1); - ng_lhs_axes.push_back(n_dims - 2); - ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes); + ng_lhs = ng::builder::numpy_transpose(ng_lhs, {1, 0}); Builder::SetTracingInfo(op->name(), ng_lhs); - ng_lhs_shape = ng_lhs->get_shape(); - } else { - ng_lhs_axes.push_back(n_dims - 2); - ng_lhs_axes.push_back(n_dims - 1); } // Transpose Y if AdjY = true if (tf_adj_y) { - ng_rhs_axes.push_back(n_dims - 1); - ng_rhs_axes.push_back(n_dims - 2); - ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes); + ng_rhs = ng::builder::numpy_transpose(ng_rhs, {1, 0}); Builder::SetTracingInfo(op->name(), ng_rhs); - ng_rhs_shape = ng_rhs->get_shape(); - } else { - ng_rhs_axes.push_back(n_dims - 2); - ng_rhs_axes.push_back(n_dims - 1); - } - - if (n_dims == 3) { - SaveNgOp(ng_op_map, op->name(), ConstructNgNode( - op->name(), ng_lhs, ng_rhs)); - } else { - // Find the compound size for dim1 so as to reshape to 3D - size_t compound_size = 1; - for (size_t i = 0; i < out_axes.size(); i++) { - compound_size *= ng_lhs_shape[i]; - } - - ng::Shape tmp_lhs_shape = {compound_size, ng_lhs_shape[n_dims - 2], - ng_lhs_shape[n_dims - 1]}; - ng::Shape tmp_rhs_shape = {compound_size, ng_rhs_shape[n_dims - 2], - ng_rhs_shape[n_dims - 1]}; - - auto output_shape = ng_lhs_shape; - output_shape[n_dims - 1] = ng_rhs_shape[n_dims - 1]; - ng::AxisVector tmp_axes = {0, 1, 2}; - - std::shared_ptr lhs_reshape = - ConstructNgNode(op->name(), ng_lhs, ng_lhs_axes, - tmp_lhs_shape); - std::shared_ptr rhs_reshape = - ConstructNgNode(op->name(), ng_rhs, ng_rhs_axes, - tmp_rhs_shape); - std::shared_ptr batchmatmul = - ConstructNgNode(op->name(), lhs_reshape, - rhs_reshape); - SaveNgOp(ng_op_map, op->name(), - ConstructNgNode(op->name(), batchmatmul, - tmp_axes, output_shape)); } + SaveNgOp(ng_op_map, op->name(), + ConstructNgNode(op->name(), ng_lhs, ng_rhs)); + } else if (n_dims == 3) { + SaveNgOp(ng_op_map, op->name(), + ConstructNgNode( + op->name(), ng_lhs, ng_rhs, tf_adj_x, tf_adj_y)); } else { - if (tf_adj_x) { - ng_lhs_axes.push_back(n_dims - 1); - ng_lhs_axes.push_back(n_dims - 2); - ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes); - Builder::SetTracingInfo(op->name(), ng_lhs); - } - if (tf_adj_y) { - ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2); - ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1); - ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes); - Builder::SetTracingInfo(op->name(), ng_rhs); - } else { - ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1); - ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2); - ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes); - Builder::SetTracingInfo(op->name(), ng_rhs); - } - - ng_lhs_shape = ng_lhs->get_shape(); - ng_rhs_shape = ng_rhs->get_shape(); + ng::AxisVector out_axes(n_dims); + std::iota(out_axes.begin(), out_axes.end(), 0); - if (ng_lhs_shape[n_dims - 1] != ng_rhs_shape[0]) { - return errors::InvalidArgument( - "The last dimension of ng_lhs and the first dimension of ng_rhs " - "should have the same size"); + size_t compound_size = 1; + for (size_t i = 0; i < n_dims - 2; i++) { + compound_size *= ng_lhs_shape[i]; } - if (n_dims == 2) { - SaveNgOp(ng_op_map, op->name(), - ConstructNgNode(op->name(), ng_lhs, ng_rhs)); - } else { - auto output_shape = ng_lhs_shape; - output_shape[n_dims - 1] = ng_rhs_shape[1]; - auto dot_output = - ConstructNgNode(op->name(), ng_lhs, ng_rhs); - - size_t compound_size = 1; - for (size_t i = 0; i < out_axes.size(); i++) { - compound_size *= output_shape[i]; - } - auto dot_axes = out_axes; - dot_axes.push_back(n_dims - 2); - dot_axes.push_back(n_dims - 1); - for (size_t i = 0; i < out_axes.size(); i++) { - dot_axes.push_back(n_dims + i); - } - ng::Shape dot_shape = {compound_size, ng_lhs_shape[n_dims - 2], - ng_rhs_shape[1], compound_size}; - std::shared_ptr dot_reshape; - if (n_dims == 3) { - dot_reshape = dot_output; - } else { - dot_reshape = ConstructNgNode( - op->name(), dot_output, dot_axes, dot_shape); - } - ng::Shape tmp_shape = {1, ng_lhs_shape[n_dims - 2], ng_rhs_shape[1]}; - vector> tmp_tensors; - for (size_t i = 0; i < dot_shape[0]; i++) { - const std::vector lower_bound{i, 0, 0, i}; - const std::vector upper_bound{i + 1, dot_shape[1], dot_shape[2], - i + 1}; - auto slice_out = ConstructNgNode( - op->name(), dot_reshape, lower_bound, upper_bound); - auto reshape_out = ConstructNgNode( - op->name(), slice_out, ng::AxisVector{0, 1, 2, 3}, tmp_shape); - tmp_tensors.push_back(reshape_out); - } - auto concat_op = - ConstructNgNode(op->name(), tmp_tensors, 0); - if (n_dims == 3) { - SaveNgOp(ng_op_map, op->name(), concat_op); - } else { - SaveNgOp( - ng_op_map, op->name(), - ConstructNgNode( - op->name(), concat_op, ng::AxisVector{0, 1, 2}, output_shape)); - } - } + ng::Shape tmp_lhs_shape = {compound_size, ng_lhs_shape[n_dims - 2], + ng_lhs_shape[n_dims - 1]}; + ng::Shape tmp_rhs_shape = {compound_size, ng_rhs_shape[n_dims - 2], + ng_rhs_shape[n_dims - 1]}; + + auto output_shape = ng_lhs_shape; + output_shape[n_dims - 2] = ng_lhs_shape[n_dims - (tf_adj_x ? 1 : 2)]; + output_shape[n_dims - 1] = ng_rhs_shape[n_dims - (tf_adj_y ? 2 : 1)]; + ng::AxisVector tmp_axes = {0, 1, 2}; + + std::shared_ptr lhs_reshape = + ConstructNgNode(op->name(), ng_lhs, out_axes, + tmp_lhs_shape); + std::shared_ptr rhs_reshape = + ConstructNgNode(op->name(), ng_rhs, out_axes, + tmp_rhs_shape); + std::shared_ptr batchmatmul_transpose = + ConstructNgNode( + op->name(), lhs_reshape, rhs_reshape, tf_adj_x, tf_adj_y); + SaveNgOp(ng_op_map, op->name(), + ConstructNgNode( + op->name(), batchmatmul_transpose, tmp_axes, output_shape)); } return Status::OK(); } @@ -966,7 +877,11 @@ static Status TranslateBatchMatMulV2Op( if (ng_lhs_shape.size() != ng_rhs_shape.size()) { return errors::InvalidArgument( - "Dimensions of two input args are not the same for BatchMatMul"); + "Dimensions of two input args are not the same for BatchMatMul. Left " + "shape is ", + ng::join(ng_lhs_shape), " of rank ", ng_lhs_shape.size(), + " and Right shape is ", ng::join(ng_rhs_shape), " of rank ", + ng_rhs_shape.size()); } for (size_t i = 0; i < n_dims - 2; ++i) {