-
Notifications
You must be signed in to change notification settings - Fork 74k
/
data_format_ops.cc
346 lines (309 loc) · 14.3 KB
/
data_format_ops.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/nn_ops.cc.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/data_format_ops.h"
#include <map>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
// Device tags used below as the `Device` template argument of the kernels.
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
// Returns true iff `dst` is a rearrangement of the characters of `src` and no
// character is repeated in either string.
// The ops in this file receive their permutation as two user-supplied string
// attributes; this check guards against malformed attributes before they are
// used, to prevent security vulnerabilities.
static bool IsValidPermutation(const std::string& src, const std::string& dst) {
  if (src.size() != dst.size()) return false;

  // Maps every character of `src` to "still unmatched by `dst`".
  std::map<char, bool> pending;
  for (const char ch : src) {
    // emplace() never overwrites, so a failed insertion means `ch` is a
    // duplicate within `src`.
    if (!pending.emplace(ch, true).second) return false;
  }

  for (const char ch : dst) {
    const auto entry = pending.find(ch);
    // Reject characters that are absent from `src` or were already consumed
    // by an earlier position of `dst`.
    if (entry == pending.end() || !entry->second) return false;
    entry->second = false;
  }

  // Equal lengths plus a one-to-one match of every character of `dst` means
  // the two strings describe a permutation.
  return true;
}
template <typename Device, typename T>
class DataFormatDimMapOp : public OpKernel {
public:
explicit DataFormatDimMapOp(OpKernelConstruction* context)
: OpKernel(context) {
string src_format;
OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
string dst_format;
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
errors::InvalidArgument(
"Source format must be of length 4 or 5, received "
"src_format = ",
src_format));
OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5,
errors::InvalidArgument("Destination format must be of length "
"4 or 5, received dst_format = ",
dst_format));
OP_REQUIRES(
context, IsValidPermutation(src_format, dst_format),
errors::InvalidArgument(
"Destination and source format must determine a permutation, got ",
src_format, " and ", dst_format));
dst_idx_ = Tensor(DT_INT32, {static_cast<int64_t>(src_format.size())});
for (int i = 0; i < src_format.size(); ++i) {
for (int j = 0; j < dst_format.size(); ++j) {
if (dst_format[j] == src_format[i]) {
dst_idx_.vec<int>()(i) = j;
break;
}
}
}
}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
Tensor* output;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(),
input.flat<T>(), output->flat<T>(),
dst_idx_.vec<int>());
}
Tensor dst_idx_;
};
template <typename Device, typename T>
class DataFormatVecPermuteOp : public OpKernel {
public:
explicit DataFormatVecPermuteOp(OpKernelConstruction* context)
: OpKernel(context) {
string src_format;
OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
errors::InvalidArgument(
"Source format must be of length 4 or 5, received "
"src_format = ",
src_format));
string dst_format;
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5,
errors::InvalidArgument("Destination format must be of length "
"4 or 5, received dst_format = ",
dst_format));
OP_REQUIRES(
context, IsValidPermutation(src_format, dst_format),
errors::InvalidArgument(
"Destination and source format must determine a permutation, got ",
src_format, " and ", dst_format));
src_format_ = src_format;
dst_format_ = dst_format;
}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
OP_REQUIRES(context, input.dims() == 1 || input.dims() == 2,
errors::InvalidArgument(
"input must be a vector or 2D tensor, but got shape ",
input.shape().DebugString()));
const int full_dim_count = src_format_.size();
const int spatial_dim_count = full_dim_count - 2;
if (input.dims() == 1) {
OP_REQUIRES(context,
input.NumElements() == spatial_dim_count ||
input.NumElements() == full_dim_count,
errors::InvalidArgument("1D input must be of size ",
spatial_dim_count, " or ",
full_dim_count, ", but got shape ",
input.shape().DebugString()));
} else if (input.dims() == 2) {
OP_REQUIRES(context,
input.dim_size(0) == spatial_dim_count ||
input.dim_size(0) == full_dim_count,
errors::InvalidArgument("First dimension of 2D input must be "
"of size ",
spatial_dim_count, " or ",
full_dim_count, ", but got shape ",
input.shape().DebugString()));
OP_REQUIRES(
context, input.dim_size(1) == 2,
errors::InvalidArgument(
"Second dimension of 2D input must be of size 2, but got shape ",
input.shape().DebugString()));
}
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
// Support 1D and 2D cases.
Eigen::DSizes<Eigen::DenseIndex, 10> dst_idx;
string src_format_str = src_format_;
string dst_format_str = dst_format_;
if (input.dim_size(0) == spatial_dim_count) {
// If the input is a vector of size spatial_dim_count, treat the elements
// as spatial dimensions.
auto keep_only_spatial_dimensions =
[spatial_dim_count](string* format_str) -> void {
auto new_end =
std::remove_if(format_str->begin(), format_str->end(),
[spatial_dim_count](const char dim) {
return dim != 'H' && dim != 'W' &&
(spatial_dim_count == 2 || dim != 'D');
});
format_str->erase(new_end, format_str->end());
};
keep_only_spatial_dimensions(&src_format_str);
keep_only_spatial_dimensions(&dst_format_str);
if (spatial_dim_count == 3) {
OP_REQUIRES(
context, src_format_str.size() == 3 && dst_format_str.size() == 3,
errors::InvalidArgument(
"Format specifier must contain D, H and W for 2D case"));
} else {
DCHECK(spatial_dim_count == 2);
OP_REQUIRES(context,
src_format_str.size() == 2 && dst_format_str.size() == 2,
errors::InvalidArgument(
"Format specifier must contain H and W for 2D case"));
}
}
ComputeDstIndex(src_format_str, dst_format_str, input.dims(), &dst_idx);
functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(),
input.flat<T>(),
output->flat<T>(), dst_idx);
}
private:
// Finds out the destination index. Support 1D and 2D cases.
// Example: HWNC --> NHWC
// 1D: dst = [1, 2, 0, 3],
// 2D: dst = [2, 3, 4, 5, 0, 1, 6, 7]
static void ComputeDstIndex(const string& src_format_str,
const string& dst_format_str, int num_dim,
Eigen::DSizes<Eigen::DenseIndex, 10>* dst) {
for (int i = 0; i < src_format_str.size(); ++i) {
for (int j = 0; j < dst_format_str.size(); ++j) {
if (dst_format_str[j] != src_format_str[i]) continue;
// Found the dst index. Set output based on the number of dims.
for (int k = 0; k < num_dim; ++k) {
(*dst)[i * num_dim + k] = j * num_dim + k;
}
}
}
}
string src_format_;
string dst_format_;
};
// CPU kernels for "DataFormatDimMap", registered through TF_CALL_int32 and
// TF_CALL_int64 with the "T" type constraint.
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("DataFormatDimMap").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
DataFormatDimMapOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL
// CPU kernels for "DataFormatVecPermute", registered the same way.
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("DataFormatVecPermute").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
DataFormatVecPermuteOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL
// Second DEVICE_CPU registration of "DataFormatDimMap", this time carrying the
// "host" label.
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("DataFormatDimMap") \
.Device(DEVICE_CPU) \
.Label("host") \
.TypeConstraint<T>("T"), \
DataFormatDimMapOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL
// Second DEVICE_CPU registration of "DataFormatVecPermute", also carrying the
// "host" label.
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("DataFormatVecPermute") \
.Device(DEVICE_CPU) \
.Label("host") \
.TypeConstraint<T>("T"), \
DataFormatVecPermuteOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
// Each DECLARE_GPU_SPEC declares the explicit specialization of operator()
// and an extern template for the functor struct, so this translation unit
// only references them; the definitions are expected to live elsewhere.
// Both helper macros are #undef'd after each use so they do not leak past the
// declarations they generate.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
  template <> \
  void DataFormatDimMap<GPUDevice, T>::operator()( \
      const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
      typename TTypes<T>::Flat y, const TTypes<int>::Vec dst); \
  extern template struct DataFormatDimMap<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC

