Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add chosen metric argument to clarify early stopping behaviour #6424

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
refactor: replace argument name and add checks + warning
  • Loading branch information
sami-ka committed Apr 26, 2024
commit 44fcae278557b5793bf96bb763f40f0a0136a613
54 changes: 32 additions & 22 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
@@ -279,15 +279,20 @@ def __init__(
first_metric_only: bool = False,
verbose: bool = True,
min_delta: Union[float, List[float]] = 0.0,
chosen_metric: str = None,
metric_name: Optional[str] = None,
) -> None:
self.enabled = _should_enable_early_stopping(stopping_rounds)

# Test if both parameters are used
if (first_metric_only + (chosen_metric is not None)) == 2:
error_message = """
Only one of first_metric_only and chosen_metric parameters should be used"""
raise ValueError(error_message)
if first_metric_only and (metric_name is not None):
error_msg = """
Only one of 'first_metric_only' and 'chosen_metric' should be used"""
raise ValueError(error_msg)

# If metric_name is used, min_delta must be a scalar
if isinstance(min_delta, list) and (metric_name is not None):
error_msg = "Use a scalar value for 'min_delta' when using 'chosen_metric'."
raise ValueError(error_msg)

self.order = 30
self.before_iteration = False
@@ -296,7 +301,7 @@ def __init__(
self.first_metric_only = first_metric_only
self.verbose = verbose
self.min_delta = min_delta
self.chosen_metric = chosen_metric
self.metric_name = metric_name

self._reset_storages()

@@ -353,13 +358,13 @@ def _init(self, env: CallbackEnv) -> None:

self._reset_storages()

list_metrics = {m[1] for m in env.evaluation_result_list}
if (self.chosen_metric is not None) and (self.chosen_metric not in list_metrics):
error_message = f"""Chosen callback metric: {self.chosen_metric} is not in the evaluation list.
The list of available metrics for early stopping is: {list_metrics}."""
set_metrics = {m[1] for m in env.evaluation_result_list}
if (self.metric_name is not None) and (self.metric_name not in set_metrics):
error_message = f"""Chosen callback metric:{self.metric_name} is not in the evaluation list.
The set of available metrics for early stopping is : {set_metrics}."""
raise ValueError(error_message)

n_metrics = len(list_metrics)
n_metrics = len(set_metrics)
n_datasets = len(env.evaluation_result_list) // n_metrics
if isinstance(self.min_delta, list):
if not all(t >= 0 for t in self.min_delta):
@@ -377,14 +382,11 @@ def _init(self, env: CallbackEnv) -> None:
raise ValueError("Must provide a single value for min_delta or as many as metrics.")
if self.first_metric_only and self.verbose:
_log_info(f"Using only {self.min_delta[0]} as early stopping min_delta.")
if (self.chosen_metric is not None) and self.verbose:
index_chosen_metric = list_metrics.index(self.chosen_metric)
_log_info(f"Using only {self.min_delta[index_chosen_metric]} as early stopping min_delta.")
deltas = self.min_delta * n_datasets
else:
if self.min_delta < 0:
raise ValueError("Early stopping min_delta must be non-negative.")
if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and (self.index_chosen_metric is None) and self.verbose:
if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and (self.metric_name is None) and self.verbose:
_log_info(f"Using {self.min_delta} as min_delta for all metrics.")
deltas = [self.min_delta] * n_datasets * n_metrics

@@ -408,8 +410,8 @@ def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str]
)
if self.first_metric_only:
_log_info(f"Evaluated only: {eval_name_splitted[-1]}")
if self.chosen_metric is not None:
_log_info(f"Evaluated only: {self.chosen_metric}")
if self.metric_name is not None:
_log_info(f"Evaluated only: {self.metric_name}")
raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

def __call__(self, env: CallbackEnv) -> None:
@@ -437,7 +439,7 @@ def __call__(self, env: CallbackEnv) -> None:
eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
continue # use only the first metric for early stopping
if (self.chosen_metric is not None) and self.chosen_metric != eval_name_splitted[-1]:
if (self.metric_name is not None) and self.metric_name != eval_name_splitted[-1]:
continue # use only the first metric for early stopping
if self._is_train_set(
ds_name=env.evaluation_result_list[i][0],
@@ -453,8 +455,8 @@ def __call__(self, env: CallbackEnv) -> None:
_log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
if self.first_metric_only:
_log_info(f"Evaluated only: {eval_name_splitted[-1]}")
if self.chosen_metric is not None:
_log_info(f"Evaluated only: {self.chosen_metric}")
if self.metric_name is not None:
_log_info(f"Evaluated only: {self.metric_name}")
raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
self._final_iteration_check(env, eval_name_splitted, i)

@@ -476,7 +478,7 @@ def early_stopping(
first_metric_only: bool = False,
verbose: bool = True,
min_delta: Union[float, List[float]] = 0.0,
chosen_metric: str = None,
metric_name: Optional[str] = None,
) -> _EarlyStoppingCallback:
"""Create a callback that activates early stopping.

@@ -511,10 +513,18 @@ def early_stopping(
callback : _EarlyStoppingCallback
The callback that activates early stopping.
"""

if first_metric_only:
warning_message = """
'first_metric_only' parameter is deprecated.
It will be removed in a future release of lightgbm.
"""
_log_warning(warning_message)

return _EarlyStoppingCallback(
stopping_rounds=stopping_rounds,
first_metric_only=first_metric_only,
verbose=verbose,
min_delta=min_delta,
chosen_metric=chosen_metric
metric_name=metric_name
)
Loading
Oops, something went wrong.