/
zero_out_op.cpp
45 lines (36 loc) · 1.38 KB
/
zero_out_op.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
// This code is from the guide in https://www.tensorflow.org/guide/extend/op.
#include <cstdint>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow;
/**
 * Kernel for the "ZeroOut" op.
 *
 * Reads input 0 as a flat sequence of int32 values and produces output 0 with
 * the same shape, where every element except the first (in flat order) is 0
 * and the first element is copied from the input. An input with no elements
 * produces an output with no elements.
 */
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction *context) : OpKernel(context) {}

  /**
   * Computes the zeroed output for one invocation of the op.
   *
   * @param context Kernel context supplying input 0 and receiving output 0.
   */
  void Compute(OpKernelContext *context) override {
    // Grab the input tensor and view it as a flat int32 sequence.
    const Tensor &input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor with the same shape as the input.
    // OP_REQUIRES_OK leaves Compute early if the allocation does not succeed.
    Tensor *output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Keep the element count and loop index 64-bit: storing size() in a plain
    // int would truncate for tensors holding more than INT_MAX elements and
    // leave part of the output unwritten.
    const int64_t num_elements = input.size();

    // Set all but the first element of the output tensor to 0.
    for (int64_t i = 1; i < num_elements; ++i) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if there is one.
    if (num_elements > 0) {
      output_flat(0) = input(0);
    }
  }
};
// Op interface: one int32 input ("to_zero") and one int32 output ("zeroed").
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext *ctx) {
      // Output 0 is declared to have exactly the shape of input 0.
      const auto input_shape = ctx->input(0);
      ctx->set_output(0, input_shape);
      return Status::OK();
    });

// Bind the CPU implementation of "ZeroOut" to the ZeroOutOp kernel class.
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);