Adds performance enhancements for sparse embedding lookups.

philipphack committed Feb 16, 2023
1 parent 3e18beb commit 00ee1a9
Showing 9 changed files with 604 additions and 209 deletions.
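
The commit adds a `compressed` attribute to `SparseFillEmptyRows`: when set, the `indices` input is a rank-1 vector carrying only the row id of each nonzero, rather than the usual rank-2 `[N, rank]` matrix of full coordinates. The snippet below is a minimal illustration of the two layouts and is not part of the commit; the assumption that the compressed form stores only row coordinates follows from the kernel only ever reading column 0 of `indices`.

#include <array>
#include <cstdio>

int main() {
  // Uncompressed (compressed = false): indices is an [N, rank] matrix,
  // one full coordinate tuple per nonzero value.
  const std::array<std::array<long, 2>, 2> dense_indices = {{{0, 1}, {2, 0}}};
  // Compressed (compressed = true): indices is an [N] vector holding
  // only the row coordinate of each nonzero.
  const std::array<long, 2> compressed_indices = {0, 2};
  std::printf("row of nonzero #1: %ld (dense) vs %ld (compressed)\n",
              dense_indices[1][0], compressed_indices[1]);
}
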
108 changes: 80 additions & 28 deletions tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
@@ -40,8 +40,12 @@ using GPUDevice = Eigen::GpuDevice;

namespace functor {

template <typename T, typename Tindex>
struct SparseFillEmptyRows<CPUDevice, T, Tindex> {
template <typename T, typename Tindex, bool Compressed>
struct SparseFillEmptyRows<CPUDevice, T, Tindex, Compressed> {
private:
static constexpr int IndicesRank = Compressed ? 1 : 2;

public:
Status operator()(OpKernelContext* context, const Tensor& default_value_t,
const Tensor& indices_t, const Tensor& values_t,
const Tensor& dense_shape_t,
@@ -53,7 +57,7 @@ struct SparseFillEmptyRows<CPUDevice, T, Tindex> {
const int kReverseIndexMapOutput = 3;

const T& default_value = default_value_t.scalar<T>()();
const auto indices = indices_t.matrix<Tindex>();
const auto indices = indices_t.tensor<Tindex, IndicesRank>();
const auto values = values_t.vec<T>();
const auto dense_shape = dense_shape_t.vec<Tindex>();

@@ -80,7 +84,7 @@ struct SparseFillEmptyRows<CPUDevice, T, Tindex> {
reverse_index_map = reverse_index_map_t->vec<Tindex>().data();
}

int rank = indices_t.shape().dim_size(1);
const int rank = Compressed ? 1 : indices_t.shape().dim_size(1);

if (dense_rows == 0) {
if (N != 0) {
@@ -103,11 +107,19 @@ struct SparseFillEmptyRows<CPUDevice, T, Tindex> {
return OkStatus();
}

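// Unified element accessor: when Compressed is true the index tensors
// are rank-1 (row ids only) and the column argument is ignored;
// otherwise they are rank-2 [N, rank] matrices. The branch is resolved
// at compile time, so the fill loops below work for either layout.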
auto vec_or_matrix = [](auto tensor, int index1, int index2) -> auto& {
if constexpr (Compressed) {
return tensor(index1);
} else {
return tensor(index1, index2);
}
};

bool rows_are_ordered = true;
Tindex last_indices_row = 0;
std::vector<Tindex> csr_offset(dense_rows, 0);
for (int i = 0; i < N; ++i) {
const Tindex row = indices(i, 0);
const Tindex row = vec_or_matrix(indices, i, 0);
if (row < 0 || row >= dense_rows) {
return errors::InvalidArgument("indices(", i, ", 0) is invalid: ", row,
" >= ", dense_rows);
@@ -164,11 +176,12 @@ struct SparseFillEmptyRows<CPUDevice, T, Tindex> {

// Fill in values for rows that are not missing
for (Tindex i = 0; i < N; ++i) {
const Tindex row = indices(i, 0);
const Tindex row = vec_or_matrix(indices, i, 0);
Tindex& offset = filled_count[row];
const Tindex output_i = ((row == 0) ? 0 : csr_offset[row - 1]) + offset;
offset++; // Increment the filled count for this row.
std::copy_n(&indices(i, 0), rank, &output_indices(output_i, 0));
std::copy_n(&vec_or_matrix(indices, i, 0), rank,
&vec_or_matrix(output_indices, output_i, 0));
output_values(output_i) = values(i);
// We'll need this reverse index map to backprop correctly.
if (reverse_index_map) {
@@ -183,9 +196,9 @@ struct SparseFillEmptyRows<CPUDevice, T, Tindex> {
const Tindex starting_index = (row == 0) ? 0 : csr_offset[row - 1];
// Remaining index values were set to zero already.
// Just need to set the row index in the right location.
output_indices(starting_index, 0) = row;
vec_or_matrix(output_indices, starting_index, 0) = row;
for (Tindex col = 1; col < rank; ++col) {
output_indices(starting_index, col) = 0;
vec_or_matrix(output_indices, starting_index, col) = 0;
}
output_values(starting_index) = default_value;
}
@@ -200,7 +213,7 @@ struct SparseFillEmptyRows<CPUDevice, T, Tindex> {

namespace {

template <typename Device, typename T, typename Tindex>
template <typename Device, typename T, typename Tindex, bool Compressed>
void SparseFillEmptyRowsOpImpl(OpKernelContext* context,
AsyncOpKernel::DoneCallback done = nullptr) {
// Note that setting this empty lambda as the default parameter value directly
@@ -209,6 +222,8 @@ void SparseFillEmptyRowsOpImpl(OpKernelContext* context,
done = [] {};
}

static constexpr int IndicesRank = Compressed ? 1 : 2;

const int kIndicesInput = 0;
const int kValuesInput = 1;
const int kDenseShapeInput = 2;
@@ -224,10 +239,17 @@ void SparseFillEmptyRowsOpImpl(OpKernelContext* context,
errors::InvalidArgument("dense_shape must be a vector, saw: ",
dense_shape_t.shape().DebugString()),
done);
OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsMatrix(indices_t.shape()),
errors::InvalidArgument("indices must be a matrix, saw: ",
indices_t.shape().DebugString()),
done);
if (Compressed) {
OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(indices_t.shape()),
errors::InvalidArgument("indices must be a vector, saw: ",
indices_t.shape().DebugString()),
done);
} else {
OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsMatrix(indices_t.shape()),
errors::InvalidArgument("indices must be a matrix, saw: ",
indices_t.shape().DebugString()),
done);
}
OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(values_t.shape()),
errors::InvalidArgument("values must be a vector, saw: ",
values_t.shape().DebugString()),
@@ -249,7 +271,8 @@ void SparseFillEmptyRowsOpImpl(OpKernelContext* context,
errors::InvalidArgument("Dense shape cannot be empty."),
done);

using FunctorType = functor::SparseFillEmptyRows<Device, T, Tindex>;
using FunctorType =
functor::SparseFillEmptyRows<Device, T, Tindex, Compressed>;
OP_REQUIRES_OK_ASYNC(context,
FunctorType()(context, default_value_t, indices_t,
values_t, dense_shape_t, done),
@@ -262,11 +285,20 @@ template <typename Device, typename T, typename Tindex>
class SparseFillEmptyRowsOp : public OpKernel {
public:
explicit SparseFillEmptyRowsOp(OpKernelConstruction* context)
: OpKernel(context) {}
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("compressed", &compressed_));
}

void Compute(OpKernelContext* context) override {
SparseFillEmptyRowsOpImpl<Device, T, Tindex>(context);
if (compressed_) {
SparseFillEmptyRowsOpImpl<Device, T, Tindex, true>(context);
} else {
SparseFillEmptyRowsOpImpl<Device, T, Tindex, false>(context);
}
}

private:
bool compressed_;
};

#define REGISTER_KERNELS(D, T, Tindex) \
@@ -291,35 +323,55 @@ template <typename T, typename Tindex>
class SparseFillEmptyRowsGPUOp : public AsyncOpKernel {
public:
explicit SparseFillEmptyRowsGPUOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {}
: AsyncOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("compressed", &compressed_));
}

void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
SparseFillEmptyRowsOpImpl<GPUDevice, T, Tindex>(context, done);
if (compressed_) {
SparseFillEmptyRowsOpImpl<GPUDevice, T, Tindex, true>(context, done);
} else {
SparseFillEmptyRowsOpImpl<GPUDevice, T, Tindex, false>(context, done);
}
}
};

#define REGISTER_KERNELS(T, Tindex) \
REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRows") \
.Device(DEVICE_GPU) \
.HostMemory("dense_shape") \
.TypeConstraint<T>("T"), \
SparseFillEmptyRowsGPUOp<T, Tindex>)
private:
bool compressed_;
};

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Tindex) \
template <> \
Status SparseFillEmptyRows<GPUDevice, T, Tindex>::operator()( \
Status SparseFillEmptyRows<GPUDevice, T, Tindex, true>::operator()( \
OpKernelContext* context, const Tensor& default_value_t, \
const Tensor& indices_t, const Tensor& values_t, \
const Tensor& dense_shape_t, typename AsyncOpKernel::DoneCallback done); \
extern template struct SparseFillEmptyRows<GPUDevice, T, Tindex, true>;
#define DECLARE_GPU_SPEC_INT64(T) DECLARE_GPU_SPEC(T, int64_t)
TF_CALL_POD_TYPES(DECLARE_GPU_SPEC_INT64)
#undef DECLARE_GPU_SPEC_INT64
#undef DECLARE_GPU_SPEC
#define DECLARE_GPU_SPEC(T, Tindex) \
template <> \
Status SparseFillEmptyRows<GPUDevice, T, Tindex, false>::operator()( \
OpKernelContext* context, const Tensor& default_value_t, \
const Tensor& indices_t, const Tensor& values_t, \
const Tensor& dense_shape_t, typename AsyncOpKernel::DoneCallback done); \
extern template struct SparseFillEmptyRows<GPUDevice, T, Tindex>;
extern template struct SparseFillEmptyRows<GPUDevice, T, Tindex, false>;
#define DECLARE_GPU_SPEC_INT64(T) DECLARE_GPU_SPEC(T, int64_t)
TF_CALL_POD_TYPES(DECLARE_GPU_SPEC_INT64)
#undef DECLARE_GPU_SPEC_INT64
#undef DECLARE_GPU_SPEC
} // namespace functor

#define REGISTER_KERNELS(T, Tindex) \
REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRows") \
.Device(DEVICE_GPU) \
.HostMemory("dense_shape") \
.TypeConstraint<T>("T"), \
SparseFillEmptyRowsGPUOp<T, Tindex>)

#define REGISTER_KERNELS_TINDEX(T) REGISTER_KERNELS(T, int64)
TF_CALL_POD_TYPES(REGISTER_KERNELS_TINDEX)
#undef REGISTER_KERNELS_TINDEX
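
The `vec_or_matrix` helper above lets one set of fill loops serve both indices ranks. Below is a minimal standalone sketch of the same `if constexpr` dispatch pattern, with illustrative names only (`Indices` and `FirstRow` are not TensorFlow APIs):

#include <cstdio>
#include <vector>

// Minimal stand-in for the rank-1 / rank-2 views the kernel works with.
template <bool Compressed>
struct Indices {
  std::vector<long> data;  // flat storage
  int rank;                // columns per entry (1 when compressed)
  long& at(long i, long col) {
    if constexpr (Compressed) {
      return data[i];               // rank-1: col is always 0
    } else {
      return data[i * rank + col];  // rank-2: row-major [N, rank]
    }
  }
};

template <bool Compressed>
long FirstRow(Indices<Compressed>& ind) {
  return ind.at(0, 0);  // one call site compiles for both layouts
}

int main() {
  Indices<true> compressed{{0, 2}, 1};
  Indices<false> dense{{0, 1, 2, 0}, 2};
  std::printf("%ld %ld\n", FirstRow(compressed), FirstRow(dense));
}
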
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/sparse_fill_empty_rows_op.h
@@ -24,7 +24,7 @@ namespace tensorflow {

namespace functor {

template <typename Device, typename T, typename Tindex>
template <typename Device, typename T, typename Tindex, bool Compressed>
struct SparseFillEmptyRows {
// Note that the done callback is only used by the GPU implementation.
Status operator()(OpKernelContext* context, const Tensor& default_value_t,
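
The added `bool Compressed` template parameter is a compile-time property, while the op's `compressed` attr is only known at runtime when the kernel is constructed. The `Compute`/`ComputeAsync` bodies in the .cc file therefore instantiate both variants and branch once per call. A minimal sketch of that attr-to-template dispatch pattern (names here are illustrative, not TensorFlow APIs):

#include <cstdio>

template <bool Compressed>
void RunImpl() {
  // Each instantiation compiles its own code path (e.g. rank-1 vs
  // rank-2 index handling), with no per-element runtime branching.
  std::printf("compressed=%d\n", Compressed);
}

void Compute(bool compressed_attr) {
  // The attr value is only known at runtime, so both instantiations
  // exist in the binary and one is selected per call.
  if (compressed_attr) {
    RunImpl<true>();
  } else {
    RunImpl<false>();
  }
}

int main() {
  Compute(true);
  Compute(false);
}
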
26 changes: 17 additions & 9 deletions tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc
@@ -165,21 +165,21 @@ __global__ __launch_bounds__(1024) void ScatterNewElementsKernel(

} // namespace

template <typename T, typename Tindex>
struct SparseFillEmptyRows<GPUDevice, T, Tindex> {
template <typename T, typename Tindex, bool Compressed>
struct SparseFillEmptyRows<GPUDevice, T, Tindex, Compressed> {
Status operator()(OpKernelContext* context, const Tensor& default_value_t,
const Tensor& indices_t, const Tensor& values_t,
const Tensor& dense_shape_t,
typename AsyncOpKernel::DoneCallback done) {
const int kEmptyRowIndicatorOutput = 2;

const auto default_value = default_value_t.scalar<T>();
const auto indices = indices_t.matrix<Tindex>();
const auto indices = indices_t.tensor<Tindex, IndicesRank>();
const auto values = values_t.vec<T>();
const auto dense_shape = dense_shape_t.vec<Tindex>();

const Tindex N = indices_t.shape().dim_size(0);
const int rank = indices_t.shape().dim_size(1);
const int rank = Compressed ? 1 : indices_t.shape().dim_size(1);
const Tindex dense_rows = dense_shape(0); // Must be on the host
DataType index_type = DataTypeToEnum<Tindex>::value;
const GPUDevice& device = context->eigen_device<GPUDevice>();
@@ -429,6 +429,8 @@ struct SparseFillEmptyRows<GPUDevice, T, Tindex> {
}

private:
static constexpr int IndicesRank = Compressed ? 1 : 2;

Status AllocateOutputsExceptEmptyRowIndicator(
OpKernelContext* context, Tindex N, int rank, Tindex num_empty_rows,
Tindex** output_indices, T** output_values, Tindex** reverse_index_map) {
@@ -456,10 +458,11 @@ struct SparseFillEmptyRows<GPUDevice, T, Tindex> {
return OkStatus();
}

Status ArgSortByRows(OpKernelContext* context, const GPUDevice& device,
Tindex N, int rank, Tindex dense_rows,
typename TTypes<Tindex>::ConstMatrix indices,
Tensor* input_index_map_t) {
Status ArgSortByRows(
OpKernelContext* context, const GPUDevice& device, Tindex N, int rank,
Tindex dense_rows,
typename TTypes<Tindex, IndicesRank>::ConstTensor indices,
Tensor* input_index_map_t) {
DataType index_type = DataTypeToEnum<Tindex>::value;
// Extract row indices into separate array for use as keys for sorting.
Tensor row_indices_t;
@@ -486,7 +489,12 @@ struct SparseFillEmptyRows<GPUDevice, T, Tindex> {
} // namespace functor

#define DEFINE_INT64(T) \
template struct functor::SparseFillEmptyRows<GPUDevice, T, int64>;
template struct functor::SparseFillEmptyRows<GPUDevice, T, int64, true>;
TF_CALL_POD_TYPES(DEFINE_INT64)
#undef DEFINE_INT64

#define DEFINE_INT64(T) \
template struct functor::SparseFillEmptyRows<GPUDevice, T, int64, false>;
TF_CALL_POD_TYPES(DEFINE_INT64)
#undef DEFINE_INT64

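
Because the GPU functor now takes the extra bool parameter, the explicit instantiations at the bottom of the file are doubled: one `DEFINE_INT64` expansion per `Compressed` value. A standalone sketch of that instantiation pattern (illustrative names; compiles as its own translation unit):

// Explicit instantiation over a bool non-type template parameter.
template <typename T, bool Compressed>
struct FillRowsFunctor {
  int operator()() const { return Compressed ? 1 : 2; }
};

// Mirror of the doubled DEFINE_INT64 blocks: each macro emits the
// definitions for one Compressed value so the linker finds both.
#define DEFINE_TRUE(T) template struct FillRowsFunctor<T, true>;
#define DEFINE_FALSE(T) template struct FillRowsFunctor<T, false>;
DEFINE_TRUE(float)
DEFINE_FALSE(float)
DEFINE_TRUE(double)
DEFINE_FALSE(double)
#undef DEFINE_TRUE
#undef DEFINE_FALSE
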
15 changes: 12 additions & 3 deletions tensorflow/core/ops/sparse_ops.cc
@@ -609,9 +616,16 @@ REGISTER_OP("SparseFillEmptyRows")
.Output("empty_row_indicator: bool")
.Output("reverse_index_map: int64")
.Attr("T: type")
.Attr("compressed: bool = False")
.SetShapeFn([](InferenceContext* c) {
bool compressed;
TF_RETURN_IF_ERROR(c->GetAttr("compressed", &compressed));
ShapeHandle input_indices = c->input(0);
TF_RETURN_IF_ERROR(c->WithRank(input_indices, 2, &input_indices));
if (compressed) {
TF_RETURN_IF_ERROR(c->WithRank(input_indices, 1, &input_indices));
} else {
TF_RETURN_IF_ERROR(c->WithRank(input_indices, 2, &input_indices));
}
ShapeHandle input_values = c->input(1);
TF_RETURN_IF_ERROR(c->WithRank(input_values, 1, &input_values));
ShapeHandle input_shape = c->input(2);
@@ -621,8 +628,10 @@ REGISTER_OP("SparseFillEmptyRows")
DimensionHandle N = c->Dim(input_indices, 0);
TF_RETURN_IF_ERROR(c->Merge(N, c->Dim(input_values, 0), &N));
DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(input_indices, 1),
c->Dim(input_shape, 0), &unused_dim));
if (!compressed) {
TF_RETURN_IF_ERROR(c->Merge(c->Dim(input_indices, 1),
c->Dim(input_shape, 0), &unused_dim));
}
if (c->Value(c->NumElements(input_shape)) == 0)
return errors::InvalidArgument("dense_shape must not be empty");
ShapeHandle output_indices =
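
The shape function now accepts rank-1 indices in compressed mode and skips the check that the indices' second dimension matches the dense rank, since a compressed index has no per-dimension columns. A minimal standalone sketch of that validation logic, using plain containers rather than the real `InferenceContext` API:

#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the rank checks in the op's shape function.
std::optional<std::string> ValidateIndicesShape(
    const std::vector<long>& indices_shape, long dense_rank,
    bool compressed) {
  if (compressed) {
    if (indices_shape.size() != 1) return "indices must be a vector";
  } else {
    if (indices_shape.size() != 2) return "indices must be a matrix";
    // Only the uncompressed layout carries one column per dense dimension.
    if (indices_shape[1] != dense_rank)
      return "indices columns must match dense_shape rank";
  }
  return std::nullopt;  // shapes are consistent
}
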
