
[c++] Initial Work for Pairwise Ranking #6182

Open
wants to merge 88 commits into base: master
Changes from 1 commit (of 88 commits)
Commits
9ae3476
initial work for pairwise ranking (dataset part)
shiyu1994 Nov 8, 2023
2314099
remove unrelated changes
shiyu1994 Nov 8, 2023
06ddf68
Merge branch 'master' into pairwise-ranking-dev
shiyu1994 Nov 8, 2023
42e91e2
Merge branch 'master' into pairwise-ranking-dev
shiyu1994 Nov 23, 2023
a8379d4
Merge branch 'master' into pairwise-ranking-dev
shiyu1994 Dec 1, 2023
da5f02d
first version of pairwie ranking bin
shiyu1994 Dec 5, 2023
9d0afd9
Merge branch 'pairwise-ranking-dev' of https://github.com/Microsoft/L…
shiyu1994 Dec 5, 2023
0cb436d
templates for bins in pairwise ranking dataset
shiyu1994 Dec 5, 2023
fc9b381
Merge branch 'master' into pairwise-ranking-dev
shiyu1994 Dec 5, 2023
6fbc674
fix lint issues and compilation errors
shiyu1994 Dec 6, 2023
6082913
Merge branch 'pairwise-ranking-dev' of https://github.com/Microsoft/L…
shiyu1994 Dec 6, 2023
9e16dc3
add methods for pairwise bin
shiyu1994 Dec 6, 2023
6154bde
instantiate templates
shiyu1994 Dec 6, 2023
3a646eb
remove unrelated files
shiyu1994 Dec 6, 2023
9e77ab9
add return values for unimplemented methods
shiyu1994 Dec 7, 2023
eba4560
add new files and windows/LightGBM.vcxproj and windows/LightGBM.vcxpr…
shiyu1994 Dec 7, 2023
f1d2281
Merge branch 'master' into pairwise-ranking-dev
shiyu1994 Dec 7, 2023
873d7ad
create pairwise dataset
shiyu1994 Dec 7, 2023
3838b9b
Merge branch 'pairwise-ranking-dev' of https://github.com/Microsoft/L…
shiyu1994 Dec 7, 2023
986a979
set num_data_ of pairwise dataset
shiyu1994 Dec 7, 2023
c40965a
skip query with no paired items
shiyu1994 Dec 15, 2023
97d34d7
store original query information
shiyu1994 Jan 31, 2024
1e57e27
copy position information for pairwise dataset
shiyu1994 Jan 31, 2024
1699c06
rename to pointwise members
shiyu1994 Feb 1, 2024
d5b6f0a
adding initial support for pairwise gradients and NDCG eval with pair…
metpavel Feb 9, 2024
2ee1199
fix score offsets
metpavel Feb 9, 2024
fe10a2c
Merge branch 'master' into pairwise-ranking-dev
shiyu1994 Feb 19, 2024
0aaf090
skip copy for weights and label if none
shiyu1994 Feb 19, 2024
8714bfb
fix pairwise dataset bugs
shiyu1994 Feb 29, 2024
250996b
Merge branch 'master' into pairwise-ranking-dev
shiyu1994 Feb 29, 2024
38b2f3e
fix validation set with pairwise lambda rank
shiyu1994 Feb 29, 2024
09fff25
Merge branch 'pairwise-ranking-dev' of https://github.com/Microsoft/L…
shiyu1994 Feb 29, 2024
ba3c815
fix pairwise ranking objective initialization
shiyu1994 Feb 29, 2024
d9b537d
keep the original query boundaries and add pairwise query boundaries
shiyu1994 Feb 29, 2024
362baf8
allow empty queries in pairwise query boundaries
shiyu1994 Mar 1, 2024
06597ac
fix query boundaries
shiyu1994 Mar 1, 2024
18e3a1b
clean up
shiyu1994 Mar 1, 2024
43b8582
various fixes
metpavel Mar 1, 2024
ad4e89f
construct all pairs for validation set
shiyu1994 Mar 1, 2024
dc17309
Merge branch 'pairwise-ranking-dev' of https://github.com/microsoft/L…
metpavel Mar 1, 2024
1ad78b2
fix for validation set
shiyu1994 Mar 1, 2024
9cd3b93
fix validation pairs
shiyu1994 Mar 1, 2024
f9d9c07
fatal error when no query boundary is provided
shiyu1994 Mar 1, 2024
97e0a81
Merge branch 'master' into pairwise-ranking-dev
shiyu1994 Mar 1, 2024
746bc82
add differential features
shiyu1994 Mar 8, 2024
f9ab075
add differential features
shiyu1994 Mar 20, 2024
7aa170b
bug fixing and efficiency improvement
metpavel Mar 25, 2024
abdb716
add feature group for differential features
shiyu1994 Mar 27, 2024
3cdfd83
refactor template initializations with macro
shiyu1994 Mar 28, 2024
3703495
tree learning with differential features
shiyu1994 Mar 28, 2024
8f55a93
avoid copy sampled values
shiyu1994 Mar 28, 2024
8c3e7be
fix sampled indices
shiyu1994 Apr 2, 2024
5aa2d17
push data into differential features
shiyu1994 Apr 11, 2024
1c319b8
fix differential feature bugs
shiyu1994 Apr 17, 2024
d8eb68b
clean up debug code
shiyu1994 Apr 17, 2024
b088236
fix validation set with differential features
shiyu1994 Apr 18, 2024
2d09897
support row-wise histogram construction with pairwise ranking
shiyu1994 Jun 15, 2024
406d0c1
fix row wise in pairwise ranking
shiyu1994 Jun 20, 2024
6c65d1f
save for debug
shiyu1994 Jun 20, 2024
7738915
update code for debug
shiyu1994 Jun 28, 2024
d6c16df
save changes
shiyu1994 Jul 4, 2024
0d572d7
save changes for debug
shiyu1994 Jul 8, 2024
1f59f85
save changes
shiyu1994 Aug 21, 2024
0618bb2
add bagging by query for lambdarank
shiyu1994 Aug 27, 2024
185bdf6
Merge branch 'master' into bagging/bagging-by-query-for-lambdarank
shiyu1994 Aug 27, 2024
38fa4c2
fix pre-commit
shiyu1994 Aug 27, 2024
2fce147
Merge branch 'bagging/bagging-by-query-for-lambdarank' of https://git…
shiyu1994 Aug 27, 2024
1f7f967
Merge branch 'master' into bagging/bagging-by-query-for-lambdarank
shiyu1994 Aug 29, 2024
9e2a322
fix bagging by query with cuda
shiyu1994 Aug 29, 2024
666c51e
fix bagging by query test case
shiyu1994 Aug 30, 2024
9e2c338
fix bagging by query test case
shiyu1994 Aug 30, 2024
3abbc11
fix bagging by query test case
shiyu1994 Aug 30, 2024
13fa0a3
add #include <vector>
shiyu1994 Aug 30, 2024
b8427b0
merge bagging by query
shiyu1994 Sep 4, 2024
0258f07
update CMakeLists.txt
shiyu1994 Sep 4, 2024
90a95fa
fix bagging by query with pairwise lambdarank
shiyu1994 Sep 20, 2024
306af04
Merge branch 'master' into pairwise-ranking-dev
shiyu1994 Sep 20, 2024
b69913d
fix compilation error C3200 with visual studio
shiyu1994 Oct 10, 2024
6dba1cf
clean up main.cpp
shiyu1994 Oct 11, 2024
3b2e29d
Exposing configuration parameters for pairwise ranking
metpavel Oct 18, 2024
f1c32d3
fix bugs and pass by reference for SigmoidCache&
shiyu1994 Nov 8, 2024
51693e2
add pairing approach
shiyu1994 Nov 8, 2024
5071842
add at_least_one_relevant
shiyu1994 Nov 8, 2024
598764b
fix num bin for row wise in pairwise ranking
shiyu1994 Nov 21, 2024
f7deab4
save for debug
shiyu1994 Dec 17, 2024
0d1b310
update doc
shiyu1994 Dec 18, 2024
8f9ab26
add random_k pairing mode
shiyu1994 Feb 18, 2025
d797122
clean up code
shiyu1994 Feb 18, 2025
save changes for debug
shiyu1994 committed Jul 8, 2024
commit 0d572d7a550a0f113e8d392726a405530dc92cc0
2 changes: 2 additions & 0 deletions include/LightGBM/bin.h
@@ -537,6 +537,8 @@ class Bin {
virtual const void* GetColWiseData(uint8_t* bit_type, bool* is_sparse, std::vector<BinIterator*>* bin_iterator, const int num_threads) const = 0;

virtual const void* GetColWiseData(uint8_t* bit_type, bool* is_sparse, BinIterator** bin_iterator) const = 0;

int group_index_ = -1;
};


26 changes: 13 additions & 13 deletions include/LightGBM/feature_group.h
@@ -587,26 +587,26 @@ class FeatureGroup {
multi_bin_data_.clear();
for (int i = 0; i < num_feature_; ++i) {
int addi = bin_mappers_[i]->GetMostFreqBin() == 0 ? 0 : 1;
if (bin_mappers_[i]->sparse_rate() >= kSparseThreshold) {
multi_bin_data_.emplace_back(Bin::CreateSparseBin(
num_data, bin_mappers_[i]->num_bin() + addi));
} else {
// if (bin_mappers_[i]->sparse_rate() >= kSparseThreshold) {
// multi_bin_data_.emplace_back(Bin::CreateSparseBin(
// num_data, bin_mappers_[i]->num_bin() + addi));
// } else {
multi_bin_data_.emplace_back(
Bin::CreateDenseBin(num_data, bin_mappers_[i]->num_bin() + addi));
}
// }
}
is_multi_val_ = true;
} else {
if (force_sparse ||
(!force_dense && num_feature_ == 1 &&
bin_mappers_[0]->sparse_rate() >= kSparseThreshold)) {
is_sparse_ = true;
bin_data_.reset(Bin::CreateSparseBin(num_data, num_total_bin_));
} else {
// if (force_sparse ||
// (!force_dense && num_feature_ == 1 &&
// bin_mappers_[0]->sparse_rate() >= kSparseThreshold)) {
// is_sparse_ = true;
// bin_data_.reset(Bin::CreateSparseBin(num_data, num_total_bin_));
// } else {
is_sparse_ = false;
bin_data_.reset(Bin::CreateDenseBin(num_data, num_total_bin_));
}
is_multi_val_ = false;
// }
// is_multi_val_ = false;
}
}
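The hunk above comments out `FeatureGroup`'s sparse-bin branches so that pairwise ranking always constructs dense bins. For reference, the rule being disabled compares each feature's sparse rate against a threshold, with explicit forcing flags taking priority. A minimal standalone sketch of that decision (the threshold value and the `ChooseBinKind` name are illustrative assumptions, not LightGBM's actual API):

```cpp
#include <cassert>

// Hypothetical sketch of the dense-vs-sparse bin choice that the hunk above
// disables. The kSparseThreshold value here is an assumption for illustration;
// LightGBM defines its own constant internally.
constexpr double kSparseThreshold = 0.8;

enum class BinKind { kDense, kSparse };

// Mirrors the original control flow: explicit forcing wins, otherwise the
// feature's sparse rate decides.
inline BinKind ChooseBinKind(double sparse_rate, bool force_sparse,
                             bool force_dense) {
  if (force_sparse) return BinKind::kSparse;
  if (force_dense) return BinKind::kDense;
  return sparse_rate >= kSparseThreshold ? BinKind::kSparse : BinKind::kDense;
}
```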

5 changes: 3 additions & 2 deletions src/io/bin.cpp
@@ -707,9 +707,10 @@ namespace LightGBM {
// if (use_pairwise_ranking) {
Log::Warning("Pairwise ranking with sparse row-wise bins is not supported yet.");
return CreateMultiValDenseBin(num_data, num_bin, num_feature, offsets, use_pairwise_ranking, paired_ranking_item_global_index_map);
// } else {
// return CreateMultiValSparseBin(num_data, num_bin,
// average_element_per_row, use_pairwise_ranking, paired_ranking_item_global_index_map);
// }
return CreateMultiValSparseBin(num_data, num_bin,
average_element_per_row, use_pairwise_ranking, paired_ranking_item_global_index_map);
} else {
return CreateMultiValDenseBin(num_data, num_bin, num_feature, offsets, use_pairwise_ranking, paired_ranking_item_global_index_map);
}
14 changes: 8 additions & 6 deletions src/io/dataset.cpp
@@ -499,6 +499,7 @@ void PushDataToMultiValBin(
MultiValBin* ret) {
Common::FunctionTimer fun_time("Dataset::PushDataToMultiValBin",
global_timer);
Log::Warning("num_data = %d", num_data);
if (ret->IsSparse()) {
// Log::Fatal("pairwise ranking with sparse multi val bin is not supported.");
Threading::For<data_size_t>(
@@ -634,11 +635,11 @@ MultiValBin* Dataset::GetMultiBinFromAllFeatures(const std::vector<uint32_t>& of
1.0 - sum_dense_ratio);
if (use_pairwise_ranking) {

for (size_t i = 0; i < iters.size(); ++i) {
for (size_t j = 0; j < iters[i].size(); ++j) {
Log::Warning("i = %ld, j = %ld, iters[i][j] = %d", i, j, static_cast<int>(iters[i][j] == nullptr));
}
}
// for (size_t i = 0; i < iters.size(); ++i) {
// for (size_t j = 0; j < iters[i].size(); ++j) {
// Log::Warning("i = %ld, j = %ld, iters[i][j] = %d", i, j, static_cast<int>(iters[i][j] == nullptr));
// }
// }

const int num_original_features = static_cast<int>(most_freq_bins.size()) / 2;
std::vector<uint32_t> original_most_freq_bins;
@@ -662,7 +663,7 @@ MultiValBin* Dataset::GetMultiBinFromAllFeatures(const std::vector<uint32_t>& of
ret.reset(MultiValBin::CreateMultiValBin(
num_original_data, original_offsets.back(), num_original_features,
1.0 - sum_dense_ratio, original_offsets, use_pairwise_ranking, metadata_.paired_ranking_item_global_index_map()));
PushDataToMultiValBin(num_original_features, original_most_freq_bins, original_offsets, &iters, ret.get());
PushDataToMultiValBin(num_original_data, original_most_freq_bins, original_offsets, &iters, ret.get());
} else {
ret.reset(MultiValBin::CreateMultiValBin(
num_data_, offsets.back(), static_cast<int>(most_freq_bins.size()),
@@ -1632,6 +1633,7 @@ void Dataset::ConstructHistogramsInner(
OMP_LOOP_EX_BEGIN();
int group = used_dense_group[gi];
const int num_bin = feature_groups_[group]->num_total_bin_;
feature_groups_[group]->bin_data_->group_index_ = gi;
if (USE_QUANT_GRAD) {
if (HIST_BITS == 16) {
auto data_ptr = reinterpret_cast<hist_t*>(reinterpret_cast<int32_t*>(hist_data) + group_bin_boundaries_[group]);
14 changes: 13 additions & 1 deletion src/io/multi_val_pairwise_lambdarank_bin.hpp
@@ -13,7 +13,9 @@ namespace LightGBM {
template <typename BIN_TYPE, template<typename> class MULTI_VAL_BIN_TYPE>
class MultiValPairwiseLambdarankBin : public MULTI_VAL_BIN_TYPE<BIN_TYPE> {
public:
MultiValPairwiseLambdarankBin(data_size_t num_data, int num_bin, int num_feature, const std::vector<uint32_t>& offsets): MULTI_VAL_BIN_TYPE<BIN_TYPE>(num_data, num_bin, num_feature, offsets) {}
MultiValPairwiseLambdarankBin(data_size_t num_data, int num_bin, int num_feature, const std::vector<uint32_t>& offsets): MULTI_VAL_BIN_TYPE<BIN_TYPE>(num_data, num_bin, num_feature, offsets) {
this->num_bin_ = num_bin * 2;
}
protected:
const std::pair<data_size_t, data_size_t>* paired_ranking_item_global_index_map_;
};
@@ -66,6 +68,13 @@ class MultiValDensePairwiseLambdarankBin: public MultiValPairwiseLambdarankBin<B
const score_t hessian = ORDERED ? hessians[i] : hessians[idx];
for (int j = 0; j < this->num_feature_; ++j) {
const uint32_t bin = static_cast<uint32_t>(first_data_ptr[j]);
// if (bin != 0) {
// Log::Warning("first bin = %d, num_feature_ = %d", bin, this->num_feature_);
// }
if (j == 0) {
Log::Warning("group index = %d bin = %d gradient = %f hessian = %f", j, bin, gradient, hessian);
}

const auto ti = (bin + this->offsets_[j]) << 1;
grad[ti] += gradient;
hess[ti] += hessian;
@@ -76,6 +85,9 @@ class MultiValDensePairwiseLambdarankBin: public MultiValPairwiseLambdarankBin<B
const auto base_offset = this->offsets_.back();
for (int j = 0; j < this->num_feature_; ++j) {
const uint32_t bin = static_cast<uint32_t>(second_data_ptr[j]);
// if (bin != 0) {
// Log::Warning("second bin = %d, num_feature_ = %d", bin, this->num_feature_);
// }
const auto ti = (bin + this->offsets_[j] + base_offset) << 1;
grad[ti] += gradient;
hess[ti] += hessian;
6 changes: 6 additions & 0 deletions src/io/pairwise_lambdarank_bin.cpp
@@ -98,6 +98,9 @@ void DensePairwiseRankingBin<VAL_T, IS_4BIT, ITERATOR_TYPE>::ConstructHistogramI
for (; i < pf_end; ++i) {
const auto paired_idx = USE_INDICES ? data_indices[i] : i;
const auto ti = GetBinAt(paired_idx) << 1;
if (this->group_index_ == 0) {
Log::Warning("group index = %d bin = %d gradient = %f hessian = %f", this->group_index_, ti / 2, ordered_gradients[i], ordered_hessians[i]);
}
if (USE_HESSIAN) {
grad[ti] += ordered_gradients[i];
hess[ti] += ordered_hessians[i];
@@ -110,6 +113,9 @@ void DensePairwiseRankingBin<VAL_T, IS_4BIT, ITERATOR_TYPE>::ConstructHistogramI
for (; i < end; ++i) {
const auto paired_idx = USE_INDICES ? data_indices[i] : i;
const auto ti = GetBinAt(paired_idx) << 1;
if (this->group_index_ == 0) {
Log::Warning("group index = %d bin = %d gradient = %f hessian = %f", this->group_index_, ti / 2, ordered_gradients[i], ordered_hessians[i]);
}
if (USE_HESSIAN) {
grad[ti] += ordered_gradients[i];
hess[ti] += ordered_hessians[i];
60 changes: 30 additions & 30 deletions src/io/pairwise_ranking_feature_group.cpp
@@ -35,43 +35,43 @@ void PairwiseRankingFeatureGroup::CreateBinData(int num_data, bool is_multi_val,
multi_bin_data_.clear();
for (int i = 0; i < num_feature_; ++i) {
int addi = bin_mappers_[i]->GetMostFreqBin() == 0 ? 0 : 1;
if (bin_mappers_[i]->sparse_rate() >= kSparseThreshold) {
if (is_first_or_second_in_pairing_ == 0) {
multi_bin_data_.emplace_back(Bin::CreateSparsePairwiseRankingFirstBin(
num_data, bin_mappers_[i]->num_bin() + addi, num_data_, paired_ranking_item_index_map_));
} else {
multi_bin_data_.emplace_back(Bin::CreateSparsePairwiseRankingSecondBin(
num_data, bin_mappers_[i]->num_bin() + addi, num_data_, paired_ranking_item_index_map_));
}
// if (bin_mappers_[i]->sparse_rate() >= kSparseThreshold) {
// if (is_first_or_second_in_pairing_ == 0) {
// multi_bin_data_.emplace_back(Bin::CreateSparsePairwiseRankingFirstBin(
// num_data, bin_mappers_[i]->num_bin() + addi, num_data_, paired_ranking_item_index_map_));
// } else {
// multi_bin_data_.emplace_back(Bin::CreateSparsePairwiseRankingSecondBin(
// num_data, bin_mappers_[i]->num_bin() + addi, num_data_, paired_ranking_item_index_map_));
// }
// } else {
if (is_first_or_second_in_pairing_ == 0) {
multi_bin_data_.emplace_back(
Bin::CreateDensePairwiseRankingFirstBin(num_data, bin_mappers_[i]->num_bin() + addi, num_data_, paired_ranking_item_index_map_));
} else {
if (is_first_or_second_in_pairing_ == 0) {
multi_bin_data_.emplace_back(
Bin::CreateDensePairwiseRankingFirstBin(num_data, bin_mappers_[i]->num_bin() + addi, num_data_, paired_ranking_item_index_map_));
} else {
multi_bin_data_.emplace_back(
Bin::CreateDensePairwiseRankingSecondBin(num_data, bin_mappers_[i]->num_bin() + addi, num_data_, paired_ranking_item_index_map_));
}
multi_bin_data_.emplace_back(
Bin::CreateDensePairwiseRankingSecondBin(num_data, bin_mappers_[i]->num_bin() + addi, num_data_, paired_ranking_item_index_map_));
}
// }
}
is_multi_val_ = true;
} else {
if (force_sparse ||
(!force_dense && num_feature_ == 1 &&
bin_mappers_[0]->sparse_rate() >= kSparseThreshold)) {
is_sparse_ = true;
if (is_first_or_second_in_pairing_ == 0) {
bin_data_.reset(Bin::CreateSparsePairwiseRankingFirstBin(num_data, num_total_bin_, num_data_, paired_ranking_item_index_map_));
} else {
bin_data_.reset(Bin::CreateSparsePairwiseRankingSecondBin(num_data, num_total_bin_, num_data_, paired_ranking_item_index_map_));
}
// if (force_sparse ||
// (!force_dense && num_feature_ == 1 &&
// bin_mappers_[0]->sparse_rate() >= kSparseThreshold)) {
// is_sparse_ = true;
// if (is_first_or_second_in_pairing_ == 0) {
// bin_data_.reset(Bin::CreateSparsePairwiseRankingFirstBin(num_data, num_total_bin_, num_data_, paired_ranking_item_index_map_));
// } else {
// bin_data_.reset(Bin::CreateSparsePairwiseRankingSecondBin(num_data, num_total_bin_, num_data_, paired_ranking_item_index_map_));
// }
// } else {
is_sparse_ = false;
if (is_first_or_second_in_pairing_ == 0) {
bin_data_.reset(Bin::CreateDensePairwiseRankingFirstBin(num_data, num_total_bin_, num_data_, paired_ranking_item_index_map_));
} else {
is_sparse_ = false;
if (is_first_or_second_in_pairing_ == 0) {
bin_data_.reset(Bin::CreateDensePairwiseRankingFirstBin(num_data, num_total_bin_, num_data_, paired_ranking_item_index_map_));
} else {
bin_data_.reset(Bin::CreateDensePairwiseRankingSecondBin(num_data, num_total_bin_, num_data_, paired_ranking_item_index_map_));
}
bin_data_.reset(Bin::CreateDensePairwiseRankingSecondBin(num_data, num_total_bin_, num_data_, paired_ranking_item_index_map_));
}
// }
is_multi_val_ = false;
}
}
1 change: 1 addition & 0 deletions src/io/sparse_bin.hpp
@@ -77,6 +77,7 @@ class SparseBin : public Bin {
explicit SparseBin(data_size_t num_data) : num_data_(num_data) {
int num_threads = OMP_NUM_THREADS();
push_buffers_.resize(num_threads);
Log::Warning("sparse bin is created !!!");
}

~SparseBin() {}
16 changes: 16 additions & 0 deletions src/treelearner/feature_histogram.hpp
@@ -20,6 +20,8 @@
#include "monotone_constraints.hpp"
#include "split_info.hpp"

#include <fstream>

namespace LightGBM {

class FeatureMetainfo {
@@ -1501,6 +1503,7 @@ class HistogramPool {
}
OMP_THROW_EX();
}
offsets_ = offsets;
}

void ResetConfig(const Dataset* train_data, const Config* config) {
@@ -1522,6 +1525,18 @@ class HistogramPool {
}
}

void DumpContent() const {
std::ofstream fout("histogram_wise.txt");
int cur_offsets_ptr = 0;
for (int i = 0; i < data_[0].size() / 2; ++i) {
if (i == offsets_[cur_offsets_ptr]) {
fout << "offset " << cur_offsets_ptr << " " << offsets_[cur_offsets_ptr] << " " << feature_metas_[cur_offsets_ptr].num_bin << " " << static_cast<int>(feature_metas_[cur_offsets_ptr].offset) << std::endl;
++cur_offsets_ptr;
}
fout << i << " " << data_[0][2 * i] << " " << data_[0][2 * i + 1] << std::endl;
}
}

/*!
* \brief Get data for the specific index
* \param idx which index want to get
@@ -1591,6 +1606,7 @@ class HistogramPool {
std::vector<int> inverse_mapper_;
std::vector<int> last_used_time_;
int cur_time_ = 0;
std::vector<uint32_t> offsets_;
};

} // namespace LightGBM
25 changes: 17 additions & 8 deletions src/treelearner/serial_tree_learner.cpp
@@ -756,6 +756,9 @@ std::set<int> SerialTreeLearner::FindAllForceFeatures(Json force_split_leaf_sett
void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
int* right_leaf, bool update_cnt) {
Common::FunctionTimer fun_timer("SerialTreeLearner::SplitInner", global_timer);

histogram_pool_.DumpContent();

SplitInfo& best_split_info = best_split_per_leaf_[best_leaf];
const int inner_feature_index =
train_data_->InnerFeatureIndex(best_split_info.feature);
@@ -843,7 +846,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
// init the leaves that used on next iteration
if (!config_->use_quantized_grad) {
if (best_split_info.left_count < best_split_info.right_count) {
CHECK_GT(best_split_info.left_count, 0);
// CHECK_GT(best_split_info.left_count, 0);
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
@@ -853,7 +856,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
best_split_info.right_sum_hessian,
best_split_info.right_output);
} else {
CHECK_GT(best_split_info.right_count, 0);
// CHECK_GT(best_split_info.right_count, 0);
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
@@ -865,7 +868,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
}
} else {
if (best_split_info.left_count < best_split_info.right_count) {
CHECK_GT(best_split_info.left_count, 0);
// CHECK_GT(best_split_info.left_count, 0);
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
@@ -877,7 +880,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
best_split_info.right_sum_gradient_and_hessian,
best_split_info.right_output);
} else {
CHECK_GT(best_split_info.right_count, 0);
// CHECK_GT(best_split_info.right_count, 0);
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
@@ -896,9 +899,9 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
data_partition_->leaf_count(*right_leaf));
}

#ifdef DEBUG
// #ifdef DEBUG
CheckSplit(best_split_info, *left_leaf, *right_leaf);
#endif
// #endif

auto leaves_need_update = constraints_->Update(
is_numerical_split, *left_leaf, *right_leaf,
@@ -1057,7 +1060,7 @@ std::vector<int8_t> node_used_features = col_sampler_.GetByNode(tree, leaf);
*split = bests[best_idx];
}

#ifdef DEBUG
// #ifdef DEBUG
void SerialTreeLearner::CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index) {
data_size_t num_data_in_left = 0;
data_size_t num_data_in_right = 0;
@@ -1114,6 +1117,12 @@ void SerialTreeLearner::CheckSplit(const SplitInfo& best_split_info, const int l
sum_right_gradient += gradients_[index];
sum_right_hessian += hessians_[index];
}
Log::Warning("num_data_in_left = %d, best_split_info.left_count = %d", num_data_in_left, best_split_info.left_count);
Log::Warning("num_data_in_right = %d, best_split_info.right_count = %d", num_data_in_right, best_split_info.right_count);
Log::Warning("sum_left_gradient = %f, best_split_info.left_sum_gradient = %f", sum_left_gradient, best_split_info.left_sum_gradient);
Log::Warning("sum_left_hessian = %f, best_split_info.left_sum_hessian = %f", sum_left_hessian, best_split_info.left_sum_hessian);
Log::Warning("sum_right_gradient = %f, best_split_info.right_sum_gradient = %f", sum_right_gradient, best_split_info.right_sum_gradient);
Log::Warning("sum_right_hessian = %f, best_split_info.right_sum_hessian = %f", sum_right_hessian, best_split_info.right_sum_hessian);
CHECK_EQ(num_data_in_left, best_split_info.left_count);
CHECK_EQ(num_data_in_right, best_split_info.right_count);
CHECK_LE(std::fabs(sum_left_gradient - best_split_info.left_sum_gradient), 1e-3);
@@ -1123,6 +1132,6 @@ void SerialTreeLearner::CheckSplit(const SplitInfo& best_split_info, const int l
Log::Warning("============================ pass split check ============================");
}
}
#endif
// #endif

} // namespace LightGBM
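The diff above un-guards `CheckSplit`, which recomputes per-leaf counts and gradient sums from raw data and compares them against the recorded `SplitInfo`, within a tolerance. The idea can be sketched in isolation (the `Split` struct and the threshold-based partition are simplifying assumptions; the real code walks `data_partition_` indices):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical simplification of the CheckSplit idea: recompute per-side
// counts and gradient sums directly from the data and verify they match the
// values the split claims, up to a small tolerance.
struct Split {
  int left_count;
  int right_count;
  double left_sum_grad;
  double right_sum_grad;
};

bool CheckSplitConsistent(const std::vector<double>& feature,
                          const std::vector<double>& grad, double threshold,
                          const Split& s, double tol = 1e-3) {
  int nl = 0, nr = 0;
  double gl = 0.0, gr = 0.0;
  for (size_t i = 0; i < feature.size(); ++i) {
    if (feature[i] <= threshold) {  // left side of the split
      ++nl;
      gl += grad[i];
    } else {  // right side
      ++nr;
      gr += grad[i];
    }
  }
  return nl == s.left_count && nr == s.right_count &&
         std::fabs(gl - s.left_sum_grad) <= tol &&
         std::fabs(gr - s.right_sum_grad) <= tol;
}
```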
4 changes: 2 additions & 2 deletions src/treelearner/serial_tree_learner.h
@@ -171,9 +171,9 @@ class SerialTreeLearner: public TreeLearner {

std::set<int> FindAllForceFeatures(Json force_split_leaf_setting);

#ifdef DEBUG
// #ifdef DEBUG
void CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index);
#endif
// #endif

/*!
* \brief Get the number of data in a leaf