-
Notifications
You must be signed in to change notification settings - Fork 74k
/
in_topk_op_gpu.cu.cc
182 lines (149 loc) · 6.76 KB
/
in_topk_op_gpu.cu.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/in_topk_op.h"
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
namespace functor {
// Compare each prediction in 'predictions' with a target prediction for the
// batch, and write result to the 'mask':
// -1: If the target class is out of range, or if the prediction value is not
// finite and can't be compared to target prediction (and vice versa).
// 0: If prediction is smaller than the target prediction for the batch.
// 1: If prediction is larger than the target prediction for the batch.
template <typename T, typename TargetT>
__global__ void ComputePredictionMaskKernel(
    const T* __restrict__ predictions,    // dims: [ num_targets x num_classes ]
    const TargetT* __restrict__ targets,  // dims: [ num_targets ]
    int64* __restrict__ mask,             // dims: [ num_targets x num_classes ]
    int num_targets, int num_classes) {
  GPU_1D_KERNEL_LOOP(i, num_targets * num_classes) {
    // Row of the prediction matrix this element belongs to.
    const int batch_index = i / num_classes;
    TargetT target_idx = ldg(targets + batch_index);

    if (!FastBoundsCheck(target_idx, num_classes)) {
      // Invalid target class: mark this element and move on. Must be
      // `continue`, not `return`: GPU_1D_KERNEL_LOOP is a grid-stride loop,
      // so returning here would leave every later `mask` element assigned to
      // this thread uninitialized, feeding garbage into the reduction.
      mask[i] = -1;
      continue;
    }

    T prediction = ldg(predictions + i);
    T target_prediction =
        ldg(predictions + batch_index * num_classes + target_idx);

    if (!Eigen::numext::isfinite(prediction) ||
        !Eigen::numext::isfinite(target_prediction)) {
      // Non-finite values cannot be ordered against the target prediction.
      mask[i] = -1;
    } else {
      mask[i] = prediction > target_prediction ? 1 : 0;
    }
  }
}
// Binary reduction operator that folds per-prediction masks into a single
// value per batch entry: the count of predictions strictly larger than the
// target prediction, or '-1' if the target class is invalid or any
// prediction in the batch is non-finite (the -1 sentinel is sticky).
struct MaskSum {
  __host__ __device__ int64 operator()(const int64& a, const int64& b) const {
    // Propagate the error sentinel; otherwise accumulate the count.
    return (a < 0 || b < 0) ? -1 : a + b;
  }
};
namespace reduction_op_helper {

// Identity element for the MaskSum reduction: 0 is neutral because
// MaskSum(0, b) == b for b >= 0, and the -1 error sentinel still propagates
// since MaskSum returns -1 whenever either operand is negative.
template <>
struct IdentityValue<int64, MaskSum> {
  int64 operator()() { return 0; }
};

}  // namespace reduction_op_helper
// GPU specialization of InTopKFunctor.
//
// For each row of `predictions`, decides whether the prediction at the
// target class is within the top `k` values of that row, in three steps:
//   1. ComputePredictionMaskKernel writes a per-element mask
//      (1 = larger than target prediction, 0 = not larger, -1 = error).
//   2. A row-wise ReduceImpl with MaskSum collapses each row to the number
//      of predictions larger than the target (or -1 on error).
//   3. An Eigen expression compares that count against `k`; a negative
//      (error) count always yields `false`.
template <typename T, typename TargetT>
struct InTopKFunctor<GPUDevice, T, TargetT> {
  template <int ndims>
  using Dims = Eigen::DSizes<Eigen::Index, ndims>;

