-
Notifications
You must be signed in to change notification settings - Fork 501
/
sparse_ops_cpu.cpp
3016 lines (2780 loc) · 116 KB
/
sparse_ops_cpu.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <algorithm>
#include <cmath>
#include <cstring>
#include <functional>
#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
#include <torch/library.h>
#include "ATen/Parallel.h"
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/autograd/custom_function.h>
#include "c10/util/MaybeOwned.h"
#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
namespace {
// To avoid multiple threads touching the same cache line.
// Assume cache line size is 64B and element size is at least 4B like float or
// int32.
constexpr int FALSE_SHARING_PAD = 16;

// Converts sparse tensor to dense tensor with few optimizations to be used with
// histogram binning calibration by feature. (1) Assumes dense_last_dim == 1 (2)
// Does not update default value when length > 1. HBC by feature has a separate
// logic to handle this, but we fold it over here.
//
// For each segment i: if it has exactly one value, the dense output is that
// value + 1; otherwise (empty or multi-valued) the dense output is 0.
template <typename SegmentValueType, typename SegmentLengthType>
void _to_dense_representation(
    const int64_t num_lengths,
    const SegmentValueType* const segment_value_data,
    const SegmentLengthType* const segment_lengths_data,
    SegmentValueType* const dense_segment_value_data) {
  // Offset of segment i's first value inside segment_value_data. Kept as
  // int64_t so the running sum of lengths cannot overflow on large inputs.
  int64_t k = 0;
  for (int64_t i = 0; i < num_lengths; ++i) {
    if (segment_lengths_data[i] == 1) {
      // Add 1 to distinguish between 0 inserted by densification vs. original
      // value.
      dense_segment_value_data[i] = segment_value_data[k] + 1;
    } else {
      dense_segment_value_data[i] = 0;
    }
    k += segment_lengths_data[i];
  }
}
} // namespace
using Tensor = at::Tensor;
namespace fbgemm_gpu {
// Custom PackSegments operator that is based on the Caffe2 PackSegments and
// UnpackSegments.
// Needed this to support backward pass.
class PackSegments : public torch::autograd::Function<PackSegments> {
 public:
  // Runs fbgemm::pack_segments below the autograd layer and saves what
  // backward needs: `lengths` (as a saved variable) plus the original first
  // dimension of `t_in` and `max_length` (as saved data).
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const Tensor& t_in,
      const Tensor& lengths,
      at::SymInt max_length) {
    const at::SymInt total_length = t_in.sym_size(0);
    at::AutoDispatchBelowADInplaceOrView guard;
    static auto custom_pack_segments_op =
        at::Dispatcher::singleton()
            .findSchemaOrThrow("fbgemm::pack_segments", "")
            .typed<at::Tensor(
                const at::Tensor&, const at::Tensor&, const at::SymInt)>();
    Tensor res = custom_pack_segments_op.call(t_in, lengths, max_length);
    ctx->saved_data["max_length"] = max_length;
    ctx->saved_data["total_length"] = total_length;
    ctx->save_for_backward({lengths});
    return {res};
  }

  // Routes the incoming gradient through fbgemm::pack_segments_backward to
  // produce the gradient w.r.t. `t_in`.
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    TORCH_CHECK(grad_output.size() == 2 || grad_output.size() == 1);
    const Tensor& grad = grad_output[0];
    const auto& max_length = ctx->saved_data["max_length"].toSymInt();
    const auto& total_length = ctx->saved_data["total_length"].toSymInt();
    // Retrieve saved variables for backward.
    const auto& saved_variables = ctx->get_saved_variables();
    const auto& lengths = saved_variables[0];
    // One slot per forward() input (t_in, lengths, max_length). Only t_in
    // receives a gradient; the other slots stay undefined.
    torch::autograd::variable_list grad_inputs(3);
    static auto custom_pack_segments_backward_op =
        at::Dispatcher::singleton()
            .findSchemaOrThrow("fbgemm::pack_segments_backward", "")
            .typed<at::Tensor(
                const at::Tensor&,
                const at::Tensor&,
                const at::SymInt,
                const at::SymInt)>();
    grad_inputs[0] = custom_pack_segments_backward_op.call(
        grad, lengths, total_length, max_length);
    return grad_inputs;
  }
};
// Autograd-enabled pack_segments: dispatches through the PackSegments custom
// function and hands back its only output (the packed tensor).
Tensor pack_segments_autograd(
    const Tensor& t_in,
    const Tensor& lengths,
    const at::SymInt max_length) {
  auto packed_outputs = PackSegments::apply(t_in, lengths, max_length);
  return packed_outputs.front();
}
// Calls at::native::empty_like directly (bypassing the dispatcher), forwarding
// the dtype, layout, device and pinned-memory settings of `self`'s options and
// no explicit memory format.
Tensor native_empty_like(const Tensor& self) {
  const auto self_options = self.options();
  const auto dtype = c10::optTypeMetaToScalarType(self_options.dtype_opt());
  return at::native::empty_like(
      self,
      dtype,
      self_options.layout_opt(),
      self_options.device_opt(),
      self_options.pinned_memory_opt(),
      c10::nullopt);
}
// Exclusive prefix sum: writes length + 1 values into `presum`, with
// presum[0] = 0 and presum[i + 1] = array[0] + ... + array[i].
// `length` is int64_t so tensor numel() values are not narrowed (callers pass
// int64_t sizes); existing int arguments still convert implicitly.
template <typename T>
void prefix_sum(const int64_t length, const T* const array, T* const presum) {
  presum[0] = 0;
  for (int64_t i = 0; i < length; ++i) {
    presum[i + 1] = array[i] + presum[i];
  }
}
// NOTE : _permute_2D_indices_weights_kernel_cpu and
// _permute_2D_lengths_cpu_kernel have to use the same grain size
// (FALSE_SHARING_PAD) for consistent partitioning across threads: each thread
// here starts writing at the output offset that _permute_2D_lengths_cpu_kernel
// stored for the same thread id in output_offsets_per_thread_cumsum.
//
// Copies, for every (t, b), permuted_lengths[t * B + b] indices (and weights
// when has_weight) starting at input_offsets[permute[t] * B + b].
template <
    bool has_weight,
    typename offsets_t,
    typename indices_t,
    typename weights_t>
