-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
Sorting.cu
454 lines (393 loc) · 14 KB
/
Sorting.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/SortingUtils.h>
#include <c10/macros/Macros.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/SortingRadixSelect.cuh>
#include <THC/THCDeviceUtils.cuh> // only for THCRoundUp?
#include <THC/THCNumerics.cuh>
#include <THC/THCScanUtils.cuh>
#include <THC/THCTensorMathReduce.cuh> // AddOp
#include <cassert>
#include <cstdlib>
namespace at {
namespace native {
namespace {
// Kernel: for each slice, finds the element of 1-based rank k in ascending
// order (radixSelect with Order=false) and one index at which it occurs.
//
// Launch layout: one thread block per slice; blocks whose linear id is
// >= numInputSlices exit immediately. Threads of a block stride over the
// slice in steps of blockDim.x. Uses static shared memory only.
//
// If the value occurs more than once, every thread that sees it in its own
// strided range writes its first match, so which index is stored is not
// specified. If no thread matches, the outputs are left unwritten.
template <typename scalar_t, typename index_t, int Dim>
__global__ void gatherKthValue(
    cuda::detail::TensorInfo<scalar_t, index_t> input,
    index_t inputSliceSize,
    index_t k,
    index_t numInputSlices,
    index_t inputWithinSliceStride,
    cuda::detail::TensorInfo<scalar_t, index_t> kthValue,
    cuda::detail::TensorInfo<int64_t, index_t> indices) {
  // radixSelect scratch: one int per warp, up to the warp limit. Indices are
  // bounded by floating-point integer precision, so int32 counts suffice
  // whatever index_t is.
  __shared__ int smem[C10_WARP_SIZE];

  const index_t slice = getLinearBlockId<index_t>();
  if (slice >= numInputSlices) {
    return;
  }

  // Base pointers of this block's slice in the input and both outputs.
  scalar_t* const src = &input.data[
      cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, input)];
  scalar_t* const dstValue = &kthValue.data[
      cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, kthValue)];
  int64_t* const dstIndex = &indices.data[
      cuda::detail::IndexToOffset<int64_t, index_t, Dim>::get(slice, indices)];

  // Every thread of the block takes part in the selection.
  scalar_t kValue = static_cast<scalar_t>(0);
  radixSelect<
      scalar_t,
      typename TopKTypeConfig<scalar_t>::RadixType,
      index_t,
      false>(src, k, inputSliceSize, inputWithinSliceStride, smem, &kValue);

