Skip to content

Commit 138e146

Browse files
authored
Fix prediction with sparse QDM. (dmlc#11250)
1 parent 258aed9 commit 138e146

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

src/predictor/cpu_predictor.cc

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
#include <typeinfo> // for type_info
1212
#include <vector> // for vector
1313

14-
#include "../collective/communicator-inl.h" // for Allreduce, IsDistributed
15-
#include "../collective/allreduce.h"
14+
#include "../collective/allreduce.h" // for Allreduce
15+
#include "../collective/communicator-inl.h" // for IsDistributed
1616
#include "../common/bitfield.h" // for RBitField8
17+
#include "../common/column_matrix.h" // for ColumnMatrix
1718
#include "../common/common.h" // for DivRoundUp
1819
#include "../common/error_msg.h" // for InplacePredictProxy
1920
#include "../common/math.h" // for CheckNAN
@@ -195,6 +196,7 @@ struct GHistIndexMatrixView : public DataToFeatVec<GHistIndexMatrixView> {
195196
std::vector<std::uint32_t> const &ptrs_;
196197
std::vector<float> const &mins_;
197198
std::vector<float> const &values_;
199+
common::ColumnMatrix const &columns_;
198200

199201
public:
200202
bst_idx_t const base_rowid;
@@ -206,6 +208,7 @@ struct GHistIndexMatrixView : public DataToFeatVec<GHistIndexMatrixView> {
206208
ptrs_{_page.cut.Ptrs()},
207209
mins_{_page.cut.MinValues()},
208210
values_{_page.cut.Values()},
211+
columns_{page_.Transpose()},
209212
base_rowid{_page.base_rowid} {}
210213

211214
[[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float* out) const {
@@ -235,7 +238,22 @@ struct GHistIndexMatrixView : public DataToFeatVec<GHistIndexMatrixView> {
235238
n_non_missings += n_features;
236239
} else {
237240
for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) {
238-
float f = page_.GetFvalue(ptrs_, values_, mins_, gridx, fidx, common::IsCat(ft_, fidx));
241+
float f = std::numeric_limits<float>::quiet_NaN();
242+
bool is_cat = common::IsCat(ft_, fidx);
243+
if (columns_.GetColumnType(fidx) == common::kSparseColumn) {
244+
// Special handling for extremely sparse data. Just binary search.
245+
auto bin_idx = page_.GetGindex(gridx, fidx);
246+
if (bin_idx != -1) {
247+
if (is_cat) {
248+
f = values_[bin_idx];
249+
} else {
250+
f = common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx,
251+
bin_idx);
252+
}
253+
}
254+
} else {
255+
f = page_.GetFvalue(ptrs_, values_, mins_, gridx, fidx, is_cat);
256+
}
239257
if (!common::CheckNAN(f)) {
240258
out[fidx] = f;
241259
n_non_missings++;

tests/python/test_quantile_dmatrix.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,3 +375,23 @@ def test_changed_max_bin(self) -> None:
375375

376376
def test_mixed_sparsity(self) -> None:
377377
run_mixed_sparsity("cpu")
378+
379+
def test_sparse_predict(self) -> None:
380+
X, y = make_sparse_regression(512, 16, sparsity=0.9, as_dense=False)
381+
382+
Xy: xgb.DMatrix = xgb.QuantileDMatrix(X, y)
383+
booster = xgb.train({}, Xy, num_boost_round=8)
384+
385+
p0 = booster.predict(Xy)
386+
Xy = xgb.DMatrix(X, y)
387+
p1 = booster.predict(Xy)
388+
np.testing.assert_allclose(p0, p1)
389+
390+
X, y = make_categorical(128, 16, 5, onehot=False, sparsity=0.9)
391+
Xy = xgb.QuantileDMatrix(X, y, enable_categorical=True)
392+
booster = xgb.train({}, Xy, num_boost_round=8)
393+
394+
p0 = booster.predict(Xy)
395+
Xy = xgb.DMatrix(X, y, enable_categorical=True)
396+
p1 = booster.predict(Xy)
397+
np.testing.assert_allclose(p0, p1)

0 commit comments

Comments
 (0)