11
11
#include < typeinfo> // for type_info
12
12
#include < vector> // for vector
13
13
14
- #include " ../collective/communicator-inl .h" // for Allreduce, IsDistributed
15
- #include " ../collective/allreduce .h"
14
+ #include " ../collective/allreduce .h" // for Allreduce
15
+ #include " ../collective/communicator-inl .h" // for IsDistributed
16
16
#include " ../common/bitfield.h" // for RBitField8
17
+ #include " ../common/column_matrix.h" // for ColumnMatrix
17
18
#include " ../common/common.h" // for DivRoundUp
18
19
#include " ../common/error_msg.h" // for InplacePredictProxy
19
20
#include " ../common/math.h" // for CheckNAN
@@ -195,6 +196,7 @@ struct GHistIndexMatrixView : public DataToFeatVec<GHistIndexMatrixView> {
195
196
std::vector<std::uint32_t > const &ptrs_;
196
197
std::vector<float > const &mins_;
197
198
std::vector<float > const &values_;
199
+ common::ColumnMatrix const &columns_;
198
200
199
201
public:
200
202
bst_idx_t const base_rowid;
@@ -206,6 +208,7 @@ struct GHistIndexMatrixView : public DataToFeatVec<GHistIndexMatrixView> {
206
208
ptrs_{_page.cut .Ptrs ()},
207
209
mins_{_page.cut .MinValues ()},
208
210
values_{_page.cut .Values ()},
211
+ columns_{page_.Transpose ()},
209
212
base_rowid{_page.base_rowid } {}
210
213
211
214
[[nodiscard]] bst_idx_t DoFill (bst_idx_t ridx, float * out) const {
@@ -235,7 +238,22 @@ struct GHistIndexMatrixView : public DataToFeatVec<GHistIndexMatrixView> {
235
238
n_non_missings += n_features;
236
239
} else {
237
240
for (bst_feature_t fidx = 0 ; fidx < n_features; ++fidx) {
238
- float f = page_.GetFvalue (ptrs_, values_, mins_, gridx, fidx, common::IsCat (ft_, fidx));
241
+ float f = std::numeric_limits<float >::quiet_NaN ();
242
+ bool is_cat = common::IsCat (ft_, fidx);
243
+ if (columns_.GetColumnType (fidx) == common::kSparseColumn ) {
244
+ // Special handling for extremely sparse data: just use a binary search.
245
+ auto bin_idx = page_.GetGindex (gridx, fidx);
246
+ if (bin_idx != -1 ) {
247
+ if (is_cat) {
248
+ f = values_[bin_idx];
249
+ } else {
250
+ f = common::HistogramCuts::NumericBinValue (this ->ptrs_ , values_, mins_, fidx,
251
+ bin_idx);
252
+ }
253
+ }
254
+ } else {
255
+ f = page_.GetFvalue (ptrs_, values_, mins_, gridx, fidx, is_cat);
256
+ }
239
257
if (!common::CheckNAN (f)) {
240
258
out[fidx] = f;
241
259
n_non_missings++;
0 commit comments