// ragged_tensor_to_sparse_kernel.cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
using errors::InvalidArgument;
template <typename SPLITS_TYPE>
class RaggedTensorToSparseOp : public OpKernel {
public:
using OpKernel::OpKernel;
using ConstFlatSplits = typename TTypes<SPLITS_TYPE>::ConstFlat;
void Compute(OpKernelContext* context) override {
// Read the `rt_nested_splits` input & convert to Eigen tensors.
OpInputList rt_nested_splits_in;
OP_REQUIRES_OK(
context, context->input_list("rt_nested_splits", &rt_nested_splits_in));
const int rt_nested_splits_len = rt_nested_splits_in.size();
DCHECK_GT(rt_nested_splits_len, 0); // Enforced by REGISTER_OP.
std::vector<ConstFlatSplits> rt_nested_splits;
rt_nested_splits.reserve(rt_nested_splits_len);
for (int i = 0; i < rt_nested_splits_len; ++i) {
rt_nested_splits.push_back(rt_nested_splits_in[i].flat<SPLITS_TYPE>());
}
// Read the `rt_dense_values` input.
const Tensor& rt_dense_values_in = context->input(rt_nested_splits_len);
OP_REQUIRES_OK(context,
ValidateInputs(rt_nested_splits, rt_dense_values_in));
// Assemble each value in `sparse_indices` using three parts:
// - `index_prefix` is the index in dimensions up through the last ragged
// dimension.
// - `index_middle` is the index in the last ragged dimension.
// - `index_suffix` is the index in the dense value dimensions.
std::vector<int64> index_prefix(rt_nested_splits_len);
std::vector<std::vector<int64>> index_suffixes =
MakeIndexSuffixes(rt_dense_values_in.shape());
// Allocate the `sparse_indices` output tensor.
const int64 nvals =
(rt_nested_splits.back()(rt_nested_splits.back().size() - 1) *
index_suffixes.size());
const int64 indices_len = rt_nested_splits_len + rt_dense_values_in.dims();
Tensor* sparse_indices_out = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, TensorShape({nvals, indices_len}),
&sparse_indices_out));
auto sparse_indices = sparse_indices_out->tensor<int64, 2>();
// pos[i] is the current position in rt_nested_splits[i]. final_pos is a
// reference to make it easier to refer to pos[-1].
std::vector<int64> pos(rt_nested_splits_len);
int64& final_pos = pos[rt_nested_splits_len - 1];
// Each iteration through the loop, we increment pos[-1], and add indices
// for all the values corresponding to
// rt_nested_splits[-1][pos[-1]:pos[-1]+1].
int next_index = 0;
int max_final_pos = rt_nested_splits.back().size() - 1;
for (; final_pos < max_final_pos; ++final_pos) {
// Update `pos` to skip over completed elements (i.e., elements where
// we have already generated indices for all contained values).
for (int dim = rt_nested_splits_len - 2; dim >= 0; --dim) {
while (IsCompleted(pos, dim, rt_nested_splits)) {
pos[dim] += 1;
}
}
// Update index_prefix.
for (int dim = 0; dim < index_prefix.size(); ++dim) {
int start = dim > 0 ? rt_nested_splits[dim - 1](pos[dim - 1]) : 0;
index_prefix[dim] = pos[dim] - start;
}
// Get length of the final-ragged-dimension slice.
const auto& final_splits = rt_nested_splits[rt_nested_splits_len - 1];
int64 slice_len = final_splits(final_pos + 1) - final_splits(final_pos);
// Add sparse_indices for this slice.
for (int64 i = 0; i < slice_len; ++i) {
for (const auto& index_suffix : index_suffixes) {
int dim = 0;
for (int64 index : index_prefix) { // index_prefix
sparse_indices(next_index, dim++) = index;
}
sparse_indices(next_index, dim++) = i; // index_middle
for (int64 index : index_suffix) { // index_suffix
sparse_indices(next_index, dim++) = index;
}
DCHECK_EQ(dim, indices_len);
++next_index;
}
}
}
DCHECK_EQ(next_index, nvals);
// Output the `sparse_values` Tensor.
if (rt_dense_values_in.dims() == 1) {
context->set_output(1, rt_dense_values_in);
} else {
Tensor sparse_values_out(rt_dense_values_in.dtype());
bool shapes_match = sparse_values_out.CopyFrom(
rt_dense_values_in, {rt_dense_values_in.NumElements()});
DCHECK(shapes_match);
context->set_output(1, sparse_values_out);
}
// Output the `sparse_dense_shape` Tensor.
int64 ndims = rt_nested_splits_len + rt_dense_values_in.dims();
Tensor* sparse_dense_shape_out = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({ndims}),
&sparse_dense_shape_out));
auto sparse_dense_shape = sparse_dense_shape_out->vec<int64>();
sparse_dense_shape(0) = rt_nested_splits_in[0].dim_size(0) - 1;
for (int dim = 0; dim < rt_nested_splits_len; ++dim) {
const auto& splits = rt_nested_splits[dim];
SPLITS_TYPE max_width = 0;
for (int i = 1; i < splits.size(); ++i) {
max_width = std::max(max_width, splits(i) - splits(i - 1));
}
sparse_dense_shape(dim + 1) = max_width;
}
for (int dim = 1; dim < rt_dense_values_in.dims(); ++dim) {
sparse_dense_shape(dim + rt_nested_splits_len) =
rt_dense_values_in.dim_size(dim);
}
}
private:
// Validate `rt_nested_splits` to ensure we don't get any segfaults.
static ::tensorflow::Status ValidateInputs(
std::vector<ConstFlatSplits> rt_nested_splits,
const Tensor& rt_dense_values_in) {
for (int i = 0; i < rt_nested_splits.size(); ++i) {
if (rt_nested_splits[i].size() == 0) {
return InvalidArgument("ragged splits may not be empty.");
}
if (rt_nested_splits[i](0) != 0) {
return InvalidArgument("First value of ragged splits must be 0.");
}
if (i > 0) {
SPLITS_TYPE last_split =
rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1);
if (rt_nested_splits[i].size() != last_split + 1) {
return InvalidArgument(
"Final value of ragged splits must match the length "
"the corresponding ragged values.");
}
}
}
if (rt_dense_values_in.dim_size(0) !=
rt_nested_splits.back()(rt_nested_splits.back().size() - 1)) {
return InvalidArgument(
"Final value of ragged splits must match the length "
"the corresponding ragged values.");
}
return ::tensorflow::Status::OK();
}
// Build a list of index suffixes that should be added for each ragged item,
// to encode the indices of dense values in that ragged item. This basically
// just gives a row-major enumeration of all indices in the given tensor
// shape, ignoring dim[0] (since that's the dimension that iterates over
// values, and we want index suffixes for a single value). Example:
// MakeIndexSuffixes(TensorShape({100, 3, 2})
// --> {{0, 0}, {0, 1}, {1, 0}, {1, 1}, {2, 0}, {2, 1}}
static std::vector<std::vector<int64>> MakeIndexSuffixes(
const TensorShape& values_shape) {
std::vector<std::vector<int64>> suffixes{{}};
for (int dim = 1; dim < values_shape.dims(); ++dim) {
std::vector<std::vector<int64>> new_suffixes;
for (const auto& suffix : suffixes) {
for (int i = 0; i < values_shape.dim_size(dim); ++i) {
new_suffixes.push_back(suffix);
new_suffixes.back().push_back(i);
}
}
suffixes.swap(new_suffixes);
}
return suffixes;
}
// Returns true if the ragged element at pos[dim] is "completed". A ragged
// element is completed if we have already generated indices for all of its
// values.
static bool IsCompleted(
const std::vector<int64>& pos, int dim,
const std::vector<ConstFlatSplits>& rt_nested_splits) {
int64 current_child = pos[dim + 1];
int64 limit_child = rt_nested_splits[dim](pos[dim] + 1);
return current_child >= limit_child;
}
};
// Register the CPU kernel once per supported `Tsplits` element type. The
// helper macro is local to this file and is undefined again right after use.
#define REGISTER_RAGGED_TO_SPARSE_CPU_KERNEL(splits_type)           \
  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")              \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<splits_type>("Tsplits"), \
                          RaggedTensorToSparseOp<splits_type>)
REGISTER_RAGGED_TO_SPARSE_CPU_KERNEL(int32);
REGISTER_RAGGED_TO_SPARSE_CPU_KERNEL(int64);
#undef REGISTER_RAGGED_TO_SPARSE_CPU_KERNEL
} // namespace tensorflow