void _permute_2D_indices_weights_kernel_cpu(
    const int32_t T,
    const int32_t B,
    const indices_t* const __restrict__ indices,
    const weights_t* const __restrict__ weights,
    const int32_t* const __restrict__ permute,
    const offsets_t* const __restrict__ input_offsets,
    const int64_t* const __restrict__ output_offsets_per_thread_cumsum,
    indices_t* const __restrict__ permuted_indices,
    weights_t* const __restrict__ permuted_weights,
    const offsets_t* const __restrict__ permuted_lengths) {
  at::parallel_for(
      0, T * B, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) {
        // Where this thread's chunk starts in the permuted output.
        offsets_t output_start = static_cast<offsets_t>(
            output_offsets_per_thread_cumsum
                [at::get_thread_num() * FALSE_SHARING_PAD]);
        const int64_t t_begin = tb_begin / B;
        const int64_t t_end = (tb_end + B - 1) / B;
        for (const auto t : c10::irange(t_begin, t_end)) {
          // Only the first/last feature of the chunk may be partially covered.
          const int64_t b_begin = (t == t_begin) ? tb_begin % B : 0;
          const int64_t b_end =
              (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;
          for (const auto b : c10::irange(b_begin, b_end)) {
            const offsets_t permuted_length = permuted_lengths[t * B + b];
            // Widen before multiplying so the flat index cannot overflow int32.
            const offsets_t input_start =
                input_offsets[static_cast<int64_t>(permute[t]) * B + b];
            for (const auto i : c10::irange(permuted_length)) {
              permuted_indices[output_start + i] = indices[input_start + i];
              if constexpr (has_weight) {
                permuted_weights[output_start + i] = weights[input_start + i];
              }
            }
            output_start += permuted_length;
          } // for each b
        } // for each t
      }); // parallel_for T * B
}
// Computes permuted_lengths[t * B + b] = lengths[permute[t] * B + b], the
// exclusive cumsum of `lengths` into input_offsets, and per-thread inclusive
// cumsums of permuted lengths into output_offsets_per_thread_cumsum (slot
// (tid + 1) * FALSE_SHARING_PAD holds the total up to and including thread
// tid). The partitioning must match _permute_2D_indices_weights_kernel_cpu.
template <typename index_t>
void _permute_2D_lengths_cpu_kernel(
    const int32_t T,
    const int32_t B,
    const index_t* const __restrict__ lengths,
    int64_t lengths_size,
    const int32_t* const __restrict__ permute,
    index_t* const __restrict__ permuted_lengths,
    index_t* const __restrict__ input_offsets,
    int64_t* const __restrict__ output_offsets_per_thread_cumsum) {
  // Computed in 64 bits so T * B cannot overflow int32.
  const int64_t TB = static_cast<int64_t>(T) * B;
  const int num_threads = at::get_num_threads();
  // int64_t (previously int) so per-thread sums of int64 lengths are not
  // truncated.
  std::vector<int64_t> input_offsets_per_thread_cumsum(
      (num_threads + 1) * FALSE_SHARING_PAD, 0);

  // First parallel for: populate permuted_lengths, and compute per-thread
  // summation of lengths (input_offsets_per_thread_cumsum) and permuted_lengths
  // (output_offsets_per_thread_cumsum)
  at::parallel_for(
      0, TB, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) {
        index_t current_input_offset = 0;
        // Have a separate loop for summing up lengths because lengths_size
        // can be smaller than T * B.
        for (const auto tb :
             c10::irange(tb_begin, std::min(tb_end, lengths_size))) {
          current_input_offset += lengths[tb];
        }
        index_t current_output_offset = 0;
        const int64_t t_begin = tb_begin / B;
        const int64_t t_end = (tb_end + B - 1) / B;
        for (const auto t : c10::irange(t_begin, t_end)) {
          const int64_t b_begin = (t == t_begin) ? tb_begin % B : 0;
          const int64_t b_end =
              (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;
          for (const auto b : c10::irange(b_begin, b_end)) {
            const auto permuted_length =
                lengths[static_cast<int64_t>(permute[t]) * B + b];
            permuted_lengths[t * B + b] = permuted_length;
            current_output_offset += permuted_length;
          }
        }
        const int64_t slot = (at::get_thread_num() + 1) * FALSE_SHARING_PAD;
        input_offsets_per_thread_cumsum[slot] = current_input_offset;
        output_offsets_per_thread_cumsum[slot] = current_output_offset;
      });

  // Inter-thread reduction: turn per-thread sums into inclusive prefix sums.
  for (const auto t : c10::irange(1, num_threads)) {
    input_offsets_per_thread_cumsum[(t + 1) * FALSE_SHARING_PAD] +=
        input_offsets_per_thread_cumsum[t * FALSE_SHARING_PAD];
    output_offsets_per_thread_cumsum[(t + 1) * FALSE_SHARING_PAD] +=
        output_offsets_per_thread_cumsum[t * FALSE_SHARING_PAD];
  }

  // Second parallel for: populate input_offsets
  // NOTE: this works assuming the partitioning will be the same as the
  // first parallel_for.
  at::parallel_for(
      0, TB, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) {
        index_t current_input_offset = static_cast<index_t>(
            input_offsets_per_thread_cumsum
                [at::get_thread_num() * FALSE_SHARING_PAD]);
        if (tb_begin < lengths_size) {
          input_offsets[tb_begin] = current_input_offset;
        }
        for (const auto tb :
             c10::irange(tb_begin, std::min(tb_end - 1, lengths_size))) {
          current_input_offset += lengths[tb];
          input_offsets[tb + 1] = current_input_offset;
        }
      });

  if (lengths_size >= TB) {
    input_offsets[TB] = static_cast<index_t>(
        input_offsets_per_thread_cumsum[num_threads * FALSE_SHARING_PAD]);
  }

  // Handle cases when lengths_size > T * B
  for (const auto i : c10::irange(TB, lengths_size)) {
    input_offsets[i + 1] = lengths[i] + input_offsets[i];
  }
}
// Distributes each row's indices (and optionally weights / positions) into
// my_size buckets.
//
// In-range indices are placed by block:
//   - With fixed-size blocks, index idx < blk_size * my_size goes to bucket
//     idx / blk_size with new index idx % blk_size.
//   - With per-feature boundaries (block_bucketize_pos), index idx goes to
//     bucket lb (the last boundary <= idx, assuming the boundaries are sorted
//     ascending) with new index idx - boundary[lb].
//
// Out-of-range indices fall back to cyclic placement: bucket idx % my_size
// with new index idx / my_size.
//
// new_lengths is laid out as [my_size][lengths_size] and is incremented in
// place; presumably the caller zero-initializes it (TODO confirm).
template <
    bool sequence,
    bool has_weight,
    typename offset_t,
    typename index_t,
    typename scalar_t>