  // Locate a position holding kValue. NaN never compares equal to itself, so
  // a NaN kValue is matched by any NaN element instead.
  const bool kValueIsNan = THCNumerics<scalar_t>::isnan(kValue);
  for (index_t pos = threadIdx.x; pos < inputSliceSize; pos += blockDim.x) {
    const scalar_t candidate = doLdg(&src[pos * inputWithinSliceStride]);
    const bool matches = (candidate == kValue) ||
        (kValueIsNan && THCNumerics<scalar_t>::isnan(candidate));
    if (matches) {
      *dstValue = kValue;
      *dstIndex = pos;
      break;
    }
  }
}
// CUDA kernel to find the median, and its index, of the values along dimension dim.
//
// Launch layout: one thread block per slice; blocks whose linear id is
// >= numInputSlices exit immediately. All threads of a block cooperate on
// that slice, striding over it in steps of blockDim.x. Uses static shared
// memory only.
//
// ignore_nan == false (torch.median): any NaN in the slice makes the result NaN.
// ignore_nan == true (torch.nanmedian): the lower median of the non-NaN values
// is returned; an all-NaN slice yields NaN.
//
// If several elements equal the median, each thread stores the first match in
// its own strided range, so which matching index ends up in `indices` is not
// specified.
template <typename scalar_t, typename index_t, int Dim>
__global__ void gatherMedian(
    cuda::detail::TensorInfo<scalar_t, index_t> values,
    cuda::detail::TensorInfo<int64_t, index_t> indices,
    cuda::detail::TensorInfo<scalar_t, index_t> input,
    index_t inputSliceSize,
    index_t numInputSlices,
    index_t inputWithinSliceStride,
    bool ignore_nan) {
  // Shared memory for the subroutine RadixSelect. Note that RadixSelect converts the
  // floating point type to int with the same relative ordering.
  __shared__ int smem[C10_WARP_SIZE]; // one per each warp, up to warp limit

  // The slice id is uniform across the block, so the whole block returns
  // together and the __syncthreads() calls below are reached by all threads.
  index_t slice = getLinearBlockId<index_t>();
  if (slice >= numInputSlices) {
    return;
  }

  // Finds the start offset for our slice
  index_t valuesSliceStartIndex =
      cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, values);
  index_t indicesSliceStartIndex =
      cuda::detail::IndexToOffset<int64_t, index_t, Dim>::get(slice, indices);
  index_t inputSliceStartIndex =
      cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, input);

  scalar_t* valuesSliceStart = &values.data[valuesSliceStartIndex];
  int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex];
  scalar_t* inputSliceStart = &input.data[inputSliceStartIndex];

  // Each thread counts the NaNs in its strided portion of the slice.
  index_t nan_count = 0;
  for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
    scalar_t val = doLdg(&inputSliceStart[i * inputWithinSliceStride]);
    nan_count += THCNumerics<scalar_t>::isnan(val) ? 1 : 0;
  }

  // Sum the per-thread NaN counts into a block-wide total with shared-memory
  // atomics (a simple reduction, not the most efficient one).
  __shared__ int64_t num_nan;
  if (threadIdx.x == 0) {
    num_nan = 0;
  }
  __syncthreads();
  if (nan_count > 0) {
    atomicAdd(&num_nan, nan_count);
  }
  __syncthreads();

  // 0-based rank of the element to select:
  //  - torch.median with a NaN present: the last index, so the computed value
  //    is NaN.
  //  - otherwise: the lower middle of the non-NaN values.
  // The non-NaN count is formed in signed 64-bit arithmetic on purpose. The
  // direct form `(inputSliceSize - num_nan - 1) / 2` is evaluated unsigned when
  // index_t is uint64_t. It wraps around to ~2^63 for an all-NaN slice, and
  // radixSelect would then be asked for a rank far outside the slice.
  index_t k;
  if (!ignore_nan && num_nan > 0) {
    k = inputSliceSize - 1;
  } else {
    const int64_t num_non_nan =
        static_cast<int64_t>(inputSliceSize) - num_nan;
    k = num_non_nan > 0 ? static_cast<index_t>((num_non_nan - 1) / 2)
                        : static_cast<index_t>(0);
  }

  // Find the median (radixSelect takes a 1-based rank).
  scalar_t median = static_cast<scalar_t>(0);
  radixSelect<
      scalar_t,
      typename TopKTypeConfig<scalar_t>::RadixType,
      index_t,
      false>(
      inputSliceStart,
      k + 1,
      inputSliceSize,
      inputWithinSliceStride,
      smem,
      &median);

  // Every thread writes the same value here.
  valuesSliceStart[0] = median;

  // Find an index of the median value in the slice. NaN != NaN, so a NaN
  // median is matched by any NaN element instead.
  for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
    scalar_t val = doLdg(&inputSliceStart[i * inputWithinSliceStride]);
    if (val == median ||
        (THCNumerics<scalar_t>::isnan(val) &&
         THCNumerics<scalar_t>::isnan(median))) {
      indicesSliceStart[0] = i;
      break;
    }
  }
}
// Launcher functor handed to run_launcher: carries the requested rank `k` and
// starts gatherKthValue for the collapsed tensor descriptors.
struct KthValueLauncher {
  int64_t k;

  KthValueLauncher(int64_t k) : k(k) {}

