Time-based pruning #2873
This is the best I could do for now:

```python
from datetime import datetime, timedelta

import numpy as np
import optuna
from optuna.pruners import BasePruner
from optuna.study import Study, StudyDirection
from optuna.trial import FrozenTrial, TrialState


class TimeBasedPruner(BasePruner):
    """
    Prunes a trial if another trial achieved a better objective value using the same or less time.
    """

    _logger = optuna.logging.get_logger(__name__)

    def __init__(self, min_time_elapsed: timedelta = timedelta(0)) -> None:
        """
        :param min_time_elapsed: the minimum amount of time that a trial must run for before being
            considered for pruning
        """
        self.min_time_elapsed = min_time_elapsed

    @staticmethod
    def _get_best_intermediate_result_over_steps(trial: FrozenTrial,
                                                 direction: StudyDirection) -> float:
        values = np.asarray(list(trial.intermediate_values.values()), dtype=float)
        if direction == StudyDirection.MAXIMIZE:
            return float(np.nanmax(values))
        return float(np.nanmin(values))

    def prune(self, study: Study, trial: FrozenTrial) -> bool:
        time_elapsed = datetime.now() - trial.datetime_start
        if time_elapsed < self.min_time_elapsed:
            return False
        trials = study.get_trials(deepcopy=False)
        direction = study.direction
        trial_best = self._get_best_intermediate_result_over_steps(trial, direction)
        better_trial = None
        better_symbol = None
        for other in trials:
            if other.state != TrialState.COMPLETE or other.duration > time_elapsed:
                continue
            other_best = self._get_best_intermediate_result_over_steps(other, direction)
            if direction == StudyDirection.MAXIMIZE:
                if other_best > trial_best:
                    better_trial = other
                    better_symbol = ">"
                    break
            else:
                if other_best < trial_best:
                    better_trial = other
                    better_symbol = "<"
                    break
        if better_trial is not None:
            self._logger.info(f"Pruning trial {trial.number} because trial {better_trial.number} had a "
                              f"better objective value ({better_trial.value} {better_symbol} {trial_best}) "
                              f"after duration {better_trial.duration}")
            return True
        return False
```

Currently I can only compare against the full duration of other trials. Ideally, I'd like to be able to look up the run times of other trials at each intermediate step.
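To make the comparison rule concrete, here is a stdlib-only sketch of the core check (plain dicts stand in for Optuna's trial objects, and the helper name is made up for illustration):

```python
from datetime import timedelta

def should_prune(trial_best: float, trial_elapsed: timedelta, others: list,
                 maximize: bool = True) -> bool:
    """Prune when any completed trial reached a better value in the same or less time."""
    for other in others:
        if other["duration"] > trial_elapsed:
            continue  # the other trial needed more time, so the comparison is not fair
        better = other["best"] > trial_best if maximize else other["best"] < trial_best
        if better:
            return True
    return False

others = [{"best": 0.80, "duration": timedelta(minutes=5)},
          {"best": 0.95, "duration": timedelta(minutes=8)}]
# The current trial has run 10 minutes and only reached 0.90, while another
# trial reached 0.95 within 8 minutes, so it should be pruned:
assert should_prune(0.90, timedelta(minutes=10), others)
```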
Thank you for your proposal! By the way, we can store such information by using `trial.set_user_attr`:

```python
def objective(trial):
    # sampling hyper-parameters
    trial.set_user_attr("computing_time_per_step", [])
    for step in range(epochs):
        # some training code
        # computing `computing_time_per_step`
        trial.set_user_attr(
            "computing_time_per_step",
            trial.user_attrs["computing_time_per_step"] + [computing_time_per_step],
        )
        trial.report(score, step)
        if trial.should_prune():
            raise optuna.TrialPruned()
    return score
```
@nzw0301 Thank you. Is there any way to get the current model's global step from within a pruner? It doesn't look like pruners have access to the model. Also, I don't actually want to save this state when the trial gets saved to disk; I only need it to evaluate whether pruning should take place.
Is this the only way to go?

```python
from datetime import datetime
from typing import Optional

import pytorch_lightning as pl
from optuna.integration import PyTorchLightningPruningCallback
from optuna.trial import Trial


class PyTorchLightningPruningWithTimeCallback(PyTorchLightningPruningCallback):
    """A PyTorchLightningPruningCallback that sets "train_durations" for each global step."""

    def __init__(self, trial: Trial, monitor: str):
        super().__init__(trial, monitor)
        self._trial.set_user_attr("global_step", 0)

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
              stage: Optional[str] = None) -> None:
        self.start_time = datetime.now()

    def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        train_durations = self._trial.user_attrs.get("train_durations")
        if train_durations is None:
            train_durations = {}
        train_durations[trainer.global_step] = (datetime.now() - self.start_time).total_seconds()
        self._trial.set_user_attr("train_durations", train_durations)
        self._trial.set_user_attr("global_step", pl_module.global_step)

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Note: `user_attrs` is a copy of the stored attributes, so deleting from it does not
        # remove the attribute from storage; Optuna exposes no API to delete a user attribute.
        del self._trial.user_attrs["global_step"]
```
> Is there any way to get the current model's global step from within a pruner?

I think no. To access the model's parameter, …

**code**

I'm not familiar with pytorch-lightning but the code looks nice to me.

**Proposal**

The total number of intermediate values in …
What is the best way to remove an intermediate value prior to the trial being saved to disk? Is the above …
Here is what I've got so far:

```python
import itertools
from datetime import datetime, timedelta
from typing import Callable, Iterable, Optional

import optuna
from optuna.integration import PyTorchLightningPruningCallback
from optuna.pruners import BasePruner
from optuna.study import Study, StudyDirection
from optuna.trial import FrozenTrial, TrialState
from pytorch_lightning import LightningModule, Trainer


class PyTorchLightningPruningWithTimeCallback(PyTorchLightningPruningCallback):
    """
    A PyTorchLightningPruningCallback that records "epoch_to_time_elapsed" for each epoch.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g., ``val_loss`` or ``val_acc``. The metrics
            are obtained from the returned dictionaries from e.g.
            ``pytorch_lightning.LightningModule.training_step`` or
            ``pytorch_lightning.LightningModule.validation_epoch_end`` and the names thus
            depend on how this dictionary is formatted.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__(trial, monitor)
        self._trial = trial
        self.monitor = monitor

    def setup(self, trainer: Trainer, pl_module: LightningModule,
              stage: Optional[str] = None) -> None:
        self._trial.set_user_attr("tensorboard_version", trainer.logger.version)
        self.start_time = datetime.now()

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        epoch_to_time_elapsed = self._trial.user_attrs.get("epoch_to_time_elapsed", {})
        current_epoch = trainer.current_epoch
        epoch_to_time_elapsed[current_epoch] = (datetime.now() - self.start_time).total_seconds()
        self._trial.set_user_attr("epoch_to_time_elapsed", epoch_to_time_elapsed)
        self._trial.set_user_attr("current_epoch", current_epoch)
        super().on_validation_end(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: LightningModule,
                 stage: Optional[str] = None) -> None:
        del self._trial.user_attrs["current_epoch"]


class TimeBasedPruner(BasePruner):
    """
    Prunes a trial if another trial achieved a better objective value using the same or less time.
    """

    _logger = optuna.logging.get_logger(__name__)

    def __init__(self, min_duration: timedelta = timedelta(0),
                 max_duration: timedelta = timedelta.max,
                 patience: timedelta = timedelta(0)) -> None:
        """
        :param min_duration: the minimum amount of time that a trial may run
        :param max_duration: the maximum amount of time that a trial may run
        :param patience: the amount of time a trial has to match a superior objective score
            held by other trials
        """
        self.min_time_elapsed = min_duration
        self.max_time_elapsed = max_duration
        self.patience = patience
        self.trial_to_other_to_epoch_processed = {}

    @staticmethod
    def get_stop_epoch(trial: FrozenTrial, time_elapsed: timedelta, since_epoch: int) -> int:
        """
        Returns the epoch completed after `time_elapsed`.

        :param trial: a trial
        :param time_elapsed: a relative time from the beginning of the trial
        :param since_epoch: the first epoch to search from
        :return: `len(trial.intermediate_values)` if no match was found
        """
        epoch_to_time_elapsed: dict = trial.user_attrs["epoch_to_time_elapsed"]
        max_epoch: int = len(trial.intermediate_values)
        for epoch, completion_time_elapsed in itertools.islice(
                epoch_to_time_elapsed.items(), since_epoch, max_epoch):
            # Optuna invokes json.dumps() which converts dictionary keys to str
            result = int(epoch)
            completion_time_elapsed = timedelta(seconds=completion_time_elapsed)
            if completion_time_elapsed > time_elapsed:
                return result
        return max_epoch

    def get_pruning_context(self, other: FrozenTrial, other_start_epoch: int,
                            other_stop_epoch: int, trial: FrozenTrial,
                            trial_start_epoch: int,
                            best_value: Callable[[Iterable[float]], float]
                            ) -> Optional[tuple[float, timedelta, float]]:
        """
        Evaluates whether `other` performed better than `trial`.

        :param other: the other trial
        :param other_start_epoch: the index to start searching at (inclusive)
        :param other_stop_epoch: the index to stop searching at (exclusive)
        :param trial: the current trial
        :param trial_start_epoch: the index to start searching `trial` at
        :param best_value: a function that maps a sequence of objective values to the best one
        :return: None if `other` performed worse than `trial`; otherwise, the better trial's
            objective score, the time associated with the score, and the current trial's best
            score
        """
        other_epoch_to_time_elapsed: dict = other.user_attrs["epoch_to_time_elapsed"]
        for other_epoch in range(other_start_epoch, other_stop_epoch):
            if other_epoch not in other.intermediate_values:
                continue
            # Optuna invokes json.dumps() which converts dictionary keys to str
            other_time_elapsed = timedelta(seconds=other_epoch_to_time_elapsed[str(other_epoch)])
            trial_stop_epoch = self.get_stop_epoch(trial, other_time_elapsed + self.patience,
                                                   trial_start_epoch)
            trial_candidates = list(itertools.islice(trial.intermediate_values.values(),
                                                     trial_start_epoch, trial_stop_epoch))
            if len(trial_candidates) == 0:
                continue
            other_value = other.intermediate_values[other_epoch]
            trial_best_value = best_value(trial_candidates)
            if best_value([other_value, trial_best_value]) != trial_best_value:
                return other_value, other_time_elapsed, trial_best_value
        return None

    def prune(self, study: Study, trial: FrozenTrial) -> bool:
        # If validation is not invoked after every training step, user_attrs might not be set
        # during the first couple of steps.
        epoch_to_time_elapsed: dict = trial.user_attrs.get("epoch_to_time_elapsed", {})
        if len(epoch_to_time_elapsed) == 0:
            return False
        current_epoch = trial.user_attrs["current_epoch"]
        # Optuna invokes json.dumps() which converts dictionary keys to str
        time_elapsed = timedelta(seconds=epoch_to_time_elapsed[str(current_epoch)])
        if time_elapsed < self.min_time_elapsed or time_elapsed < self.patience:
            return False
        if time_elapsed > self.max_time_elapsed:
            return True
        trial_to_start_epoch = trial.user_attrs.get("trial_to_start_epoch", {})
        trial_start_epoch = trial_to_start_epoch.get(trial.number, 0)
        time_elapsed_before_patience = time_elapsed - self.patience
        if study.direction == StudyDirection.MAXIMIZE:
            best_value = max
            better_symbol = ">"
        else:
            best_value = min
            better_symbol = "<"
        trials = study.get_trials(deepcopy=False)
        for other in trials:
            if other.state != TrialState.COMPLETE:
                continue
            if len(other.intermediate_values) == 0:
                continue
            other_start_epoch = trial_to_start_epoch.get(other.number, 0)
            other_stop_epoch = self.get_stop_epoch(other, time_elapsed_before_patience,
                                                   other_start_epoch)
            context = self.get_pruning_context(other, other_start_epoch, other_stop_epoch,
                                               trial, trial_start_epoch, best_value)
            trial_to_start_epoch[other.number] = other_stop_epoch
            trial.set_user_attr("trial_to_start_epoch", trial_to_start_epoch)
            if context is not None:
                other_value, other_time_elapsed, trial_best_value = context
                self._logger.info(f"Pruning trial {trial.number} after {time_elapsed} because "
                                  f"trial {other.number} had a better objective value "
                                  f"({other_value} {better_symbol} {trial_best_value}) after "
                                  f"{other_time_elapsed}")
                return True
        trial_stop_epoch = self.get_stop_epoch(trial, time_elapsed_before_patience,
                                               trial_start_epoch)
        trial_to_start_epoch[trial.number] = trial_stop_epoch
        trial.set_user_attr("trial_to_start_epoch", trial_to_start_epoch)
        return False
```

This allows you to abort trials if another trial has a better score by the current epoch. You can configure `min_duration`, `max_duration` and `patience`. So for example you could say: "Run trials for a maximum of 1 hour each. After the first 30 minutes, abort trials if someone else held a better score (by this time) for at least 5 minutes." I'm not sure whether this will lead to better results than epoch-based pruning, but I'm hopeful. I hope this helps other people. You'll need to use the `PyTorchLightningPruningWithTimeCallback` above together with this pruner.
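For instance, the "1 hour / 30 minutes / 5 minutes" configuration described above could be expressed as follows (shown as a plain kwargs dict so the snippet stands alone; it assumes the `TimeBasedPruner` signature defined earlier):

```python
from datetime import timedelta

# "Run trials for a maximum of 1 hour each. After the first 30 minutes, abort
# trials if someone else held a better score (by this time) for at least 5 minutes."
pruner_kwargs = dict(
    min_duration=timedelta(minutes=30),
    max_duration=timedelta(hours=1),
    patience=timedelta(minutes=5),
)
# e.g. study = optuna.create_study(pruner=TimeBasedPruner(**pruner_kwargs))
assert pruner_kwargs["max_duration"] > pruner_kwargs["min_duration"]
```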
This issue has not seen any recent activity.
In terms of … Again, I'm not familiar with pl, so I'm not sure.
I don't understand what you mean. How can a trial look up its own …? When evaluating whether the current trial should be pruned, I need to look at the other trials' …
Sorry, I did not look at …
I just do …
@droukd Sorry, what? How does that help?
In the end, I found it was easier to calculate "time elapsed" myself and add that factor into the objective value. This way faster runs score higher than slower ones, and we don't need to play around with the pruner logic. Unless I missed something, this should be equivalent to the original request.
I am also interested in a time-based pruner. I assume it would be useful, but I am very open to any arguments! Tweaking the objective seems quite unstable as a solution; I guess it's very easy that way to mess up the weight parameter, isn't it? To be honest, the second code example from @cowwoc is hard to get as an optuna noob :( (Maybe some further comments could help.) Furthermore, I have another approach to suggest which might be easier.
Reopening the issue as there is more interest.
@UntotaufUrlaub Regarding the objective function messing up the weights: I use a dedicated objective function for Optuna and a separate one for training and validation. My understanding is that the time component only influences which hyperparameters get chosen, as opposed to influencing the weights directly. That said, it would be nice to take this out of the objective function if possible.
@cowwoc Separating it that way is pretty clever. So the Optuna objective function is a wrapper around your training function which adds some penalty for the consumed time? So you have to configure how big the penalty is manually?
Right, so what I did was something along the lines of adding a time penalty to the training score. I wanted the penalty to be a value between 0 and 1 so it wouldn't overwhelm the actual objective value.
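A minimal stdlib-only sketch of this idea (the function name, the weight, and the 0-to-1 scaling against a time budget are assumptions for illustration, not the author's original code):

```python
from datetime import timedelta

def penalized_objective(accuracy: float, elapsed: timedelta,
                        max_duration: timedelta, weight: float = 0.1) -> float:
    """Subtract a small time penalty, scaled to [0, 1], from the raw score.

    Capping at the time budget keeps the penalty from overwhelming the accuracy.
    """
    penalty = min(elapsed / max_duration, 1.0)  # fraction of the time budget used
    return accuracy - weight * penalty

# A faster run with the same accuracy scores higher:
fast = penalized_objective(0.90, timedelta(minutes=10), timedelta(hours=1))
slow = penalized_objective(0.90, timedelta(minutes=50), timedelta(hours=1))
assert fast > slow
```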
Makes sense, but to be honest it feels a little like a workaround, and getting the right range for the penalty seems cumbersome to me :(
I don't fully understand how/where your approach uses the real objective value. I can see the times being used, but the former isn't clear to me.
I am using a callback every "x" updates to report the accuracy (the current value of my objective). The report is stored into the `intermediate_values` map, using the time diff as key and the accuracy as value (if I got it right).
And how does the …
Maybe an example makes things clearer. … Now the next trial runs slower and reports an accuracy of 0.3 after 2.5 minutes. Then, using the for loop, the interval in the old trial which surrounds the current step is found as [2, 3]. This is translated to the memorized objective values [0.3, 0.5] in line …
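The interval lookup described in this example can be sketched in plain Python (reported times in minutes as dict keys; a hypothetical simplification of the described approach, not the actual implementation):

```python
# Intermediate values of the old trial, keyed by elapsed minutes.
old_trial = {1: 0.1, 2: 0.3, 3: 0.5, 4: 0.7}

def surrounding_interval(reported: dict, t: float):
    """Find the recorded times that bracket `t` and the values memorized there."""
    times = sorted(reported)
    for lo, hi in zip(times, times[1:]):
        if lo <= t <= hi:
            return (lo, hi), (reported[lo], reported[hi])
    return None, None

# The next trial reports an accuracy of 0.3 after 2.5 minutes; the old trial's
# surrounding interval is [2, 3], memorized as the values [0.3, 0.5]:
interval, values = surrounding_interval(old_trial, 2.5)
assert interval == (2, 3) and values == (0.3, 0.5)
```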
Thank you for the explanation. I think it makes sense so long as you implement some sort of value interpolation. As you hinted above, the time elapsed between epochs may vary drastically between trials. In one run it could be 100 ms per epoch; in another, it could be 10 seconds. Ideally, we want to apply the objective value at time X (or between times X and Y) across both runs, even if the exact key does not exist.
If you insist, …

I am happy with existing pruners.
@droukd I think you missed the point of this pruner. The goal isn't to prune trials that run longer than others. The goal is to prune trials that take more time to achieve the same score as others achieved earlier. So the objective function score is important, but so is the time at which the score was reached.
@cowwoc Thank you for taking the time to understand my approach :)
Currently I am just using the mean of the interval as interpolation, but linear interpolation or even something more advanced would surely be better! (My intervals are quite small, so I hoped the mean would be enough, but I should update it.) @droukd …
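Linear interpolation between the two recorded points around the query time could look like this (a sketch; the function name and argument layout are illustrative assumptions):

```python
def interpolate(t_lo: float, v_lo: float, t_hi: float, v_hi: float, t: float) -> float:
    """Linearly interpolate the objective value at time `t` between two reports."""
    frac = (t - t_lo) / (t_hi - t_lo)
    return v_lo + frac * (v_hi - v_lo)

# At 2.5 minutes, between the reports (2 min, 0.3) and (3 min, 0.5),
# the interpolated value is 0.4 (versus 0.4 for the plain mean here,
# but the two differ as soon as `t` is not the midpoint):
assert abs(interpolate(2, 0.3, 3, 0.5, 2.5) - 0.4) < 1e-9
```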
I was thinking I am getting what is expected, although I am reporting integers. Maybe rounding can help. If someone is able to quickly set up testing code, that would be amazing. As not only the time spent is important, but the combination of time, score and hyper-parameters (different sets of parameters result in different performance; one model can report "100 score" with a slow start, another with a slow "90-to-100 jump"), I let it run for quite long, with different results.
This issue was closed automatically because it had not seen any recent activity. If you want to discuss it, you can reopen it freely.
Out of curiosity, why isn't there a time-based pruner?
Justification: some hyperparameters cause epochs to run much longer than others, and while the objective value might be better per epoch, it might still be a losing strategy from a time perspective.
Does it make sense to compare trials by time instead of by epochs? Meaning: consider a trial better if its objective value is better after the same amount of time (regardless of how many epochs elapsed).