void _block_bucketize_sparse_features_cpu(
    const Tensor& lengths,
    const Tensor& indices,
    const c10::optional<Tensor>& weights,
    const bool bucketize_pos,
    const Tensor& block_sizes,
    const int64_t my_size,
    Tensor new_lengths,
    Tensor new_indices,
    c10::optional<Tensor> new_weights,
    c10::optional<Tensor> new_pos,
    const c10::optional<Tensor>& unbucketize_permute,
    const c10::optional<Tensor>& batch_size_per_feature,
    const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
  // allocate tensors and buffers
  const auto lengths_size = lengths.numel();
  const auto new_lengths_size = lengths_size * my_size;
  const int32_t T = block_sizes.numel();
  const int32_t B = lengths_size / T;
  auto offsets = at::empty({lengths_size + 1}, lengths.options());
  auto new_offsets = at::empty({new_lengths_size + 1}, lengths.options());
  const offset_t* lengths_data = lengths.data_ptr<offset_t>();
  offset_t* offsets_data = offsets.data_ptr<offset_t>();
  const index_t* indices_data = indices.data_ptr<index_t>();
  scalar_t* weights_data = nullptr;
  scalar_t* new_weights_data = nullptr;
  index_t* new_pos_data = nullptr;
  index_t* unbucketize_permute_data = nullptr;
  offset_t* const new_lengths_data = new_lengths.data_ptr<offset_t>();
  offset_t* const new_offsets_data = new_offsets.data_ptr<offset_t>();
  index_t* const new_indices_data = new_indices.data_ptr<index_t>();
  const index_t* const block_sizes_data = block_sizes.data_ptr<index_t>();
  offset_t* batch_sizes_data = nullptr;
  const auto variable_batch_size = batch_size_per_feature.has_value();
  const auto variable_bucket_sizes = block_bucketize_pos.has_value() &&
      block_bucketize_pos.value().size() != 0;
  using uindex_t = std::make_unsigned_t<index_t>;
  using uoffset_t = std::make_unsigned_t<offset_t>;
  // Bucket id computed for each index in the counting pass and reused by the
  // scatter pass. It is only needed with per-feature bucket boundaries, so it
  // is left empty otherwise instead of allocating indices.numel() entries.
  std::vector<int64_t> lower_bounds(
      variable_bucket_sizes ? indices.numel() : 0, 0);

  if constexpr (sequence) {
    unbucketize_permute_data = unbucketize_permute.value().data_ptr<index_t>();
  }
  if constexpr (has_weight) {
    weights_data = weights.value().data_ptr<scalar_t>();
    new_weights_data = new_weights.value().data_ptr<scalar_t>();
  }
  if (bucketize_pos) {
    new_pos_data = new_pos.value().data_ptr<index_t>();
  }
  if (variable_batch_size) {
    batch_sizes_data = batch_size_per_feature.value().data_ptr<offset_t>();
  }

  // count nonzeros
  prefix_sum(lengths_size, lengths_data, offsets_data);
  assert(offsets_data[lengths_size] == indices.numel());
  int64_t cur_offset = 0;
  for (const auto t : c10::irange(T)) {
    const auto blk_size = block_sizes_data[t];
    const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
    const index_t* bucketize_offset = nullptr;
    int64_t bucket_size = 0;
    if (variable_bucket_sizes) {
      bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
      bucket_size = block_bucketize_pos.value()[t].numel();
    }
    for (const auto b : c10::irange(cur_batch_size)) {
      const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
      const offset_t rowstart = offsets_data[b_t];
      const offset_t rowend = offsets_data[b_t + 1];
      for (const auto i : c10::irange(rowstart, rowend)) {
        // We have use cases using none-hashed raw indices that can be either
        // negative or larger than embedding table hash_size (blk_size *
        // my_size). In cases of none-hashed indices we need to ensure
        // bucketization can distribute them into different ranks and within
        // range of blk_size, we expect the later embedding module to take care
        // of hashing indices calculation.
        uindex_t idx = static_cast<uindex_t>(indices_data[i]);
        if (variable_bucket_sizes) {
          int64_t lb = std::upper_bound(
                           bucketize_offset,
                           bucketize_offset + static_cast<index_t>(bucket_size),
                           indices_data[i]) -
              bucketize_offset - 1;
          lower_bounds[i] = lb;
          // lb == -1 means the index lies below the first boundary. Treat it
          // like any other out-of-range index rather than using -1 as a bucket
          // id, which would write outside new_lengths.
          uindex_t p = (lb >= 0 && lb < my_size) ? lb : idx % my_size;
          new_lengths_data[p * lengths_size + b_t]++;
        } else {
          uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
              ? idx / blk_size
              : idx % my_size;
          new_lengths_data[p * lengths_size + b_t]++;
        }
      }
    }
    cur_offset += cur_batch_size;
  }

  // bucketize nonzeros
  prefix_sum(new_lengths_size, new_lengths_data, new_offsets_data);
  assert(new_offsets_data[new_lengths_size] == new_indices.numel());
  cur_offset = 0;
  for (const auto t : c10::irange(T)) {
    const auto blk_size = block_sizes_data[t];
    const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
    const index_t* bucketize_offset = nullptr;
    if (variable_bucket_sizes) {
      bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
    }
    for (const auto b : c10::irange(cur_batch_size)) {
      const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
      const offset_t rowstart = offsets_data[b_t];
      const offset_t rowend = offsets_data[b_t + 1];
      for (const auto i : c10::irange(rowstart, rowend)) {
        // Same raw-index handling as the counting pass above.
        const uindex_t idx = static_cast<uindex_t>(indices_data[i]);
        uindex_t p, new_idx;
        if (variable_bucket_sizes) {
          const int64_t lb = lower_bounds[i];
          // Must match the bucket choice made in the counting pass.
          const bool in_bucket = lb >= 0 && lb < my_size;
          p = in_bucket ? lb : idx % my_size;
          new_idx = in_bucket ? idx - bucketize_offset[lb] : idx / my_size;
        } else {
          p = idx < static_cast<uindex_t>(blk_size * my_size) ? idx / blk_size
                                                              : idx % my_size;
          new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
              ? idx % blk_size
              : idx / my_size;
        }
        // new_offsets_data doubles as a per-(bucket, row) write cursor.
        const uoffset_t pos = new_offsets_data[p * lengths_size + b_t];
        new_indices_data[pos] = new_idx;
        if constexpr (sequence) {
          unbucketize_permute_data[i] = pos;
        }
        new_offsets_data[p * lengths_size + b_t]++;
        if constexpr (has_weight) {
          new_weights_data[pos] = weights_data[i];
        }
        if (bucketize_pos) {
          new_pos_data[pos] = i - rowstart;
        }
      }
    }
    cur_offset += cur_batch_size;
  }
}
// Reference float32 -> bfloat16 conversion. Each output element is the upper
// 16 bits of the float's bit pattern after adding 1 << 15, which rounds half
// up on the discarded low 16 bits.
// The bits are read with std::memcpy instead of reinterpret_cast to avoid the
// strict-aliasing undefined behavior of reading a float through a uint32_t*.
void FloatToBFloat16Quantized_ref(
    const float* const input,
    const size_t numel,
    uint16_t* const output) {
  for (size_t idx = 0; idx < numel; ++idx) {
    uint32_t fp32_bits = 0;
    std::memcpy(&fp32_bits, input + idx, sizeof(fp32_bits));
    output[idx] = static_cast<uint16_t>((fp32_bits + (1u << 15)) >> 16);
  }
}
// Reference bfloat16 -> float32 conversion: the 16 stored bits become the
// upper half of the float's bit pattern, and the lower 16 bits are zero.
// std::memcpy is used for both the read and the write. Storing a uint32_t
// through a float* is a strict-aliasing violation.
// Assumes sizeof(at::BFloat16) == sizeof(uint16_t) (TODO confirm for all
// supported toolchains).
void BFloat16QuantizedToFloat_ref(
    const at::BFloat16* const input,
    const size_t numel,
    float* const output) {
  for (const auto idx : c10::irange(numel)) {
    uint16_t bf16_bits = 0;
    std::memcpy(&bf16_bits, input + idx, sizeof(bf16_bits));
    const uint32_t fp32_bits = static_cast<uint32_t>(bf16_bits) << 16;
    std::memcpy(output + idx, &fp32_bits, sizeof(fp32_bits));
  }
}
// TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia NCCL
// Converts a float32 CPU tensor to bfloat16 bit patterns stored in a
// kHalf-typed tensor of the same shape.
at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input) {
  TENSOR_ON_CPU(input);
  // The reference kernel reads numel() consecutive floats from data_ptr, so a
  // non-contiguous view must be materialized first. Otherwise elements come
  // out in the wrong order or outside the view.
  const auto input_contig = input.expect_contiguous();
  auto output = at::empty(input.sizes(), input.options().dtype(at::kHalf));
  FloatToBFloat16Quantized_ref(
      input_contig->data_ptr<float>(),
      input_contig->numel(),
      reinterpret_cast<uint16_t*>(output.data_ptr<at::Half>()));
  return output;
}
// TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia NCCL
// Converts bfloat16 bit patterns stored in a kHalf-typed CPU tensor back to a
// float32 tensor of the same shape.
at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input) {
  TENSOR_ON_CPU(input);
  // The reference kernel walks the raw buffer linearly; require a contiguous
  // view so strided inputs convert element-for-element.
  const auto input_contig = input.expect_contiguous();
  auto output = at::empty(input.sizes(), input.options().dtype(at::kFloat));
  BFloat16QuantizedToFloat_ref(
      reinterpret_cast<at::BFloat16*>(input_contig->data_ptr<at::Half>()),
      input_contig->numel(),
      output.data_ptr<float>());
  return output;
}
// This function partitions sparse features
// cyclically along the sparse dimension into my_size blocks:
// index idx goes to bucket idx % my_size with new index idx / my_size.
// new_lengths is laid out as [my_size][lengths_size] and is incremented in
// place; presumably the caller zero-initializes it (TODO confirm).
template <bool has_weight, typename index_t, typename scalar_t>
void _bucketize_sparse_features_cpu(
    const at::Tensor& lengths,
    const at::Tensor& indices,
    const c10::optional<at::Tensor>& weights,
    const bool bucketize_pos,
    const int64_t my_size,
    at::Tensor& new_lengths,
    at::Tensor& new_indices,
    c10::optional<at::Tensor> new_weights,
    c10::optional<at::Tensor> new_pos) {
  TENSOR_ON_CPU(lengths);
  TENSOR_ON_CPU(indices);
  TENSOR_EMPTY_OR_ON_CPU(weights);
  TENSOR_ON_CPU(new_lengths);
  TENSOR_ON_CPU(new_indices);
  TENSOR_EMPTY_OR_ON_CPU(new_weights);
  TENSOR_EMPTY_OR_ON_CPU(new_pos);
  using uindex_t = std::make_unsigned_t<index_t>;
  // allocate tensors and buffers
  const auto lengths_size = lengths.numel();
  const auto new_lengths_size = lengths_size * my_size;
  auto offsets = at::empty({lengths_size + 1}, lengths.options());
  auto new_offsets = at::empty({new_lengths_size + 1}, lengths.options());
  const index_t* lengths_data = lengths.data_ptr<index_t>();
  index_t* offsets_data = offsets.data_ptr<index_t>();
  const index_t* indices_data = indices.data_ptr<index_t>();
  // Optional buffers start as nullptr and are bound only when requested, so
  // they are never left indeterminate.
  scalar_t* weights_data = nullptr;
  scalar_t* new_weights_data = nullptr;
  index_t* new_pos_data = nullptr;
  index_t* const new_lengths_data = new_lengths.data_ptr<index_t>();
  index_t* const new_offsets_data = new_offsets.data_ptr<index_t>();
  index_t* const new_indices_data = new_indices.data_ptr<index_t>();
  if constexpr (has_weight) {
    weights_data = weights.value().data_ptr<scalar_t>();
    new_weights_data = new_weights.value().data_ptr<scalar_t>();
  }
  if (bucketize_pos) {
    new_pos_data = new_pos.value().data_ptr<index_t>();
  }

  // count nonzeros
  prefix_sum(lengths_size, lengths_data, offsets_data);
  assert(offsets_data[lengths_size] == indices.numel());
  for (const auto r : c10::irange(lengths_size)) {
    const index_t rowstart = offsets_data[r];
    const index_t rowend = offsets_data[r + 1];
    for (const auto i : c10::irange(rowstart, rowend)) {
      // Need to handle negative indices if we use raw indices instead of
      // hashed indices, convert to unsigned
      const uindex_t idx = static_cast<uindex_t>(indices_data[i]);
      const uindex_t p = idx % my_size;
      new_lengths_data[p * lengths_size + r]++;
    }
  }

  // bucketize nonzeros
  prefix_sum(new_lengths_size, new_lengths_data, new_offsets_data);
  assert(new_offsets_data[new_lengths_size] == new_indices.numel());
  for (const auto r : c10::irange(lengths_size)) {
    const index_t rowstart = offsets_data[r];
    const index_t rowend = offsets_data[r + 1];
    for (const auto i : c10::irange(rowstart, rowend)) {
      // Need to handle negative indices if we use raw indices instead of
      // hashed indices, convert to unsigned
      const uindex_t idx = static_cast<uindex_t>(indices_data[i]);
      const uindex_t p = idx % my_size;
      const uindex_t new_idx = idx / my_size;
      // new_offsets_data doubles as a per-(bucket, row) write cursor.
      const uindex_t pos = new_offsets_data[p * lengths_size + r];
      new_indices_data[pos] = new_idx;
      new_offsets_data[p * lengths_size + r]++;
      if constexpr (has_weight) {
        new_weights_data[pos] = weights_data[i];
      }
      if (bucketize_pos) {
        new_pos_data[pos] = i - rowstart;
      }
    }
  }
}
// Permutes a 2D [T_in, B] jagged batch along its first dimension. Output row t
// is input row permute[t]; rows may be dropped or repeated. Returns the
// permuted lengths ([T, B]), indices and, if given, weights.
// permuted_lengths_sum, if provided, is used as the output indices size
// instead of the computed total.
std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_2D_sparse_data_cpu(
    const Tensor& permute,
    const Tensor& lengths,
    const Tensor& indices,
    const c10::optional<Tensor>& weights,
    const c10::optional<int64_t>& permuted_lengths_sum) {
  TENSOR_ON_CPU(permute);
  TENSOR_ON_CPU(lengths);
  TENSOR_ON_CPU(indices);
  if (weights) {
    TENSOR_ON_CPU(weights);
  }
  TORCH_CHECK(lengths.dim() == 2);

  // The kernels index raw pointers linearly, so they must only ever see these
  // contiguous views.
  const auto permute_contig = permute.expect_contiguous();
  const auto lengths_contig = lengths.expect_contiguous();
  const auto indices_contig = indices.expect_contiguous();
  // the data to permute over can be less or more with or without
  // repetitions
  const auto T = permute.numel();
  const auto B = lengths.size(1);

  Tensor permuted_lengths = at::empty({T, B}, lengths.options());
  Tensor permuted_indices;
  c10::optional<Tensor> permuted_weights;

  const auto lengths_size = lengths.numel();
  auto input_offsets = at::empty({lengths_size + 1}, lengths.options());

  // Per-thread output offsets (padded against false sharing): produced by the
  // lengths kernel, consumed by the indices/weights kernel.
  int num_threads = at::get_num_threads();
  std::vector<int64_t> output_offsets_per_thread_cumsum(
      (num_threads + 1) * FALSE_SHARING_PAD, 0);

  AT_DISPATCH_INDEX_TYPES(
      lengths.scalar_type(), "permute_2D_lengths_cpu_kernel", [&] {
        _permute_2D_lengths_cpu_kernel(
            T,
            B,
            lengths_contig->data_ptr<index_t>(),
            lengths_size,
            // Previously permute.data_ptr(), which is wrong for a strided
            // permute and inconsistent with the indices kernel below.
            permute_contig->data_ptr<int32_t>(),
            permuted_lengths.data_ptr<index_t>(),
            input_offsets.data_ptr<index_t>(),
            output_offsets_per_thread_cumsum.data());
      }); // for each scalar_t

  int64_t permuted_indices_size = 0;
  if (permuted_lengths_sum.has_value()) {
    permuted_indices_size = permuted_lengths_sum.value();
  } else {
    permuted_indices_size =
        output_offsets_per_thread_cumsum[num_threads * FALSE_SHARING_PAD];
  }

  permuted_indices = at::empty(permuted_indices_size, indices.options());
  AT_DISPATCH_INDEX_TYPES(
      input_offsets.scalar_type(), "permute_2D_indices_weights_kernel_1", [&] {
        using offsets_t = index_t;
        AT_DISPATCH_ALL_TYPES(
            indices.scalar_type(), "permute_2D_indices_weights_kernel_2", [&] {
              using indices_t = scalar_t;
              AT_DISPATCH_FLOATING_TYPES(
                  weights.has_value() ? weights.value().scalar_type()
                                      : at::ScalarType::Float,
                  "permute_2D_indices_weights_kernel_3",
                  [&] {
                    using weights_t = scalar_t;
                    if (weights.has_value()) {
                      const auto weights_value_contig =
                          weights.value().expect_contiguous();
                      permuted_weights = at::empty(
                          permuted_indices_size, weights.value().options());
                      _permute_2D_indices_weights_kernel_cpu<
                          true,
                          index_t,
                          indices_t,
                          weights_t>(
                          T,
                          B,
                          indices_contig->data_ptr<indices_t>(),
                          weights_value_contig->data_ptr<weights_t>(),
                          permute_contig->data_ptr<int32_t>(),
                          input_offsets.data_ptr<offsets_t>(),
                          output_offsets_per_thread_cumsum.data(),
                          permuted_indices.data_ptr<indices_t>(),
                          permuted_weights->data_ptr<weights_t>(),
                          permuted_lengths.data_ptr<offsets_t>());
                    } else {
                      _permute_2D_indices_weights_kernel_cpu<
                          false,
                          index_t,
                          indices_t,
                          weights_t>(
                          T,
                          B,
                          indices_contig->data_ptr<indices_t>(),
                          nullptr,
                          permute_contig->data_ptr<int32_t>(),
                          input_offsets.data_ptr<offsets_t>(),
                          output_offsets_per_thread_cumsum.data(),
                          permuted_indices.data_ptr<indices_t>(),
                          nullptr,
                          permuted_lengths.data_ptr<offsets_t>());
                    }
                  }); // for each weights_t
            }); // for each indices_t
      }); // for each offsets_t
  return {permuted_lengths, permuted_indices, permuted_weights};
}
// specialization for variable B and T,
// the permute here maps to all items in length.
template <typename index_t>
void _permute_1D_lengths_cpu_kernel(
const index_t* const __restrict__ lengths,
int64_t permuted_lengths_size,
const int32_t* const __restrict__ permute,
index_t* const __restrict__ permuted_lengths) {
at::parallel_for(
0,
permuted_lengths_size,
FALSE_SHARING_PAD,
[&](int64_t tb_begin, int64_t tb_end) {
// Have a separate loop for summing up lengths
index_t current_output_offset = 0;
for (int tb = tb_begin; tb < std::min(tb_end, permuted_lengths_size);
++tb) {
auto permuted_length = lengths[permute[tb]];
permuted_lengths[tb] = permuted_length;
current_output_offset += permuted_length;
}
});
}
// specialization for variable B and T,
// the permute here maps to all items in length.
// For each output segment tb, copies permuted_lengths[tb] indices (and weights
// when has_weight) from input_offsets[permute[tb]] to output_offsets[tb].
template <
    bool has_weight,
    typename offsets_t,
    typename indices_t,
    typename weights_t>
