/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <thrust/copy.h>
#include <thrust/device_allocator.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/logical.h>
#include <thrust/pair.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/sort.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#include <cub/cub.cuh>
#include <algorithm>
#include <functional>
#include <set>
#include <stdexcept>
#include <utility>
#include <vector>
namespace gpu_treeshap {
struct XgboostSplitCondition {
XgboostSplitCondition() = default;
XgboostSplitCondition(float feature_lower_bound, float feature_upper_bound,
bool is_missing_branch)
: feature_lower_bound(feature_lower_bound),
feature_upper_bound(feature_upper_bound),
is_missing_branch(is_missing_branch) {
assert(feature_lower_bound <= feature_upper_bound);
}
/*! Feature values >= lower and < upper flow down this path. */
float feature_lower_bound;
float feature_upper_bound;
/*! Do missing values flow down this path? */
bool is_missing_branch;
// Does this instance flow down this path?
__host__ __device__ bool EvaluateSplit(float x) const {
// is nan
if (isnan(x)) {
return is_missing_branch;
}
return x >= feature_lower_bound && x < feature_upper_bound;
}
// Combine two split conditions on the same feature
__host__ __device__ void Merge(
const XgboostSplitCondition& other) { // Combine duplicate features
feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound);
is_missing_branch = is_missing_branch && other.is_missing_branch;
}
};
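/*!
 * For illustration: a decision tree split of the form "feature < 0.5" can be
 * encoded with the "yes" branch as XgboostSplitCondition(-inf, 0.5f, d) and
 * the "no" branch as XgboostSplitCondition(0.5f, +inf, !d), where d is the
 * tree's default direction for missing values. EvaluateSplit then reproduces
 * the half-open interval test x >= lower && x < upper.
 */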
/*!
* An element of a unique path through a decision tree. Can implement various
* types of splits via the templated SplitConditionT. Different decision tree
* implementations may use double or single precision, < or <= as the split
* threshold, different handling of missing values, or support for
* categorical features.
*
* \tparam SplitConditionT A split condition implementing the methods
* EvaluateSplit and Merge.
*/
template <typename SplitConditionT>
struct PathElement {
using split_type = SplitConditionT;
__host__ __device__ PathElement(size_t path_idx, int64_t feature_idx,
int group, SplitConditionT split_condition,
double zero_fraction, float v)
: path_idx(path_idx),
feature_idx(feature_idx),
group(group),
split_condition(split_condition),
zero_fraction(zero_fraction),
v(v) {}
PathElement() = default;
__host__ __device__ bool IsRoot() const { return feature_idx == -1; }
template <typename DatasetT>
__host__ __device__ bool EvaluateSplit(DatasetT X, size_t row_idx) const {
if (this->IsRoot()) {
return true;
}
return split_condition.EvaluateSplit(X.GetElement(row_idx, feature_idx));
}
/*! Unique path index. */
size_t path_idx;
/*! Feature of this split, -1 indicates bias term. */
int64_t feature_idx;
/*! Indicates class for multiclass problems. */
int group;
SplitConditionT split_condition;
/*! Probability of following this path when feature_idx is not in the active
* set. */
double zero_fraction;
float v; // Leaf weight at the end of the path
};
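/*!
 * For illustration: each root-to-leaf path of the ensemble is represented as
 * one unique path (identified by path_idx) made up of a bias element with
 * feature_idx == -1 plus one element per feature split encountered on the
 * path; repeated splits on the same feature are merged by DeduplicatePaths
 * below using SplitConditionT::Merge. Every element of a path stores the same
 * leaf value v reached at the end of that path.
 */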
// Helper function that accepts an index into a flat contiguous array and the
// dimensions of a tensor and returns the indices with respect to the tensor
template <typename T, size_t N>
__device__ void FlatIdxToTensorIdx(T flat_idx, const T (&shape)[N],
T (&out_idx)[N]) {
T current_size = shape[0];
for (auto i = 1ull; i < N; i++) {
current_size *= shape[i];
}
for (auto i = 0ull; i < N; i++) {
current_size /= shape[i];
out_idx[i] = flat_idx / current_size;
flat_idx -= current_size * out_idx[i];
}
}
// Given a shape and coordinates into a tensor, return the index into the
// backing storage one-dimensional array
template <typename T, size_t N>
__device__ T TensorIdxToFlatIdx(const T (&shape)[N], const T (&tensor_idx)[N]) {
T current_size = shape[0];
for (auto i = 1ull; i < N; i++) {
current_size *= shape[i];
}
T idx = 0;
for (auto i = 0ull; i < N; i++) {
current_size /= shape[i];
idx += tensor_idx[i] * current_size;
}
return idx;
}
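// For example, with shape {2, 3, 4} the tensor index {1, 2, 3} maps to flat
// index 1 * 12 + 2 * 4 + 3 = 23, and FlatIdxToTensorIdx(23, ...) recovers
// {1, 2, 3}.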
// Maps values to the phi array according to row, group and column
__host__ __device__ inline size_t IndexPhi(size_t row_idx, size_t num_groups,
size_t group, size_t num_columns,
size_t column_idx) {
return (row_idx * num_groups + group) * (num_columns + 1) + column_idx;
}
__host__ __device__ inline size_t IndexPhiInteractions(size_t row_idx,
size_t num_groups,
size_t group,
size_t num_columns,
size_t i, size_t j) {
size_t matrix_size = (num_columns + 1) * (num_columns + 1);
size_t matrix_offset = (row_idx * num_groups + group) * matrix_size;
return matrix_offset + i * (num_columns + 1) + j;
}
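// The phi output is laid out as a dense [num_rows, num_groups, num_columns + 1]
// array (the extra column holds the bias term), and the interactions output as
// [num_rows, num_groups, num_columns + 1, num_columns + 1]; callers are
// expected to allocate buffers of at least those sizes.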
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Shorthand for creating a device vector with an appropriate allocator type
template <class T, class DeviceAllocatorT>
using RebindVector =
thrust::device_vector<T,
typename DeviceAllocatorT::template rebind<T>::other>;
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)
__device__ __forceinline__ double atomicAddDouble(double* address, double val) {
return atomicAdd(address, val);
}
#else // In device code and CUDA < 600
__device__ __forceinline__ double atomicAddDouble(double* address,
double val) { // NOLINT
unsigned long long int* address_as_ull = // NOLINT
(unsigned long long int*)address; // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
// Note: uses integer comparison to avoid hang in case of NaN (since NaN !=
// NaN)
} while (assumed != old);
return __longlong_as_double(old);
}
#endif
__forceinline__ __device__ unsigned int lanemask32_lt() {
unsigned int lanemask32_lt;
asm volatile("mov.u32 %0, %%lanemask_lt;" : "=r"(lanemask32_lt));
return (lanemask32_lt);
}
// Like a coalesced group, except we can make the assumption that all threads in
// a group are next to each other. This makes shuffle operations much cheaper.
class ContiguousGroup {
public:
__device__ ContiguousGroup(uint32_t mask) : mask_(mask) {}
__device__ uint32_t size() const { return __popc(mask_); }
__device__ uint32_t thread_rank() const {
return __popc(mask_ & lanemask32_lt());
}
template <typename T>
__device__ T shfl(T val, uint32_t src) const {
return __shfl_sync(mask_, val, src + __ffs(mask_) - 1);
}
template <typename T>
__device__ T shfl_up(T val, uint32_t delta) const {
return __shfl_up_sync(mask_, val, delta);
}
__device__ uint32_t ballot(int predicate) const {
return __ballot_sync(mask_, predicate) >> (__ffs(mask_) - 1);
}
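// Group-wide reduction: shfl_up implements an inclusive scan over the group,
// so after the loop the highest-ranked lane holds op applied to all values,
// which is then broadcast back to every lane.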
template <typename T, typename OpT>
__device__ T reduce(T val, OpT op) {
for (int i = 1; i < this->size(); i *= 2) {
T shfl = shfl_up(val, i);
if (static_cast<int>(thread_rank()) - i >= 0) {
val = op(val, shfl);
}
}
return shfl(val, size() - 1);
}
uint32_t mask_;
};
// Separate the active threads by labels
// This functionality is available in CUDA 11.0 on compute capability >= 7.0
// We reimplement it for backwards compatibility
// Assumes partitions are contiguous
inline __device__ ContiguousGroup active_labeled_partition(uint32_t mask,
int label) {
#if __CUDA_ARCH__ >= 700
uint32_t subgroup_mask = __match_any_sync(mask, label);
#else
uint32_t subgroup_mask = 0;
for (int i = 0; i < 32;) {
int current_label = __shfl_sync(mask, label, i);
uint32_t ballot = __ballot_sync(mask, label == current_label);
if (label == current_label) {
subgroup_mask = ballot;
}
uint32_t completed_mask =
(1 << (32 - __clz(ballot))) - 1; // Threads that have finished
// Find the start of the next group: mask off completed threads from the
// active threads, then use __ffs - 1 to find the position of the next group
int next_i = __ffs(mask & ~completed_mask) - 1;
if (next_i == -1) break; // -1 indicates all finished
assert(next_i > i); // Prevent infinite loops when the constraints are not met
i = next_i;
}
#endif
return ContiguousGroup(subgroup_mask);
}
// Group of threads where each thread holds a path element
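// Roughly speaking, lane i of the group holds pweight[i] of the TreeSHAP
// EXTEND recurrence; Extend() updates all entries in parallel as
//   pweight[i] = pz * pweight[i] * (d - i) / (d + 1)
//              + po * pweight[i - 1] * i / (d + 1)
// where d is the depth after adding the new element and pz / po are its zero
// and one fractions (broadcast from lane d). UnwoundPathSum() then undoes one
// extension per feature to produce the summed permutation weights.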
class GroupPath {
protected:
const ContiguousGroup& g_;
// These are combined so we can communicate them in a single 64 bit shuffle
// instruction
float zero_one_fraction_[2];
float pweight_;
int unique_depth_;
public:
__device__ GroupPath(const ContiguousGroup& g, float zero_fraction,
float one_fraction)
: g_(g),
zero_one_fraction_{zero_fraction, one_fraction},
pweight_(g.thread_rank() == 0 ? 1.0f : 0.0f),
unique_depth_(0) {}
// Cooperatively extend the path with a group of threads
// Each thread maintains pweight for its path element in register
__device__ void Extend() {
unique_depth_++;
// Broadcast the zero and one fraction from the newly added path element
// Combine 2 shuffle operations into 64 bit word
const size_t rank = g_.thread_rank();
const float inv_unique_depth =
__fdividef(1.0f, static_cast<float>(unique_depth_ + 1));
uint64_t res = g_.shfl(*reinterpret_cast<uint64_t*>(&zero_one_fraction_),
unique_depth_);
const float new_zero_fraction = reinterpret_cast<float*>(&res)[0];
const float new_one_fraction = reinterpret_cast<float*>(&res)[1];
float left_pweight = g_.shfl_up(pweight_, 1);
// pweight of threads with rank > unique_depth_ is 0
// We use max(x,0) to avoid using a branch
// pweight_ *=
// new_zero_fraction * max(unique_depth_ - rank, 0llu) * inv_unique_depth;
pweight_ = __fmul_rn(
__fmul_rn(pweight_, new_zero_fraction),
__fmul_rn(max(unique_depth_ - rank, size_t(0)), inv_unique_depth));
// pweight_ += new_one_fraction * left_pweight * rank * inv_unique_depth;
pweight_ = __fmaf_rn(__fmul_rn(new_one_fraction, left_pweight),
__fmul_rn(rank, inv_unique_depth), pweight_);
}
// Each thread unwinds the path for its feature and returns the sum
__device__ float UnwoundPathSum() {
float next_one_portion = g_.shfl(pweight_, unique_depth_);
float total = 0.0f;
const float zero_frac_div_unique_depth = __fdividef(
zero_one_fraction_[0], static_cast<float>(unique_depth_ + 1));
for (int i = unique_depth_ - 1; i >= 0; i--) {
float ith_pweight = g_.shfl(pweight_, i);
float precomputed =
__fmul_rn((unique_depth_ - i), zero_frac_div_unique_depth);
const float tmp =
__fdividef(__fmul_rn(next_one_portion, unique_depth_ + 1), i + 1);
total = __fmaf_rn(tmp, zero_one_fraction_[1], total);
next_one_portion = __fmaf_rn(-tmp, precomputed, ith_pweight);
float numerator =
__fmul_rn(__fsub_rn(1.0f, zero_one_fraction_[1]), ith_pweight);
if (precomputed > 0.0f) {
total += __fdividef(numerator, precomputed);
}
}
return total;
}
};
// Has different permutation weightings to the above
// Used in Taylor Shapley interaction index
class TaylorGroupPath : GroupPath {
public:
__device__ TaylorGroupPath(const ContiguousGroup& g, float zero_fraction,
float one_fraction)
: GroupPath(g, zero_fraction, one_fraction) {}
// Extending the path is the same as in GroupPath; all reweighting happens in UnwoundPathSum
__device__ void Extend() { GroupPath::Extend(); }
// Each thread unwinds the path for its feature and returns the sum
// We use a different permutation weighting for Taylor interactions
// As if the total number of features was one larger
__device__ float UnwoundPathSum() {
float one_fraction = zero_one_fraction_[1];
float zero_fraction = zero_one_fraction_[0];
float next_one_portion = g_.shfl(pweight_, unique_depth_) /
static_cast<float>(unique_depth_ + 2);
float total = 0.0f;
for (int i = unique_depth_ - 1; i >= 0; i--) {
float ith_pweight =
g_.shfl(pweight_, i) * (static_cast<float>(unique_depth_ - i + 1) /
static_cast<float>(unique_depth_ + 2));
if (one_fraction > 0.0f) {
const float tmp =
next_one_portion * (unique_depth_ + 2) / ((i + 1) * one_fraction);
total += tmp;
next_one_portion =
ith_pweight - tmp * zero_fraction *
((unique_depth_ - i + 1) /
static_cast<float>(unique_depth_ + 2));
} else if (zero_fraction > 0.0f) {
total +=
(ith_pweight / zero_fraction) /
((unique_depth_ - i + 1) / static_cast<float>(unique_depth_ + 2));
}
}
return 2 * total;
}
};
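/*!
 * The DatasetT template parameter used by the functions below is assumed to
 * provide NumRows(), NumCols() and GetElement(row_idx, feature_idx), all
 * callable from device code. As an illustrative sketch (not a type defined by
 * this header), a dense row-major adapter could look like:
 * \code
 * struct DenseDatasetView {
 *   const float* data;   // num_rows * num_cols values in device memory
 *   size_t num_rows;
 *   size_t num_cols;
 *   __host__ __device__ size_t NumRows() const { return num_rows; }
 *   __host__ __device__ size_t NumCols() const { return num_cols; }
 *   __host__ __device__ float GetElement(size_t row_idx, size_t col_idx) const {
 *     return data[row_idx * num_cols + col_idx];
 *   }
 * };
 * \endcode
 */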
template <typename DatasetT, typename SplitConditionT>
__device__ float ComputePhi(const PathElement<SplitConditionT>& e,
size_t row_idx, const DatasetT& X,
const ContiguousGroup& group, float zero_fraction) {
float one_fraction =
e.EvaluateSplit(X, row_idx);
GroupPath path(group, zero_fraction, one_fraction);
size_t unique_path_length = group.size();
// Extend the path
for (auto unique_depth = 1ull; unique_depth < unique_path_length;
unique_depth++) {
path.Extend();
}
float sum = path.UnwoundPathSum();
return sum * (one_fraction - zero_fraction) * e.v;
}
inline __host__ __device__ size_t DivRoundUp(size_t a, size_t b) {
return (a + b - 1) / b;
}
template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
typename SplitConditionT>
void __device__
ConfigureThread(const DatasetT& X, const size_t bins_per_row,
const PathElement<SplitConditionT>* path_elements,
const size_t* bin_segments, size_t* start_row, size_t* end_row,
PathElement<SplitConditionT>* e, bool* thread_active) {
// Partition work
// Each warp processes a set of training instances applied to a path
size_t tid = kBlockSize * blockIdx.x + threadIdx.x;
const size_t warp_size = 32;
size_t warp_rank = tid / warp_size;
if (warp_rank >= bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp)) {
*thread_active = false;
return;
}
size_t bin_idx = warp_rank % bins_per_row;
size_t bank = warp_rank / bins_per_row;
size_t path_start = bin_segments[bin_idx];
size_t path_end = bin_segments[bin_idx + 1];
uint32_t thread_rank = threadIdx.x % warp_size;
if (thread_rank >= path_end - path_start) {
*thread_active = false;
} else {
*e = path_elements[path_start + thread_rank];
*start_row = bank * kRowsPerWarp;
*end_row = min((bank + 1) * kRowsPerWarp, X.NumRows());
*thread_active = true;
}
}
#define GPUTREESHAP_MAX_THREADS_PER_BLOCK 256
#define FULL_MASK 0xffffffff
template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
typename SplitConditionT>
__global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK)
ShapKernel(DatasetT X, size_t bins_per_row,
const PathElement<SplitConditionT>* path_elements,
const size_t* bin_segments, size_t num_groups, double* phis) {
// Use shared memory for structs, otherwise nvcc puts them in local memory
__shared__ DatasetT s_X;
s_X = X;
__shared__ PathElement<SplitConditionT> s_elements[kBlockSize];
PathElement<SplitConditionT>& e = s_elements[threadIdx.x];
size_t start_row, end_row;
bool thread_active;
ConfigureThread<DatasetT, kBlockSize, kRowsPerWarp>(
s_X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, &e,
&thread_active);
uint32_t mask = __ballot_sync(FULL_MASK, thread_active);
if (!thread_active) return;
float zero_fraction = e.zero_fraction;
auto labelled_group = active_labeled_partition(mask, e.path_idx);
for (int64_t row_idx = start_row; row_idx < end_row; row_idx++) {
float phi = ComputePhi(e, row_idx, X, labelled_group, zero_fraction);
if (!e.IsRoot()) {
atomicAddDouble(&phis[IndexPhi(row_idx, num_groups, e.group, X.NumCols(),
e.feature_idx)],
phi);
}
}
}
template <typename DatasetT, typename SizeTAllocatorT, typename PathAllocatorT,
typename SplitConditionT>
void ComputeShap(
DatasetT X,
const thrust::device_vector<size_t, SizeTAllocatorT>& bin_segments,
const thrust::device_vector<PathElement<SplitConditionT>, PathAllocatorT>&
path_elements,
size_t num_groups, double* phis) {
size_t bins_per_row = bin_segments.size() - 1;
const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK;
const int warps_per_block = kBlockThreads / 32;
const int kRowsPerWarp = 1024;
size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp);
const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block);
ShapKernel<DatasetT, kBlockThreads, kRowsPerWarp>
<<<grid_size, kBlockThreads>>>(
X, bins_per_row, path_elements.data().get(),
bin_segments.data().get(), num_groups, phis);
}
template <typename PathT, typename DatasetT, typename SplitConditionT>
__device__ float ComputePhiCondition(const PathElement<SplitConditionT>& e,
size_t row_idx, const DatasetT& X,
const ContiguousGroup& group,
int64_t condition_feature) {
float one_fraction = e.EvaluateSplit(X, row_idx);
PathT path(group, e.zero_fraction, one_fraction);
size_t unique_path_length = group.size();
float condition_on_fraction = 1.0f;
float condition_off_fraction = 1.0f;
// Extend the path
for (auto i = 1ull; i < unique_path_length; i++) {
bool is_condition_feature =
group.shfl(e.feature_idx, i) == condition_feature;
float o_i = group.shfl(one_fraction, i);
float z_i = group.shfl(e.zero_fraction, i);
if (is_condition_feature) {
condition_on_fraction = o_i;
condition_off_fraction = z_i;
} else {
path.Extend();
}
}
float sum = path.UnwoundPathSum();
if (e.feature_idx == condition_feature) {
return 0.0f;
}
float phi = sum * (one_fraction - e.zero_fraction) * e.v;
return phi * (condition_on_fraction - condition_off_fraction) * 0.5f;
}
// If there is a feature in the path we are conditioning on, swap it to the end
// of the path
template <typename SplitConditionT>
inline __device__ void SwapConditionedElement(
PathElement<SplitConditionT>** e, PathElement<SplitConditionT>* s_elements,
uint32_t condition_rank, const ContiguousGroup& group) {
auto last_rank = group.size() - 1;
auto this_rank = group.thread_rank();
if (this_rank == last_rank) {
*e = &s_elements[(threadIdx.x - this_rank) + condition_rank];
} else if (this_rank == condition_rank) {
*e = &s_elements[(threadIdx.x - this_rank) + last_rank];
}
}
template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
typename SplitConditionT>
__global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK)
ShapInteractionsKernel(DatasetT X, size_t bins_per_row,
const PathElement<SplitConditionT>* path_elements,
const size_t* bin_segments, size_t num_groups,
double* phis_interactions) {
// Use shared memory for structs, otherwise nvcc puts them in local memory
__shared__ DatasetT s_X;
s_X = X;
__shared__ PathElement<SplitConditionT> s_elements[kBlockSize];
PathElement<SplitConditionT>* e = &s_elements[threadIdx.x];
size_t start_row, end_row;
bool thread_active;
ConfigureThread<DatasetT, kBlockSize, kRowsPerWarp>(
s_X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, e,
&thread_active);
uint32_t mask = __ballot_sync(FULL_MASK, thread_active);
if (!thread_active) return;
auto labelled_group = active_labeled_partition(mask, e->path_idx);
for (int64_t row_idx = start_row; row_idx < end_row; row_idx++) {
float phi = ComputePhi(*e, row_idx, X, labelled_group, e->zero_fraction);
if (!e->IsRoot()) {
auto phi_offset =
IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
e->feature_idx, e->feature_idx);
atomicAddDouble(phis_interactions + phi_offset, phi);
}
for (auto condition_rank = 1ull; condition_rank < labelled_group.size();
condition_rank++) {
e = &s_elements[threadIdx.x];
int64_t condition_feature =
labelled_group.shfl(e->feature_idx, condition_rank);
SwapConditionedElement(&e, s_elements, condition_rank, labelled_group);
float x = ComputePhiCondition<GroupPath>(*e, row_idx, X, labelled_group,
condition_feature);
if (!e->IsRoot()) {
auto phi_offset =
IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
e->feature_idx, condition_feature);
atomicAddDouble(phis_interactions + phi_offset, x);
// Subtract effect from diagonal
auto phi_diag =
IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
e->feature_idx, e->feature_idx);
atomicAddDouble(phis_interactions + phi_diag, -x);
}
}
}
}
template <typename DatasetT, typename SizeTAllocatorT, typename PathAllocatorT,
typename SplitConditionT>
void ComputeShapInteractions(
DatasetT X,
const thrust::device_vector<size_t, SizeTAllocatorT>& bin_segments,
const thrust::device_vector<PathElement<SplitConditionT>, PathAllocatorT>&
path_elements,
size_t num_groups, double* phis) {
size_t bins_per_row = bin_segments.size() - 1;
const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK;
const int warps_per_block = kBlockThreads / 32;
const int kRowsPerWarp = 100;
size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp);
const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block);
ShapInteractionsKernel<DatasetT, kBlockThreads, kRowsPerWarp>
<<<grid_size, kBlockThreads>>>(
X, bins_per_row, path_elements.data().get(),
bin_segments.data().get(), num_groups, phis);
}
template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
typename SplitConditionT>
__global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK)
ShapTaylorInteractionsKernel(
DatasetT X, size_t bins_per_row,
const PathElement<SplitConditionT>* path_elements,
const size_t* bin_segments, size_t num_groups,
double* phis_interactions) {
// Use shared memory for structs, otherwise nvcc puts them in local memory
__shared__ DatasetT s_X;
if (threadIdx.x == 0) {
s_X = X;
}
__syncthreads();
__shared__ PathElement<SplitConditionT> s_elements[kBlockSize];
PathElement<SplitConditionT>* e = &s_elements[threadIdx.x];
size_t start_row, end_row;
bool thread_active;
ConfigureThread<DatasetT, kBlockSize, kRowsPerWarp>(
s_X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, e,
&thread_active);
uint32_t mask = __ballot_sync(FULL_MASK, thread_active);
if (!thread_active) return;
auto labelled_group = active_labeled_partition(mask, e->path_idx);
for (int64_t row_idx = start_row; row_idx < end_row; row_idx++) {
for (auto condition_rank = 1ull; condition_rank < labelled_group.size();
condition_rank++) {
e = &s_elements[threadIdx.x];
// Compute the diagonal terms
// TODO(Rory): this can be more efficient
float reduce_input =
e->IsRoot() || labelled_group.thread_rank() == condition_rank
? 1.0f
: e->zero_fraction;
float reduce =
labelled_group.reduce(reduce_input, thrust::multiplies<float>());
if (labelled_group.thread_rank() == condition_rank) {
float one_fraction = e->split_condition.EvaluateSplit(
X.GetElement(row_idx, e->feature_idx));
auto phi_offset =
IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
e->feature_idx, e->feature_idx);
atomicAddDouble(phis_interactions + phi_offset,
reduce * (one_fraction - e->zero_fraction) * e->v);
}
int64_t condition_feature =
labelled_group.shfl(e->feature_idx, condition_rank);
SwapConditionedElement(&e, s_elements, condition_rank, labelled_group);
float x = ComputePhiCondition<TaylorGroupPath>(
*e, row_idx, X, labelled_group, condition_feature);
if (!e->IsRoot()) {
auto phi_offset =
IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
e->feature_idx, condition_feature);
atomicAddDouble(phis_interactions + phi_offset, x);
}
}
}
}
template <typename DatasetT, typename SizeTAllocatorT, typename PathAllocatorT,
typename SplitConditionT>
void ComputeShapTaylorInteractions(
DatasetT X,
const thrust::device_vector<size_t, SizeTAllocatorT>& bin_segments,
const thrust::device_vector<PathElement<SplitConditionT>, PathAllocatorT>&
path_elements,
size_t num_groups, double* phis) {
size_t bins_per_row = bin_segments.size() - 1;
const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK;
const int warps_per_block = kBlockThreads / 32;
const int kRowsPerWarp = 100;
size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp);
const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block);
ShapTaylorInteractionsKernel<DatasetT, kBlockThreads, kRowsPerWarp>
<<<grid_size, kBlockThreads>>>(
X, bins_per_row, path_elements.data().get(),
bin_segments.data().get(), num_groups, phis);
}
inline __host__ __device__ int64_t Factorial(int64_t x) {
int64_t y = 1;
for (auto i = 2; i <= x; i++) {
y *= i;
}
return y;
}
// Compute the ratio of factorials in log space using lgamma to avoid overflow
inline __host__ __device__ double W(double s, double n) {
assert(n - s - 1 >= 0);
return exp(lgamma(s + 1) - lgamma(n + 1) + lgamma(n - s));
}
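// Equivalently, W(s, n) = s! * (n - s - 1)! / n!, the Shapley permutation
// weight used by ShapInterventionalKernel below, where s and n correspond to
// the popcounts of the foreground/background ballots computed in that kernel.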
template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
typename SplitConditionT>
__global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK)
ShapInterventionalKernel(DatasetT X, DatasetT R, size_t bins_per_row,
const PathElement<SplitConditionT>* path_elements,
const size_t* bin_segments, size_t num_groups,
double* phis) {
// Cache W coefficients
__shared__ float s_W[33][33];
for (int i = threadIdx.x; i < 33 * 33; i += kBlockSize) {
auto s = i % 33;
auto n = i / 33;
if (n - s - 1 >= 0) {
s_W[s][n] = W(s, n);
} else {
s_W[s][n] = 0.0;
}
}
__syncthreads();
__shared__ PathElement<SplitConditionT> s_elements[kBlockSize];
PathElement<SplitConditionT>& e = s_elements[threadIdx.x];
size_t start_row, end_row;
bool thread_active;
ConfigureThread<DatasetT, kBlockSize, kRowsPerWarp>(
X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, &e,
&thread_active);
uint32_t mask = __ballot_sync(FULL_MASK, thread_active);
if (!thread_active) return;
auto labelled_group = active_labeled_partition(mask, e.path_idx);
for (int64_t x_idx = start_row; x_idx < end_row; x_idx++) {
float result = 0.0f;
bool x_cond = e.EvaluateSplit(X, x_idx);
uint32_t x_ballot = labelled_group.ballot(x_cond);
for (int64_t r_idx = 0; r_idx < R.NumRows(); r_idx++) {
bool r_cond = e.EvaluateSplit(R, r_idx);
uint32_t r_ballot = labelled_group.ballot(r_cond);
assert(!e.IsRoot() ||
(x_cond == r_cond)); // These should be the same for the root
uint32_t s = __popc(x_ballot & ~r_ballot);
uint32_t n = __popc(x_ballot ^ r_ballot);
float tmp = 0.0f;
// Theorem 1
if (x_cond && !r_cond) {
tmp += s_W[s - 1][n];
}
tmp -= s_W[s][n] * (r_cond && !x_cond);
// No foreground samples make it to this leaf, increment bias
if (e.IsRoot() && s == 0) {
tmp += 1.0f;
}
// If neither foreground nor background samples go down this path, ignore it
bool reached_leaf = !labelled_group.ballot(!x_cond && !r_cond);
tmp *= reached_leaf;
result += tmp;
}
if (result != 0.0) {
result /= R.NumRows();
// Root writes bias
auto feature = e.IsRoot() ? X.NumCols() : e.feature_idx;
atomicAddDouble(
&phis[IndexPhi(x_idx, num_groups, e.group, X.NumCols(), feature)],
result * e.v);
}
}
}
template <typename DatasetT, typename SizeTAllocatorT, typename PathAllocatorT,
typename SplitConditionT>
void ComputeShapInterventional(
DatasetT X, DatasetT R,
const thrust::device_vector<size_t, SizeTAllocatorT>& bin_segments,
const thrust::device_vector<PathElement<SplitConditionT>, PathAllocatorT>&
path_elements,
size_t num_groups, double* phis) {
size_t bins_per_row = bin_segments.size() - 1;
const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK;
const int warps_per_block = kBlockThreads / 32;
const int kRowsPerWarp = 100;
size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp);
const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block);
ShapInterventionalKernel<DatasetT, kBlockThreads, kRowsPerWarp>
<<<grid_size, kBlockThreads>>>(
X, R, bins_per_row, path_elements.data().get(),
bin_segments.data().get(), num_groups, phis);
}
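// Counts how many path elements fall into each bin (bin_map maps path_idx to
// bin) and exclusive-scans the counts, so that after SortPaths the elements of
// bin i occupy the half-open range [bin_segments[i], bin_segments[i + 1]).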
template <typename PathVectorT, typename SizeVectorT, typename DeviceAllocatorT>
void GetBinSegments(const PathVectorT& paths, const SizeVectorT& bin_map,
SizeVectorT* bin_segments) {
DeviceAllocatorT alloc;
size_t num_bins =
thrust::reduce(thrust::cuda::par(alloc), bin_map.begin(), bin_map.end(),
size_t(0), thrust::maximum<size_t>()) +
1;
bin_segments->resize(num_bins + 1, 0);
auto counting = thrust::make_counting_iterator(0llu);
auto d_paths = paths.data().get();
auto d_bin_segments = bin_segments->data().get();
auto d_bin_map = bin_map.data();
thrust::for_each_n(counting, paths.size(), [=] __device__(size_t idx) {
auto path_idx = d_paths[idx].path_idx;
atomicAdd(reinterpret_cast<unsigned long long*>(d_bin_segments) + // NOLINT
d_bin_map[path_idx],
1);
});
thrust::exclusive_scan(thrust::cuda::par(alloc), bin_segments->begin(),
bin_segments->end(), bin_segments->begin());
}
struct DeduplicateKeyTransformOp {
template <typename SplitConditionT>
__device__ thrust::pair<size_t, int64_t> operator()(
const PathElement<SplitConditionT>& e) {
return {e.path_idx, e.feature_idx};
}
};
inline void CheckCuda(cudaError_t err) {
if (err != cudaSuccess) {
throw thrust::system_error(err, thrust::cuda_category());
}
}
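// cub::DeviceReduce::ReduceByKey deduces the key type from the value_type of
// its unique-keys output iterator, so a plain thrust::discard_iterator cannot
// be used directly; this wrapper overrides value_type so that the unique keys
// can be discarded in DeduplicatePaths below.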
template <typename Return>
class DiscardOverload : public thrust::discard_iterator<Return> {
public:
using value_type = Return; // NOLINT
};
template <typename PathVectorT, typename DeviceAllocatorT,
typename SplitConditionT>
void DeduplicatePaths(PathVectorT* device_paths,
PathVectorT* deduplicated_paths) {
DeviceAllocatorT alloc;
// Sort by feature
thrust::sort(thrust::cuda::par(alloc), device_paths->begin(),
device_paths->end(),
[=] __device__(const PathElement<SplitConditionT>& a,
const PathElement<SplitConditionT>& b) {
if (a.path_idx < b.path_idx) return true;
if (b.path_idx < a.path_idx) return false;
if (a.feature_idx < b.feature_idx) return true;
if (b.feature_idx < a.feature_idx) return false;
return false;
});
deduplicated_paths->resize(device_paths->size());
using Pair = thrust::pair<size_t, int64_t>;
auto key_transform = thrust::make_transform_iterator(
device_paths->begin(), DeduplicateKeyTransformOp());
thrust::device_vector<size_t> d_num_runs_out(1);
size_t* h_num_runs_out;
CheckCuda(cudaMallocHost(&h_num_runs_out, sizeof(size_t)));
auto combine = [] __host__ __device__(PathElement<SplitConditionT> a,
PathElement<SplitConditionT> b) {
// Combine duplicate features
a.split_condition.Merge(b.split_condition);
a.zero_fraction *= b.zero_fraction;
return a;
}; // NOLINT
size_t temp_size = 0;
CheckCuda(cub::DeviceReduce::ReduceByKey(
nullptr, temp_size, key_transform, DiscardOverload<Pair>(),
device_paths->begin(), deduplicated_paths->begin(),
d_num_runs_out.begin(), combine, device_paths->size()));
using TempAlloc = RebindVector<char, DeviceAllocatorT>;
TempAlloc tmp(temp_size);
CheckCuda(cub::DeviceReduce::ReduceByKey(
tmp.data().get(), temp_size, key_transform, DiscardOverload<Pair>(),
device_paths->begin(), deduplicated_paths->begin(),
d_num_runs_out.begin(), combine, device_paths->size()));
CheckCuda(cudaMemcpy(h_num_runs_out, d_num_runs_out.data().get(),
sizeof(size_t), cudaMemcpyDeviceToHost));
deduplicated_paths->resize(*h_num_runs_out);
CheckCuda(cudaFreeHost(h_num_runs_out));
}
template <typename PathVectorT, typename SplitConditionT, typename SizeVectorT,
typename DeviceAllocatorT>
void SortPaths(PathVectorT* paths, const SizeVectorT& bin_map) {
auto d_bin_map = bin_map.data();
DeviceAllocatorT alloc;
thrust::sort(thrust::cuda::par(alloc), paths->begin(), paths->end(),
[=] __device__(const PathElement<SplitConditionT>& a,
const PathElement<SplitConditionT>& b) {
size_t a_bin = d_bin_map[a.path_idx];
size_t b_bin = d_bin_map[b.path_idx];
if (a_bin < b_bin) return true;
if (b_bin < a_bin) return false;
if (a.path_idx < b.path_idx) return true;
if (b.path_idx < a.path_idx) return false;
if (a.feature_idx < b.feature_idx) return true;
if (b.feature_idx < a.feature_idx) return false;
return false;
});
}
using kv = std::pair<size_t, int>;
struct BFDCompare {
bool operator()(const kv& lhs, const kv& rhs) const {
if (lhs.second == rhs.second) {
return lhs.first < rhs.first;
}
return lhs.second < rhs.second;
}
};
// Best Fit Decreasing bin packing
// Efficient O(nlogn) implementation with balanced tree using std::set
template <typename IntVectorT>
std::vector<size_t> BFDBinPacking(const IntVectorT& counts,
int bin_limit = 32) {
thrust::host_vector<int> counts_host(counts);
std::vector<kv> path_lengths(counts_host.size());
for (auto i = 0ull; i < counts_host.size(); i++) {
path_lengths[i] = {i, counts_host[i]};
}
std::sort(path_lengths.begin(), path_lengths.end(),
[&](const kv& a, const kv& b) {
std::greater<> op;
return op(a.second, b.second);
});
// map unique_id -> bin
std::vector<size_t> bin_map(counts_host.size());
std::set<kv, BFDCompare> bin_capacities;
bin_capacities.insert({bin_capacities.size(), bin_limit});
for (auto pair : path_lengths) {
int new_size = pair.second;
auto itr = bin_capacities.lower_bound({0, new_size});
// Does not fit in any bin
if (itr == bin_capacities.end()) {
size_t new_bin_idx = bin_capacities.size();
bin_capacities.insert({new_bin_idx, bin_limit - new_size});
bin_map[pair.first] = new_bin_idx;
} else {
kv entry = *itr;
entry.second -= new_size;
bin_map[pair.first] = entry.first;
bin_capacities.erase(itr);
bin_capacities.insert(entry);
}