Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to cuDNN CTC loss #32302

Merged
merged 22 commits into from Jan 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
a98e8ca
Add changes to support cuDNN CTC loss
kaixih Jul 10, 2019
5a07e2c
CPU CTC tests without V2 and update goldens
kaixih Sep 9, 2019
6ff2298
Added pbtxt for ctc loss v2
kaixih Sep 9, 2019
c3e8f6f
Changed some positions of macros for cuDNN CTC loss
kaixih Sep 9, 2019
1ab863f
Switch to non-deterministic algo which allow larger label size
kaixih Sep 13, 2019
46aa1ca
Put the reusable class CudnnAllocatorInTemp to a separate file
kaixih Nov 8, 2019
4ef99df
Simplified the CtcLossDescriptor
kaixih Nov 18, 2019
0ae9214
Added ElementType and DeviceMemoryBase for CTC Loss
kaixih Nov 19, 2019
33d4b5a
Formatting
kaixih Nov 19, 2019
deecd42
Modified the macros to ompile with old cudnn version
kaixih Nov 19, 2019
0b9feec
Move the CtcLossDescriptor constructor/destructor back to the header
kaixih Dec 2, 2019
4a89f04
Use DnnScratchAllocator
kaixih Dec 5, 2019
cbf169c
Variables init and decl in one line; check attr in constructor; check…
kaixih Dec 5, 2019
9966ed2
Remove empty lines
kaixih Dec 5, 2019
cb7e008
Avoid to register the ctc loss kernel when cudnn is older than 7.6.3
kaixih Dec 5, 2019
0681377
Remove the empty CtcLossDescriptor
kaixih Dec 6, 2019
47306c8
Solved a conflict
kaixih Dec 11, 2019
ecdaf8e
remove empty CtcLossDescriptor in dnn.h
kaixih Dec 17, 2019
7ee06aa
Add impl selector to remove the env var about the CUDNN CTC Loss
kaixih Dec 5, 2019
f9e38a4
remove unused import
kaixih Jan 9, 2020
bb87219
Set CTCLossV2 to visibility:HIDDEN
kaixih Jan 9, 2020
195729d
update goldens
kaixih Jan 10, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
72 changes: 72 additions & 0 deletions tensorflow/core/api_def/base_api/api_def_CTCLossV2.pbtxt
@@ -0,0 +1,72 @@
op {
graph_op_name: "CTCLossV2"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a line here saying "visibility: HIDDEN"; this will prevent the generation of a tf.ctc_loss_v2 API

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Done.

visibility: HIDDEN
in_arg {
name: "inputs"
description: <<END
3-D, shape: `(max_time x batch_size x num_classes)`, the logits. Default blank
label is 0 rather num_classes - 1.
END
}
in_arg {
name: "labels_indices"
description: <<END
The indices of a `SparseTensor<int32, 2>`.
`labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for
`(batch b, time t)`.
END
}
in_arg {
name: "labels_values"
description: <<END
The values (labels) associated with the given batch and time.
END
}
in_arg {
name: "sequence_length"
description: <<END
A vector containing sequence lengths (batch).
END
}
out_arg {
name: "loss"
description: <<END
A vector (batch) containing log-probabilities.
END
}
out_arg {
name: "gradient"
description: <<END
The gradient of `loss`. 3-D, shape:
`(max_time x batch_size x num_classes)`.
END
}
attr {
name: "preprocess_collapse_repeated"
description: <<END
Scalar, if true then repeated labels are
collapsed prior to the CTC calculation.
END
}
attr {
name: "ctc_merge_repeated"
description: <<END
Scalar. If set to false, *during* CTC calculation
repeated non-blank labels will not be merged and are interpreted as
individual labels. This is a simplified version of CTC.
END
}
attr {
name: "ignore_longer_outputs_than_inputs"
description: <<END
Scalar. If set to true, during CTC
calculation, items that have longer output sequences than input sequences
are skipped: they don't contribute to the loss term and have zero-gradient.
END
}
summary: "Calculates the CTC Loss (log probability) for each batch entry. Also calculates"
description: <<END
the gradient. This class performs the softmax operation for you, so inputs
should be e.g. linear projections of outputs by an LSTM.
END
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be visibility:hidden so it doesn't show up as tf.ctc_loss_v2

6 changes: 5 additions & 1 deletion tensorflow/core/kernels/BUILD
Expand Up @@ -2356,7 +2356,11 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core/util/ctc:ctc_beam_search_lib",
"//tensorflow/core/util/ctc:ctc_loss_calculator_lib",
],
] + if_cuda([
":gpu_utils",
":conv_ops_gpu_hdrs",
"@local_config_cuda//cuda:cudnn_header",
]),
)

tf_cc_test(
Expand Down
183 changes: 183 additions & 0 deletions tensorflow/core/kernels/ctc_loss_op.cc
Expand Up @@ -15,6 +15,10 @@ limitations under the License.

// See docs in ../ops/ctc_ops.cc.

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
Expand All @@ -25,8 +29,39 @@ limitations under the License.
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA
using GPUDevice = Eigen::GpuDevice;

namespace {
using se::Stream;
using se::StreamExecutor;
using se::dnn::RnnStateTensorDescriptor;
using se::dnn::ToDataType;

template<typename T>
void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices,
int num_indices, int batch_size,
std::vector<int> *labels_lengths) {
const T* h_in = labels_indices->flat<T>().data();
for(int i = 0; i < num_indices; i++) {
const T& key = h_in[i * 2];
(*labels_lengths)[key]++;
}
}

} // end namespace
#endif // GOOGLE_CUDA

template <typename T>
class CTCLossOp : public OpKernel {
typedef Eigen::Map<
Expand Down Expand Up @@ -186,4 +221,152 @@ REGISTER_CPU(double);

#undef REGISTER_CPU

#if GOOGLE_CUDA && CUDNN_VERSION >= 7603
class CTCLossOpGPU : public OpKernel {
public:
explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
bool preprocess_collapse_repeated;
bool ctc_merge_repeated;
bool ignore_longer_outputs_than_inputs;
OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated",
&preprocess_collapse_repeated));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated));
OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
&ignore_longer_outputs_than_inputs));

OP_REQUIRES(ctx, !preprocess_collapse_repeated,
errors::InvalidArgument("GPU CTCLossOp requires "
"preprocess_collapse_repeated to be "
"false"));
OP_REQUIRES(ctx, ctc_merge_repeated,
errors::InvalidArgument("GPU CTCLossOp requires "
"ctc_merge_repeated to be "
"true"));
OP_REQUIRES(ctx, !ignore_longer_outputs_than_inputs,
errors::InvalidArgument("GPU CTCLossOp requires "
"ignore_longer_outputs_than_inputs to"
"be false"));
}

void Compute(OpKernelContext* ctx) override {
const Tensor* inputs;
const Tensor* labels_indices;
const Tensor* labels_values;
const Tensor* seq_len;
OP_REQUIRES_OK(ctx, ctx->input("inputs", &inputs));
OP_REQUIRES_OK(ctx, ctx->input("labels_indices", &labels_indices));
OP_REQUIRES_OK(ctx, ctx->input("labels_values", &labels_values));
OP_REQUIRES_OK(ctx, ctx->input("sequence_length", &seq_len));

OP_REQUIRES(ctx, inputs->shape().dims() == 3,
errors::InvalidArgument("inputs is not a 3-Tensor"));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()),
errors::InvalidArgument("sequence_length is not a vector"));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()),
errors::InvalidArgument("labels_indices is not a matrix"));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()),
errors::InvalidArgument("labels_values is not a vector"));

const TensorShape& inputs_shape = inputs->shape();
const int64 max_time_raw = inputs_shape.dim_size(0);
const int64 batch_size_raw = inputs_shape.dim_size(1);
const int64 num_classes_raw = inputs_shape.dim_size(2);
OP_REQUIRES(
ctx, FastBoundsCheck(max_time_raw, std::numeric_limits<int>::max()),
errors::InvalidArgument("max_time_ cannot exceed max int"));
OP_REQUIRES(
ctx, FastBoundsCheck(batch_size_raw, std::numeric_limits<int>::max()),
errors::InvalidArgument("batch_size cannot exceed max int"));
OP_REQUIRES(
ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
errors::InvalidArgument("num_classes cannot exceed max int"));
Comment on lines +281 to +283
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we not bounds checking max_time_raw and batch_size_raw also?

const int max_time = static_cast<const int>(max_time_raw);
const int batch_size = static_cast<const int>(batch_size_raw);
const int num_classes = static_cast<const int>(num_classes_raw);

OP_REQUIRES(
ctx, batch_size == seq_len->dim_size(0),
errors::InvalidArgument("len(sequence_length) != batch_size. ",
"len(sequence_length): ", seq_len->dim_size(0),
" batch_size: ", batch_size));

OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0),
errors::InvalidArgument(
"labels_indices and labels_values must contain the "
"same number of rows, but saw shapes: ",
labels_indices->shape().DebugString(), " vs. ",
labels_values->shape().DebugString()));
auto num_indices = labels_indices->dim_size(0);

OP_REQUIRES(ctx, batch_size != 0,
errors::InvalidArgument("batch_size must not be 0"));

Tensor* loss = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));

Tensor* gradient = nullptr;
OP_REQUIRES_OK(ctx,
ctx->allocate_output("gradient", inputs_shape, &gradient));


// Convert the labels_indices to labels_lengths.
std::vector<int> labels_lengths(batch_size, 0);
DoHistogram<int64>(ctx, labels_indices, num_indices, batch_size,
&labels_lengths);

StreamExecutor* executor = ctx->op_device_context()->stream()->parent();
se::dnn::DataType data_type = ToDataType<float>::value;

auto probs_desc_s = executor->createRnnStateTensorDescriptor(
max_time, batch_size, num_classes, data_type);
OP_REQUIRES_OK(ctx, probs_desc_s.status());
std::unique_ptr<RnnStateTensorDescriptor> probs_desc =
probs_desc_s.ConsumeValueOrDie();

auto grads_desc_s = executor->createRnnStateTensorDescriptor(
max_time, batch_size, num_classes, data_type);
OP_REQUIRES_OK(ctx, grads_desc_s.status());
std::unique_ptr<RnnStateTensorDescriptor> grads_desc =
grads_desc_s.ConsumeValueOrDie();

absl::Span<const int32> labels_data(labels_values->flat<int32>().data(),
num_indices);
absl::Span<const int32> labels_lengths_data(labels_lengths.data(),
batch_size);
absl::Span<const int32> input_lengths_data(seq_len->flat<int32>().data(),
batch_size);

auto probs_data = StreamExecutorUtil::AsDeviceMemory<float>(*inputs);
auto costs_data = StreamExecutorUtil::AsDeviceMemory<float>(*loss);
auto grads_data = StreamExecutorUtil::AsDeviceMemory<float>(*gradient);

// Set the memory limitation to 4GB for workspace memory.
DnnScratchAllocator workspace_allocator(1LL << 32, ctx);

Stream* stream = ctx->op_device_context()->stream();
bool cudnn_launch_status =
stream
->ThenCtcLoss(
*probs_desc, probs_data, labels_data, labels_lengths_data,
input_lengths_data, &costs_data, *grads_desc, &grads_data,
&workspace_allocator)
.ok();

if (!cudnn_launch_status) {
ctx->SetStatus(
errors::Internal("cuDNN CTCLoss launch failure"));
}
}

private:
TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOpGPU);
};

REGISTER_KERNEL_BUILDER(Name("CTCLossV2").Device(DEVICE_GPU)
.HostMemory("labels_indices")
.HostMemory("labels_values")
.HostMemory("sequence_length"),
CTCLossOpGPU);
Comment on lines +366 to +370
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we register the kernel even if cuDNN is older than 7.6.3 (and the kernel is guaranteed to fail)?

#endif // GOOGLE_CUDA && CUDNN_VERSION >= 7603
} // end namespace tensorflow
37 changes: 37 additions & 0 deletions tensorflow/core/ops/ctc_ops.cc
Expand Up @@ -62,6 +62,43 @@ REGISTER_OP("CTCLoss")
return Status::OK();
});

REGISTER_OP("CTCLossV2")
.Input("inputs: float")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the CuDNN implementation support types other than float? If so, we should also support them here.

#31164 added support for double for CTCLossOp, for example.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, CuDNN only support the float CTCLoss.

.Input("labels_indices: int64")
.Input("labels_values: int32")
.Input("sequence_length: int32")
.Attr("preprocess_collapse_repeated: bool = false")
.Attr("ctc_merge_repeated: bool = true")
.Attr("ignore_longer_outputs_than_inputs: bool = false")
.Output("loss: float")
.Output("gradient: float")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle inputs;
ShapeHandle labels_indices;
ShapeHandle labels_values;
ShapeHandle sequence_length;

TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &labels_indices));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &labels_values));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &sequence_length));

DimensionHandle unused;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(labels_indices, 0),
c->Dim(labels_values, 0), &unused));

// Get batch size from inputs and sequence_length, and update inputs
// with the merged batch_size since it is returned.
DimensionHandle batch_size;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));
TF_RETURN_IF_ERROR(c->ReplaceDim(inputs, 1, batch_size, &inputs));

c->set_output(0, c->Vector(batch_size));
c->set_output(1, inputs);
return Status::OK();
});

REGISTER_OP("CTCGreedyDecoder")
.Input("inputs: T")
.Input("sequence_length: int32")
Expand Down