Kanvi/fused BMMT support #404
Merged
Changes from all commits (8 commits):
- 4459d99: Initial changes for using the fused BMMT (kanvi-nervana)
- 57973b9: Make modifications to support >3 rank size (kanvi-nervana)
- 20be96c: Merge branch 'master' into Kanvi/Fused_BMMT_support (kanvi-nervana)
- 7e8ddbb: Merge branch 'master' into Kanvi/Fused_BMMT_support (kanvi-nervana)
- 47f83f3: Merge branch 'master' into Kanvi/Fused_BMMT_support (kanvi-nervana)
- 25ad660: Address comments (kanvi-nervana)
- 306c25c: Merge branch 'Kanvi/Fused_BMMT_support' of https://github.com/tensorf… (kanvi-nervana)
- e05bdaf: more comments addressed (kanvi-nervana)
The first hunk rewrites `TranslateBatchMatMulOp`: rank-2 inputs now go to `Dot`, rank-3 inputs go directly to the fused `BatchMatMulTranspose` (BMMT) op, and rank > 3 inputs are reshaped to 3-D so the same fused op can be used, replacing the old CPU-only `BatchMatMul` path and the `Dot`/`Slice`/`Concat` fallback.

```diff
@@ -786,170 +786,81 @@ static Status TranslateBatchMatMulOp(

   if (ng_lhs_shape.size() != ng_rhs_shape.size()) {
     return errors::InvalidArgument(
-        "Dimensions of two input args are not the same for BatchMatMul");
+        "Dimensions of two input args are not the same for BatchMatMul. Left "
+        "shape is ",
+        ng::join(ng_lhs_shape), " of rank ", ng_lhs_shape.size(),
+        " and Right shape is ", ng::join(ng_rhs_shape), " of rank ",
+        ng_rhs_shape.size());
   }
   size_t n_dims = ng_lhs_shape.size();
   if (n_dims < 2) {
     return errors::InvalidArgument(
-        "Dimensions of input args for BatchMatMul must be >=2", n_dims);
+        "Dimensions of input args for BatchMatMul must be >=2 but is ", n_dims);
   }

   ng::AxisVector out_axes;
   for (size_t i = 0; i < n_dims - 2; ++i) {
     if (ng_lhs_shape[i] != ng_rhs_shape[i]) {
       return errors::InvalidArgument(
           "ng_lhs_shape and ng_rhs_shape must be the same for BatchMatMul "
-          "for each dimension",
-          i);
+          "for each dimension but found ",
+          i, "th dimension different. Left shape is ", ng::join(ng_lhs_shape),
+          "and Right shape is ", ng::join(ng_rhs_shape));
     }
     out_axes.push_back(i);
   }

   bool tf_adj_x = false;
   bool tf_adj_y = false;
   TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "adj_x", &tf_adj_x));
   TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "adj_y", &tf_adj_y));

-  auto ng_lhs_axes = out_axes;
-  auto ng_rhs_axes = out_axes;
-
-  // Get the backend name, if the backend is CPU and n_dims >= 3
-  // then use the BatchMatMul op supported by nGraph
-  if (n_dims >= 3 && backend_name == "CPU") {
+  if (n_dims == 2) {
     // Transpose X if AdjX = true
     if (tf_adj_x) {
-      ng_lhs_axes.push_back(n_dims - 1);
-      ng_lhs_axes.push_back(n_dims - 2);
-      ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
+      ng_lhs = ng::builder::numpy_transpose(ng_lhs, {1, 0});
       Builder::SetTracingInfo(op->name(), ng_lhs);
       ng_lhs_shape = ng_lhs->get_shape();
-    } else {
-      ng_lhs_axes.push_back(n_dims - 2);
-      ng_lhs_axes.push_back(n_dims - 1);
     }
     // Transpose Y if AdjY = true
     if (tf_adj_y) {
-      ng_rhs_axes.push_back(n_dims - 1);
-      ng_rhs_axes.push_back(n_dims - 2);
-      ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
+      ng_rhs = ng::builder::numpy_transpose(ng_rhs, {1, 0});
       Builder::SetTracingInfo(op->name(), ng_rhs);
       ng_rhs_shape = ng_rhs->get_shape();
-    } else {
-      ng_rhs_axes.push_back(n_dims - 2);
-      ng_rhs_axes.push_back(n_dims - 1);
     }
-
-    if (n_dims == 3) {
-      SaveNgOp(ng_op_map, op->name(), ConstructNgNode<ngraph::op::BatchMatMul>(
-                                          op->name(), ng_lhs, ng_rhs));
-    } else {
-      // Find the compound size for dim1 so as to reshape to 3D
-      size_t compound_size = 1;
-      for (size_t i = 0; i < out_axes.size(); i++) {
-        compound_size *= ng_lhs_shape[i];
-      }
-
-      ng::Shape tmp_lhs_shape = {compound_size, ng_lhs_shape[n_dims - 2],
-                                 ng_lhs_shape[n_dims - 1]};
-      ng::Shape tmp_rhs_shape = {compound_size, ng_rhs_shape[n_dims - 2],
-                                 ng_rhs_shape[n_dims - 1]};
-
-      auto output_shape = ng_lhs_shape;
-      output_shape[n_dims - 1] = ng_rhs_shape[n_dims - 1];
-      ng::AxisVector tmp_axes = {0, 1, 2};
-
-      std::shared_ptr<ng::Node> lhs_reshape =
-          ConstructNgNode<ngraph::op::Reshape>(op->name(), ng_lhs, ng_lhs_axes,
-                                               tmp_lhs_shape);
-      std::shared_ptr<ng::Node> rhs_reshape =
-          ConstructNgNode<ngraph::op::Reshape>(op->name(), ng_rhs, ng_rhs_axes,
-                                               tmp_rhs_shape);
-      std::shared_ptr<ng::Node> batchmatmul =
-          ConstructNgNode<ngraph::op::BatchMatMul>(op->name(), lhs_reshape,
-                                                   rhs_reshape);
-      SaveNgOp(ng_op_map, op->name(),
-               ConstructNgNode<ngraph::op::Reshape>(op->name(), batchmatmul,
-                                                    tmp_axes, output_shape));
-    }
+    SaveNgOp(ng_op_map, op->name(),
+             ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs));
+  } else if (n_dims == 3) {
+    SaveNgOp(ng_op_map, op->name(),
+             ConstructNgNode<ngraph::op::BatchMatMulTranspose>(
+                 op->name(), ng_lhs, ng_rhs, tf_adj_x, tf_adj_y));
   } else {
-    if (tf_adj_x) {
-      ng_lhs_axes.push_back(n_dims - 1);
-      ng_lhs_axes.push_back(n_dims - 2);
-      ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
-      Builder::SetTracingInfo(op->name(), ng_lhs);
-    }
-    if (tf_adj_y) {
-      ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
-      ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
-      ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
-      Builder::SetTracingInfo(op->name(), ng_rhs);
-    } else {
-      ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
-      ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
-      ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
-      Builder::SetTracingInfo(op->name(), ng_rhs);
-    }
-
-    ng_lhs_shape = ng_lhs->get_shape();
-    ng_rhs_shape = ng_rhs->get_shape();
+    ng::AxisVector out_axes(n_dims);
+    std::iota(out_axes.begin(), out_axes.end(), 0);

-    if (ng_lhs_shape[n_dims - 1] != ng_rhs_shape[0]) {
-      return errors::InvalidArgument(
-          "The last dimension of ng_lhs and the first dimension of ng_rhs "
-          "should have the same size");
+    size_t compound_size = 1;
+    for (size_t i = 0; i < n_dims - 2; i++) {
+      compound_size *= ng_lhs_shape[i];
     }

-    if (n_dims == 2) {
-      SaveNgOp(ng_op_map, op->name(),
-               ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs));
-    } else {
-      auto output_shape = ng_lhs_shape;
-      output_shape[n_dims - 1] = ng_rhs_shape[1];
-      auto dot_output =
-          ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs);
-
-      size_t compound_size = 1;
-      for (size_t i = 0; i < out_axes.size(); i++) {
-        compound_size *= output_shape[i];
-      }
-      auto dot_axes = out_axes;
-      dot_axes.push_back(n_dims - 2);
-      dot_axes.push_back(n_dims - 1);
-      for (size_t i = 0; i < out_axes.size(); i++) {
-        dot_axes.push_back(n_dims + i);
-      }
-      ng::Shape dot_shape = {compound_size, ng_lhs_shape[n_dims - 2],
-                             ng_rhs_shape[1], compound_size};
-      std::shared_ptr<ng::Node> dot_reshape;
-      if (n_dims == 3) {
-        dot_reshape = dot_output;
-      } else {
-        dot_reshape = ConstructNgNode<ngraph::op::Reshape>(
-            op->name(), dot_output, dot_axes, dot_shape);
-      }
-      ng::Shape tmp_shape = {1, ng_lhs_shape[n_dims - 2], ng_rhs_shape[1]};
-      vector<shared_ptr<ngraph::Node>> tmp_tensors;
-      for (size_t i = 0; i < dot_shape[0]; i++) {
-        const std::vector<size_t> lower_bound{i, 0, 0, i};
-        const std::vector<size_t> upper_bound{i + 1, dot_shape[1],
-                                              dot_shape[2], i + 1};
-        auto slice_out = ConstructNgNode<ngraph::op::Slice>(
-            op->name(), dot_reshape, lower_bound, upper_bound);
-        auto reshape_out = ConstructNgNode<ngraph::op::Reshape>(
-            op->name(), slice_out, ng::AxisVector{0, 1, 2, 3}, tmp_shape);
-        tmp_tensors.push_back(reshape_out);
-      }
-      auto concat_op =
-          ConstructNgNode<ngraph::op::Concat>(op->name(), tmp_tensors, 0);
-      if (n_dims == 3) {
-        SaveNgOp(ng_op_map, op->name(), concat_op);
-      } else {
-        SaveNgOp(
-            ng_op_map, op->name(),
-            ConstructNgNode<ngraph::op::Reshape>(
-                op->name(), concat_op, ng::AxisVector{0, 1, 2}, output_shape));
-      }
-    }
+    ng::Shape tmp_lhs_shape = {compound_size, ng_lhs_shape[n_dims - 2],
+                               ng_lhs_shape[n_dims - 1]};
+    ng::Shape tmp_rhs_shape = {compound_size, ng_rhs_shape[n_dims - 2],
+                               ng_rhs_shape[n_dims - 1]};
+
+    auto output_shape = ng_lhs_shape;
+    output_shape[n_dims - 2] = ng_lhs_shape[n_dims - (tf_adj_x ? 1 : 2)];
+    output_shape[n_dims - 1] = ng_rhs_shape[n_dims - (tf_adj_y ? 2 : 1)];
+    ng::AxisVector tmp_axes = {0, 1, 2};
+
+    std::shared_ptr<ng::Node> lhs_reshape =
+        ConstructNgNode<ngraph::op::Reshape>(op->name(), ng_lhs, out_axes,
+                                             tmp_lhs_shape);
+    std::shared_ptr<ng::Node> rhs_reshape =
+        ConstructNgNode<ngraph::op::Reshape>(op->name(), ng_rhs, out_axes,
+                                             tmp_rhs_shape);
+    std::shared_ptr<ng::Node> batchmatmul_transpose =
+        ConstructNgNode<ngraph::op::BatchMatMulTranspose>(
+            op->name(), lhs_reshape, rhs_reshape, tf_adj_x, tf_adj_y);
+    SaveNgOp(ng_op_map, op->name(),
+             ConstructNgNode<ngraph::op::Reshape>(
+                 op->name(), batchmatmul_transpose, tmp_axes, output_shape));
   }
   return Status::OK();
 }
```

Contributor (inline comment on the n_dims == 2 branch): these 2 are not needed/used for n_dims == 2 (apparently referring to the `Builder::SetTracingInfo(op->name(), ng_lhs);` and `ng_lhs_shape = ng_lhs->get_shape();` lines).
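To make the rank > 3 lowering concrete, here is a minimal standalone sketch of the shape bookkeeping (plain C++ with `std::vector` standing in for `ng::Shape`/`ng::AxisVector`; the sample shapes and the `main` harness are illustrative, not from the PR):

```cpp
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical rank-4 inputs: [batch0, batch1, M, K] x [batch0, batch1, K, N].
  std::vector<std::size_t> lhs_shape = {2, 5, 4, 3};
  std::vector<std::size_t> rhs_shape = {2, 5, 3, 6};
  bool adj_x = false, adj_y = false;
  std::size_t n_dims = lhs_shape.size();

  // Identity axis order for the reshapes, built with std::iota as in the PR.
  std::vector<std::size_t> out_axes(n_dims);
  std::iota(out_axes.begin(), out_axes.end(), std::size_t{0});

  // Collapse all leading batch dimensions into one compound dimension.
  std::size_t compound_size = 1;
  for (std::size_t i = 0; i < n_dims - 2; i++) compound_size *= lhs_shape[i];

  // 3-D shapes handed to the fused BatchMatMulTranspose kernel.
  std::vector<std::size_t> tmp_lhs = {compound_size, lhs_shape[n_dims - 2],
                                      lhs_shape[n_dims - 1]};
  std::vector<std::size_t> tmp_rhs = {compound_size, rhs_shape[n_dims - 2],
                                      rhs_shape[n_dims - 1]};

  // Final output shape, honoring the adjoint (transpose) flags.
  std::vector<std::size_t> output_shape = lhs_shape;
  output_shape[n_dims - 2] = lhs_shape[n_dims - (adj_x ? 1 : 2)];
  output_shape[n_dims - 1] = rhs_shape[n_dims - (adj_y ? 2 : 1)];

  std::cout << "compound batch = " << compound_size << "\n";  // 10
  std::cout << "lhs 3-D shape  = " << tmp_lhs[0] << " " << tmp_lhs[1] << " "
            << tmp_lhs[2] << "\n";                            // 10 4 3
  std::cout << "rhs 3-D shape  = " << tmp_rhs[0] << " " << tmp_rhs[1] << " "
            << tmp_rhs[2] << "\n";                            // 10 3 6
  std::cout << "output shape   = ";
  for (std::size_t d : output_shape) std::cout << d << " ";   // 2 5 4 6
  std::cout << "\n";
}
```

The fused op then performs the batched matmul (with the transposes folded in) on the collapsed 3-D tensors, and the final `Reshape` restores the original batch dimensions.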
The second hunk applies the same error-message improvement to `TranslateBatchMatMulV2Op`:

```diff
@@ -966,7 +877,11 @@ static Status TranslateBatchMatMulV2Op(

   if (ng_lhs_shape.size() != ng_rhs_shape.size()) {
     return errors::InvalidArgument(
-        "Dimensions of two input args are not the same for BatchMatMul");
+        "Dimensions of two input args are not the same for BatchMatMul. Left "
+        "shape is ",
+        ng::join(ng_lhs_shape), " of rank ", ng_lhs_shape.size(),
+        " and Right shape is ", ng::join(ng_rhs_shape), " of rank ",
+        ng_rhs_shape.size());
   }

   for (size_t i = 0; i < n_dims - 2; ++i) {
```
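For illustration, with a hypothetical rank-3 left input of shape {2, 3, 4} and a rank-2 right input of shape {3, 4}, the new message would render roughly as below (assuming `ng::join` separates dimensions with commas; the exact separator is not visible in this diff):

```
Dimensions of two input args are not the same for BatchMatMul. Left shape is 2, 3, 4 of rank 3 and Right shape is 3, 4 of rank 2
```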
Review thread:

- Maybe we do not need the else parts.
- I'd say we can construct `ng_lhs_axes` and `ng_rhs_axes` in each of the cases using `std::iota` instead of `push_back` (which is probably slower) and get rid of `out_axes`. For example, in the n > 3 case only one needs to be constructed; we do not need both `ng_lhs_axes` and `ng_rhs_axes` (see the sketch after this thread).
- Computing `ng_lhs_shape = ng_lhs->get_shape();` is not needed in the n == 2 case. There is some leftover/legacy code from the previous implementation that we can clean up.
- Done
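A minimal sketch of the `std::iota` construction suggested above (assuming `ng::AxisVector` can be built like a `std::vector<std::size_t>`; the helper name is hypothetical):

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

// Identity axis order {0, 1, ..., n_dims - 1} built in one shot,
// instead of growing the vector with repeated push_back calls.
std::vector<std::size_t> identity_axes(std::size_t n_dims) {
  std::vector<std::size_t> axes(n_dims);
  std::iota(axes.begin(), axes.end(), std::size_t{0});
  return axes;
}
```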