Add support for CTC for float64 #31164

Merged 2 commits on Aug 20, 2019
65 changes: 41 additions & 24 deletions tensorflow/core/kernels/ctc_decoder_ops.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
@@ -33,11 +34,11 @@ namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
-inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
-                    int* c) {
+template<typename T>
+inline T RowMax(const typename TTypes<T>::UnalignedConstMatrix& m, int r, int* c) {
   *c = 0;
   CHECK_LT(0, m.dimension(1));
-  float p = m(r, 0);
+  auto p = m(r, 0);
   for (int i = 1; i < m.dimension(1); ++i) {
     if (m(r, i) > p) {
       p = m(r, i);
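
The templating of RowMax is mechanical: only the element type changes. A minimal standalone sketch of the same scan, using a plain vector-of-rows in place of TTypes<T>::UnalignedConstMatrix (an assumption made purely to keep the example self-contained):

#include <vector>

// Returns the largest value in row r of m and stores its column in *c,
// mirroring the templated RowMax above (assumes m[r] is non-empty).
template <typename T>
T RowMax(const std::vector<std::vector<T>>& m, int r, int* c) {
  *c = 0;
  T p = m[r][0];
  for (int i = 1; i < static_cast<int>(m[r].size()); ++i) {
    if (m[r][i] > p) {
      p = m[r][i];
      *c = i;
    }
  }
  return p;
}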
@@ -170,6 +171,7 @@ class CTCDecodeHelper {
   TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
 };
 
+template<typename T>
 class CTCGreedyDecoderOp : public OpKernel {
  public:
   explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -189,7 +191,7 @@ class CTCGreedyDecoderOp : public OpKernel {
 
     const TensorShape& inputs_shape = inputs->shape();
 
-    std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
+    std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
     const int64 max_time = inputs_shape.dim_size(0);
     const int64 batch_size = inputs_shape.dim_size(1);
     const int64 num_classes_raw = inputs_shape.dim_size(2);
@@ -198,14 +200,14 @@
                 errors::InvalidArgument("num_classes cannot exceed max int"));
     const int num_classes = static_cast<const int>(num_classes_raw);
 
-    auto inputs_t = inputs->tensor<float, 3>();
+    auto inputs_t = inputs->tensor<T, 3>();
 
     for (std::size_t t = 0; t < max_time; ++t) {
       input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                 batch_size, num_classes);
     }
     auto seq_len_t = seq_len->vec<int32>();
-    auto log_prob_t = log_prob->matrix<float>();
+    auto log_prob_t = log_prob->matrix<T>();
 
     log_prob_t.setZero();
 
@@ -221,7 +223,7 @@
       int prev_indices = -1;
       for (int t = 0; t < seq_len_t(b); ++t) {
         int max_class_indices;
-        log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices);
+        log_prob_t(b, 0) += -RowMax<T>(input_list_t[t], b, &max_class_indices);
         if (max_class_indices != blank_index &&
             !(merge_repeated_ && max_class_indices == prev_indices)) {
           sequence.push_back(max_class_indices);
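
Note on the call site above: RowMax<T> needs the explicit template argument because T appears only inside `typename TTypes<T>::UnalignedConstMatrix`, which is a non-deduced context, so the compiler cannot infer T from the argument.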
@@ -250,10 +252,18 @@
   TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
 };
 
-REGISTER_KERNEL_BUILDER(Name("CTCGreedyDecoder").Device(DEVICE_CPU),
-                        CTCGreedyDecoderOp);
+#define REGISTER_CPU(T)                                                   \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      CTCGreedyDecoderOp<T>);
+
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+
+#undef REGISTER_CPU
 
 // CTC beam search
+template<typename T>
 class CTCBeamSearchDecoderOp : public OpKernel {
  public:
   explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -275,9 +285,9 @@ class CTCBeamSearchDecoderOp : public OpKernel {
                        ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
                        &decoded_values, &decoded_shape));
 
-    auto inputs_t = inputs->tensor<float, 3>();
+    auto inputs_t = inputs->tensor<T, 3>();
     auto seq_len_t = seq_len->vec<int32>();
-    auto log_prob_t = log_prob->matrix<float>();
+    auto log_prob_t = log_prob->matrix<T>();
 
     const TensorShape& inputs_shape = inputs->shape();
 
@@ -291,30 +301,30 @@
 
     log_prob_t.setZero();
 
-    std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
+    std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
 
     for (std::size_t t = 0; t < max_time; ++t) {
       input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                 batch_size, num_classes);
     }
 
-    ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,
-                                            &beam_scorer_, 1 /* batch_size */,
-                                            merge_repeated_);
-    Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
-    auto input_chip_t = input_chip.flat<float>();
+    ctc::CTCBeamSearchDecoder<T> beam_search(num_classes, beam_width_,
+                                             &beam_scorer_, 1 /* batch_size */,
+                                             merge_repeated_);
+    Tensor input_chip(DataTypeToEnum<T>::v(), TensorShape({num_classes}));
+    auto input_chip_t = input_chip.flat<T>();
 
     std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
-    std::vector<float> log_probs;
+    std::vector<T> log_probs;
 
     // Assumption: the blank index is num_classes - 1
     for (int b = 0; b < batch_size; ++b) {
       auto& best_paths_b = best_paths[b];
       best_paths_b.resize(decode_helper_.GetTopPaths());
       for (int t = 0; t < seq_len_t(b); ++t) {
         input_chip_t = input_list_t[t].chip(b, 0);
-        auto input_bi =
-            Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
+        auto input_bi = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>
+            (input_chip_t.data(), num_classes);
         beam_search.Step(input_bi);
       }
       OP_REQUIRES_OK(
@@ -335,13 +345,20 @@
 
  private:
   CTCDecodeHelper decode_helper_;
-  ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
+  typename ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer beam_scorer_;
   bool merge_repeated_;
   int beam_width_;
-  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp);
+  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp<T>);
 };
 
-REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoder").Device(DEVICE_CPU),
-                        CTCBeamSearchDecoderOp);
+#define REGISTER_CPU(T)                                                       \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("CTCBeamSearchDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      CTCBeamSearchDecoderOp<T>);
+
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+
+#undef REGISTER_CPU
 
 } // end namespace tensorflow
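
For readers unfamiliar with this registration idiom: the REGISTER_CPU macro expands mechanically into one kernel registration per type. For example, REGISTER_CPU(double) above expands to the following (shown for the greedy decoder; the beam-search and loss cases differ only in op and class name):

// Expansion of REGISTER_CPU(double): the "T" type constraint routes
// double-typed graph nodes to the double instantiation of the templated op.
REGISTER_KERNEL_BUILDER(
    Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<double>("T"),
    CTCGreedyDecoderOp<double>);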
30 changes: 19 additions & 11 deletions tensorflow/core/kernels/ctc_loss_op.cc
@@ -19,21 +19,21 @@ limitations under the License.
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
 #include "tensorflow/core/util/sparse/sparse_tensor.h"
 
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
+template<typename T>
 class CTCLossOp : public OpKernel {
-  typedef Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic,
+  typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
                                          Eigen::RowMajor> >
       InputMap;
   typedef Eigen::Map<
-      Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
+      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
       OutputMap;
 
  public:
@@ -110,7 +110,7 @@ class CTCLossOp : public OpKernel {
                 errors::InvalidArgument("label SparseTensor is not valid: ",
                                         labels_sp_valid.error_message()));
 
-    ctc::CTCLossCalculator::LabelSequences labels_t(batch_size);
+    typename ctc::CTCLossCalculator<T>::LabelSequences labels_t(batch_size);
     for (const auto& g : labels_sp.group({0})) {  // iterate by batch
       const int64 batch_indices = g.group()[0];
       OP_REQUIRES(ctx, FastBoundsCheck(batch_indices, batch_size),
@@ -137,13 +137,13 @@
 
     Tensor* loss = nullptr;
     OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));
-    auto loss_t = loss->vec<float>();
+    auto loss_t = loss->vec<T>();
 
     Tensor* gradient;
     OP_REQUIRES_OK(ctx,
                    ctx->allocate_output("gradient", inputs_shape, &gradient));
-    auto gradient_t = gradient->tensor<float, 3>();
-    auto inputs_t = inputs->tensor<float, 3>();
+    auto gradient_t = gradient->tensor<T, 3>();
+    auto inputs_t = inputs->tensor<T, 3>();
     std::vector<OutputMap> gradient_list_t;
     std::vector<InputMap> input_list_t;
@@ -158,7 +158,7 @@
     gradient_t.setZero();
 
     // Assumption: the blank index is num_classes - 1
-    ctc::CTCLossCalculator ctc_loss_calculator(num_classes - 1, 0);
+    ctc::CTCLossCalculator<T> ctc_loss_calculator(num_classes - 1, 0);
     DeviceBase::CpuWorkerThreads workers =
         *ctx->device()->tensorflow_cpu_worker_threads();
     OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss(
@@ -173,9 +173,17 @@
   bool ctc_merge_repeated_;
   bool ignore_longer_outputs_than_inputs_;
 
-  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp);
+  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp<T>);
 };
 
-REGISTER_KERNEL_BUILDER(Name("CTCLoss").Device(DEVICE_CPU), CTCLossOp);
+#define REGISTER_CPU(T)                                          \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("CTCLoss").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      CTCLossOp<T>);
+
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+
+#undef REGISTER_CPU
 
 } // end namespace tensorflow
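
The InputMap/OutputMap typedefs were the only Eigen-facing pieces of this file hard-coded to float. A minimal sketch of what the templated versions do, assuming only Eigen is available (the buffer and sizes are made up for illustration):

#include <Eigen/Core>

// Read-only and writable row-major views over an existing buffer, as in
// the templated InputMap/OutputMap typedefs above. No data is copied.
template <typename T>
using InputMap = Eigen::Map<
    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
template <typename T>
using OutputMap = Eigen::Map<
    Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

int main() {
  double buf[6] = {0, 1, 2, 3, 4, 5};
  InputMap<double> in(buf, 2, 3);    // views buf as a 2x3 matrix
  OutputMap<double> out(buf, 2, 3);
  out = in * 2.0;                    // scales the buffer in place
  return 0;
}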
17 changes: 10 additions & 7 deletions tensorflow/core/ops/ctc_ops.cc
@@ -25,15 +25,16 @@ using shape_inference::ShapeHandle;
 // CTC is Connectionist Temporal Classification. See util/ctc/ for details.
 
 REGISTER_OP("CTCLoss")
-    .Input("inputs: float")
+    .Input("inputs: T")
     .Input("labels_indices: int64")
     .Input("labels_values: int32")
     .Input("sequence_length: int32")
     .Attr("preprocess_collapse_repeated: bool = false")
     .Attr("ctc_merge_repeated: bool = true")
     .Attr("ignore_longer_outputs_than_inputs: bool = false")
-    .Output("loss: float")
-    .Output("gradient: float")
+    .Output("loss: T")
+    .Output("gradient: T")
+    .Attr("T: {float, double} = DT_FLOAT")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle inputs;
       ShapeHandle labels_indices;
@@ -62,13 +63,14 @@ REGISTER_OP("CTCLoss")
     });
 
 REGISTER_OP("CTCGreedyDecoder")
-    .Input("inputs: float")
+    .Input("inputs: T")
     .Input("sequence_length: int32")
     .Attr("merge_repeated: bool = false")
     .Output("decoded_indices: int64")
     .Output("decoded_values: int64")
     .Output("decoded_shape: int64")
-    .Output("log_probability: float")
+    .Output("log_probability: T")
+    .Attr("T: {float, double} = DT_FLOAT")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle inputs;
       ShapeHandle sequence_length;
@@ -90,15 +92,16 @@ REGISTER_OP("CTCGreedyDecoder")
     });
 
 REGISTER_OP("CTCBeamSearchDecoder")
-    .Input("inputs: float")
+    .Input("inputs: T")
     .Input("sequence_length: int32")
     .Attr("beam_width: int >= 1")
     .Attr("top_paths: int >= 1")
     .Attr("merge_repeated: bool = true")
    .Output("decoded_indices: top_paths * int64")
     .Output("decoded_values: top_paths * int64")
     .Output("decoded_shape: top_paths * int64")
-    .Output("log_probability: float")
+    .Output("log_probability: T")
+    .Attr("T: {float, double} = DT_FLOAT")
    .SetShapeFn([](InferenceContext* c) {
       ShapeHandle inputs;
       ShapeHandle sequence_length;
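
The op-registration change is the part that keeps old graphs working: "T" defaults to DT_FLOAT, so GraphDefs serialized before this change, which carry no "T" attr, still resolve to the float kernels. A minimal sketch of the pattern with a hypothetical op name (ExampleTypedIdentity is not a real TensorFlow op):

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// Hypothetical op illustrating the attr pattern used by the CTC ops above:
// inputs and outputs typed by attr "T", restricted to {float, double},
// defaulting to DT_FLOAT for backward compatibility.
REGISTER_OP("ExampleTypedIdentity")
    .Input("inputs: T")
    .Output("outputs: T")
    .Attr("T: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

}  // end namespace tensorflow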