Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add chosen metric argument to clarify early stopping behaviour #6424

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: add chosen_metric attribute in early stopping callback class
  • Loading branch information
sami-ka committed Apr 20, 2024
commit 4ce23bd92eb12825cd1f513c112411a73fc9f6e4
29 changes: 27 additions & 2 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
@@ -279,16 +279,24 @@ def __init__(
first_metric_only: bool = False,
verbose: bool = True,
min_delta: Union[float, List[float]] = 0.0,
chosen_metric: str = None,
) -> None:
self.enabled = _should_enable_early_stopping(stopping_rounds)

# Test if both parameters are used
if (first_metric_only + (chosen_metric is not None)) == 2:
error_message = """
Only one of first_metric_only and chosen_metric parameters should be used"""
raise ValueError(error_message)

self.order = 30
self.before_iteration = False

self.stopping_rounds = stopping_rounds
self.first_metric_only = first_metric_only
self.verbose = verbose
self.min_delta = min_delta
self.chosen_metric = chosen_metric

self._reset_storages()

@@ -345,7 +353,13 @@ def _init(self, env: CallbackEnv) -> None:

self._reset_storages()

n_metrics = len({m[1] for m in env.evaluation_result_list})
list_metrics = {m[1] for m in env.evaluation_result_list}
if (self.chosen_metric is not None) and (self.chosen_metric not in list_metrics):
error_message = f"""Chosen callback metric: {self.chosen_metric} is not in the evaluation list.
The list of available metrics for early stopping is: {list_metrics}."""
raise ValueError(error_message)

n_metrics = len(list_metrics)
n_datasets = len(env.evaluation_result_list) // n_metrics
if isinstance(self.min_delta, list):
if not all(t >= 0 for t in self.min_delta):
@@ -363,11 +377,14 @@ def _init(self, env: CallbackEnv) -> None:
raise ValueError("Must provide a single value for min_delta or as many as metrics.")
if self.first_metric_only and self.verbose:
_log_info(f"Using only {self.min_delta[0]} as early stopping min_delta.")
if (self.chosen_metric is not None) and self.verbose:
index_chosen_metric = list_metrics.index(self.chosen_metric)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this works, list_metrics is not actually a list 👀

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will look into it.
Following @jameslamb comment, this part of the code will be impacted so I will need to rewrite it anyway

_log_info(f"Using only {self.min_delta[index_chosen_metric]} as early stopping min_delta.")
deltas = self.min_delta * n_datasets
else:
if self.min_delta < 0:
raise ValueError("Early stopping min_delta must be non-negative.")
if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and (self.index_chosen_metric is None) and self.verbose:
_log_info(f"Using {self.min_delta} as min_delta for all metrics.")
deltas = [self.min_delta] * n_datasets * n_metrics

@@ -391,6 +408,8 @@ def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str]
)
if self.first_metric_only:
_log_info(f"Evaluated only: {eval_name_splitted[-1]}")
if self.chosen_metric is not None:
_log_info(f"Evaluated only: {self.chosen_metric}")
raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

def __call__(self, env: CallbackEnv) -> None:
@@ -418,6 +437,8 @@ def __call__(self, env: CallbackEnv) -> None:
eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
continue # use only the first metric for early stopping
if (self.chosen_metric is not None) and self.chosen_metric != eval_name_splitted[-1]:
continue # use only the first metric for early stopping
if self._is_train_set(
ds_name=env.evaluation_result_list[i][0],
eval_name=eval_name_splitted[0],
@@ -432,6 +453,8 @@ def __call__(self, env: CallbackEnv) -> None:
_log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
if self.first_metric_only:
_log_info(f"Evaluated only: {eval_name_splitted[-1]}")
if self.chosen_metric is not None:
_log_info(f"Evaluated only: {self.chosen_metric}")
raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
self._final_iteration_check(env, eval_name_splitted, i)

@@ -453,6 +476,7 @@ def early_stopping(
first_metric_only: bool = False,
verbose: bool = True,
min_delta: Union[float, List[float]] = 0.0,
chosen_metric: str = None,
) -> _EarlyStoppingCallback:
"""Create a callback that activates early stopping.

@@ -492,4 +516,5 @@ def early_stopping(
first_metric_only=first_metric_only,
verbose=verbose,
min_delta=min_delta,
chosen_metric=chosen_metric
)