  // `k` may carry either a runtime tensor (k.k_tensor, int32 or int64) or a
  // compile-time value (k.k_value); `output` is a bool vector of size
  // num_targets.
  void operator()(OpKernelContext* context,
                  typename TTypes<T, 2>::ConstTensor predictions,
                  typename TTypes<TargetT>::ConstVec targets, const TopKArg k,
                  typename TTypes<bool>::Vec output) {
    const Eigen::Index num_targets = predictions.dimension(0);
    const Eigen::Index num_classes = predictions.dimension(1);

    // The mask kernel iterates with an `int` index over
    // num_targets * num_classes elements, so the product must fit in int.
    OP_REQUIRES(
        context, num_targets * num_classes < std::numeric_limits<int>::max(),
        errors::InvalidArgument(
            "Number of targets * number of classes must be less than INT_MAX"));

    if (num_targets == 0 || num_classes == 0) {
      // Result is empty, so shortcut the rest of the function to avoid
      // launching kernels with empty input.
      return;
    }

    // Temporary storage for a mask computed by `ComputePredictionMaskKernel`.
    Tensor predictions_mask;
    OP_REQUIRES_OK(
        context, context->allocate_temp(DT_INT64,
                                        TensorShape({num_targets, num_classes}),
                                        &predictions_mask));

    // Number of predictions for each target that are larger than the target
    // prediction (or -1 if we can't compute this number, because not all
    // predictions are finite or target class is out of range).
    Tensor num_larger_prediction;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_INT64, TensorShape({num_targets}),
                                          &num_larger_prediction));

    const auto& d = context->eigen_device<GPUDevice>();

    // Compute a mask for all predictions.
    GpuLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
    OP_REQUIRES_OK(
        context, GpuLaunchKernel(ComputePredictionMaskKernel<T, TargetT>,
                                 config.block_count, config.thread_per_block, 0,
                                 d.stream(), predictions.data(), targets.data(),
                                 predictions_mask.flat<int64_t>().data(),
                                 num_targets, num_classes));

    // Reduce prediction masks to number of predictions larger than the target
    // prediction, or to the negative value if we can't compute an answer.
    {
      auto in = predictions_mask.matrix<int64_t>();
      auto out = num_larger_prediction.flat<int64_t>();

      // Dims<1>(1) reduces along the inner (class) dimension; the rank>=2/3
      // ternaries fill ReduceImpl's fixed 3-dim signature for a rank-2 input.
      ReduceImpl<int64, MaskSum, int64*, int64*, Dims<1>>(
          context, (int64*)out.data(), (int64*)in.data(), in.rank(),
          in.dimension(0), in.rank() >= 2 ? in.dimension(1) : 1,
          in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), Dims<1>(1),
          MaskSum());
    }

    // Compute if target prediction is in top K predictions. In every branch
    // `cnt >= 0` filters out rows whose reduction produced the -1 error
    // sentinel, forcing their output to false.
    auto cnt = num_larger_prediction.flat<int64_t>();
    if (k.k_tensor != nullptr) {
      // `k` comes from a runtime tensor; broadcast it across all targets.
      // NOTE(review): broadcast over Dims<1>(num_targets) presumes the k
      // tensor is scalar-like — confirm against the op's shape checks.
      if (k.k_tensor->dtype() == DT_INT32) {
        output.device(d) =
            (cnt >= cnt.constant(0)) &&
            (cnt < k.k_tensor->flat<int32>().template cast<int64_t>().broadcast(
                       Dims<1>(num_targets)));
      } else {
        output.device(d) =
            (cnt >= cnt.constant(0)) &&
            (cnt < k.k_tensor->flat<int64_t>().broadcast(Dims<1>(num_targets)));
      }
    } else {
      // `k` is a plain value; `targets.constant(...)` gives a same-sized
      // constant expression to compare against.
      output.device(d) =
          (cnt >= cnt.constant(0)) && (cnt < targets.constant(k.k_value));
    }
  }
};
} // namespace functor
// Definition of the GPU implementations declared in in_topk_op.cc.
#define DEFINE_GPU_KERNELS(T, TARGET_T) \
  template struct functor::InTopKFunctor<GPUDevice, T, TARGET_T>;

// Explicit instantiations: float predictions with int32 or int64 target
// class indices — the only combinations registered by the op.
DEFINE_GPU_KERNELS(float, int32);
DEFINE_GPU_KERNELS(float, int64);

#undef DEFINE_GPU_KERNELS
} // end namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM