Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
346 changes: 90 additions & 256 deletions ngraph_bridge/ngraph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/experimental/layers/interpolate.hpp"
#include "ngraph/op/util/logical_reduction.hpp"
#include "ngraph/slice_plan.hpp"

#include "logging/ngraph_log.h"
#include "ngraph_bridge/ngraph_api.h"
Expand Down Expand Up @@ -4410,7 +4411,6 @@ static Status TranslateSqueezeOp(const Node* op,
static Status TranslateStridedSliceOp(
const Node* op, const std::vector<const Tensor*>& static_input_map,
Builder::OpMap& ng_op_map) {
// TODO: implement new_axis_mask, ellipsis_mask
shared_ptr<ng::Node> ng_input;
TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 0, &ng_input));

Expand Down Expand Up @@ -4442,277 +4442,111 @@ static Status TranslateStridedSliceOp(
TF_RETURN_IF_ERROR(
GetStaticInputVector(op, 3, static_input_map, &stride_vec));

auto& input_shape = ng_input->get_shape();

// Summary: Convert tf indexes (-inf, inf) to clamped_begin_idx [0, d] and
// clamped_end_idx [-1, d], which are then converted to ngraph indexes [0,
// d]
// tf->ng is done through tf_to_ng, which calls clamper, which converts
// tf->clamped

// Graph/function for tf->cmapled
// | ....... <-- y = max_val (max_val = 5)
// .| .
// . | .
// . | . <-- y = x>=0 ? x : x+max_val
// . |.
// -.-.-.----.------------ <-- y = 0 (for inclusive)
// * * | <-- y = -1 (for exclusive)
// |
// X axis: TF indexes. Y axis: Clamped indexes

// clamper is a function that implements the graph above.
// For inclusive, the graph is clamped at 0 and dim-1
// Given dimension d, [0, d-1] are valid locations.
// -1 represents std::rend(). d represents std::end().
// These two are useful for representing exclusive boundaries for end-ranges
// Example for dim = 3:
// ranges: (-inf,-d)| [-d,0) |[0,d-1]|(d-1,inf)
// TF index: -5 -4 |-3 -2 -1 | 0 1 2 | 3 4 5
// clamped begin (inclusive): 0 0 | 0 1 2 | 0 1 2 | 3 3 3
// clamped end (exclusive): -1 -1 | 0 1 2 | 0 1 2 | 3 3 3
auto clamper = [](int idx, size_t dim, bool inclusive) {
// if idx is in [-(d-1), d-1], then its same for both inclusive and
// exclusive
// The first 2 cases breaks down this range
if (idx >= 0 && idx <= (static_cast<int>(dim) - 1)) {
return idx;
} else if (idx < 0 &&
idx + static_cast<int>(dim) >=
0) { // careful not to do idx >= -dim
// (since dim is unsigned)
return idx + static_cast<int>(
dim); // Type casting to int to enable unambiguous auto
// type inference of return type
} else if (idx > static_cast<int>(dim) - 1) {
return static_cast<int>(dim);
} else if (idx + static_cast<int>(dim) < 0) {
// The next case handles the clamping (differently for inclusive and
// exclusive cases)

// careful not to do idx < -dim (since dim is unsigned)
return 0 - (inclusive ? 0 : 1);
}
// Default case
return 0;
};

auto tf_to_ng = [clamper](int tf_begin_idx, int tf_end_idx, int tf_stride,
size_t dim, bool begin_mask, bool end_mask,
bool shrink_mask) {
// if begin mask is present, depending on stride sign use 0 (std::begin)
// or
// dim-1 (std::rbegin)
// clamped_end_idx could line in [-1, d]
int tf_ignore_begin_if_needed =
begin_mask ? (tf_stride > 0 ? 0 : dim - 1) : tf_begin_idx;
// if end mask is present, depending on stride sign use -1 (std::rend) or
// dim (std::end).
// However note, we cannot set to -1, since it has another meaning, hence
// setting to -(dim+1), which would translate to -1 in clamped coordinates
// take care to convert dim from sixze_t to int
int tf_ignore_end_if_needed =
end_mask ? (tf_stride > 0 ? dim : (-((int)dim + 1))) : tf_end_idx;
// using size_t for clamped_begin_idx because: clamped_begin_idx is
// inclusive, so it must lie in [0, dim-1]
size_t clamped_begin_idx = clamper(tf_ignore_begin_if_needed, dim, true);
int64 clamped_end_idx =
clamper(shrink_mask ? clamped_begin_idx + 1 : tf_ignore_end_if_needed,
dim, false);

// Now we have converted semantically non-monotonic and unbounded TF
// indexes
// (-inf, inf) to bounded and monotonic clamped indexes [-1, d]
// Now we need to convert clamped indexes [-1, d] to ngraph indexes [0, d]
// (taking care of reversal in case of negative strides)

size_t needs_reverse = 0;
size_t ng_begin_idx, ng_end_idx;

if (!shrink_mask) {
if ((int)clamped_begin_idx == clamped_end_idx) {
// Empty due to matching indexes
ng_begin_idx = clamped_begin_idx;
// Type safety: clamped_begin_idx == clamped_end_idx implies,
// clamped_end_idx!=-1 (since clamped_begin_idx cannot be -1), hence
// end
// index assignment is type safe
ng_end_idx = clamped_end_idx;
} else { // In the whole of this else: clamped_begin_idx !=
// clamped_end_idx, so !(a < b) iff a > b and vice versa when
// comparing the indexes
// take care to use (int) typecase when comparing int and size_t
if (((int)clamped_begin_idx < clamped_end_idx) != (tf_stride > 0)) {
// Empty due to mismatching directions
ng_begin_idx = clamped_begin_idx;
// Type safe: since clamped_begin_idx is size_t (>0)
// [0:-4:1] in TF would convert to [0:-1:1] in clamped domain. hence
// we do not assign ng_end_idx = clamped_end_idx (which would not be
// type safe due to the -1)
ng_end_idx = clamped_begin_idx;
// Any assignment where ng_begin_idx = ng_end_idx = x (where 0 <= x
// <=
// d-1) would have worked for the 2 empty cases above
}
// Anything after this is non-empty. Anything before this has dealt
// with
// empty cases
else {
// in this case either (clamped_begin_idx < clamped_end_idx &&
// tf_stride > 0) or (clamped_begin_idx > clamped_end_idx &&
// tf_stride
// < 0)
// that is clamped_begin_idx < clamped_end_idx <==> tf_stride > 0.
// hence using only 1 of the clauses is enough
if (tf_stride > 0) {
ng_begin_idx = clamped_begin_idx;
// Type safety: tf_stride > 0 ==> clamped_begin_idx <
// clamped_end_idx. clamped_begin_idx could be 0,
// which means clamped_end_idx > 0. Hence type-safe
ng_end_idx = clamped_end_idx;
} else { // clamped_begin_idx > clamped_end_idx, tf_stride < 0

// clamped_begin_idx is [0, d] && clamped_begin_idx >
// clamped_end_idx,
// which implies clamped_end_idx is [-1,d-1]
// Type safety: With clamped_end_idx in [-1,d-1],
// dim - 1 - clamped_end_idx is in [0, dim]. Hence type safe
ng_end_idx = dim - 1 - clamped_end_idx;

if (clamped_begin_idx == dim) {
clamped_begin_idx = dim - 1;
}
// Note clamped_begin_idx != dim here.
// If clamped_begin_idx==dim && clamped_end_idx==dim, then "Empty
// due to matching indexes" handles it
// If clamped_begin_idx==dim && clamped_end_idx<dim, then 2 cases:
// tf_stride > 0: then "Empty due to mismatching directions"
// handles it
// tf_stride < 0: Then we set it to dim-1 above
// Consider the case of dim=3, where in tf notation we have:
// [4:1:-1], in clampe notation, we get [3:1:-1], which really
// means
// [2:1:-1]

// Type safety: Since clamped_begin_idx is [0, d-1] here, it is
// type
// safe
ng_begin_idx = dim - 1 - clamped_begin_idx;
needs_reverse = 1;
}
}
// Desired implementation ==>
// SaveNgOp(ng_op_map, op->name(),
// ConstructNgNode<ng::op::StridedSlice>(op->name(), begin_vec,
// end_vec, stride_vec,
// tf_begin_mask, tf_end_mask,
// tf_new_axis_mask, tf_shrink_axis_mask,
// tf_ellipsis_mask));

// Temporarily we are borrowing this implementation from nGraph-core until
// ng::op::StridedSlice is released for use in ngraph-bridge

auto convert_mask_to_axes = [](const int mask) {
ng::AxisSet axes{};
for (auto i = 0; i < sizeof(int) * 8; ++i) {
if ((unsigned char)(mask >> i & 0x01) == 1) {
axes.emplace(i);
}
} else {
// cases when clamped indexes are in [0,d] and hence can be directly
// copied
// TODO: what about tf_begin=d, shrink=T, then clamped_end_idx = d, so a
// 0-d axis.
// But since shrink is on, that is reshaped and the 0-d axis is removed?
// Is that a valid config, as shrink_axis must get an axis with dim = 1,
// right?

ng_begin_idx = clamped_begin_idx;
ng_end_idx = clamped_end_idx;
}
return std::make_tuple(ng_begin_idx, ng_end_idx, std::abs(tf_stride),
needs_reverse);
return axes;
};

auto extract_bit = [](int bit_mask, int bit_location) {
return (bit_mask & (1 << bit_location)) != 0;
};
ng::Shape input_shape = ng_input->get_shape();

auto dim_vec = ng_input->get_shape();
auto in_rank = dim_vec.size();

if (begin_vec.size() > in_rank) {
return errors::InvalidArgument("Index out of range using input dim ",
begin_vec.size(), "; input has only ",
in_rank, " dims");
}

// TODO/Note/Question: Are begin, end and stride vectors are of equal length

// begin, end and stride vectors may not have same size as input rank, hence
// initialize them with 0, dim and 1 respectively
vector<size_t> ng_begin_vec(in_rank, 0), ng_stride_vec(in_rank, 1);
vector<size_t> ng_end_vec(dim_vec);
vector<size_t> ng_needs_reversal(in_rank, 0); // should have been a
// vector<bool>, but it is
// optimized, so tie won't
// work. Hence using size_t
for (size_t dim_idx = 0; dim_idx < begin_vec.size(); dim_idx++) {
std::tie(ng_begin_vec[dim_idx], ng_end_vec[dim_idx], ng_stride_vec[dim_idx],
ng_needs_reversal[dim_idx]) =
tf_to_ng(begin_vec[dim_idx], end_vec[dim_idx], stride_vec[dim_idx],
dim_vec[dim_idx], extract_bit(tf_begin_mask, dim_idx),
extract_bit(tf_end_mask, dim_idx),
extract_bit(tf_shrink_axis_mask, dim_idx));
}

// filter out negative stride dimensions
vector<size_t> neg_strides;
for (size_t dim_idx = 0; dim_idx < in_rank; dim_idx++) {
if (ng_needs_reversal[dim_idx]) {
neg_strides.push_back(dim_idx);
std::vector<int64_t> begin_vec_longint(begin_vec.begin(), begin_vec.end());
std::vector<int64_t> end_vec_longint(end_vec.begin(), end_vec.end());
std::vector<int64_t> stride_vec_longint(stride_vec.begin(), stride_vec.end());

NGRAPH_VLOG(4) << "Arguments to make_slice_plan: Input shape: " << input_shape
<< ", begin vector: " << ng::join(begin_vec_longint)
<< ", end vector: " << ng::join(end_vec_longint)
<< ", stride vector: " << ng::join(stride_vec_longint)
<< ", begin mask: " << tf_begin_mask
<< ", end mask: " << tf_end_mask
<< ", new axis mask: " << tf_new_axis_mask
<< ", shrink axis mask: " << tf_shrink_axis_mask
<< ", ellipsis mask: " << tf_ellipsis_mask;

auto in_rank = ng_input->get_shape().size();
if (tf_new_axis_mask == 0) {
if (begin_vec_longint.size() > in_rank) {
return errors::InvalidArgument("Index out of range using input dim ",
begin_vec_longint.size(),
"; input has only ", in_rank, " dims");
}
}

// atleast one stride was negative, in which case reverse the input
if (neg_strides.size() > 0)
ng_input =
ConstructNgNode<ng::op::Reverse>(op->name(), ng_input, neg_strides);
NGRAPH_VLOG(3) << "NG Lower Vector " << ng::join(ng_begin_vec);
NGRAPH_VLOG(3) << "NG End Vector " << ng::join(ng_end_vec);
NGRAPH_VLOG(3) << "NG Stride Vector " << ng::join(ng_stride_vec);
NGRAPH_VLOG(3) << "NG Needs Reversal: " << ng::join(ng_needs_reversal);

std::shared_ptr<ng::Node> ng_strided_slice = ConstructNgNode<ng::op::Slice>(
op->name(), ng_input, ng_begin_vec, ng_end_vec, ng_stride_vec);

if (tf_shrink_axis_mask) {
int64 shrink_axis_mask = tf_shrink_axis_mask;
vector<size_t> output_shape;

// Note: do not use rank instead of ng_begin_vec.size()
// since ng_begin_vec.size() can be less than rank, and
// shrink_mask will have atmost ng_begin_vec.size() elements
for (size_t i = 0; i < ng_begin_vec.size(); i++) {
if ((shrink_axis_mask & 1) != 1) {
output_shape.push_back(ng_end_vec[i] - ng_begin_vec[i]);
} else {
// TODO: must it equal 1 or can it be 0 too?
if (ng_end_vec[i] - ng_begin_vec[i] > 1)
return errors::InvalidArgument(
"Trying to shrink specification ", i,
"where tf begin, end, strides are: ", begin_vec[i], ":",
end_vec[i], ":", stride_vec[i],
". nGraph begin, end, stride are: ", ng_begin_vec[i], ":",
ng_end_vec[i], ":", ng_stride_vec[i],
". nGraph's begin and end have difference greater than 1");
}
shrink_axis_mask >>= 1;
}
auto sp = ng::make_slice_plan(
input_shape, begin_vec_longint, end_vec_longint, stride_vec_longint,
convert_mask_to_axes(tf_begin_mask), convert_mask_to_axes(tf_end_mask),
convert_mask_to_axes(tf_new_axis_mask),
convert_mask_to_axes(tf_shrink_axis_mask),
convert_mask_to_axes(tf_ellipsis_mask));

NGRAPH_VLOG(4) << "Return values of make_slice_plan: begin: "
<< ng::join(sp.begins) << ", end: " << ng::join(sp.ends)
<< ", stride: " << ng::join(sp.strides)
<< ", reshape input shape: " << sp.reshape_in_shape
<< ", reshape output shape: " << sp.reshape_out_shape
<< ", reverse axis: " << sp.reverse_axes;

// To handle cases like x[2:2], where shape(x) = [1],
// TF returns shape = [0], empty vector
// make_slice_plan returns begin=2, end=2, but that is > 1
// So must clamp them
// Another example:
// for dimension 3, Also 2:3:-1 gives 4:4, which will also fail if we try to
// construct slice. So must clamp to 2:2 etc

auto clamp = [](int64_t x, int64_t min, int64_t max) {
return x > max ? max : (x < min ? min : x);
};
for (int i = 0; i < sp.begins.size(); i++) {
sp.begins[i] = clamp(sp.begins[i], 0, input_shape[i]);
sp.ends[i] = clamp(sp.ends[i], 0, input_shape[i]);
}

NGRAPH_VLOG(3) << "Shrink axis mask " << tf_shrink_axis_mask;
ng::Shape ng_final_shape(output_shape);
ng::AxisVector ng_axis_order(input_shape.size());
// Need to convert int64_t to size_t
std::vector<size_t> sp_begins(sp.begins.begin(), sp.begins.end());
std::vector<size_t> sp_ends(sp.ends.begin(), sp.ends.end());
std::vector<size_t> sp_strides(sp.strides.begin(), sp.strides.end());

shared_ptr<ng::Node> ng_result = ConstructNgNode<ng::op::Slice>(
op->name(), ng_input, sp_begins, sp_ends, sp_strides);

if (sp.reshape_in_shape != sp.reshape_out_shape) {
ng::Shape ng_out_shape(sp.reshape_out_shape);
ng::AxisVector ng_axis_order(sp.reshape_in_shape.size());
// std::iota Fills the range [first, last) with sequentially increasing
// values,
// starting with value and repetitively evaluating ++value
std::iota(ng_axis_order.begin(), ng_axis_order.end(), 0);

NGRAPH_VLOG(3) << " Output shape " << ng::join(output_shape);
NGRAPH_VLOG(3) << " Output shape " << ng::join(ng_out_shape);
NGRAPH_VLOG(3) << " NG axis order " << ng::join(ng_axis_order);

ng_strided_slice = ConstructNgNode<ng::op::Reshape>(
op->name(), ng_strided_slice, ng_axis_order, ng_final_shape);
ng_result = ConstructNgNode<ng::op::Reshape>(op->name(), ng_result,
ng_axis_order, ng_out_shape);
}

// TODO: assert size in this dim was 1
// TODO: assert new_axis_mask and tf_shrink_axis_mask are not set at the
// same
// time?
// TODO: tf_new_axis_mask can exceed rank
if (!sp.reverse_axes.empty()) {
ng_result = ConstructNgNode<ng::op::Reverse>(op->name(), ng_result,
sp.reverse_axes);
}

SaveNgOp(ng_op_map, op->name(), ng_strided_slice);
SaveNgOp(ng_op_map, op->name(), ng_result);
return Status::OK();
}

Expand Down
Loading