Add a new op to convert the sorted COO tensor into sparse core CSR wrapped COO format.

PiperOrigin-RevId: 627555672
pineapplejuice233 authored and tensorflower-gardener committed Apr 24, 2024
1 parent b9ca938 commit 0d6a9e6
Showing 12 changed files with 746 additions and 1 deletion.
34 changes: 34 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -2488,4 +2488,38 @@ def TF_SortListOfSparseCoreCooTensorsOp : TF_Op<"SortListOfSparseCoreCooTensors"
// N represents the number of COO tensors in the list.
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<1>;
}


def TF_ConvertToSparseCoreCsrWrappedCooTensorOp : TF_Op<"ConvertToSparseCoreCsrWrappedCooTensor", [Pure, SameVariadicOperandSize]> {
let summary = "An op which converts the sorted COO tensor into sparse core CSR wrapped COO format.";

let arguments = (ins
Variadic<TF_Int32Tensor>:$sorted_row_ids_list,
Variadic<TF_Int32Tensor>:$sorted_col_ids_list,
Variadic<TF_Float32Tensor>:$sorted_gains_list,
Variadic<TF_Int32Tensor>:$id_counts_list,
TF_Int64Tensor:$splits,

ConfinedAttr<I64Attr, [IntMinValue<1>]>:$sample_count_per_sc,
ConfinedAttr<I64Attr, [IntMinValue<1>]>:$num_replica,
ConfinedAttr<I64Attr, [IntMinValue<1>]>:$max_minibatches_per_sc,
ConfinedAttr<I64Attr, [IntMinValue<1>]>:$max_ids_per_chip_per_sample,
ConfinedAttr<I64Attr, [IntMinValue<1>]>:$table_vocab_size,
ConfinedAttr<I64Attr, [IntMinValue<1>]>:$feature_width,
StrAttr:$table_name,
BoolAttr:$allow_id_dropping
);

let results = (outs
TF_Int32Tensor:$row_pointers,
TF_Int32Tensor:$sorted_sample_ids,
TF_Int32Tensor:$sorted_token_ids,
TF_Float32Tensor:$sorted_gains,
TF_Int32Tensor:$row_pointers_unpadded_size,
TF_Int32Tensor:$ids_unpadded_size,
TF_Int32Tensor:$num_minibatches_per_sc
);

TF_DerivedOperandSizeAttr num_sc_per_chip = TF_DerivedOperandSizeAttr<1>;
}
#endif // TF_OPS
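
For orientation, the arithmetic below mirrors how the CPU kernel further down in this commit sizes the dense outputs from the attributes alone; the *_unpadded_size outputs then report how much of each buffer is actually populated. This is a minimal standalone sketch, not part of the commit: ConvertOpSizes, ComputeSizes, the local RoundUpTo, and the example attribute values are all assumed for illustration.

#include <cstdint>

// Sketch of the output-size arithmetic; xla_pad_size is hard-coded to 8 in
// the kernel, everything else comes from op attributes.
constexpr int32_t RoundUpTo(int32_t value, int32_t multiple) {
  return (value + multiple - 1) / multiple * multiple;
}

struct ConvertOpSizes {
  int32_t max_ids_per_chip;   // length of sorted_token_ids/sample_ids/gains
  int32_t row_pointers_size;  // length of row_pointers
};

constexpr ConvertOpSizes ComputeSizes(int32_t max_ids_per_chip_per_sample,
                                      int32_t sample_count_per_sc,
                                      int32_t num_sc_per_chip,
                                      int32_t num_replica,
                                      int32_t max_minibatches_per_sc) {
  const int32_t xla_pad_size = 8;
  const int32_t max_ids_per_chip =
      max_ids_per_chip_per_sample * sample_count_per_sc * num_sc_per_chip;
  const int32_t num_physical_replica = num_replica * num_sc_per_chip;
  const int32_t padded_row_pointers_size_per_sc =
      RoundUpTo(num_physical_replica, xla_pad_size);
  return {max_ids_per_chip,
          max_minibatches_per_sc * num_sc_per_chip *
              padded_row_pointers_size_per_sc};
}

// Example with assumed values: 64 ids per sample, 16 samples per SparseCore,
// 4 SparseCores per chip, 2 replicas, at most 4 minibatches per SparseCore.
static_assert(ComputeSizes(64, 16, 4, 2, 4).max_ids_per_chip == 4096);
static_assert(ComputeSizes(64, 16, 4, 2, 4).row_pointers_size == 128);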
@@ -136,7 +136,7 @@ TEST_F(LegalizationOpConfigTest, CountLoweringsSet) {
// a new op, we should expect these to change too.
EXPECT_EQ(mlir_lowering_count, 67);
EXPECT_EQ(tf2xla_fallback_count, 316);
EXPECT_EQ(non_categorized_count, 426);
EXPECT_EQ(non_categorized_count, 427);
}

// Just a counter test to see which ops have duplicate lowerings. This isn't a
@@ -0,0 +1,4 @@
op {
graph_op_name: "ConvertToSparseCoreCsrWrappedCooTensor"
visibility: HIDDEN
}
@@ -0,0 +1,4 @@
op {
graph_op_name: "ConvertToSparseCoreCsrWrappedCooTensor"
visibility: HIDDEN
}
261 changes: 261 additions & 0 deletions tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc
@@ -1510,4 +1510,265 @@ REGISTER_KERNEL_BUILDER(
Name("SortListOfSparseCoreCooTensors").Device(DEVICE_CPU),
SortListOfSparseCoreCooTensorsOp)

ConvertToSparseCoreCsrWrappedCooTensorOp::
ConvertToSparseCoreCsrWrappedCooTensorOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_name_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_replica", &num_replica_));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("sample_count_per_sc", &sample_count_per_sc_));
OP_REQUIRES_OK(
ctx, ctx->GetAttr("max_minibatches_per_sc", &max_minibatches_per_sc_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("max_ids_per_chip_per_sample",
&max_ids_per_chip_per_sample_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("table_vocab_size", &table_vocab_size_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_width", &feature_width_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_sc_per_chip", &num_sc_per_chip_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("allow_id_dropping", &allow_id_dropping_));

device_name_ = ctx->device()->name();
}

void ConvertToSparseCoreCsrWrappedCooTensorOp::Compute(OpKernelContext* ctx) {
OpInputList sorted_row_ids_list;
OpInputList sorted_col_ids_list;
OpInputList sorted_gains_list;
OpInputList id_counts_list;
OP_REQUIRES_OK(ctx,
ctx->input_list("sorted_row_ids_list", &sorted_row_ids_list));
OP_REQUIRES_OK(ctx,
ctx->input_list("sorted_col_ids_list", &sorted_col_ids_list));
OP_REQUIRES_OK(ctx, ctx->input_list("sorted_gains_list", &sorted_gains_list));
OP_REQUIRES_OK(ctx, ctx->input_list("id_counts_list", &id_counts_list));
const Tensor* splits;
OP_REQUIRES_OK(ctx, ctx->input("splits", &splits));

OP_REQUIRES(
ctx, sorted_row_ids_list.size() == num_sc_per_chip_,
absl::InvalidArgumentError(
"Sorted row ids list size is not equal to the num sc per chip."));
const int32_t num_physical_replica = num_replica_ * num_sc_per_chip_;

const int32_t max_ids_per_chip =
max_ids_per_chip_per_sample_ * sample_count_per_sc_ * num_sc_per_chip_;
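// Illustrative values (assumed, not from this commit): with
// max_ids_per_chip_per_sample = 64, sample_count_per_sc = 16 and
// num_sc_per_chip = 4, max_ids_per_chip = 4096, which also passes the
// xla_pad_size divisibility check below.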

const int max_division_level = GetMinibatchMaxDivisionLevel();

const int32_t kMaxDivisions = 1 << max_division_level;

const int64_t* splits_tensor_ptr = splits->flat<int64_t>().data();

int64_t binary_splits = 0;
for (int i = 0; i < splits->NumElements(); ++i) {
binary_splits |= *(splits_tensor_ptr + i);
}

std::vector<int> bucket_splits =
ConvertBinarySplitsToBucketSplits(binary_splits, max_division_level);

// Compute the number of minibatches per SparseCore.
const int32_t num_minibatch_per_sc = bucket_splits.size() + 1;
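// Illustrative (assumed) example: if two SparseCores report split masks
// 0b001 and 0b010, the OR above yields 0b011; assuming each set bit maps to
// one division boundary, bucket_splits has two entries and
// num_minibatch_per_sc = 3, so every SparseCore uses the same (finest)
// minibatch partition.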

int32_t total_id_count = 0;

for (int sc_id = 0; sc_id < num_sc_per_chip_; ++sc_id) {
OP_REQUIRES(
ctx,
id_counts_list[sc_id].NumElements() == num_physical_replica + 1 ||
id_counts_list[sc_id].NumElements() ==
num_physical_replica * kMaxDivisions + 1,
absl::InvalidArgumentError(absl::StrCat(
"The id counts should have either ", num_physical_replica + 1,
" elements when there is no minibatching or ",
num_physical_replica * kMaxDivisions + 1,
" elements when there are multiple minibatches for each "
"sparsecore. But instead got ",
id_counts_list[sc_id].NumElements(), " elements.")));
total_id_count += *(id_counts_list[sc_id].flat<int32_t>().data() +
id_counts_list[sc_id].NumElements() - 1);
}

// We use the number of elements in id_counts_list, rather than
// num_minibatch_per_sc, to determine whether minibatching is needed. The
// minibatching logic can be triggered even when only a single minibatch is
// computed at the end, with some ids getting dropped.
const bool is_minibatching = id_counts_list[0].NumElements() ==
num_physical_replica * kMaxDivisions + 1;

const int32_t total_num_minibatch = num_minibatch_per_sc * num_sc_per_chip_;

bucket_splits.insert(bucket_splits.begin(), 0);
bucket_splits.push_back(kMaxDivisions);
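// Illustrative (assumed) example: with max_division_level = 2
// (kMaxDivisions = 4) and bucket_splits = {2}, the padded boundaries become
// {0, 2, 4}, i.e. two minibatches per SparseCore covering divisions [0, 2)
// and [2, 4).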

const int32_t xla_pad_size = 8;

OP_REQUIRES(
ctx, max_ids_per_chip % xla_pad_size == 0,
absl::InvalidArgumentError(absl::StrCat(
"The max_ids_per_chip is set to be ", max_ids_per_chip,
" which is not divisible by the xla_pad_size ", xla_pad_size, " .")));

const int32_t padded_row_pointers_size_per_sc =
xla::RoundUpTo<int32_t>(num_physical_replica, xla_pad_size);
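// Illustrative (assumed) example: num_replica = 1 and num_sc_per_chip = 4
// give num_physical_replica = 4, which is rounded up to 8 row pointers per
// minibatch per SparseCore.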

Tensor* row_pointers_tensor;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(
"row_pointers",
TensorShape({max_minibatches_per_sc_ * num_sc_per_chip_ *
padded_row_pointers_size_per_sc}),
&row_pointers_tensor));

Tensor* sorted_sample_ids_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output("sorted_sample_ids",
TensorShape({max_ids_per_chip}),
&sorted_sample_ids_tensor));
Tensor* sorted_token_ids_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output("sorted_token_ids",
TensorShape({max_ids_per_chip}),
&sorted_token_ids_tensor));
Tensor* sorted_gains_tensor;
OP_REQUIRES_OK(
ctx, ctx->allocate_output("sorted_gains", TensorShape({max_ids_per_chip}),
&sorted_gains_tensor));

int32_t* row_pointers_tensor_ptr =
row_pointers_tensor->flat<int32_t>().data();
int32_t* sorted_sample_ids_tensor_ptr =
sorted_sample_ids_tensor->flat<int32_t>().data();
int32_t* sorted_token_ids_tensor_ptr =
sorted_token_ids_tensor->flat<int32_t>().data();
float* sorted_gains_tensor_ptr = sorted_gains_tensor->flat<float>().data();

// Tracks how many ids have been packed into the output tensors; the
// difference between this and total_id_count tells us how many ids were
// dropped.
int32_t packed_id_count = 0;

int32_t global_index = 0;
int32_t row_pointers_index = 0;
for (int sc_id = 0; sc_id < num_sc_per_chip_; ++sc_id) {
const int32_t* row_ids_tensor_ptr =
sorted_row_ids_list[sc_id].flat<int32_t>().data();
const int32_t* col_ids_tensor_ptr =
sorted_col_ids_list[sc_id].flat<int32_t>().data();
const float* gains_tensor_ptr =
sorted_gains_list[sc_id].flat<float>().data();
const int32_t* id_counts_tensor_ptr =
id_counts_list[sc_id].flat<int32_t>().data();
for (int bucket_id = 1; bucket_id < bucket_splits.size(); ++bucket_id) {
for (int replica_id = 0; replica_id < num_physical_replica;
++replica_id) {
int start_division_pos, end_division_pos;
if (is_minibatching) {
start_division_pos =
replica_id * kMaxDivisions + bucket_splits[bucket_id - 1];
end_division_pos =
replica_id * kMaxDivisions + bucket_splits[bucket_id];
} else {
start_division_pos = replica_id;
end_division_pos = replica_id + 1;
}
const int32_t start_pos = *(id_counts_tensor_ptr + start_division_pos);
const int32_t end_pos = *(id_counts_tensor_ptr + end_division_pos);

const int32_t token_id_count = end_pos - start_pos;

if (global_index + token_id_count > max_ids_per_chip) {
if (allow_id_dropping_) {
const int32_t copy_id_count =
std::min(max_ids_per_chip - global_index, token_id_count);
std::copy_n(col_ids_tensor_ptr + start_pos, copy_id_count,
sorted_token_ids_tensor_ptr + global_index);
std::copy_n(row_ids_tensor_ptr + start_pos, copy_id_count,
sorted_sample_ids_tensor_ptr + global_index);
std::copy_n(gains_tensor_ptr + start_pos, copy_id_count,
sorted_gains_tensor_ptr + global_index);
packed_id_count += copy_id_count;
global_index = max_ids_per_chip;
} else {
const int32_t remain_id_count = total_id_count - packed_id_count;
ctx->CtxFailure(absl::InvalidArgumentError(absl::StrCat(
"The max_ids_per_chip is set to be ", max_ids_per_chip,
" which is not going to fit all ids. The remaining id count "
"is ",
remain_id_count,
" . Please consider setting the allow_id_dropping to be "
"true. ")));
return;
}
} else {
std::copy_n(col_ids_tensor_ptr + start_pos, token_id_count,
sorted_token_ids_tensor_ptr + global_index);
std::copy_n(row_ids_tensor_ptr + start_pos, token_id_count,
sorted_sample_ids_tensor_ptr + global_index);
std::copy_n(gains_tensor_ptr + start_pos, token_id_count,
sorted_gains_tensor_ptr + global_index);

global_index += token_id_count;
packed_id_count += token_id_count;
}

*(row_pointers_tensor_ptr + row_pointers_index) = global_index;
int32 num_ids_to_pad_per_replica =
xla::RoundUpTo<int32_t>(global_index, xla_pad_size) - global_index;

std::fill_n(sorted_token_ids_tensor_ptr + global_index,
num_ids_to_pad_per_replica, kXlaPadValue);
std::fill_n(sorted_sample_ids_tensor_ptr + global_index,
num_ids_to_pad_per_replica, kXlaPadValue);
std::fill_n(sorted_gains_tensor_ptr + global_index,
num_ids_to_pad_per_replica, kXlaPadValue);

global_index += num_ids_to_pad_per_replica;
++row_pointers_index;
}
// Pad the row_pointers to be memory aligned.
int32 num_row_pointers_to_pad =
xla::RoundUpTo<int32>(row_pointers_index, xla_pad_size) -
row_pointers_index;
std::fill_n(row_pointers_tensor_ptr + row_pointers_index,
num_row_pointers_to_pad, global_index);
row_pointers_index += num_row_pointers_to_pad;
}
}
int32_t ids_unpadded_size = global_index;

if (packed_id_count < total_id_count) {
const int32_t dropped_id_count = total_id_count - packed_id_count;
LOG(WARNING) << "Table " << table_name_ << " is dropping "
<< dropped_id_count
<< " ids so that the produced CsrWrappedCooTensor can be fit "
"in static bound of "
<< max_ids_per_chip
<< " . This could potentially impact the model quality.";
}

int32 row_pointers_unpadded_size =
total_num_minibatch * padded_row_pointers_size_per_sc;

Tensor* num_minibatches_per_sc_tensor;
OP_REQUIRES_OK(ctx,
ctx->allocate_output("num_minibatches_per_sc", TensorShape({}),
&num_minibatches_per_sc_tensor));

Tensor* row_pointers_unpadded_size_tensor;
OP_REQUIRES_OK(
ctx, ctx->allocate_output("row_pointers_unpadded_size", TensorShape({}),
&row_pointers_unpadded_size_tensor));

Tensor* ids_unpadded_size_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output("ids_unpadded_size", TensorShape({}),
&ids_unpadded_size_tensor));

num_minibatches_per_sc_tensor->flat<int32>()(0) = num_minibatch_per_sc;
row_pointers_unpadded_size_tensor->flat<int32>()(0) =
row_pointers_unpadded_size;
ids_unpadded_size_tensor->flat<int32>()(0) = ids_unpadded_size;
}

REGISTER_KERNEL_BUILDER(
Name("ConvertToSparseCoreCsrWrappedCooTensor").Device(DEVICE_CPU),
ConvertToSparseCoreCsrWrappedCooTensorOp)

} // namespace tensorflow
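
As a rough, standalone illustration of the per-replica copy-and-pad step inside the loop above: the ids for one replica are copied, the unpadded end position is recorded as a row pointer, and the id buffer is then padded to the next 8-aligned offset. PackOneReplica, kPadValue, and the vector-based buffers are hypothetical simplifications; the kernel itself writes directly into the output tensors and uses kXlaPadValue.

#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch of one replica's copy + pad step (buffers are assumed to be large
// enough): copy its ids, record the unpadded end position as a row pointer,
// then pad the id buffer so the next replica starts on an 8-aligned offset.
void PackOneReplica(const std::vector<int32_t>& col_ids, int32_t start_pos,
                    int32_t end_pos, std::vector<int32_t>& out_token_ids,
                    std::vector<int32_t>& out_row_pointers,
                    int32_t& global_index) {
  constexpr int32_t kXlaPadSize = 8;
  constexpr int32_t kPadValue = -1;  // stand-in for the kernel's kXlaPadValue
  const int32_t count = end_pos - start_pos;
  std::copy_n(col_ids.begin() + start_pos, count,
              out_token_ids.begin() + global_index);
  global_index += count;
  out_row_pointers.push_back(global_index);
  const int32_t padded_index =
      (global_index + kXlaPadSize - 1) / kXlaPadSize * kXlaPadSize;
  std::fill_n(out_token_ids.begin() + global_index,
              padded_index - global_index, kPadValue);
  global_index = padded_index;
}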
24 changes: 24 additions & 0 deletions tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h
@@ -207,6 +207,30 @@ class SortListOfSparseCoreCooTensorsOp : public OpKernel {
std::map<int32_t, std::vector<int32_t>> col_offset_to_feature_id_;
};

class ConvertToSparseCoreCsrWrappedCooTensorOp : public OpKernel {
public:
explicit ConvertToSparseCoreCsrWrappedCooTensorOp(OpKernelConstruction* ctx);
~ConvertToSparseCoreCsrWrappedCooTensorOp() override = default;
ConvertToSparseCoreCsrWrappedCooTensorOp(
const ConvertToSparseCoreCsrWrappedCooTensorOp&) = delete;
ConvertToSparseCoreCsrWrappedCooTensorOp& operator=(
const ConvertToSparseCoreCsrWrappedCooTensorOp&) = delete;

void Compute(OpKernelContext* ctx) override;

private:
int32_t num_sc_per_chip_;
int32_t table_vocab_size_;
int32_t feature_width_;
int32_t num_replica_;
int32_t sample_count_per_sc_;
int32_t max_minibatches_per_sc_;
int32_t max_ids_per_chip_per_sample_;
bool allow_id_dropping_;
std::string table_name_;
std::string device_name_;
};

} // namespace tensorflow

#endif // TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_PREPROCESS_OPS_H_
