
Add a SortListOfSparseCoreCooTensors op.
PiperOrigin-RevId: 627519922
pineapplejuice233 authored and tensorflower-gardener committed Apr 23, 2024
1 parent 9515a6c commit eec9480
Showing 12 changed files with 456 additions and 3 deletions.
31 changes: 31 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -2457,4 +2457,35 @@ def TF_ConvertToListOfSparseCoreCooTensorsOp : TF_Op<"ConvertToListOfSparseCoreC

TF_DerivedResultSizeAttr num_sc_per_chip = TF_DerivedResultSizeAttr<0>;
}


def TF_SortListOfSparseCoreCooTensorsOp : TF_Op<"SortListOfSparseCoreCooTensors", [Pure, SameVariadicOperandSize]> {
let summary = "An op which sorts each COO tensors in the list by which SparseCore the id will go to. This op should be used along with the ConvertToSparseCoreCsrWrappedCooTensorOp.";

let arguments = (ins
Variadic<TF_Int32Tensor>:$row_ids_list,
Variadic<TF_Int32Tensor>:$col_ids_list,
Variadic<TF_Float32Tensor>:$gains_list,

I64ArrayAttr:$sample_count_list,
I64ArrayAttr:$col_offset_list,
ConfinedAttr<I64Attr, [IntMinValue<1>]>:$num_replica,
ConfinedAttr<I64Attr, [IntMinValue<1>]>:$table_vocab_size,
ConfinedAttr<I64Attr, [IntMinValue<1>]>:$feature_width,
ConfinedAttr<I64Attr, [IntMinValue<1>]>:$num_sc_per_chip,
ConfinedAttr<I64Attr, [IntMinValue<1>]>:$max_ids_per_sparse_core,
ConfinedAttr<I64Attr, [IntMinValue<1>]>:$max_unique_ids_per_sparse_core,
StrAttr:$table_name
);

let results = (outs
TF_Int32Tensor:$sorted_row_ids,
TF_Int32Tensor:$sorted_col_ids,
TF_Float32Tensor:$sorted_gains,
TF_Int32Tensor:$id_counts
);

// N represents the number of COO tensors in the list.
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<1>;
}
#endif // TF_OPS
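
For orientation, here is a minimal sketch (editorial, not part of this change) of the routing arithmetic the new kernel implements: with a power-of-two number of physical replicas, the low bits of a col id select the target SparseCore and the high bits give the id's column within that core's table shard. The names RoutedId and RouteColId are illustrative only.

#include <cstdint>

// Hypothetical helper showing power-of-two mod sharding of a col id.
struct RoutedId {
  uint32_t replica_id;   // physical SparseCore the id is routed to
  int32_t local_col_id;  // column within that core's table shard
};

inline RoutedId RouteColId(int32_t col_id, int num_physical_replica_bit) {
  const uint32_t mask = (1u << num_physical_replica_bit) - 1;
  return {static_cast<uint32_t>(col_id) & mask,
          col_id >> num_physical_replica_bit};
}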
@@ -136,7 +136,7 @@ TEST_F(LegalizationOpConfigTest, CountLoweringsSet) {
// a new op, we should expect these to change too.
EXPECT_EQ(mlir_lowering_count, 67);
EXPECT_EQ(tf2xla_fallback_count, 316);
- EXPECT_EQ(non_categorized_count, 425);
+ EXPECT_EQ(non_categorized_count, 426);
}

// Just a counter test to see which ops have duplicate lowerings. This isn't a
@@ -0,0 +1,4 @@
op {
graph_op_name: "SortListOfSparseCoreCooTensors"
visibility: HIDDEN
}
@@ -0,0 +1,4 @@
op {
graph_op_name: "SortListOfSparseCoreCooTensors"
visibility: HIDDEN
}
257 changes: 257 additions & 0 deletions tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc
@@ -1253,4 +1253,261 @@ REGISTER_KERNEL_BUILDER(
Name("ConvertToListOfSparseCoreCooTensors").Device(DEVICE_CPU),
ConvertToListOfSparseCoreCooTensorsOp)

SortListOfSparseCoreCooTensorsOp::SortListOfSparseCoreCooTensorsOp(
OpKernelConstruction* ctx)
: OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_name_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_replica", &num_replica_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("sample_count_list", &sample_count_list_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("col_offset_list", &col_offset_list_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_width", &feature_width_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_sc_per_chip", &num_sc_per_chip_));

// Group feature ids by col offset; the std::map keeps offsets sorted.
for (int i = 0; i < col_offset_list_.size(); ++i) {
col_offset_to_feature_id_[col_offset_list_[i]].push_back(i);
}

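// Each replica hosts num_sc_per_chip_ SparseCores, so the number of
// physical replicas (i.e. SparseCores) is their product.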
num_physical_replica_ = num_replica_ * num_sc_per_chip_;

OP_REQUIRES(ctx, IsPowerOfTwo(num_physical_replica_),
absl::FailedPreconditionError(
"Expected num_physical_replica to be a power of two"));

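// With a power-of-two SparseCore count, the target core of an id can be
// extracted with a bit mask (see num_physical_replica_mod in Compute).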
num_physical_replica_bit_ = std::log2(num_physical_replica_);

OP_REQUIRES_OK(
ctx, ctx->GetAttr("max_ids_per_sparse_core", &max_ids_per_sparse_core_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("max_unique_ids_per_sparse_core",
&max_unique_ids_per_sparse_core_));

OP_REQUIRES(
ctx, max_ids_per_sparse_core_ > 0,
absl::InvalidArgumentError("max_ids_per_sparse_core must be > 0"));
OP_REQUIRES(
ctx, max_unique_ids_per_sparse_core_ > 0,
absl::InvalidArgumentError("max_unique_ids_per_sparse_core must be > 0"));
}

void SortListOfSparseCoreCooTensorsOp::Compute(OpKernelContext* ctx) {
OpInputList row_ids_list;
OpInputList col_ids_list;
OpInputList gains_list;
OP_REQUIRES_OK(ctx, ctx->input_list("row_ids_list", &row_ids_list));
OP_REQUIRES_OK(ctx, ctx->input_list("col_ids_list", &col_ids_list));
OP_REQUIRES_OK(ctx, ctx->input_list("gains_list", &gains_list));

const int32_t per_sparse_core_batch_size =
absl::c_accumulate(sample_count_list_, 0);

const int32_t num_input_feature_group = col_offset_to_feature_id_.size();

std::vector<std::unique_ptr<uint64_t[]>> col_ids_index_list(
num_input_feature_group);

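// Mask selecting the low num_physical_replica_bit_ bits of a col id; these
// bits identify the physical replica (SparseCore) an id routes to.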
const int32_t num_physical_replica_mod = (1 << num_physical_replica_bit_) - 1;

Tensor* id_counts_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output(
"id_counts", TensorShape({num_physical_replica_ + 1}),
&id_counts_tensor));
int32_t* id_counts_tensor_ptr = id_counts_tensor->flat<int32_t>().data();
*id_counts_tensor_ptr = 0;

std::vector<int32_t> total_id_counter(num_physical_replica_);
std::vector<int32_t> total_unique_id_counter(num_physical_replica_);
std::vector<std::unique_ptr<uint32_t[]>> dedup_ids_index_mapping(
num_input_feature_group);
std::vector<std::unique_ptr<float[]>> gains_after_dedup(
num_input_feature_group);

std::vector<int32_t> total_id_counts(num_input_feature_group);
std::vector<std::unique_ptr<int32_t[]>> updated_row_ids(
num_input_feature_group);
std::vector<std::unique_ptr<int32_t[]>> updated_col_ids(
num_input_feature_group);
std::vector<std::unique_ptr<float[]>> updated_gains(num_input_feature_group);

// Concatenate input features that share the same col offset (i.e. features
// mapped to the same table).
int32_t feature_group_id = 0;
for (const auto& [col_offset, feature_id_list] : col_offset_to_feature_id_) {
int32_t total_id_count = 0;
for (int32_t feature_id : feature_id_list) {
total_id_count += col_ids_list[feature_id].NumElements();
}
total_id_counts[feature_group_id] = total_id_count;
updated_row_ids[feature_group_id] =
absl::make_unique_for_overwrite<int32_t[]>(total_id_count);
updated_col_ids[feature_group_id] =
absl::make_unique_for_overwrite<int32_t[]>(total_id_count);
updated_gains[feature_group_id] =
absl::make_unique_for_overwrite<float[]>(total_id_count);
int32_t tmp_size = 0;
for (int32_t feature_id : feature_id_list) {
int32_t feature_id_count = row_ids_list[feature_id].NumElements();
std::copy_n(row_ids_list[feature_id].flat<int32_t>().data(),
feature_id_count,
updated_row_ids[feature_group_id].get() + tmp_size);
std::copy_n(col_ids_list[feature_id].flat<int32_t>().data(),
feature_id_count,
updated_col_ids[feature_group_id].get() + tmp_size);
std::copy_n(gains_list[feature_id].flat<float>().data(), feature_id_count,
updated_gains[feature_group_id].get() + tmp_size);
tmp_size += feature_id_count;
}
feature_group_id++;
}

for (int feature_group_id = 0; feature_group_id < num_input_feature_group;
++feature_group_id) {
int32_t total_id_count = total_id_counts[feature_group_id];
dedup_ids_index_mapping[feature_group_id] =
absl::make_unique_for_overwrite<uint32_t[]>(total_id_count);

gains_after_dedup[feature_group_id] =
absl::make_unique_for_overwrite<float[]>(total_id_count);

uint32_t* per_feature_dedup_ids_index_mapping =
dedup_ids_index_mapping[feature_group_id].get();

float* per_feature_gains_after_dedup =
gains_after_dedup[feature_group_id].get();
const int32_t* row_ids_ptr = updated_row_ids[feature_group_id].get();
const int32_t* col_ids_ptr = updated_col_ids[feature_group_id].get();
const float* gains_ptr = updated_gains[feature_group_id].get();
col_ids_index_list[feature_group_id] =
absl::make_unique_for_overwrite<uint64_t[]>(total_id_count);
uint64_t* per_feature_col_ids_index_list =
col_ids_index_list[feature_group_id].get();
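// Pack each (col id, original index) pair into a single uint64_t key so one
// sort orders ids by col id while remembering each id's source position.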
for (int32_t index = 0; index < total_id_count; ++index) {
per_feature_col_ids_index_list[index] =
(static_cast<uint64_t>(*(col_ids_ptr + index)) << 32) + index;
}
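// Vectorized sort from the highway library. The original index sits in the
// low 32 bits, so ids with equal col ids remain in original order.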
hwy::VQSort(per_feature_col_ids_index_list, total_id_count,
hwy::SortAscending());

// Walk the sorted col ids, counting total and unique ids per replica and
// deduplicating repeated (row id, col id) pairs.
int32_t previous_col_id = -1;
int32_t previous_row_id = -1;
uint32_t previous_id_array_index = 0;
for (int32_t index = 0; index < total_id_count; ++index) {
uint64_t item = per_feature_col_ids_index_list[index];
int32_t col_id = item >> 32;
uint32_t id_array_index = item & 0xffffffff;
int32_t row_id = *(row_ids_ptr + id_array_index);
// If both the row id and the col id match the previous entry, dedup the id
// by adding its gain to the first occurrence; otherwise record a new id.
if (row_id != previous_row_id || col_id != previous_col_id) {
per_feature_dedup_ids_index_mapping[id_array_index] = id_array_index;
per_feature_gains_after_dedup[id_array_index] =
*(gains_ptr + id_array_index);
uint32_t replica_id = col_id & num_physical_replica_mod;
total_id_counter[replica_id]++;
if (col_id != previous_col_id) total_unique_id_counter[replica_id]++;
} else {
// Both row id and col id match the previous entry, so fold this id's
// gain into its first occurrence.
uint32_t parent_idx =
per_feature_dedup_ids_index_mapping[previous_id_array_index];
per_feature_dedup_ids_index_mapping[id_array_index] = parent_idx;
per_feature_gains_after_dedup[parent_idx] +=
*(gains_ptr + id_array_index);
}
previous_id_array_index = id_array_index;
previous_col_id = col_id;
previous_row_id = row_id;
}
}

for (int replica_id = 0; replica_id < num_physical_replica_; ++replica_id) {
// Fail the op if any replica's total or unique id count exceeds the
// configured maximum.
OP_REQUIRES(
ctx, total_id_counter[replica_id] <= max_ids_per_sparse_core_,
absl::InvalidArgumentError(absl::StrCat(
"Sparse core ", replica_id, " gets ", total_id_counter[replica_id],
" ids while the max ids per sparse core is set to be ",
max_ids_per_sparse_core_, " for table ", table_name_)));
OP_REQUIRES(
ctx,
total_unique_id_counter[replica_id] <= max_unique_ids_per_sparse_core_,
absl::InvalidArgumentError(absl::StrCat(
"Sparse core ", replica_id, " gets ",
total_unique_id_counter[replica_id],
" unique ids while the max unique ids per sparse core is set "
"to be ",
max_unique_ids_per_sparse_core_, " for table ", table_name_)));
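// id_counts is an exclusive prefix sum of per-replica id counts: entry r
// is the offset of replica r's segment in the sorted outputs.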
*(id_counts_tensor_ptr + replica_id + 1) =
total_id_counter[replica_id] + *(id_counts_tensor_ptr + replica_id);
}

const int32_t updated_total_id_count =
absl::c_accumulate(total_id_counter, 0);

Tensor* sorted_row_ids_tensor;
OP_REQUIRES_OK(ctx,
ctx->allocate_output("sorted_row_ids",
TensorShape({updated_total_id_count}),
&sorted_row_ids_tensor));
Tensor* sorted_col_ids_tensor;
OP_REQUIRES_OK(ctx,
ctx->allocate_output("sorted_col_ids",
TensorShape({updated_total_id_count}),
&sorted_col_ids_tensor));
Tensor* sorted_gains_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output(
"sorted_gains", TensorShape({updated_total_id_count}),
&sorted_gains_tensor));

int32_t* sorted_row_ids_tensor_ptr =
sorted_row_ids_tensor->flat<int32_t>().data();
int32_t* sorted_col_ids_tensor_ptr =
sorted_col_ids_tensor->flat<int32_t>().data();
float* sorted_gains_tensor_ptr = sorted_gains_tensor->flat<float>().data();

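// Running count of ids already written into each replica's output segment.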
std::vector<int32_t> per_physical_replica_index(num_physical_replica_);

for (int feature_group_id = 0; feature_group_id < num_input_feature_group;
++feature_group_id) {
const int32_t* row_ids_ptr = updated_row_ids[feature_group_id].get();

const uint32_t* per_feature_dedup_ids_index_mapping =
dedup_ids_index_mapping[feature_group_id].get();

const float* per_feature_gains_after_dedup =
gains_after_dedup[feature_group_id].get();

const uint64_t* per_feature_col_ids_index_list =
col_ids_index_list[feature_group_id].get();

const int32_t total_id_count = total_id_counts[feature_group_id];

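// Scatter each surviving (deduped) id into its replica's segment of the
// sorted output tensors.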
for (int32_t index = 0; index < total_id_count; ++index) {
uint64_t item = per_feature_col_ids_index_list[index];
uint32_t id_array_index = item & 0xffffffff;
if (id_array_index !=
per_feature_dedup_ids_index_mapping[id_array_index]) {
continue;
}
int32_t col_id = item >> 32;
int32_t replica_id = col_id & num_physical_replica_mod;

int32_t main_index = *(id_counts_tensor_ptr + replica_id) +
per_physical_replica_index[replica_id];
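// Reduce the row id modulo the per-SparseCore batch size so it indexes
// within a single SparseCore's batch.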
*(sorted_row_ids_tensor_ptr + main_index) =
*(row_ids_ptr + id_array_index) % per_sparse_core_batch_size;
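// Drop the replica bits; the remaining high bits give the id's column
// within this replica's shard of the table.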
*(sorted_col_ids_tensor_ptr + main_index) =
col_id >> num_physical_replica_bit_;
// Write the deduped (accumulated) gain rather than the original value.
*(sorted_gains_tensor_ptr + main_index) =
per_feature_gains_after_dedup[id_array_index];
++per_physical_replica_index[replica_id];
}
}
}

REGISTER_KERNEL_BUILDER(
Name("SortListOfSparseCoreCooTensors").Device(DEVICE_CPU),
SortListOfSparseCoreCooTensorsOp)

} // namespace tensorflow
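
As a self-contained illustration of the packed-key trick used above (std::sort stands in for hwy::VQSort; the data and all names are made up):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Toy col ids with duplicates, at original positions 0..3.
  const std::vector<int32_t> col_ids = {7, 3, 7, 3};
  std::vector<uint64_t> keys;
  for (uint32_t i = 0; i < col_ids.size(); ++i) {
    // High 32 bits: col id. Low 32 bits: original index.
    keys.push_back((static_cast<uint64_t>(col_ids[i]) << 32) | i);
  }
  std::sort(keys.begin(), keys.end());  // the kernel uses hwy::VQSort
  for (uint64_t key : keys) {
    std::cout << "col_id=" << (key >> 32) << " index=" << (key & 0xffffffff)
              << "\n";
  }
  // Output is ordered by col id, ties by original index:
  // (3,1), (3,3), (7,0), (7,2).
  return 0;
}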
30 changes: 29 additions & 1 deletion tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h
@@ -16,12 +16,15 @@ limitations under the License.
#define TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_PREPROCESS_OPS_H_

#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/tpu/kernels/sparse_core_ops_stats_handler.h"
@@ -179,6 +182,31 @@ class ConvertToListOfSparseCoreCooTensorsOp : public OpKernel {
std::string combiner_;
};

class SortListOfSparseCoreCooTensorsOp : public OpKernel {
public:
explicit SortListOfSparseCoreCooTensorsOp(OpKernelConstruction* ctx);
~SortListOfSparseCoreCooTensorsOp() override = default;
SortListOfSparseCoreCooTensorsOp(const SortListOfSparseCoreCooTensorsOp&) =
delete;
SortListOfSparseCoreCooTensorsOp& operator=(
const SortListOfSparseCoreCooTensorsOp&) = delete;

void Compute(OpKernelContext* ctx) override;

private:
int32_t num_sc_per_chip_;
int32_t feature_width_;
int32_t num_replica_;
int32_t num_physical_replica_;
int32_t num_physical_replica_bit_;
int32_t max_ids_per_sparse_core_;
int32_t max_unique_ids_per_sparse_core_;
std::string table_name_;
std::vector<int32_t> sample_count_list_;
std::vector<int32_t> col_offset_list_;
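// Maps each col offset to the feature ids sharing it; std::map iteration
// visits offsets in sorted order.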
std::map<int32_t, std::vector<int32_t>> col_offset_to_feature_id_;
};

} // namespace tensorflow

#endif // TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_PREPROCESS_OPS_H_
1 change: 1 addition & 0 deletions tensorflow/core/tpu/ops/BUILD
@@ -190,6 +190,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_xla//xla:util",
],
alwayslink = 1,