Add support to cuDNN CTC loss #32302
@@ -0,0 +1,72 @@
op {
  graph_op_name: "CTCLossV2"
  visibility: HIDDEN
  in_arg {
    name: "inputs"
    description: <<END
3-D, shape: `(max_time x batch_size x num_classes)`, the logits. Default blank
label is 0 rather than num_classes - 1.
END
  }
  in_arg {
    name: "labels_indices"
    description: <<END
The indices of a `SparseTensor<int32, 2>`.
`labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for
`(batch b, time t)`.
END
  }
  in_arg {
    name: "labels_values"
    description: <<END
The values (labels) associated with the given batch and time.
END
  }
  in_arg {
    name: "sequence_length"
    description: <<END
A vector containing sequence lengths (batch).
END
  }
  out_arg {
    name: "loss"
    description: <<END
A vector (batch) containing log-probabilities.
END
  }
  out_arg {
    name: "gradient"
    description: <<END
The gradient of `loss`. 3-D, shape:
`(max_time x batch_size x num_classes)`.
END
  }
  attr {
    name: "preprocess_collapse_repeated"
    description: <<END
Scalar, if true then repeated labels are
collapsed prior to the CTC calculation.
END
  }
  attr {
    name: "ctc_merge_repeated"
    description: <<END
Scalar. If set to false, *during* CTC calculation
repeated non-blank labels will not be merged and are interpreted as
individual labels. This is a simplified version of CTC.
END
  }
  attr {
    name: "ignore_longer_outputs_than_inputs"
    description: <<END
Scalar. If set to true, during CTC
calculation, items that have longer output sequences than input sequences
are skipped: they don't contribute to the loss term and have zero-gradient.
END
  }
  summary: "Calculates the CTC Loss (log probability) for each batch entry."
  description: <<END
Also calculates the gradient. This op performs the softmax operation for you,
so inputs should be, e.g., linear projections of outputs by an LSTM.
END
}
Review comment: This should be visibility: HIDDEN so it doesn't show up as tf.ctc_loss_v2.
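For readers skimming the api_def: the labels use the standard two-column SparseTensor encoding described above. Below is a minimal sketch of how a small batch of label sequences maps to labels_indices / labels_values. EncodeLabels is a hypothetical helper written for illustration, not code from this PR.

#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: flatten a batch of label sequences into the
// (indices, values) pairs CTCLossV2 expects. Row i of `indices` is
// [batch, time]; values[i] is the label id at that position.
void EncodeLabels(const std::vector<std::vector<int32_t>>& batch_labels,
                  std::vector<std::array<int64_t, 2>>* indices,
                  std::vector<int32_t>* values) {
  for (size_t b = 0; b < batch_labels.size(); ++b) {
    for (size_t t = 0; t < batch_labels[b].size(); ++t) {
      indices->push_back({static_cast<int64_t>(b), static_cast<int64_t>(t)});
      values->push_back(batch_labels[b][t]);
    }
  }
}

int main() {
  // Batch 0 has labels "2 3 2", batch 1 has "1 1" (0 is the blank for this
  // op, so real labels start at 1).
  std::vector<std::vector<int32_t>> batch_labels = {{2, 3, 2}, {1, 1}};
  std::vector<std::array<int64_t, 2>> indices;
  std::vector<int32_t> values;
  EncodeLabels(batch_labels, &indices, &values);
  // Prints: [0,0]->2 [0,1]->3 [0,2]->2 [1,0]->1 [1,1]->1
  for (size_t i = 0; i < values.size(); ++i)
    std::cout << "[" << indices[i][0] << "," << indices[i][1] << "]->"
              << values[i] << " ";
  std::cout << "\n";
}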
@@ -15,6 +15,10 @@ limitations under the License.

// See docs in ../ops/ctc_ops.cc.

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

@@ -25,8 +29,39 @@ limitations under the License.
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA
using GPUDevice = Eigen::GpuDevice;

namespace {
using se::Stream;
using se::StreamExecutor;
using se::dnn::RnnStateTensorDescriptor;
using se::dnn::ToDataType;
// Builds a histogram of label lengths per batch element: labels_indices is a
// host tensor of shape [num_indices, 2] whose rows are (batch, time) pairs,
// so counting how often each batch id appears yields the label length for
// that batch element.
template <typename T>
void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices,
                 int num_indices, int batch_size,
                 std::vector<int>* labels_lengths) {
  const T* h_in = labels_indices->flat<T>().data();
  for (int i = 0; i < num_indices; i++) {
    const T& key = h_in[i * 2];  // Column 0 of row i: the batch index.
    (*labels_lengths)[key]++;
  }
}
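// Worked example (illustrative, not part of the diff): with labels_indices
// rows [0,0], [0,1], [0,2], [1,0], [1,1] and batch_size == 2, DoHistogram
// yields labels_lengths == {3, 2}. These host-side per-batch label lengths
// are exactly what the cuDNN CTC call below consumes.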

}  // end namespace
#endif  // GOOGLE_CUDA

template <typename T>
class CTCLossOp : public OpKernel {
  typedef Eigen::Map<
@@ -186,4 +221,152 @@ REGISTER_CPU(double);

#undef REGISTER_CPU

#if GOOGLE_CUDA && CUDNN_VERSION >= 7603
class CTCLossOpGPU : public OpKernel {
 public:
  explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
    bool preprocess_collapse_repeated;
    bool ctc_merge_repeated;
    bool ignore_longer_outputs_than_inputs;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated",
                                     &preprocess_collapse_repeated));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
                                     &ignore_longer_outputs_than_inputs));

    // cuDNN's CTC implementation only supports the default attr values;
    // reject any graph that asks for the unsupported variants.
    OP_REQUIRES(ctx, !preprocess_collapse_repeated,
                errors::InvalidArgument("GPU CTCLossOp requires "
                                        "preprocess_collapse_repeated to be "
                                        "false"));
    OP_REQUIRES(ctx, ctc_merge_repeated,
                errors::InvalidArgument("GPU CTCLossOp requires "
                                        "ctc_merge_repeated to be "
                                        "true"));
    OP_REQUIRES(ctx, !ignore_longer_outputs_than_inputs,
                errors::InvalidArgument("GPU CTCLossOp requires "
                                        "ignore_longer_outputs_than_inputs "
                                        "to be false"));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* labels_indices;
    const Tensor* labels_values;
    const Tensor* seq_len;
    OP_REQUIRES_OK(ctx, ctx->input("inputs", &inputs));
    OP_REQUIRES_OK(ctx, ctx->input("labels_indices", &labels_indices));
    OP_REQUIRES_OK(ctx, ctx->input("labels_values", &labels_values));
    OP_REQUIRES_OK(ctx, ctx->input("sequence_length", &seq_len));

    OP_REQUIRES(ctx, inputs->shape().dims() == 3,
                errors::InvalidArgument("inputs is not a 3-Tensor"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()),
                errors::InvalidArgument("sequence_length is not a vector"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()),
                errors::InvalidArgument("labels_indices is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()),
                errors::InvalidArgument("labels_values is not a vector"));

    const TensorShape& inputs_shape = inputs->shape();
    const int64 max_time_raw = inputs_shape.dim_size(0);
    const int64 batch_size_raw = inputs_shape.dim_size(1);
    const int64 num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(max_time_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("max_time cannot exceed max int"));
    OP_REQUIRES(
        ctx, FastBoundsCheck(batch_size_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("batch_size cannot exceed max int"));
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
Review comment (on lines +281 to +283): Why are we not bounds checking …
    const int max_time = static_cast<const int>(max_time_raw);
    const int batch_size = static_cast<const int>(batch_size_raw);
    const int num_classes = static_cast<const int>(num_classes_raw);

    OP_REQUIRES(
        ctx, batch_size == seq_len->dim_size(0),
        errors::InvalidArgument("len(sequence_length) != batch_size. ",
                                "len(sequence_length): ", seq_len->dim_size(0),
                                " batch_size: ", batch_size));

    OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0),
                errors::InvalidArgument(
                    "labels_indices and labels_values must contain the "
                    "same number of rows, but saw shapes: ",
                    labels_indices->shape().DebugString(), " vs. ",
                    labels_values->shape().DebugString()));
    auto num_indices = labels_indices->dim_size(0);

    OP_REQUIRES(ctx, batch_size != 0,
                errors::InvalidArgument("batch_size must not be 0"));

    Tensor* loss = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));

    Tensor* gradient = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("gradient", inputs_shape, &gradient));

    // Convert the labels_indices to labels_lengths.
    std::vector<int> labels_lengths(batch_size, 0);
    DoHistogram<int64>(ctx, labels_indices, num_indices, batch_size,
                       &labels_lengths);

    StreamExecutor* executor = ctx->op_device_context()->stream()->parent();
    se::dnn::DataType data_type = ToDataType<float>::value;

    auto probs_desc_s = executor->createRnnStateTensorDescriptor(
        max_time, batch_size, num_classes, data_type);
    OP_REQUIRES_OK(ctx, probs_desc_s.status());
    std::unique_ptr<RnnStateTensorDescriptor> probs_desc =
        probs_desc_s.ConsumeValueOrDie();

    auto grads_desc_s = executor->createRnnStateTensorDescriptor(
        max_time, batch_size, num_classes, data_type);
    OP_REQUIRES_OK(ctx, grads_desc_s.status());
    std::unique_ptr<RnnStateTensorDescriptor> grads_desc =
        grads_desc_s.ConsumeValueOrDie();

    absl::Span<const int32> labels_data(labels_values->flat<int32>().data(),
                                        num_indices);
    absl::Span<const int32> labels_lengths_data(labels_lengths.data(),
                                                batch_size);
    absl::Span<const int32> input_lengths_data(seq_len->flat<int32>().data(),
                                               batch_size);

    auto probs_data = StreamExecutorUtil::AsDeviceMemory<float>(*inputs);
    auto costs_data = StreamExecutorUtil::AsDeviceMemory<float>(*loss);
    auto grads_data = StreamExecutorUtil::AsDeviceMemory<float>(*gradient);

    // Set the memory limit to 4GB for workspace memory.
    DnnScratchAllocator workspace_allocator(1LL << 32, ctx);

    Stream* stream = ctx->op_device_context()->stream();
    bool cudnn_launch_status =
        stream
            ->ThenCtcLoss(
                *probs_desc, probs_data, labels_data, labels_lengths_data,
                input_lengths_data, &costs_data, *grads_desc, &grads_data,
                &workspace_allocator)
            .ok();

    if (!cudnn_launch_status) {
      ctx->SetStatus(errors::Internal("cuDNN CTCLoss launch failure"));
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOpGPU);
};

// The labels, their lengths, and the sequence lengths stay in host memory:
// cuDNN consumes them as host-side int arrays.
REGISTER_KERNEL_BUILDER(Name("CTCLossV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("labels_indices")
                            .HostMemory("labels_values")
                            .HostMemory("sequence_length"),
                        CTCLossOpGPU);
Review comment (on lines +366 to +370): Should we register the kernel even if cuDNN is older than 7.6.3 (and the kernel is guaranteed to fail)?

#endif  // GOOGLE_CUDA && CUDNN_VERSION >= 7603
}  // end namespace tensorflow
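For context on what ThenCtcLoss ultimately drives: a rough sketch of the raw cuDNN (7.6+) CTC call sequence, based on cuDNN's public API. This is an illustration, not code from the PR; error handling is elided and the (T, B, C) tensor descriptors are assumed to be set up by the caller.

#include <cuda_runtime.h>
#include <cudnn.h>
#include <vector>

// Illustrative only: the PR reaches cuDNN through StreamExecutor rather
// than calling these entry points directly.
void RawCudnnCtcLoss(cudnnHandle_t handle,
                     cudnnTensorDescriptor_t probs_desc,   // (T, B, C) float
                     cudnnTensorDescriptor_t grads_desc,   // (T, B, C) float
                     const float* d_probs, float* d_costs, float* d_grads,
                     const std::vector<int>& labels,           // host
                     const std::vector<int>& label_lengths,    // host, size B
                     const std::vector<int>& input_lengths) {  // host, size B
  cudnnCTCLossDescriptor_t ctc_desc;
  cudnnCreateCTCLossDescriptor(&ctc_desc);
  cudnnSetCTCLossDescriptor(ctc_desc, CUDNN_DATA_FLOAT);

  // Query and allocate scratch space; the kernel above caps this at 4GB via
  // DnnScratchAllocator.
  size_t workspace_bytes = 0;
  cudnnGetCTCLossWorkspaceSize(
      handle, probs_desc, grads_desc, labels.data(), label_lengths.data(),
      input_lengths.data(), CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctc_desc,
      &workspace_bytes);
  void* d_workspace = nullptr;
  cudaMalloc(&d_workspace, workspace_bytes);

  // A single call computes both the per-batch losses and the full gradient,
  // matching the op's two outputs.
  cudnnCTCLoss(handle, probs_desc, d_probs, labels.data(),
               label_lengths.data(), input_lengths.data(), d_costs,
               grads_desc, d_grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
               ctc_desc, d_workspace, workspace_bytes);

  cudaFree(d_workspace);
  cudnnDestroyCTCLossDescriptor(ctc_desc);
}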
@@ -62,6 +62,43 @@ REGISTER_OP("CTCLoss")
      return Status::OK();
    });

REGISTER_OP("CTCLossV2") | ||
.Input("inputs: float") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does the CuDNN implementation support types other than float? If so, we should also support them here. #31164 added support for double for CTCLossOp, for example. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, CuDNN only support the float CTCLoss. |
||
.Input("labels_indices: int64") | ||
.Input("labels_values: int32") | ||
.Input("sequence_length: int32") | ||
.Attr("preprocess_collapse_repeated: bool = false") | ||
.Attr("ctc_merge_repeated: bool = true") | ||
.Attr("ignore_longer_outputs_than_inputs: bool = false") | ||
.Output("loss: float") | ||
.Output("gradient: float") | ||
.SetShapeFn([](InferenceContext* c) { | ||
ShapeHandle inputs; | ||
ShapeHandle labels_indices; | ||
ShapeHandle labels_values; | ||
ShapeHandle sequence_length; | ||
|
||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs)); | ||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &labels_indices)); | ||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &labels_values)); | ||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &sequence_length)); | ||
|
||
DimensionHandle unused; | ||
TF_RETURN_IF_ERROR(c->Merge(c->Dim(labels_indices, 0), | ||
c->Dim(labels_values, 0), &unused)); | ||
|
||
// Get batch size from inputs and sequence_length, and update inputs | ||
// with the merged batch_size since it is returned. | ||
DimensionHandle batch_size; | ||
TF_RETURN_IF_ERROR( | ||
c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size)); | ||
TF_RETURN_IF_ERROR(c->ReplaceDim(inputs, 1, batch_size, &inputs)); | ||
|
||
c->set_output(0, c->Vector(batch_size)); | ||
c->set_output(1, inputs); | ||
return Status::OK(); | ||
}); | ||

REGISTER_OP("CTCGreedyDecoder")
    .Input("inputs: T")
    .Input("sequence_length: int32")
Review comment: Can you add a line here saying "visibility: HIDDEN"; this will prevent the generation of a tf.ctc_loss_v2 API.
Reply: Sure. Done.
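As a closing reference, a worked example of the shape contract the new shape function enforces. The concrete dimensions below are made up for illustration; this is a sketch, not TensorFlow code.

#include <cassert>
#include <cstdio>

int main() {
  // Made-up dimensions for illustration.
  const int max_time = 50, batch_size = 16, num_classes = 29;
  const int num_label_entries = 120;

  // Ranks checked by the shape function:
  //   inputs          [max_time, batch_size, num_classes]   rank 3
  //   labels_indices  [num_label_entries, 2]                rank 2
  //   labels_values   [num_label_entries]                   rank 1
  //   sequence_length [batch_size]                          rank 1
  // Merges: dim 0 of labels_indices with dim 0 of labels_values, and
  // dim 1 of inputs with dim 0 of sequence_length (the batch size).
  const int labels_indices_dim0 = num_label_entries;
  const int labels_values_dim0 = num_label_entries;
  const int sequence_length_dim0 = batch_size;
  assert(labels_indices_dim0 == labels_values_dim0);
  assert(sequence_length_dim0 == batch_size);

  // Outputs: loss is [batch_size]; gradient matches the inputs shape.
  std::printf("loss: [%d]  gradient: [%d, %d, %d]\n", batch_size, max_time,
              batch_size, num_classes);
  return 0;
}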