Merged
Commits
22 commits
4269c80
Redo BatchMatMul since nGraph has the op support for it now
kanvi-nervana May 10, 2019
36b0ead
Merge branch 'master' into kanvi/redo-batchmatmul
avijit-nervana May 10, 2019
2fe3036
Merge branch 'master' into kanvi/redo-batchmatmul
avijit-nervana May 10, 2019
5efa29e
Skip BatchMatMul3* tests for GPU since GPU does not support this op
kanvi-nervana May 10, 2019
20f966d
Changes to support nGraph BatchMatMul for CPU and the older way of tr…
kanvi-nervana May 10, 2019
512dc38
Fix formatting
kanvi-nervana May 10, 2019
c879a97
Merge branch 'master' into kanvi/redo-batchmatmul
kanvi-nervana May 10, 2019
e1a0aa3
Merge branch 'master' into kanvi/redo-batchmatmul
avijit-nervana May 13, 2019
d5e5b6c
Merge branch 'master' into kanvi/redo-batchmatmul
kanvi-nervana May 13, 2019
0c197d1
Merge branch 'kanvi/redo-batchmatmul' of https://github.com/tensorflo…
kanvi-nervana May 13, 2019
c3198a8
Merge branch 'master' into kanvi/redo-batchmatmul
avijit-nervana May 14, 2019
66e8485
Redo the translation for BatchMatMul so that it uses the CPU BatchMatMul
kanvi-nervana May 13, 2019
4eda798
Merge branch 'kanvi/redo-batchmatmul' of https://github.com/tensorflo…
kanvi-nervana May 14, 2019
58f91bf
Fix formatting
kanvi-nervana May 14, 2019
64e0213
Merge branch 'master' into kanvi/redo-batchmatmul
avijit-nervana May 14, 2019
11205ec
Reorganise the code to add the correct backend as attribute
kanvi-nervana May 15, 2019
cc4384e
Add logic to pass higher-order (>3) BatchMatMul to CPU as well.
kanvi-nervana May 15, 2019
2d2231f
Get the shape again in case of a transpose
kanvi-nervana May 15, 2019
3d57f39
Fix formatting
kanvi-nervana May 15, 2019
c769a3d
Update test/test_math_ops.cpp
sayantan-nervana May 16, 2019
05e9378
Merge branch 'master' into kanvi/redo-batchmatmul
sayantan-nervana May 16, 2019
3c0425d
Merge branch 'master' into kanvi/redo-batchmatmul
kanvi-nervana May 16, 2019
190 changes: 127 additions & 63 deletions src/ngraph_builder.cc
@@ -750,6 +750,9 @@ static Status TranslateBatchMatMulOp(
shared_ptr<ng::Node> ng_lhs, ng_rhs;
TF_RETURN_IF_ERROR(GetInputNodes(ng_op_map, op, &ng_lhs, &ng_rhs));

std::string backend_name;
TF_RETURN_IF_ERROR(ngraph_bridge::GetNodeBackend(op, &backend_name));

auto ng_lhs_shape = ng_lhs->get_shape();
auto ng_rhs_shape = ng_rhs->get_shape();

@@ -781,77 +784,138 @@

auto ng_lhs_axes = out_axes;
auto ng_rhs_axes = out_axes;
if (tf_adj_x) {
ng_lhs_axes.push_back(n_dims - 1);
ng_lhs_axes.push_back(n_dims - 2);
ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
}
if (tf_adj_y) {
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
} else {
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
}

ng_lhs_shape = ng_lhs->get_shape();
ng_rhs_shape = ng_rhs->get_shape();

if (ng_lhs_shape[n_dims - 1] != ng_rhs_shape[0]) {
return errors::InvalidArgument(
"The last dimension of ng_lhs and the first dimension of ng_rhs "
"should have the same size");
}
if (n_dims == 2) {
SaveNgOp(ng_op_map, op->name(),
ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs));
} else {
auto output_shape = ng_lhs_shape;
output_shape[n_dims - 1] = ng_rhs_shape[1];
auto dot_output =
ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs);
size_t compound_size = 1;
for (int i = 0; i < out_axes.size(); i++) {
compound_size *= output_shape[i];
// If the backend is CPU and n_dims >= 3, use the BatchMatMul op
// supported by nGraph
if (n_dims >= 3 && backend_name == "CPU") {
// Transpose X if AdjX = true
if (tf_adj_x) {
ng_lhs_axes.push_back(n_dims - 1);
ng_lhs_axes.push_back(n_dims - 2);
ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
ng_lhs_shape = ng_lhs->get_shape();
} else {
ng_lhs_axes.push_back(n_dims - 2);
ng_lhs_axes.push_back(n_dims - 1);
}
auto dot_axes = out_axes;
dot_axes.push_back(n_dims - 2);
dot_axes.push_back(n_dims - 1);
for (int i = 0; i < out_axes.size(); i++) {
dot_axes.push_back(n_dims + i);
// Transpose Y if AdjY = true
if (tf_adj_y) {
ng_rhs_axes.push_back(n_dims - 1);
ng_rhs_axes.push_back(n_dims - 2);
ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
ng_rhs_shape = ng_rhs->get_shape();
} else {
ng_rhs_axes.push_back(n_dims - 2);
ng_rhs_axes.push_back(n_dims - 1);
}
ng::Shape dot_shape = {compound_size, ng_lhs_shape[n_dims - 2],
ng_rhs_shape[1], compound_size};
std::shared_ptr<ng::Node> dot_reshape;

if (n_dims == 3) {
dot_reshape = dot_output;
SaveNgOp(ng_op_map, op->name(), ConstructNgNode<ngraph::op::BatchMatMul>(
op->name(), ng_lhs, ng_rhs));
} else {
dot_reshape = ConstructNgNode<ngraph::op::Reshape>(op->name(), dot_output,
dot_axes, dot_shape);
// Fold the leading batch dimensions into one compound size so the
// inputs can be reshaped to 3D
size_t compound_size = 1;
for (int i = 0; i < out_axes.size(); i++) {
compound_size *= ng_lhs_shape[i];
}

ng::Shape tmp_lhs_shape = {compound_size, ng_lhs_shape[n_dims - 2],
ng_lhs_shape[n_dims - 1]};
ng::Shape tmp_rhs_shape = {compound_size, ng_rhs_shape[n_dims - 2],
ng_rhs_shape[n_dims - 1]};

auto output_shape = ng_lhs_shape;
output_shape[n_dims - 1] = ng_rhs_shape[n_dims - 1];
ng::AxisVector tmp_axes = {0, 1, 2};

std::shared_ptr<ng::Node> lhs_reshape =
ConstructNgNode<ngraph::op::Reshape>(op->name(), ng_lhs, ng_lhs_axes,
tmp_lhs_shape);
std::shared_ptr<ng::Node> rhs_reshape =
ConstructNgNode<ngraph::op::Reshape>(op->name(), ng_rhs, ng_rhs_axes,
tmp_rhs_shape);
std::shared_ptr<ng::Node> batchmatmul =
ConstructNgNode<ngraph::op::BatchMatMul>(op->name(), lhs_reshape,
rhs_reshape);
SaveNgOp(ng_op_map, op->name(),
ConstructNgNode<ngraph::op::Reshape>(op->name(), batchmatmul,
tmp_axes, output_shape));
}
ng::Shape tmp_shape = {1, ng_lhs_shape[n_dims - 2], ng_rhs_shape[1]};
vector<shared_ptr<ngraph::Node>> tmp_tensors;
for (size_t i = 0; i < dot_shape[0]; i++) {
const std::vector<size_t> lower_bound{i, 0, 0, i};
const std::vector<size_t> upper_bound{i + 1, dot_shape[1], dot_shape[2],
i + 1};
auto slice_out = ConstructNgNode<ngraph::op::Slice>(
op->name(), dot_reshape, lower_bound, upper_bound);
auto reshape_out = ConstructNgNode<ngraph::op::Reshape>(
op->name(), slice_out, ng::AxisVector{0, 1, 2, 3}, tmp_shape);
tmp_tensors.push_back(reshape_out);
} else {
if (tf_adj_x) {
ng_lhs_axes.push_back(n_dims - 1);
ng_lhs_axes.push_back(n_dims - 2);
ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
}
auto concat_op =
ConstructNgNode<ngraph::op::Concat>(op->name(), tmp_tensors, 0);
if (n_dims == 3) {
SaveNgOp(ng_op_map, op->name(), concat_op);
if (tf_adj_y) {
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
} else {
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
}

ng_lhs_shape = ng_lhs->get_shape();
ng_rhs_shape = ng_rhs->get_shape();

if (ng_lhs_shape[n_dims - 1] != ng_rhs_shape[0]) {
return errors::InvalidArgument(
"The last dimension of ng_lhs and the first dimension of ng_rhs "
"should have the same size");
}

if (n_dims == 2) {
SaveNgOp(ng_op_map, op->name(),
ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs));
} else {
SaveNgOp(
ng_op_map, op->name(),
ConstructNgNode<ngraph::op::Reshape>(
op->name(), concat_op, ng::AxisVector{0, 1, 2}, output_shape));
auto output_shape = ng_lhs_shape;
output_shape[n_dims - 1] = ng_rhs_shape[1];
auto dot_output =
ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs);

size_t compound_size = 1;
for (int i = 0; i < out_axes.size(); i++) {
compound_size *= output_shape[i];
}
auto dot_axes = out_axes;
dot_axes.push_back(n_dims - 2);
dot_axes.push_back(n_dims - 1);
for (int i = 0; i < out_axes.size(); i++) {
dot_axes.push_back(n_dims + i);
}
ng::Shape dot_shape = {compound_size, ng_lhs_shape[n_dims - 2],
ng_rhs_shape[1], compound_size};
std::shared_ptr<ng::Node> dot_reshape;
if (n_dims == 3) {
dot_reshape = dot_output;
} else {
dot_reshape = ConstructNgNode<ngraph::op::Reshape>(
op->name(), dot_output, dot_axes, dot_shape);
}
ng::Shape tmp_shape = {1, ng_lhs_shape[n_dims - 2], ng_rhs_shape[1]};
vector<shared_ptr<ngraph::Node>> tmp_tensors;
for (size_t i = 0; i < dot_shape[0]; i++) {
const std::vector<size_t> lower_bound{i, 0, 0, i};
const std::vector<size_t> upper_bound{i + 1, dot_shape[1], dot_shape[2],
i + 1};
auto slice_out = ConstructNgNode<ngraph::op::Slice>(
op->name(), dot_reshape, lower_bound, upper_bound);
auto reshape_out = ConstructNgNode<ngraph::op::Reshape>(
op->name(), slice_out, ng::AxisVector{0, 1, 2, 3}, tmp_shape);
tmp_tensors.push_back(reshape_out);
}
auto concat_op =
ConstructNgNode<ngraph::op::Concat>(op->name(), tmp_tensors, 0);
if (n_dims == 3) {
SaveNgOp(ng_op_map, op->name(), concat_op);
} else {
SaveNgOp(
ng_op_map, op->name(),
ConstructNgNode<ngraph::op::Reshape>(
op->name(), concat_op, ng::AxisVector{0, 1, 2}, output_shape));
}
}
}
return Status::OK();
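The adj_x/adj_y handling above builds an axis order that keeps the batch axes in place and swaps the two innermost axes before calling ng::builder::numpy_transpose. A minimal standalone sketch of that permutation (plain C++, not the nGraph API; AdjointAxes is a hypothetical helper and assumes n_dims >= 2):

#include <cstddef>
#include <numeric>
#include <vector>

// Build the axis order for a batched adjoint: identity on the batch axes,
// with the last two axes swapped (e.g. n_dims = 4 gives {0, 1, 3, 2}).
std::vector<size_t> AdjointAxes(size_t n_dims) {
  std::vector<size_t> axes(n_dims);
  std::iota(axes.begin(), axes.end(), 0);         // identity order 0..n-1
  std::swap(axes[n_dims - 2], axes[n_dims - 1]);  // swap the two inner axes
  return axes;
}

This matches the axis vectors built in the CPU branch: out_axes holds 0..n-3, and the code pushes n-1 then n-2 when the corresponding adjoint flag is set.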
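For ranks above 3, the CPU branch folds every leading batch dimension into a single compound dimension, runs the rank-3 ngraph::op::BatchMatMul, and reshapes the result back. A minimal sketch of that shape bookkeeping (CollapseTo3D is a hypothetical helper, assuming a rank >= 3 input shape):

#include <cstddef>
#include <vector>

// Fold the leading batch dimensions of a rank-N shape (N >= 3) into one
// compound dimension, producing the rank-3 shape BatchMatMul expects.
std::vector<size_t> CollapseTo3D(const std::vector<size_t>& shape) {
  const size_t n_dims = shape.size();
  size_t compound_size = 1;
  for (size_t i = 0; i + 2 < n_dims; i++) {
    compound_size *= shape[i];  // product of the batch dimensions
  }
  return {compound_size, shape[n_dims - 2], shape[n_dims - 1]};
}

For example, a {2, 5, 3, 4} left-hand side collapses to {10, 3, 4}; after the 3D BatchMatMul, the output is reshaped back to the original batch layout, which is what output_shape and tmp_axes track in the hunk above.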
41 changes: 23 additions & 18 deletions test/opexecuter.cpp
@@ -200,6 +200,29 @@ void OpExecuter::ExecuteOnNGraph(vector<Tensor>& ngraph_outputs,

// Get Tensor input shapes and values from the const nodes
int number_of_inputs = test_op->num_inputs();

// Create nGraph backend
// If NGRAPH_TF_BACKEND is set, create that backend
// Else create backend of type ng_backend_name
string ng_backend_type = ng_backend_name;
const char* ng_backend_env_value = std::getenv("NGRAPH_TF_BACKEND");

if (ng_backend_env_value != nullptr) {
string backend_env = std::string(ng_backend_env_value);
bool valid_ngraph_tf_backend =
!backend_env.empty() && BackendManager::IsSupportedBackend(backend_env);
ASSERT_TRUE(valid_ngraph_tf_backend) << "NGRAPH_TF_BACKEND " << backend_env
<< " is not a supported backend";
ng_backend_type = backend_env;
}

NGRAPH_VLOG(5) << " Creating NG Backend " << ng_backend_type;
BackendManager::CreateBackend(ng_backend_type);
auto backend = BackendManager::GetBackend(ng_backend_type);

// Add the _ngraph_backend attr to the node
test_op->AddAttr("_ngraph_backend", ng_backend_type);

// TODO : Validate static_input_indexes < number_of_inputs
vector<TensorShape> input_shapes;
vector<DataType> input_dt;
@@ -328,24 +351,6 @@ void OpExecuter::ExecuteOnNGraph(vector<Tensor>& ngraph_outputs,
NgraphSerialize("unit_test_" + test_op_type_ + ".json", ng_function);
}

// Create nGraph backend
// If NGRAPH_TF_BACKEND is set, create that backend
// Else create backend of type ng_backend_name
string ng_backend_type = ng_backend_name;
const char* ng_backend_env_value = std::getenv("NGRAPH_TF_BACKEND");
if (ng_backend_env_value != nullptr) {
string backend_env = std::string(ng_backend_env_value);
bool valid_ngraph_tf_backend =
!backend_env.empty() && BackendManager::IsSupportedBackend(backend_env);
ASSERT_TRUE(valid_ngraph_tf_backend) << "NGRAPH_TF_BACKEND " << backend_env
<< " is not a supported backend";
ng_backend_type = backend_env;
}

NGRAPH_VLOG(5) << " Creating NG Backend " << ng_backend_type;
BackendManager::CreateBackend(ng_backend_type);
auto backend = BackendManager::GetBackend(ng_backend_type);

// Allocate tensors for inputs
vector<std::shared_ptr<ngraph::runtime::Tensor>> ng_ip_tensors;
vector<std::shared_ptr<ngraph::runtime::Tensor>> ng_op_tensors;
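The test-harness change above moves backend creation ahead of graph construction so the _ngraph_backend attribute can be attached to the node under test. The selection rule itself is simple; a minimal sketch (ResolveBackend is a hypothetical helper; the real test instead asserts, via BackendManager::IsSupportedBackend, that a set value is non-empty and supported rather than silently falling back):

#include <cstdlib>
#include <string>

// NGRAPH_TF_BACKEND, when set and non-empty, overrides the default
// backend name; otherwise the caller's default is used.
std::string ResolveBackend(const std::string& default_backend) {
  const char* env = std::getenv("NGRAPH_TF_BACKEND");
  if (env != nullptr && env[0] != '\0') {
    return std::string(env);  // environment variable wins
  }
  return default_backend;
}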