Skip to content

Commit

Permalink
Use callback in LGBM.train. (microsoft#974)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChiahungTai committed Mar 13, 2022
1 parent c716437 commit 40a6bd1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
14 changes: 8 additions & 6 deletions qlib/contrib/model/gbdt.py
Expand Up @@ -68,17 +68,18 @@ def fit(
evals_result = {} # in case of unsafety of Python default values
ds_l = self._prepare_data(dataset, reweighter)
ds, names = list(zip(*ds_l))
early_stopping_callback = lgb.early_stopping(
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
)
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
evals_result_callback = lgb.record_evaluation(evals_result)
self.model = lgb.train(
self.params,
ds[0], # training dataset
num_boost_round=self.num_boost_round if num_boost_round is None else num_boost_round,
valid_sets=ds,
valid_names=names,
early_stopping_rounds=(
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
),
verbose_eval=verbose_eval,
evals_result=evals_result,
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
**kwargs,
)
for k in names:
Expand Down Expand Up @@ -110,12 +111,13 @@ def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20, rewei
dtrain, _ = self._prepare_data(dataset, reweighter) # pylint: disable=W0632
if dtrain.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
init_model=self.model,
valid_sets=[dtrain],
valid_names=["train"],
verbose_eval=verbose_eval,
callbacks=[verbose_eval_callback],
)
14 changes: 9 additions & 5 deletions qlib/contrib/model/highfreq_gdbt_model.py
Expand Up @@ -110,18 +110,21 @@ def fit(
num_boost_round=1000,
early_stopping_rounds=50,
verbose_eval=20,
evals_result=dict(),
evals_result=None,
):
if evals_result is None:
evals_result = dict()
dtrain, dvalid = self._prepare_data(dataset)
early_stopping_callback = lgb.early_stopping(early_stopping_rounds)
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
evals_result_callback = lgb.record_evaluation(evals_result)
self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
evals_result=evals_result,
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
Expand All @@ -147,12 +150,13 @@ def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
"""
# Based on existing model and finetune by train more rounds
dtrain, _ = self._prepare_data(dataset)
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
init_model=self.model,
valid_sets=[dtrain],
valid_names=["train"],
verbose_eval=verbose_eval,
callbacks=[verbose_eval_callback],
)

0 comments on commit 40a6bd1

Please sign in to comment.