Conversation

@kanvi-nervana (Contributor)
nGraph now supports a fused BatchMatMulTranspose. This PR makes changes to the BatchMatMul translation to accommodate it.

@kanvi-nervana kanvi-nervana changed the title Initial changes for using the fused BMMT Kanvi/fused BMMT support Dec 20, 2019
@sayantan-nervana (Contributor) left a comment

LGTM.

General comments:

Since this translate op will be heavily used, let's improve its error messages.

For example:
if (ng_lhs_shape.size() != ng_rhs_shape.size()) {
  return errors::InvalidArgument(
      "Dimensions of two input args are not the same for BatchMatMul");
}

could become:

if (ng_lhs_shape.size() != ng_rhs_shape.size()) {
  return errors::InvalidArgument(
      "Dimensions of two input args are not the same for BatchMatMul. "
      "Left shape is ", ng::join(ng_lhs_shape), " of rank ",
      ng_lhs_shape.size(), " and right shape is ", ng::join(ng_rhs_shape),
      " of rank ", ng_rhs_shape.size());
}

Similarly, for `"Dimensions of input args for BatchMatMul must be >=2", n_dims);` we can say: `"Dimensions of input args for BatchMatMul must be >=2 but is ", n_dims);`

And:

"ng_lhs_shape and ng_rhs_shape must be the same for BatchMatMul "
"for each dimension ",
i)

could be:

"ng_lhs_shape and ng_rhs_shape must be the same for BatchMatMul "
"for each dimension, but found the ",
i, "th dimension different.")  // and print the left shape and right shape using ng::join()

// Get the backend name, if the backend is CPU and n_dims >= 3
// then use the BatchMatMul op supported by nGraph
if (n_dims >= 3 && backend_name == "CPU") {
if (n_dims == 2) {
Maybe we do not need the else parts:

else {
      ng_lhs_axes.push_back(n_dims - 2);
      ng_lhs_axes.push_back(n_dims - 1);
    }

I'd say we can construct ng_lhs_axes and ng_rhs_axes in each of the cases using std::iota instead of repeated push_back (which is probably slower) and get rid of out_axes.

For example, in the n_dims > 3 case only one of them needs to be constructed; we do not need both ng_lhs_axes and ng_rhs_axes.

Computing `ng_lhs_shape = ng_lhs->get_shape();` is not needed in the n_dims == 2 case. There is some leftover/legacy code from the previous implementation that we can clean up:

ng_lhs = ng::builder::numpy_transpose(
    ng_lhs, tf_adj_x ? ng::AxisVector{1, 0} : ng::AxisVector{0, 1});
Builder::SetTracingInfo(op->name(), ng_lhs);
ng_rhs = ng::builder::numpy_transpose(
    ng_rhs, tf_adj_y ? ng::AxisVector{1, 0} : ng::AxisVector{0, 1});
Builder::SetTracingInfo(op->name(), ng_rhs);
SaveNgOp(ng_op_map, op->name(),
         ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs));

@kanvi-nervana (Contributor, Author)

Done

ng_lhs_axes.push_back(n_dims - 1);
ng_lhs_axes.push_back(n_dims - 2);
ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
ng_lhs = ng::builder::numpy_transpose(ng_lhs, {1, 0});
These two are not needed/used for n_dims == 2:

ng_lhs_shape = ng_lhs->get_shape();
ng_rhs_shape = ng_rhs->get_shape();

@sayantan-nervana (Contributor) left a comment
LGTM

@sayantan-nervana added the labels "fully reviewed" and "ready to merge" ("This PR is the next in the queue.") Dec 31, 2019
@sayantan-nervana merged commit fa0d491 into master Dec 31, 2019
@sayantan-nervana deleted the Kanvi/Fused_BMMT_support branch December 31, 2019 22:51