Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 50 additions & 135 deletions ngraph_bridge/ngraph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -786,170 +786,81 @@ static Status TranslateBatchMatMulOp(

if (ng_lhs_shape.size() != ng_rhs_shape.size()) {
return errors::InvalidArgument(
"Dimensions of two input args are not the same for BatchMatMul");
"Dimensions of two input args are not the same for BatchMatMul. Left "
"shape is ",
ng::join(ng_lhs_shape), " of rank ", ng_lhs_shape.size(),
" and Right shape is ", ng::join(ng_rhs_shape), " of rank ",
ng_rhs_shape.size());
}
size_t n_dims = ng_lhs_shape.size();
if (n_dims < 2) {
return errors::InvalidArgument(
"Dimensions of input args for BatchMatMul must be >=2", n_dims);
"Dimensions of input args for BatchMatMul must be >=2 but is ", n_dims);
}

ng::AxisVector out_axes;
for (size_t i = 0; i < n_dims - 2; ++i) {
if (ng_lhs_shape[i] != ng_rhs_shape[i]) {
return errors::InvalidArgument(
"ng_lhs_shape and ng_rhs_shape must be the same for BatchMatMul "
"for each dimension",
i);
"for each dimension but found ",
i, "th dimension different. Left shape is ", ng::join(ng_lhs_shape),
"and Right shape is ", ng::join(ng_rhs_shape));
}
out_axes.push_back(i);
}

bool tf_adj_x = false;
bool tf_adj_y = false;
TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "adj_x", &tf_adj_x));
TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "adj_y", &tf_adj_y));

auto ng_lhs_axes = out_axes;
auto ng_rhs_axes = out_axes;

// Get the backend name, if the backend is CPU and n_dims >= 3
// then use the BatchMatMul op supported by nGraph
if (n_dims >= 3 && backend_name == "CPU") {
if (n_dims == 2) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we do not need the else parts:

else {
      ng_lhs_axes.push_back(n_dims - 2);
      ng_lhs_axes.push_back(n_dims - 1);
    }

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say we can construct ng_lhs_axes and ng_rhs_axes in each of the cases using std::iota instead of push_back (which is probably slower) and get rid of out_axes

For example, in the n>3 case only one of them needs to be constructed; we do not need both ng_lhs_axes and ng_rhs_axes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Computing ng_lhs_shape = ng_lhs->get_shape(); is not needed in the n==2 case. There is some leftover/legacy code from the previous implementation that we can clean up.

ng_lhs = ng::builder::numpy_transpose(ng_lhs, tf_adj_x ? {1,0} : {0,1}); // or ng::AxisVector{0,1}
Builder::SetTracingInfo(op->name(), ng_lhs);
ng_rhs = ng::builder::numpy_transpose(ng_rhs, tf_adj_y ? {1,0} : {0,1});
Builder::SetTracingInfo(op->name(), ng_rhs);
SaveNgOp(ng_op_map, op->name(),
             ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

// Transpose X if AdjX = true
if (tf_adj_x) {
ng_lhs_axes.push_back(n_dims - 1);
ng_lhs_axes.push_back(n_dims - 2);
ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
ng_lhs = ng::builder::numpy_transpose(ng_lhs, {1, 0});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two assignments are not needed/used for n_dims == 2:
ng_lhs_shape = ng_lhs->get_shape();
ng_rhs_shape = ng_rhs->get_shape();

Builder::SetTracingInfo(op->name(), ng_lhs);
ng_lhs_shape = ng_lhs->get_shape();
} else {
ng_lhs_axes.push_back(n_dims - 2);
ng_lhs_axes.push_back(n_dims - 1);
}
// Transpose Y if AdjY = true
if (tf_adj_y) {
ng_rhs_axes.push_back(n_dims - 1);
ng_rhs_axes.push_back(n_dims - 2);
ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
ng_rhs = ng::builder::numpy_transpose(ng_rhs, {1, 0});
Builder::SetTracingInfo(op->name(), ng_rhs);
ng_rhs_shape = ng_rhs->get_shape();
} else {
ng_rhs_axes.push_back(n_dims - 2);
ng_rhs_axes.push_back(n_dims - 1);
}

if (n_dims == 3) {
SaveNgOp(ng_op_map, op->name(), ConstructNgNode<ngraph::op::BatchMatMul>(
op->name(), ng_lhs, ng_rhs));
} else {
// Find the compound size for dim1 so as to reshape to 3D
size_t compound_size = 1;
for (size_t i = 0; i < out_axes.size(); i++) {
compound_size *= ng_lhs_shape[i];
}

ng::Shape tmp_lhs_shape = {compound_size, ng_lhs_shape[n_dims - 2],
ng_lhs_shape[n_dims - 1]};
ng::Shape tmp_rhs_shape = {compound_size, ng_rhs_shape[n_dims - 2],
ng_rhs_shape[n_dims - 1]};

auto output_shape = ng_lhs_shape;
output_shape[n_dims - 1] = ng_rhs_shape[n_dims - 1];
ng::AxisVector tmp_axes = {0, 1, 2};

std::shared_ptr<ng::Node> lhs_reshape =
ConstructNgNode<ngraph::op::Reshape>(op->name(), ng_lhs, ng_lhs_axes,
tmp_lhs_shape);
std::shared_ptr<ng::Node> rhs_reshape =
ConstructNgNode<ngraph::op::Reshape>(op->name(), ng_rhs, ng_rhs_axes,
tmp_rhs_shape);
std::shared_ptr<ng::Node> batchmatmul =
ConstructNgNode<ngraph::op::BatchMatMul>(op->name(), lhs_reshape,
rhs_reshape);
SaveNgOp(ng_op_map, op->name(),
ConstructNgNode<ngraph::op::Reshape>(op->name(), batchmatmul,
tmp_axes, output_shape));
}
SaveNgOp(ng_op_map, op->name(),
ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs));
} else if (n_dims == 3) {
SaveNgOp(ng_op_map, op->name(),
ConstructNgNode<ngraph::op::BatchMatMulTranspose>(
op->name(), ng_lhs, ng_rhs, tf_adj_x, tf_adj_y));
} else {
if (tf_adj_x) {
ng_lhs_axes.push_back(n_dims - 1);
ng_lhs_axes.push_back(n_dims - 2);
ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
Builder::SetTracingInfo(op->name(), ng_lhs);
}
if (tf_adj_y) {
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
Builder::SetTracingInfo(op->name(), ng_rhs);
} else {
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
Builder::SetTracingInfo(op->name(), ng_rhs);
}

ng_lhs_shape = ng_lhs->get_shape();
ng_rhs_shape = ng_rhs->get_shape();
ng::AxisVector out_axes(n_dims);
std::iota(out_axes.begin(), out_axes.end(), 0);

if (ng_lhs_shape[n_dims - 1] != ng_rhs_shape[0]) {
return errors::InvalidArgument(
"The last dimension of ng_lhs and the first dimension of ng_rhs "
"should have the same size");
size_t compound_size = 1;
for (size_t i = 0; i < n_dims - 2; i++) {
compound_size *= ng_lhs_shape[i];
}

if (n_dims == 2) {
SaveNgOp(ng_op_map, op->name(),
ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs));
} else {
auto output_shape = ng_lhs_shape;
output_shape[n_dims - 1] = ng_rhs_shape[1];
auto dot_output =
ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs);

size_t compound_size = 1;
for (size_t i = 0; i < out_axes.size(); i++) {
compound_size *= output_shape[i];
}
auto dot_axes = out_axes;
dot_axes.push_back(n_dims - 2);
dot_axes.push_back(n_dims - 1);
for (size_t i = 0; i < out_axes.size(); i++) {
dot_axes.push_back(n_dims + i);
}
ng::Shape dot_shape = {compound_size, ng_lhs_shape[n_dims - 2],
ng_rhs_shape[1], compound_size};
std::shared_ptr<ng::Node> dot_reshape;
if (n_dims == 3) {
dot_reshape = dot_output;
} else {
dot_reshape = ConstructNgNode<ngraph::op::Reshape>(
op->name(), dot_output, dot_axes, dot_shape);
}
ng::Shape tmp_shape = {1, ng_lhs_shape[n_dims - 2], ng_rhs_shape[1]};
vector<shared_ptr<ngraph::Node>> tmp_tensors;
for (size_t i = 0; i < dot_shape[0]; i++) {
const std::vector<size_t> lower_bound{i, 0, 0, i};
const std::vector<size_t> upper_bound{i + 1, dot_shape[1], dot_shape[2],
i + 1};
auto slice_out = ConstructNgNode<ngraph::op::Slice>(
op->name(), dot_reshape, lower_bound, upper_bound);
auto reshape_out = ConstructNgNode<ngraph::op::Reshape>(
op->name(), slice_out, ng::AxisVector{0, 1, 2, 3}, tmp_shape);
tmp_tensors.push_back(reshape_out);
}
auto concat_op =
ConstructNgNode<ngraph::op::Concat>(op->name(), tmp_tensors, 0);
if (n_dims == 3) {
SaveNgOp(ng_op_map, op->name(), concat_op);
} else {
SaveNgOp(
ng_op_map, op->name(),
ConstructNgNode<ngraph::op::Reshape>(
op->name(), concat_op, ng::AxisVector{0, 1, 2}, output_shape));
}
}
ng::Shape tmp_lhs_shape = {compound_size, ng_lhs_shape[n_dims - 2],
ng_lhs_shape[n_dims - 1]};
ng::Shape tmp_rhs_shape = {compound_size, ng_rhs_shape[n_dims - 2],
ng_rhs_shape[n_dims - 1]};

