Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to specifying position using position_column parameter #6825

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add position_column support
  • Loading branch information
NProkoptsev committed Feb 14, 2025
commit 43631b873d973092818d7e5e540265e3fe08e986
1 change: 1 addition & 0 deletions R-package/R/aliases.R
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@
, "two_round"
, "use_missing"
, "weight_column"
, "position_column"
, "zero_as_missing"
)])
}
1 change: 1 addition & 0 deletions R-package/tests/testthat/test_lgb.Booster.R
Original file line number Diff line number Diff line change
@@ -1076,6 +1076,7 @@ test_that("all parameters are stored correctly with save_model_to_string()", {
, "[label_column: ]"
, "[weight_column: ]"
, "[group_column: ]"
, "[position_column: ]"
, "[ignore_column: ]"
, "[categorical_feature: ]"
, "[forcedbins_filename: ]"
12 changes: 12 additions & 0 deletions docs/Parameters.rst
Original file line number Diff line number Diff line change
@@ -925,6 +925,18 @@ Dataset Parameters

- **Note**: index starts from ``0`` and it doesn't count the label column when passing type is ``int``, e.g. when label is column\_0 and query\_id is column\_1, the correct parameter is ``query=0``

- ``position_column`` :raw-html:`<a id="position_column" title="Permalink to this parameter" href="#position_column">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = int or string, aliases: ``position``, ``position_id``

- used to specify the position id column

- use number for index, e.g. ``position=0`` means column\_0 is the position

- add a prefix ``name:`` for column name, e.g. ``position=name:position_id``

- **Note**: works only in case of loading data directly from text file

- **Note**: index starts from ``0`` and it doesn't count the label column when passing type is ``int``, e.g. when label is column\_0 and position\_id is column\_1, the correct parameter is ``position=0``

- ``ignore_column`` :raw-html:`<a id="ignore_column" title="Permalink to this parameter" href="#ignore_column">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = multi-int or string, aliases: ``ignore_feature``, ``blacklist``

- used to specify some ignoring columns in training
6 changes: 6 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
@@ -169,6 +169,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc
* \param has_weights Whether the dataset has Metadata weights
* \param has_init_scores Whether the dataset has Metadata initial scores
* \param has_queries Whether the dataset has Metadata queries/groups
* \param has_positions Whether the dataset has Metadata positions
* \param nclasses Number of initial score classes
* \param nthreads Number of external threads that will use the PushRows APIs
* \param omp_max_threads Maximum number of OpenMP threads (-1 for default)
@@ -178,6 +179,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetInitStreaming(DatasetHandle dataset,
int32_t has_weights,
int32_t has_init_scores,
int32_t has_queries,
int32_t has_positions,
int32_t nclasses,
int32_t nthreads,
int32_t omp_max_threads);
@@ -233,6 +235,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
* \param weight Optional pointer to array with nrow weights
* \param init_score Optional pointer to array with nrow*nclasses initial scores, in column format
* \param query Optional pointer to array with nrow query values
* \param position Optional pointer to array with nrow position values
* \param tid The id of the calling thread, from 0...N-1 threads
* \return 0 when succeed, -1 when failure happens
*/
@@ -246,6 +249,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsWithMetadata(DatasetHandle dataset,
const float* weight,
const double* init_score,
const int32_t* query,
const int32_t* position,
int32_t tid);

/*!
@@ -288,6 +292,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
* \param weight Optional pointer to array with nindptr-1 weights
* \param init_score Optional pointer to array with (nindptr-1)*nclasses initial scores, in column format
* \param query Optional pointer to array with nindptr-1 query values
* \param position Optional pointer to array with nindptr-1 position values
* \param tid The id of the calling thread, from 0...N-1 threads
* \return 0 when succeed, -1 when failure happens
*/
@@ -304,6 +309,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSRWithMetadata(DatasetHandle datase
const float* weight,
const double* init_score,
const int32_t* query,
const int32_t* position,
int32_t tid);

/*!
9 changes: 9 additions & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
@@ -766,6 +766,15 @@ struct Config {
// desc = **Note**: index starts from ``0`` and it doesn't count the label column when passing type is ``int``, e.g. when label is column\_0 and query\_id is column\_1, the correct parameter is ``query=0``
std::string group_column = "";

// type = int or string
// alias = position, position_id
// desc = used to specify the position/position id column
// desc = use number for index, e.g. ``position=0`` means column\_0 is the position id
// desc = add a prefix ``name:`` for column name, e.g. ``position=name:position_id``
// desc = **Note**: works only in case of loading data directly from text file
// desc = **Note**: index starts from ``0`` and it doesn't count the label column when passing type is ``int``, e.g. when label is column\_0 and position\_id is column\_1, the correct parameter is ``position=0``
std::string position_column = "";

// type = multi-int or string
// alias = ignore_feature, blacklist
// desc = used to specify some ignoring columns in training
50 changes: 29 additions & 21 deletions include/LightGBM/dataset.h
Original file line number Diff line number Diff line change
@@ -76,8 +76,9 @@ class Metadata {
* \param num_data Number of training data
* \param weight_idx Index of weight column, < 0 means doesn't exists
* \param query_idx Index of query id column, < 0 means doesn't exists
* \param position_idx Index of position id column, < 0 means doesn't exists
*/
void Init(data_size_t num_data, int weight_idx, int query_idx);
void Init(data_size_t num_data, int weight_idx, int query_idx, int position_idx);

/*!
* \brief Allocate space for label, weight (if exists), initial score (if exists) and query (if exists)
@@ -92,9 +93,10 @@ class Metadata {
* \param has_weights Whether the metadata has weights
* \param has_init_scores Whether the metadata has initial scores
* \param has_queries Whether the metadata has queries
* \param has_positions Whether the metadata has positions
* \param nclasses Number of classes for initial scores
*/
void Init(data_size_t num_data, int32_t has_weights, int32_t has_init_scores, int32_t has_queries, int32_t nclasses);
void Init(data_size_t num_data, int32_t has_weights, int32_t has_init_scores, int32_t has_queries, int32_t has_positions, int32_t nclasses);

/*!
* \brief Partition label by used indices
@@ -120,6 +122,7 @@ class Metadata {
void SetQuery(const ArrowChunkedArray& array);

void SetPosition(const data_size_t* position, data_size_t len);
void SetPosition(const ArrowChunkedArray& array);

/*!
* \brief Set initial scores
@@ -186,6 +189,15 @@ class Metadata {
queries_[idx] = static_cast<data_size_t>(value);
}

/*!
* \brief Store the position id of a single record
* \param idx Row index of the record to update
* \param value Position id to assign to that record
*/
inline void SetPositionAt(data_size_t idx, data_size_t value) {
  // value already has the element type of positions_, so no conversion is needed
  auto& slot = positions_[idx];
  slot = value;
}

/*! \brief Load initial scores from file */
void LoadInitialScore(const std::string& data_filename);

@@ -197,13 +209,15 @@ class Metadata {
* \param weights Pointer to weight data, or null
* \param init_scores Pointer to init-score data, or null
* \param queries Pointer to query data, or null
* \param positions Pointer to position data, or null
*/
void InsertAt(data_size_t start_index,
data_size_t count,
const float* labels,
const float* weights,
const double* init_scores,
const int32_t* queries);
const int32_t* queries,
const int32_t* positions);

/*!
* \brief Perform any extra operations after all data has been loaded
@@ -233,24 +247,13 @@ class Metadata {
}
}

/*!
* \brief Get position IDs, if does not exist then return nullptr
* \return Pointer of position IDs
*/
inline const std::string* position_ids() const {
if (!position_ids_.empty()) {
return position_ids_.data();
} else {
return nullptr;
}
}

/*!
* \brief Get the number of position ID slots, i.e. the largest stored position ID plus one
* \return 0 when no positions are stored, otherwise max(positions_) + 1
* \note Performs a linear scan over positions_ on every call
*/
inline size_t num_position_ids() const {
  // std::max_element returns end() for an empty range; dereferencing it is undefined behavior
  if (positions_.empty()) {
    return 0;
  }
  const auto max_position = *std::max_element(positions_.begin(), positions_.end());
  return static_cast<size_t>(max_position) + 1;
}

/*!
@@ -354,6 +357,11 @@ class Metadata {
void SetInitScoresFromIterator(It first, It last);
/*! \brief Insert queries at the given index */
void InsertQueries(const data_size_t* queries, data_size_t start_index, data_size_t len);
/*! \brief Set positions from pointers to the first element and the end of an iterator. */
template <typename It>
void SetPositionsFromIterator(It first, It last);
/*! \brief Insert positions at the given index */
void InsertPositions(const data_size_t* positions, data_size_t start_index, data_size_t len);
/*! \brief Set queries from pointers to the first element and the end of an iterator. */
template <typename It>
void SetQueriesFromIterator(It first, It last);
@@ -371,8 +379,6 @@ class Metadata {
std::vector<label_t> weights_;
/*! \brief Positions data */
std::vector<data_size_t> positions_;
/*! \brief Position identifiers */
std::vector<std::string> position_ids_;
/*! \brief Query boundaries */
std::vector<data_size_t> query_boundaries_;
/*! \brief Query weights */
@@ -519,6 +525,7 @@ class Dataset {
int32_t has_weights,
int32_t has_init_scores,
int32_t has_queries,
int32_t has_positions,
int32_t nclasses,
int32_t nthreads,
int32_t omp_max_threads) {
@@ -529,7 +536,7 @@ class Dataset {
omp_max_threads_ = OMP_NUM_THREADS();
}

metadata_.Init(num_data, has_weights, has_init_scores, has_queries, nclasses);
metadata_.Init(num_data, has_weights, has_init_scores, has_queries, has_positions, nclasses);
for (int i = 0; i < num_groups_; ++i) {
feature_groups_[i]->InitStreaming(nthreads, omp_max_threads_);
}
@@ -623,8 +630,9 @@ class Dataset {
const label_t* labels,
const label_t* weights,
const double* init_scores,
const data_size_t* queries) {
metadata_.InsertAt(start_index, count, labels, weights, init_scores, queries);
const data_size_t* queries,
const data_size_t* positions) {
metadata_.InsertAt(start_index, count, labels, weights, init_scores, queries, positions);
}

inline int RealFeatureIndex(int fidx) const {
2 changes: 2 additions & 0 deletions include/LightGBM/dataset_loader.h
Original file line number Diff line number Diff line change
@@ -95,6 +95,8 @@ class DatasetLoader {
int weight_idx_;
/*! \brief index of group column */
int group_idx_;
/*! \brief index of position column */
int position_idx_;
/*! \brief Mapper from real feature index to used index*/
std::unordered_set<int> ignore_features_;
/*! \brief store feature names */
2 changes: 2 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
@@ -2042,7 +2042,9 @@ def get_params(self) -> Dict[str, Any]:
"two_round",
"use_missing",
"weight_column",
"position_column",
"zero_as_missing",
"position_column"
)
return {k: v for k, v in self.params.items() if k in dataset_params}
else:
14 changes: 11 additions & 3 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
@@ -323,6 +323,11 @@ class Booster {
Log::Fatal(
"Cannot change group_column after constructed Dataset handle.");
}
if (new_param.count("position_column") &&
new_config.position_column != old_config.position_column) {
Log::Fatal(
"Cannot change position_column after constructed Dataset handle.");
}
if (new_param.count("ignore_column") &&
new_config.ignore_column != old_config.ignore_column) {
Log::Fatal(
@@ -1114,13 +1119,14 @@ int LGBM_DatasetInitStreaming(DatasetHandle dataset,
int32_t has_weights,
int32_t has_init_scores,
int32_t has_queries,
int32_t has_positions,
int32_t nclasses,
int32_t nthreads,
int32_t omp_max_threads) {
API_BEGIN();
auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto num_data = p_dataset->num_data();
p_dataset->InitStreaming(num_data, has_weights, has_init_scores, has_queries, nclasses, nthreads, omp_max_threads);
p_dataset->InitStreaming(num_data, has_weights, has_init_scores, has_queries, has_positions, nclasses, nthreads, omp_max_threads);
p_dataset->set_wait_for_manual_finish(true);
API_END();
}
@@ -1163,6 +1169,7 @@ int LGBM_DatasetPushRowsWithMetadata(DatasetHandle dataset,
const float* weights,
const double* init_scores,
const int32_t* queries,
const int32_t* positions,
int32_t tid) {
API_BEGIN();
#ifdef LABEL_T_USE_DOUBLE
@@ -1191,7 +1198,7 @@ int LGBM_DatasetPushRowsWithMetadata(DatasetHandle dataset,
}
OMP_THROW_EX();

p_dataset->InsertMetadataAt(start_row, nrow, labels, weights, init_scores, queries);
p_dataset->InsertMetadataAt(start_row, nrow, labels, weights, init_scores, queries, positions);

if (!p_dataset->wait_for_manual_finish() && (start_row + nrow == p_dataset->num_data())) {
p_dataset->FinishLoad();
@@ -1245,6 +1252,7 @@ int LGBM_DatasetPushRowsByCSRWithMetadata(DatasetHandle dataset,
const float* weights,
const double* init_scores,
const int32_t* queries,
const int32_t* positions,
int32_t tid) {
API_BEGIN();
#ifdef LABEL_T_USE_DOUBLE
@@ -1274,7 +1282,7 @@ int LGBM_DatasetPushRowsByCSRWithMetadata(DatasetHandle dataset,
}
OMP_THROW_EX();

p_dataset->InsertMetadataAt(static_cast<int32_t>(start_row), nrow, labels, weights, init_scores, queries);
p_dataset->InsertMetadataAt(static_cast<int32_t>(start_row), nrow, labels, weights, init_scores, queries, positions);

if (!p_dataset->wait_for_manual_finish() && (start_row + nrow == static_cast<int64_t>(p_dataset->num_data()))) {
p_dataset->FinishLoad();
8 changes: 8 additions & 0 deletions src/io/config_auto.cpp
Original file line number Diff line number Diff line change
@@ -133,6 +133,8 @@ const std::unordered_map<std::string, std::string>& Config::alias_table() {
{"query_column", "group_column"},
{"query", "group_column"},
{"query_id", "group_column"},
{"position", "position_column"},
{"position_id", "position_column"},
{"ignore_feature", "ignore_column"},
{"blacklist", "ignore_column"},
{"cat_feature", "categorical_feature"},
@@ -274,6 +276,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"label_column",
"weight_column",
"group_column",
"position_column",
"ignore_column",
"categorical_feature",
"forcedbins_filename",
@@ -552,6 +555,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str

GetString(params, "group_column", &group_column);

GetString(params, "position_column", &position_column);

GetString(params, "ignore_column", &ignore_column);

GetString(params, "categorical_feature", &categorical_feature);
@@ -754,6 +759,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[label_column: " << label_column << "]\n";
str_buf << "[weight_column: " << weight_column << "]\n";
str_buf << "[group_column: " << group_column << "]\n";
str_buf << "[position_column: " << position_column << "]\n";
str_buf << "[ignore_column: " << ignore_column << "]\n";
str_buf << "[categorical_feature: " << categorical_feature << "]\n";
str_buf << "[forcedbins_filename: " << forcedbins_filename << "]\n";
@@ -883,6 +889,7 @@ const std::unordered_map<std::string, std::vector<std::string>>& Config::paramet
{"label_column", {"label"}},
{"weight_column", {"weight"}},
{"group_column", {"group", "group_id", "query_column", "query", "query_id"}},
{"position_column", {"position", "position_id"}},
{"ignore_column", {"ignore_feature", "blacklist"}},
{"categorical_feature", {"cat_feature", "categorical_column", "cat_column", "categorical_features"}},
{"forcedbins_filename", {}},
@@ -1028,6 +1035,7 @@ const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
{"label_column", "string"},
{"weight_column", "string"},
{"group_column", "string"},
{"position_column", "string"},
{"ignore_column", "vector<int>"},
{"categorical_feature", "vector<int>"},
{"forcedbins_filename", "string"},
2 changes: 1 addition & 1 deletion src/io/dataset.cpp
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ Dataset::Dataset(data_size_t num_data) {
CHECK_GT(num_data, 0);
data_filename_ = "noname";
num_data_ = num_data;
metadata_.Init(num_data_, NO_SPECIFIC, NO_SPECIFIC);
metadata_.Init(num_data_, NO_SPECIFIC, NO_SPECIFIC, NO_SPECIFIC);
is_finish_load_ = false;
wait_for_manual_finish_ = false;
group_bin_boundaries_.push_back(0);
Loading
Oops, something went wrong.
Loading
Oops, something went wrong.