Permalink
Browse files

[tf.contrib.seq2seq] Bugfixes to BeamSearchDecoder and GatherTree.

1. Begin the gather tree at the maximum sequence length across all beams (within the batch).
2. Take a second pass starting from t=0 and mask out any beam ids past the *first* beam occurrence of end_token.
3. Update the final sequence lengths to include the first <eos> token in the beam.
4. Update dynamic_decode to allow the BeamSearchDecoder to keep track of its own "finished" states, as the shuffling in the decoder confused the tracking mechanism in dynamic_decode.  This fixes a bug where beam search decoding stops early.
5. Cap sequence length used in GatherTree to min(max_time, max_seq_len(b)) to avoid accessing memory outside the dimensions of input matrices.

Bugs caught by @bdaskalov on github and Pavel Sountsov.  Proper solution and analysis thanks to Rui Zhao.  Thanks all!

Fixes #13536.

PiperOrigin-RevId: 172471462
  • Loading branch information...
ebrevdo authored and tensorflower-gardener committed Oct 17, 2017
1 parent a1ba9f3 commit 18f89c81d288f191abd56501ec6f86fe29265bdd
@@ -49,40 +49,46 @@ class GatherTreeOp : public OpKernel {
const Device& device = ctx->eigen_device<Device>();
const Tensor& step_ids = ctx->input(0);
const Tensor& parent_ids = ctx->input(1);
const Tensor& sequence_length = ctx->input(2);
const Tensor& max_sequence_lengths = ctx->input(2);
const Tensor& end_token = ctx->input(3);
const TensorShape& step_ids_shape = step_ids.shape();
OP_REQUIRES(
ctx, step_ids_shape.dims() == 3,
errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ",
step_ids_shape.DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsMatrix(sequence_length.shape()),
errors::InvalidArgument("sequence_length must be a matrix, saw shape: ",
sequence_length.shape().DebugString()));
OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1),
errors::InvalidArgument(
"Inconsistent batch sizes: sequence_length.shape[0] (",
sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (",
step_ids_shape.dim_size(1), ")"));
OP_REQUIRES(ctx, sequence_length.dim_size(1) == step_ids_shape.dim_size(2),
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(max_sequence_lengths.shape()),
errors::InvalidArgument(
"Inconsistent batch sizes: sequence_length.shape[1] (",
sequence_length.dim_size(1), ") != ", "step_ids.shape[2] (",
step_ids_shape.dim_size(2), ")"));
"max_sequence_lengths must be a vector, saw shape: ",
max_sequence_lengths.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(end_token.shape()),
errors::InvalidArgument("end_token must be a scalar, saw shape: ",
end_token.shape().DebugString()));
OP_REQUIRES(
ctx, step_ids_shape == parent_ids.shape(),
errors::InvalidArgument(
"step_ids.shape must match parent_ids.shape. but shapes are: ",
step_ids_shape.DebugString(), " and ",
parent_ids.shape().DebugString()));
OP_REQUIRES(
ctx,
step_ids_shape.dim_size(1) == max_sequence_lengths.shape().dim_size(0),
errors::InvalidArgument("batch size dimensions step_ids.shape[1] and "
"max_seqeuence_lengths.shape[0] must match. "
"but shapes are: ",
step_ids_shape.DebugString(), " and ",
max_sequence_lengths.shape().DebugString()));
Tensor* beams;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>();
typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>();
typename TTypes<T>::ConstMatrix seq_len_t = sequence_length.matrix<T>();
typename TTypes<int32>::ConstVec max_seq_lens_t =
max_sequence_lengths.vec<int32>();
typename TTypes<T>::ConstScalar end_token_t = end_token.scalar<T>();
typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>();
const T end_token_value = end_token_t();
functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
seq_len_t, beams_t);
max_seq_lens_t, end_token_value, beams_t);
}
};
@@ -99,27 +105,29 @@ namespace functor {
template <>
struct GatherTree<CPUDevice, int32> {
void operator()(OpKernelContext* ctx, const CPUDevice& d,
typename TTypes<int32, 3>::ConstTensor step_ids,
typename TTypes<int32, 3>::ConstTensor parent_ids,
typename TTypes<int32>::ConstMatrix sequence_length,
typename TTypes<int32, 3>::Tensor beams) {
const int64 max_time = parent_ids.dimension(0);
const int64 batch_size = parent_ids.dimension(1);
const int64 beam_width = parent_ids.dimension(2);
TTypes<int32, 3>::ConstTensor step_ids,
TTypes<int32, 3>::ConstTensor parent_ids,
TTypes<int32>::ConstVec max_sequence_lengths,
const int32 end_token, TTypes<int32, 3>::Tensor beams) {
const int32 max_time = parent_ids.dimension(0);
const int32 batch_size = parent_ids.dimension(1);
const int32 beam_width = parent_ids.dimension(2);
beams.setConstant(-1);
auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) {
auto DoWork = [&, ctx, end_token](int start_batch_beam,
int limit_batch_beam) {
for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;
int32 seq_len_b = sequence_length(batch, beam);
if (seq_len_b <= 0) {
const int32 max_seq_len_b =
Eigen::numext::mini(max_time, max_sequence_lengths(batch));
if (max_seq_len_b <= 0) {
continue;
}
beams(seq_len_b - 1, batch, beam) =
step_ids(seq_len_b - 1, batch, beam);
int32 parent = parent_ids(seq_len_b - 1, batch, beam);
for (int32 level = seq_len_b - 2; level >= 0; --level) {
beams(max_seq_len_b - 1, batch, beam) =
step_ids(max_seq_len_b - 1, batch, beam);
int32 parent = parent_ids(max_seq_len_b - 1, batch, beam);
for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
if (parent < 0 || parent > beam_width) {
ctx->SetStatus(
errors::InvalidArgument("Saw invalid parent id ", parent,
@@ -130,14 +138,22 @@ struct GatherTree<CPUDevice, int32> {
beams(level, batch, beam) = step_ids(level, batch, parent);
parent = parent_ids(level, batch, parent);
}
bool finished = false;
for (int32 time = 0; time < max_seq_len_b; ++time) {
if (finished) {
beams(time, batch, beam) = -1;
} else if (beams(time, batch, beam) == end_token) {
finished = true;
}
}
}
};
// Guesstimate of cost; ~5 lookup/store/compare per inner beam
// traversal time step.
const int64 batch_beam_cost =
Eigen::TensorOpCost::DivCost<int32>() +
6 * Eigen::TensorOpCost::AddCost<int32>() +
max_time * (5 * Eigen::TensorOpCost::AddCost<int32>());
2 * max_time * (5 * Eigen::TensorOpCost::AddCost<int32>());
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers,
batch_size * beam_width, batch_beam_cost, DoWork);
@@ -148,24 +164,26 @@ struct GatherTree<CPUDevice, int32> {
#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void GatherTree<GPUDevice, T>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, \
typename TTypes<T, 3>::ConstTensor step_ids, \
typename TTypes<T, 3>::ConstTensor parent_ids, \
typename TTypes<T>::ConstMatrix sequence_length, \
typename TTypes<T, 3>::Tensor beams); \
#define DECLARE_GPU_SPEC(T) \
template <> \
void GatherTree<GPUDevice, T>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, \
typename TTypes<T, 3>::ConstTensor step_ids, \
typename TTypes<T, 3>::ConstTensor parent_ids, \
TTypes<int32>::ConstVec max_sequence_lengths, const T end_token, \
typename TTypes<T, 3>::Tensor beams); \
extern template struct GatherTree<GPUDevice, T>;
DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC
} // end namespace functor
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("GatherTree").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
GatherTreeOp<GPUDevice, T>);
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("GatherTree") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.HostMemory("end_token"), \
GatherTreeOp<GPUDevice, T>);
REGISTER_GPU_KERNEL(int32);
#undef REGISTER_GPU_KERNEL
@@ -31,8 +31,8 @@ struct GatherTree {
void operator()(OpKernelContext* ctx, const Device& d,
typename TTypes<T, 3>::ConstTensor step_ids,
typename TTypes<T, 3>::ConstTensor parent_ids,
typename TTypes<T>::ConstMatrix sequence_length,
typename TTypes<T, 3>::Tensor beams);
TTypes<int32>::ConstVec max_sequence_lengths,
const T end_token, typename TTypes<T, 3>::Tensor beams);
};
} // namespace functor
@@ -29,20 +29,24 @@ template <typename T>
__global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
const int32 beam_width, const T* step_ids,
const T* parent_ids,
const T* sequence_length, T* beams) {
const int32* max_sequence_lengths,
const T end_token, T* beams) {
CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;
const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam);
if (seq_len_b <= 0) continue;
const int32 max_seq_len_b =
Eigen::numext::mini(max_time, ldg(max_sequence_lengths + batch));
if (max_seq_len_b <= 0) {
continue;
}
#define GET_IX(time_ix, beam_ix) \
(batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam);
const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam);
beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix);
int32 parent = ldg(parent_ids + initial_beam_ix);
for (int32 level = seq_len_b - 2; level >= 0; --level) {
for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
const int32 level_beam_ix = GET_IX(level, beam);
const int32 level_parent_ix = GET_IX(level, parent);
if (parent < 0 || parent > beam_width) {
@@ -53,6 +57,15 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
parent = ldg(parent_ids + level_parent_ix);
}
}
bool finished = false;
for (int32 time = 0; time < max_seq_len_b; ++time) {
const int32 level_beam_ix = GET_IX(time, beam);
if (finished) {
beams[level_beam_ix] = -1;
} else if (beams[level_beam_ix] == end_token) {
finished = true;
}
}
#undef GET_IX
}
}
@@ -62,8 +75,8 @@ struct GatherTree<GPUDevice, T> {
void operator()(OpKernelContext* ctx, const GPUDevice& d,
typename TTypes<T, 3>::ConstTensor step_ids,
typename TTypes<T, 3>::ConstTensor parent_ids,
typename TTypes<T>::ConstMatrix sequence_length,
typename TTypes<T, 3>::Tensor beams) {
TTypes<int32>::ConstVec max_sequence_length,
const T end_token, typename TTypes<T, 3>::Tensor beams) {
const int32 max_time = parent_ids.dimension(0);
const int32 batch_size = parent_ids.dimension(1);
const int32 beam_width = parent_ids.dimension(2);
@@ -75,7 +88,10 @@ struct GatherTree<GPUDevice, T> {
GatherTreeOpKernel<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
batch_size, max_time, beam_width,
step_ids.data(), parent_ids.data(), sequence_length.data(),
step_ids.data(),
parent_ids.data(),
max_sequence_length.data(),
end_token,
beams.data());
// clang-format on
}
@@ -25,27 +25,27 @@ using shape_inference::ShapeHandle;
REGISTER_OP("GatherTree")
.Input("step_ids: T")
.Input("parent_ids: T")
.Input("sequence_length: T")
.Input("max_sequence_lengths: int32")
.Input("end_token: T")
.Output("beams: T")
.Attr("T: {int32}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle step_ids, parent_ids, sequence_length;
ShapeHandle step_ids, parent_ids, max_sequence_lengths, end_token;
// step_ids, parent_ids, and output are all shaped:
// [max_time, batch_size, beam_width].
// sequence_length is shaped [batch_size, beam_width].
// max_sequence_length is shaped [batch_size] and end_token is a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &sequence_length));
DimensionHandle batch_size = c->Dim(step_ids, 1);
DimensionHandle beam_width = c->Dim(step_ids, 2);
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max_sequence_lengths));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &end_token));
TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids));
DimensionHandle batch_size = c->Dim(step_ids, 1);
TF_RETURN_IF_ERROR(
c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size));
TF_RETURN_IF_ERROR(
c->Merge(beam_width, c->Dim(sequence_length, 1), &beam_width));
c->Merge(batch_size, c->Dim(max_sequence_lengths, 0), &batch_size));
ShapeHandle step_ids_prefix = c->Matrix(c->Dim(step_ids, 0), batch_size);
TF_RETURN_IF_ERROR(c->MergePrefix(step_ids, step_ids_prefix, &step_ids,
&step_ids_prefix));
c->set_output(0, step_ids);
return tensorflow::Status::OK();
@@ -61,7 +61,8 @@ TODO(ebrevdo): fill in
step_ids: `[max_time, batch_size, beam_width]`.
parent_ids: `[max_time, batch_size, beam_width]`.
sequence_length: `[batch_size, beam_width]`.
max_sequence_lengths: `[batch_size]`.
end_token: `[]`.
beams: `[max_time, batch_size, beam_width]`.
)doc");
@@ -54,15 +54,18 @@ def test_gather_tree(self):
[[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
dtype=np.int32).transpose([1, 0, 2])
# sequence_lengths is shaped (batch_size = 2, beam_width = 3)
sequence_lengths = [[3, 3, 3], [3, 3, 3]]
# sequence_lengths is shaped (batch_size = 3)
max_sequence_lengths = [3, 3]
expected_result = np.array(
[[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
[[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2])
res = beam_search_ops.gather_tree(
predicted_ids, parent_ids, sequence_lengths)
predicted_ids,
parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=11)
with self.test_session() as sess:
res_ = sess.run(res)
Oops, something went wrong.

0 comments on commit 18f89c8

Please sign in to comment.