#define DECLARE_GPU_SPEC(T) \
  template <> \
  void DataFormatVecPermute<GPUDevice, T>::operator()( \
      const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
      typename TTypes<T>::Vec y, \
      const Eigen::DSizes<Eigen::DenseIndex, 10>& dst_idx); \
  extern template struct DataFormatVecPermute<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T) \
  REGISTER_KERNEL_BUILDER( \
      Name("DataFormatDimMap").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DataFormatDimMapOp<GPUDevice, T>);
TF_CALL_int32(REGISTER_GPU_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#define REGISTER_GPU_KERNEL(T) \
  REGISTER_KERNEL_BUILDER( \
      Name("DataFormatVecPermute").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DataFormatVecPermuteOp<GPUDevice, T>);
TF_CALL_int32(REGISTER_GPU_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Registration of the DEVICE_DEFAULT implementations.
// These reuse the CPUDevice kernels: "x" and "y" are declared with
// HostMemory(...) and the registration carries the "host" label.
#define REGISTER_DEVICE_DEFAULT_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("DataFormatDimMap") \
.Device(DEVICE_DEFAULT) \
.HostMemory("x") \
.HostMemory("y") \
.Label("host") \
.TypeConstraint<T>("T"), \
DataFormatDimMapOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_DEVICE_DEFAULT_KERNEL);
TF_CALL_int64(REGISTER_DEVICE_DEFAULT_KERNEL);
#undef REGISTER_DEVICE_DEFAULT_KERNEL
// Same DEVICE_DEFAULT / host-memory registration for "DataFormatVecPermute".
#define REGISTER_DEVICE_DEFAULT_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("DataFormatVecPermute") \
.Device(DEVICE_DEFAULT) \
.HostMemory("x") \
.HostMemory("y") \
.Label("host") \
.TypeConstraint<T>("T"), \
DataFormatVecPermuteOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_DEVICE_DEFAULT_KERNEL);
TF_CALL_int64(REGISTER_DEVICE_DEFAULT_KERNEL);
#undef REGISTER_DEVICE_DEFAULT_KERNEL
} // namespace tensorflow