Kanvi/fused BMMT support #404
Conversation
sayantan-nervana left a comment
LGTM.
General comments:
Since this translation op will be heavily used, let's improve its error messages.
For example:
if (ng_lhs_shape.size() != ng_rhs_shape.size()) {
  return errors::InvalidArgument(
      "Dimensions of two input args are not the same for BatchMatMul");
}
could become:
if (ng_lhs_shape.size() != ng_rhs_shape.size()) {
  return errors::InvalidArgument(
      "Dimensions of two input args are not the same for BatchMatMul. "
      "Left shape is ", ng::join(ng_lhs_shape), " of rank ",
      ng_lhs_shape.size(), " and right shape is ", ng::join(ng_rhs_shape),
      " of rank ", ng_rhs_shape.size());
}
Similarly, for `"Dimensions of input args for BatchMatMul must be >=2", n_dims);` we can say: `"Dimensions of input args for BatchMatMul must be >=2 but is ", n_dims);`
and:
"ng_lhs_shape and ng_rhs_shape must be the same for BatchMatMul "
"for each dimension ",
i)
could be:
"ng_lhs_shape and ng_rhs_shape must be the same for BatchMatMul "
"for each dimension but found ",
i, "th dimension different.") // and print the left shape and right shape using ng::join()
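The error-message suggestion above can be sketched with a plain standalone join helper; `join_shape` here is a hypothetical stand-in for `ng::join`, whose exact output format may differ:

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for ng::join: renders a shape as "2, 3, 4".
std::string join_shape(const std::vector<size_t>& shape) {
  std::ostringstream os;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (i != 0) os << ", ";
    os << shape[i];
  }
  return os.str();
}

// Builds the kind of detailed message suggested above, with both
// shapes and both ranks included so the failure is self-explanatory.
std::string rank_mismatch_message(const std::vector<size_t>& lhs,
                                  const std::vector<size_t>& rhs) {
  std::ostringstream os;
  os << "Dimensions of two input args are not the same for BatchMatMul. "
     << "Left shape is " << join_shape(lhs) << " of rank " << lhs.size()
     << " and right shape is " << join_shape(rhs) << " of rank "
     << rhs.size();
  return os.str();
}
```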
// Get the backend name, if the backend is CPU and n_dims >= 3
// then use the BatchMatMul op supported by nGraph
if (n_dims >= 3 && backend_name == "CPU") {
if (n_dims == 2) {
Maybe we do not need the else parts:
else {
  ng_lhs_axes.push_back(n_dims - 2);
  ng_lhs_axes.push_back(n_dims - 1);
}
I'd say we can construct ng_lhs_axes and ng_rhs_axes in each of the cases using iota instead of push_back (which is probably slower) and get rid of out_axes.
For example, in the n > 3 case only one of them needs to be constructed; we do not need both ng_lhs_axes and ng_rhs_axes.
Computing `ng_lhs_shape = ng_lhs->get_shape();` is not needed in the n == 2 case. There is some leftover/legacy code from the previous implementation that we can clean up:
ng_lhs = ng::builder::numpy_transpose(ng_lhs, tf_adj_x ? ng::AxisVector{1, 0} : ng::AxisVector{0, 1});
Builder::SetTracingInfo(op->name(), ng_lhs);
ng_rhs = ng::builder::numpy_transpose(ng_rhs, tf_adj_y ? ng::AxisVector{1, 0} : ng::AxisVector{0, 1});
Builder::SetTracingInfo(op->name(), ng_rhs);
SaveNgOp(ng_op_map, op->name(),
ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs));
Done
ng_lhs_axes.push_back(n_dims - 1);
ng_lhs_axes.push_back(n_dims - 2);
ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
ng_lhs = ng::builder::numpy_transpose(ng_lhs, {1, 0});
These two are not needed/used for n_dims == 2:
ng_lhs_shape = ng_lhs->get_shape();
ng_rhs_shape = ng_rhs->get_shape();
sayantan-nervana left a comment
LGTM
nGraph now supports a fused BatchMatMulTranspose. This PR makes changes to the translation of BatchMatMul to accommodate that.