auto output_shape = ng_lhs_shape;
output_shape[n_dims - 2] = ng_lhs_shape[n_dims - (tf_adj_x ? 1 : 2)];
output_shape[n_dims - 1] = ng_rhs_shape[n_dims - (tf_adj_y ? 2 : 1)];
ng::AxisVector tmp_axes = {0, 1, 2};

std::shared_ptr<ng::Node> lhs_reshape =
ConstructNgNode<ngraph::op::Reshape>(op->name(), ng_lhs, out_axes,
tmp_lhs_shape);
std::shared_ptr<ng::Node> rhs_reshape =
ConstructNgNode<ngraph::op::Reshape>(op->name(), ng_rhs, out_axes,
tmp_rhs_shape);
std::shared_ptr<ng::Node> batchmatmul_transpose =
ConstructNgNode<ngraph::op::BatchMatMulTranspose>(
op->name(), lhs_reshape, rhs_reshape, tf_adj_x, tf_adj_y);
SaveNgOp(ng_op_map, op->name(),
ConstructNgNode<ngraph::op::Reshape>(
op->name(), batchmatmul_transpose, tmp_axes, output_shape));
}
return Status::OK();
}
Expand All @@ -966,7 +877,11 @@ static Status TranslateBatchMatMulV2Op(

if (ng_lhs_shape.size() != ng_rhs_shape.size()) {
return errors::InvalidArgument(
"Dimensions of two input args are not the same for BatchMatMul");
"Dimensions of two input args are not the same for BatchMatMul. Left "
"shape is ",
ng::join(ng_lhs_shape), " of rank ", ng_lhs_shape.size(),
" and Right shape is ", ng::join(ng_rhs_shape), " of rank ",
ng_rhs_shape.size());
}

for (size_t i = 0; i < n_dims - 2; ++i) {
Expand Down