  // Launches gatherKthValue on the current CUDA stream.
  // Grid: derived from num_slices by getGridFromTiles; raises
  // "slices are too many" if that fails.
  // Block: slice_size rounded up to a whole number of warps, capped at 1024.
  // collapse_values_dim / collapse_indices_dim are not used by this launcher.
  template <typename scalar_t, typename index_t, int all_dims>
  inline void launch(
      cuda::detail::TensorInfo<scalar_t, index_t> values_info,
      int collapse_values_dim,
      cuda::detail::TensorInfo<int64_t, index_t> indices_info,
      int collapse_indices_dim,
      cuda::detail::TensorInfo<scalar_t, index_t> self_info,
      int collapse_self_dim,
      int64_t num_slices,
      int64_t slice_size) {
    dim3 grid;
    const bool grid_fits = getGridFromTiles(num_slices, grid);
    if (!grid_fits) {
      AT_ERROR("slices are too many");
    }

    const int64_t warp_rounded =
        THCRoundUp(slice_size, static_cast<int64_t>(C10_WARP_SIZE));
    const dim3 block(std::min(warp_rounded, static_cast<int64_t>(1024)));

    // The dimension the k-selection runs along may have been renumbered by
    // collapseDims(), so take its stride from the collapsed self dimension.
    const auto within_slice_stride = self_info.strides[collapse_self_dim];

    gatherKthValue<scalar_t, index_t, all_dims>
        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
            self_info,
            slice_size,
            k,
            num_slices,
            within_slice_stride,
            values_info,
            indices_info);
  }
};
// Launcher functor handed to run_launcher: carries the NaN policy and starts
// gatherMedian for the collapsed tensor descriptors.
struct MedianLauncher {
  bool ignore_nan;

  MedianLauncher(bool ignore_nan) : ignore_nan(ignore_nan) {}