void _permute_1D_indices_weights_kernel_cpu(
    const offsets_t* const __restrict__ input_offsets,
    const indices_t* const __restrict__ indices,
    const weights_t* const __restrict__ weights,
    const int64_t permuted_lengths_size,
    const int32_t* const __restrict__ permute,
    const offsets_t* const __restrict__ permuted_lengths,
    const offsets_t* const __restrict__ output_offsets,
    indices_t* const __restrict__ permuted_indices,
    weights_t* const __restrict__ permuted_weights) {
  at::parallel_for(
      0,
      permuted_lengths_size,
      FALSE_SHARING_PAD,
      [&](int64_t tb_begin, int64_t tb_end) {
        // int64_t loop variable (previously int) so large sizes are not
        // truncated.
        for (const auto tb : c10::irange(tb_begin, tb_end)) {
          const offsets_t permuted_length = permuted_lengths[tb];
          const offsets_t input_start = input_offsets[permute[tb]];
          const offsets_t output_start = output_offsets[tb];
          for (const auto i : c10::irange(permuted_length)) {
            permuted_indices[output_start + i] = indices[input_start + i];
            if constexpr (has_weight) {
              permuted_weights[output_start + i] = weights[input_start + i];
            }
          }
        }
      }); // parallel_for T x B, different B across T
}
// Permutes a flat (variable-B) jagged batch: output segment i is input segment
// permute[i]. Returns permuted lengths, indices and, only when `weights` was
// given, permuted weights. An empty optional is returned otherwise, matching
// permute_2D_sparse_data_cpu.
std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cpu(
    const Tensor& permute,
    const Tensor& lengths,
    const Tensor& indices,
    const c10::optional<Tensor>& weights,
    const c10::optional<int64_t>& permuted_lengths_sum) {
  TENSOR_ON_CPU(permute);
  TENSOR_ON_CPU(lengths);
  TENSOR_ON_CPU(indices);
  TENSOR_ON_CPU(weights);

  const auto permute_contig = permute.expect_contiguous();
  const auto lengths_contig = lengths.expect_contiguous();
  const auto indices_contig = indices.expect_contiguous();
  // the data to permute over can be less or more with or without
  // repetitions

  Tensor permuted_lengths;
  Tensor permuted_indices;
  // Optional (previously a plain Tensor, which produced an engaged optional
  // holding an undefined tensor when no weights were supplied).
  c10::optional<Tensor> permuted_weights;

  const auto permuted_lengths_size = permute.numel();
  permuted_lengths = at::empty({permuted_lengths_size}, lengths.options());

  AT_DISPATCH_INDEX_TYPES(
      lengths.scalar_type(), "permute_1D_lengths_cpu_kernel", [&] {
        _permute_1D_lengths_cpu_kernel(
            lengths_contig->data_ptr<index_t>(),
            permuted_lengths_size,
            permute_contig->data_ptr<int32_t>(),
            permuted_lengths.data_ptr<index_t>());
      }); // for each scalar_t

  const auto input_offsets = asynchronous_exclusive_cumsum_cpu(lengths);
  const auto output_offsets =
      asynchronous_complete_cumsum_cpu(permuted_lengths);
  int64_t permuted_indices_size = 0;
  if (permuted_lengths_sum.has_value()) {
    permuted_indices_size = permuted_lengths_sum.value();
  } else {
    permuted_indices_size =
        output_offsets[permuted_lengths_size].item<int64_t>();
  }

  permuted_indices = at::empty(permuted_indices_size, indices.options());
  AT_DISPATCH_INDEX_TYPES(
      input_offsets.scalar_type(), "permute_1D_indices_weights_kernel_1", [&] {
        using offsets_t = index_t;
        AT_DISPATCH_ALL_TYPES(
            indices.scalar_type(), "permute_1D_indices_weights_kernel_2", [&] {
              using indices_t = scalar_t;
              AT_DISPATCH_FLOATING_TYPES(
                  weights.has_value() ? weights.value().scalar_type()
                                      : at::ScalarType::Float,
                  "permute_1D_indices_weights_kernel_3",
                  [&] {
                    using weights_t = scalar_t;
                    if (weights.has_value()) {
                      const auto weights_value_contig =
                          weights.value().expect_contiguous();
                      permuted_weights = at::empty(
                          permuted_indices_size, weights.value().options());
                      _permute_1D_indices_weights_kernel_cpu<
                          true,
                          index_t,
                          indices_t,
                          weights_t>(
                          input_offsets.data_ptr<offsets_t>(),
                          indices_contig->data_ptr<indices_t>(),
                          weights_value_contig->data_ptr<weights_t>(),
                          permuted_lengths_size,
                          permute_contig->data_ptr<int32_t>(),
                          permuted_lengths.data_ptr<offsets_t>(),
                          output_offsets.data_ptr<offsets_t>(),
                          permuted_indices.data_ptr<indices_t>(),
                          permuted_weights->data_ptr<weights_t>());
                    } else {
                      _permute_1D_indices_weights_kernel_cpu<
                          false,
                          index_t,
                          indices_t,
                          weights_t>(
                          input_offsets.data_ptr<offsets_t>(),
                          indices_contig->data_ptr<indices_t>(),
                          nullptr,
                          permuted_lengths_size,
                          permute_contig->data_ptr<int32_t>(),
                          permuted_lengths.data_ptr<offsets_t>(),
                          output_offsets.data_ptr<offsets_t>(),
                          permuted_indices.data_ptr<indices_t>(),
                          nullptr);
                    }
                  }); // for each weights_t
            }); // for each indices_t
      }); // for each offsets_t
  return {permuted_lengths, permuted_indices, permuted_weights};
}
// Expands a segment-level permutation into an element-level one: for output
// segment t, writes the input element positions
// input_offsets[permute[t]] + i for i in [0, output segment length).
template <typename index_t, typename offsets_t>
void _expand_into_jagged_permute_cpu_kernel(
    const offsets_t* const __restrict__ input_offsets,
    const offsets_t* const __restrict__ output_offsets,
    const int64_t permute_size,
    const index_t* const __restrict__ permute,
    index_t* const __restrict__ output_permute) {
  at::parallel_for(
      0, permute_size, FALSE_SHARING_PAD, [&](int64_t t_begin, int64_t t_end) {
        // int64_t loop variable (previously int) so large sizes are not
        // truncated.
        for (const auto t : c10::irange(t_begin, t_end)) {
          const offsets_t permute_length =
              output_offsets[t + 1] - output_offsets[t];
          const offsets_t input_start = input_offsets[permute[t]];
          const offsets_t output_start = output_offsets[t];
          for (const auto i : c10::irange(permute_length)) {
            output_permute[output_start + i] = input_start + i;
          }
        }
      }); // parallel_for T
}
// Builds an element-level permutation of length output_size from a
// segment-level `permute` and the input/output segment offsets (each of size
// permute.numel() + 1).
Tensor expand_into_jagged_permute_cpu(
    const Tensor& permute,
    const Tensor& input_offsets,
    const Tensor& output_offsets,
    int64_t output_size) {
  TENSOR_ON_CPU(permute);
  TENSOR_ON_CPU(input_offsets);
  TENSOR_ON_CPU(output_offsets);
  TORCH_CHECK(permute.numel() > 0);
  TORCH_CHECK(permute.numel() == input_offsets.numel() - 1);
  TORCH_CHECK(permute.numel() == output_offsets.numel() - 1);

  // The kernel walks raw pointers linearly. Previously permute_contig was
  // computed but never used, and strided inputs were read incorrectly.
  const auto permute_contig = permute.expect_contiguous();
  const auto input_offsets_contig = input_offsets.expect_contiguous();
  const auto output_offsets_contig = output_offsets.expect_contiguous();
  const auto permute_size = permute.numel();
  Tensor output_permute = at::empty({output_size}, input_offsets.options());
  AT_DISPATCH_INDEX_TYPES(
      permute.scalar_type(), "expand_into_jagged_permute_cpu", [&] {
        using offset_t = index_t;
        _expand_into_jagged_permute_cpu_kernel(
            input_offsets_contig->data_ptr<offset_t>(),
            output_offsets_contig->data_ptr<offset_t>(),
            permute_size,
            permute_contig->data_ptr<index_t>(),
            output_permute.data_ptr<index_t>());
      });
  return output_permute;
}
template <typename index_t>
void _invert_permute_cpu_kernel(
const int64_t permute_size,
const index_t* const __restrict__ permute,
index_t* const __restrict__ inversed_permute) {
at::parallel_for(
0, permute_size, FALSE_SHARING_PAD, [&](int64_t t_begin, int64_t t_end) {
for (int t = t_begin; t < std::min(t_end, permute_size); ++t) {
inversed_permute[permute[t]] = t;
}
});
}
// Returns the inverse of the permutation `permute`.
Tensor invert_permute_cpu(const Tensor& permute) {
  TENSOR_ON_CPU(permute);
  const auto permute_contig = permute.expect_contiguous();
  const auto permute_size = permute.numel();
  // Allocate from the contiguous view. empty_like(permute) would keep a
  // strided input's layout, but the kernel writes the output linearly.
  Tensor inversed_permute = at::empty_like(*permute_contig);
  AT_DISPATCH_INDEX_TYPES(
      permute.scalar_type(), "invert_permute_cpu_kernel", [&] {
        _invert_permute_cpu_kernel<index_t>(
            permute_size,
            permute_contig->data_ptr<index_t>(),
            inversed_permute.data_ptr<index_t>());
      }); // for each scalar_t
  return inversed_permute;
}
std::tuple<
Tensor,
Tensor,
c10::optional<Tensor>,
c10::optional<Tensor>,
c10::optional<Tensor>>
block_bucketize_sparse_features_cpu(
const Tensor& lengths,
const Tensor& indices,
const bool bucketize_pos,
const bool sequence,
const Tensor& block_sizes,
const int64_t my_size,
const c10::optional<Tensor>& weights,
const c10::optional<Tensor>& batch_size_per_feature,
const int64_t /* max_batch_size */, // Only used in GPU variant
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
const auto lengths_size = lengths.numel();
const auto new_lengths_size = lengths_size * my_size;
auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
auto new_indices = native_empty_like(indices);
Tensor new_weights;
Tensor new_pos;
Tensor unbucketize_permute;
if (bucketize_pos) {
new_pos = native_empty_like(indices);
}
if (weights.has_value()) {
const auto lengths_sum = indices.numel();
Tensor weights_value = weights.value();
new_weights = native_empty_like(weights_value);
if (sequence) {
unbucketize_permute = at::empty({lengths_sum}, indices.options());
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(),
"block_bucketize_sparse_features_weights_cpu_1",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"block_bucketize_sparse_features_weights_cpu_2",
[&] {
AT_DISPATCH_FLOATING_TYPES(
weights_value.scalar_type(),
"bucketize_sparse_features_weights_cpu_3",
[&] {
_block_bucketize_sparse_features_cpu<
true,
true,
offset_t,
index_t,
scalar_t>(
lengths,
indices,
weights,
bucketize_pos,
block_sizes,
my_size,
new_lengths,
new_indices,
new_weights,
new_pos,
unbucketize_permute,
batch_size_per_feature,
block_bucketize_pos);
});
});
});
} else {
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(),