  // Launches gatherMedian on the current CUDA stream.
  // Grid: derived from num_slices by getGridFromTiles; raises
  // "slices are too many" if that fails.
  // Block: slice_size rounded up to a multiple of the warp size, at most 1024.
  // collapse_values_dim / collapse_indices_dim are not used by this launcher.
  template <typename scalar_t, typename index_t, int all_dims>
  inline void launch(
      cuda::detail::TensorInfo<scalar_t, index_t> values_info,
      int collapse_values_dim,
      cuda::detail::TensorInfo<int64_t, index_t> indices_info,
      int collapse_indices_dim,
      cuda::detail::TensorInfo<scalar_t, index_t> self_info,
      int collapse_self_dim,
      int64_t num_slices,
      int64_t slice_size) {
    dim3 grid;
    if (!getGridFromTiles(num_slices, grid)) {
      AT_ERROR("slices are too many");
    }

    const int64_t threads_per_block = std::min(
        THCRoundUp(slice_size, static_cast<int64_t>(C10_WARP_SIZE)),
        static_cast<int64_t>(1024));
    const dim3 block(threads_per_block);
    const auto stream = at::cuda::getCurrentCUDAStream();

    // Stride of the reduced dimension after collapseDims().
    const auto reduce_stride = self_info.strides[collapse_self_dim];

    gatherMedian<scalar_t, index_t, all_dims><<<grid, block, 0, stream>>>(
        values_info,
        indices_info,
        self_info,
        slice_size,
        num_slices,
        reduce_stride,
        ignore_nan);
  }
};
// Computes the k-th value (1-based k) of `self` along `dim` into
// `values`/`indices` for one dtype.
//
// Raises (TORCH_CHECK) when:
//   - `self` has no elements;
//   - k is outside [1, size of the reduced dim]; a 0-dim input counts as size 1;
//   - `self` has more than MAX_TENSORINFO_DIMS dimensions.
// Outputs are allocated/resized by _reduction_with_indices_allocate_or_resize_output
// and squeezed along `dim` afterwards when keepdim is false.
template <typename scalar_t>
void kthvalue_cuda_template(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim_,
    bool keepdim) {
  const int64_t dim = maybe_wrap_dim(dim_, self.dim());
  const int64_t slice_size = (self.dim() == 0) ? 1 : self.size(dim);

  // FIXME: This seems bogus, I only do this because it was the old behaviour.
  // The reductions are fine, as long as the axis being reduced along
  // isn't of 0 elements (and the output has elements).
  TORCH_CHECK(
      self.numel() > 0,
      "cannot perform reduction function kthvalue",
      " on tensor with no elements because the operation does not have an identity");
  TORCH_CHECK(k >= 1 && k <= slice_size, "selected number k out of range");

  _reduction_with_indices_allocate_or_resize_output(
      values, indices, self, dim, keepdim);

  // Scalar input: the range check above forces k == 1, so the answer is the
  // element itself at index 0; no kernel launch is needed.
  const bool is_scalar = self.dim() == 0 && self.numel() == 1;
  if (is_scalar) {
    values.copy_(self);
    indices.zero_();
    return;
  }

  TORCH_CHECK(
      self.dim() <= MAX_TENSORINFO_DIMS,
      "cannot operate on more than ",
      MAX_TENSORINFO_DIMS,
      " dimensions");

  // Use 32-bit index arithmetic in the kernel only when every tensor allows it.
  const bool use_32bit_indexing = cuda::detail::canUse32BitIndexMath(self) &&
      cuda::detail::canUse32BitIndexMath(values) &&
      cuda::detail::canUse32BitIndexMath(indices);
  if (use_32bit_indexing) {
    run_launcher<scalar_t, uint32_t>(
        values, indices, self, dim, KthValueLauncher(k));
  } else {
    run_launcher<scalar_t, uint64_t>(
        values, indices, self, dim, KthValueLauncher(k));
  }

  if (!keepdim) {
    values.squeeze_(dim);
    indices.squeeze_(dim);
  }
  AT_CUDA_CHECK(cudaGetLastError());
}
// Dispatches kthvalue_cuda_template on self's dtype (all standard types plus
// Half) and returns references to the (possibly resized) outputs.
std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool keepdim) {
  const auto self_dtype = self.scalar_type();
  AT_DISPATCH_ALL_TYPES_AND(
      at::ScalarType::Half, self_dtype, "kthvalue_cuda", [&] {
        kthvalue_cuda_template<scalar_t>(
            values, indices, self, k, dim, keepdim);
      });
  return std::forward_as_tuple(values, indices);
}
// Shared implementation of torch.median / torch.nanmedian along `dim`.
//
// Resizes `values`/`indices` to the reduced shape (size 1 at `dim` with
// keepdim, otherwise `dim` removed), launches gatherMedian when `self` is
// non-empty, and propagates dimension names to both outputs.
//
// Raises (via TORCH_CHECK / check* helpers) when:
//   - the reduced dim has size 0;
//   - outputs are on a different device type than `self`;
//   - `indices` is not kLong, or `values` does not match self's type;
//   - `self` has more than MAX_TENSORINFO_DIMS dimensions.
std::tuple<Tensor&, Tensor&> median_with_indices_impl(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    bool ignore_nan) {
  // Names are dropped for the computation and re-attached at the end.
  NoNamesGuard guard;

  dim = at::maybe_wrap_dim(dim, self.dim());
  // The kernel reads a contiguous input; a 0-dim input is viewed as [1].
  Tensor input = (self.dim() == 0) ? self.unsqueeze(0) : self.contiguous();

  const int64_t reduce_size = input.size(dim);
  TORCH_CHECK(
      reduce_size > 0,
      "median() cannot compute median for a dimension of size 0 because ",
      "the operation does not have an identity");

  checkDeviceType("median", {values, indices}, self.device().type());
  checkScalarType("median", {indices, "indices", 1}, kLong);
  checkSameType("median", {values, "values", 0}, {self, "self", 2});

  TORCH_CHECK(
      self.dim() <= MAX_TENSORINFO_DIMS,
      "median() cannot operate on more than ",
      MAX_TENSORINFO_DIMS,
      " dimensions");

  // Output shape: a 0-dim input keeps its empty shape.
  std::vector<int64_t> out_shape = self.sizes().vec();
  if (self.dim() > 0 && keepdim) {
    out_shape[dim] = 1;
  } else if (self.dim() > 0) {
    out_shape.erase(out_shape.begin() + dim);
  }
  values.resize_(out_shape);
  indices.resize_(out_shape);

  // Only launch the kernel when there is something to reduce.
  if (self.numel() > 0) {
    // The launcher needs outputs with the same rank as `input`, so
    // re-insert the reduced dim unless it is already present.
    const bool rank_matches = keepdim && self.dim() > 0;
    Tensor out_values = rank_matches ? values : values.unsqueeze(dim);
    Tensor out_indices = rank_matches ? indices : indices.unsqueeze(dim);

    AT_DISPATCH_ALL_TYPES_AND(
        at::ScalarType::Half, self.scalar_type(), "median_out_impl", [&] {
          const bool use_32bit_indexing =
              cuda::detail::canUse32BitIndexMath(out_values) &&
              cuda::detail::canUse32BitIndexMath(out_indices) &&
              cuda::detail::canUse32BitIndexMath(input);
          if (use_32bit_indexing) {
            run_launcher<scalar_t, uint32_t>(
                out_values, out_indices, input, dim, MedianLauncher(ignore_nan));
          } else {
            run_launcher<scalar_t, uint64_t>(
                out_values, out_indices, input, dim, MedianLauncher(ignore_nan));
          }
        });
    AT_CUDA_CHECK(cudaGetLastError());
  }

  guard.reset();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);

  return std::forward_as_tuple(values, indices);
}
// Median of all elements of `self`, returned as a 0-dim tensor.
//
// ignore_nan == false (torch.median): returns NaN if any element is NaN,
//   otherwise the lower middle element (index (numel - 1) / 2 after sorting).
// ignore_nan == true (torch.nanmedian): returns the lower middle element of
//   the non-NaN values; an all-NaN input yields NaN.
//
// Raises (TORCH_CHECK) if `self` is empty.
Tensor median_impl(const Tensor& self, bool ignore_nan) {
  NoNamesGuard guard;

  int64_t size = self.numel();
  TORCH_CHECK(size > 0, "median() input tensor cannot be empty");

  // Sort input tensor to efficiently query for median element. NaNs are
  // sorted as the largest values, i.e. to the end.
  Tensor sorted = std::get<0>(self.flatten().sort());

  if (!ignore_nan) {
    // For torch.median return either the middle element or nan (sorted as
    // largest) if there are any
    int64_t k = (size - 1) / 2;
    return at::where(sorted[-1].isnan(), sorted[-1], sorted[k]);
  }

  // For torch.nanmedian return the middle element among the non-nan values.
  // The index is computed in exact int64 arithmetic. Dividing the Long count
  // tensor with `/` would be a true division into float32, which rounds
  // indices above 2^24 and can select the wrong element (even a NaN).
  // item() costs one device->host copy.
  const int64_t num_nan = sorted.isnan().sum().item<int64_t>();
  const int64_t num_non_nan = size - num_nan;
  // All-NaN input: index 0 holds a NaN, which is the result.
  const int64_t k = num_non_nan > 0 ? (num_non_nan - 1) / 2 : 0;
  // clone() so the 0-dim result does not keep the whole sorted buffer alive.
  return sorted[k].clone();
}
} // namespace
// Mark: kthvalue
// Out-variant entry point for torch.kthvalue on CUDA. Runs the
// computation on a contiguous copy of `self` with names suppressed, then
// propagates self's dimension names to both outputs.
std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool keepdim) {
  {
    // The names guard must be released before names are propagated below.
    NoNamesGuard guard;
    // `kthvalue_out_impl_cuda` expects contiguous in input `self`.
    kthvalue_out_impl_cuda(values, indices, self.contiguous(), k, dim, keepdim);
  }
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
  // kthvalue_out_impl_cuda returns references to these same two tensors.
  return std::forward_as_tuple(values, indices);
}
// Mark: median
// Out-variant of torch.median along `dim`: a NaN in a slice makes that
// slice's result NaN.
std::tuple<Tensor&, Tensor&> median_out_cuda(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  constexpr bool kIgnoreNan = false;
  return median_with_indices_impl(values, indices, self, dim, keepdim, kIgnoreNan);
}
// torch.median over all elements: NaN if any element is NaN.
Tensor median_cuda(const Tensor& self) {
  constexpr bool kIgnoreNan = false;
  return median_impl(self, kIgnoreNan);
}
// Out-variant of torch.nanmedian along `dim`: NaNs are skipped when choosing
// each slice's median.
std::tuple<Tensor&, Tensor&> nanmedian_out_cuda(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  constexpr bool kIgnoreNan = true;
  return median_with_indices_impl(values, indices, self, dim, keepdim, kIgnoreNan);
}
// torch.nanmedian over all elements: median of the non-NaN values.
Tensor nanmedian_cuda(const Tensor& self) {
  constexpr bool kIgnoreNan = true;
  return median_impl(self, kIgnoreNan);
}
} // namespace native
} // namespace at