
Time-based pruning #2873

Closed
cowwoc opened this issue Aug 26, 2021 · 34 comments
Labels
feature Change that does not break compatibility, but affects the public interfaces. stale Exempt from stale bot labeling.

Comments

@cowwoc
Contributor

cowwoc commented Aug 26, 2021

Out of curiosity, why isn't there a time-based pruner?

Justification: Some hyperparameters cause epochs to run much longer than others, and while the objective value per epoch might be better, it might still be a losing strategy from a time perspective.

Does it make sense to compare trials by time instead of by epochs? Meaning: consider a trial better if its objective value is better after the same amount of time (regardless of how many epochs elapsed).

@cowwoc cowwoc added the feature Change that does not break compatibility, but affects the public interfaces. label Aug 26, 2021
@cowwoc
Contributor Author

cowwoc commented Aug 27, 2021

This is the best I could do for now:

from datetime import datetime, timedelta

import numpy as np
import optuna
from optuna.pruners import BasePruner
from optuna.study import Study, StudyDirection
from optuna.trial import FrozenTrial, TrialState


class TimeBasedPruner(BasePruner):
    """
    Prunes a trial if another trial achieved a better objective value using the same or less time.
    """
    _logger = optuna.logging.get_logger(__name__)

    def __init__(self, min_time_elapsed: timedelta = timedelta()) -> None:
        """
        :param min_time_elapsed: the minimum amount of time that a trial must run for before being
        considered for pruning
        """
        self.min_time_elapsed = min_time_elapsed

    @staticmethod
    def _get_best_intermediate_result_over_steps(trial: FrozenTrial,
                                                 direction: StudyDirection) -> float:
        values = np.asarray(list(trial.intermediate_values.values()), dtype=float)
        if direction == StudyDirection.MAXIMIZE:
            return float(np.nanmax(values))
        return float(np.nanmin(values))

    def prune(self, study: Study, trial: FrozenTrial) -> bool:
        time_elapsed = datetime.now() - trial.datetime_start
        if time_elapsed < self.min_time_elapsed:
            return False
        trials = study.get_trials(deepcopy=False)
        direction = study.direction
        trial_best = self._get_best_intermediate_result_over_steps(trial, direction)
        better_trial = None
        better_value = None
        better_symbol = None
        for other in trials:
            # Only compare against completed trials that finished within the time we have already spent.
            if other.state != TrialState.COMPLETE or other.duration > time_elapsed:
                continue
            other_best = self._get_best_intermediate_result_over_steps(other, direction)
            if direction == StudyDirection.MAXIMIZE:
                if other_best > trial_best:
                    better_trial = other
                    better_value = other_best
                    better_symbol = ">"
                    break
            else:
                if other_best < trial_best:
                    better_trial = other
                    better_value = other_best
                    better_symbol = "<"
                    break
        if better_trial is not None:
            self._logger.info(f"Pruning trial {trial.number} because trial {better_trial.number} had a "
                              f"better objective value ({better_value} {better_symbol} {trial_best}) "
                              f"after duration {better_trial.duration}")
            return True
        return False
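
For context, a possible way to plug this in (a sketch only; `objective` is assumed to be a function that calls `trial.report()` and `trial.should_prune()` each epoch, and the five-minute threshold is an arbitrary example):

study = optuna.create_study(
    direction="minimize",
    pruner=TimeBasedPruner(min_time_elapsed=timedelta(minutes=5)),
)
study.optimize(objective, n_trials=50)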

Currently I can only compare against the full duration of other trials. Ideally, I'd like to be able to look up the run times of trial.intermediate_values elements and prune the current trial if one of these intermediate values had a better score earlier than the current trial. Would you consider adding this information?

@nzw0301
Member

nzw0301 commented Aug 28, 2021

Thank you for your proposal!

By the way, we can store such information by using trial.user_attr:

def objective(trial):
    # sampling hyper-parameters

    trial.set_user_attr("computing_time_per_step", [])
    for step in range(epochs):
        # some training code
        # computing `computing_time_per_step`
        trial.set_user_attr("computing_time_per_step", trial.user_attrs["computing_time_per_step"] + [computing_time_per_step])
        trial.report(score, step)        
        if trial.should_prune():
            raise optuna.TrialPruned()
    return score

@cowwoc
Contributor Author

cowwoc commented Aug 30, 2021

@nzw0301 Thank you. Is there any way to get the current model's global step from within a pruner? It doesn't look like pruners have access to the model. Also, I don't actually want this state saved when the trial gets saved to disk; I only need it to evaluate whether pruning should take place.

@cowwoc
Contributor Author

cowwoc commented Aug 30, 2021

Is this the only way to go?

from datetime import datetime
from typing import Optional

import pytorch_lightning as pl
from optuna.integration import PyTorchLightningPruningCallback
from optuna.trial import Trial


class PyTorchLightningPruningWithTimeCallback(PyTorchLightningPruningCallback):
    """
    A PyTorchLightningPruningCallback that sets "train_durations" for each global step
    """

    def __init__(self, trial: Trial, monitor: str):
        super().__init__(trial, monitor)
        self._trial.set_user_attr("global_step", 0)

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
              stage: Optional[str] = None) -> None:
        self.start_time = datetime.now()

    def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        train_durations = self._trial.user_attrs.get("train_durations")
        if train_durations is None:
            train_durations = {}
        train_durations[trainer.global_step] = (datetime.now() - self.start_time).total_seconds()
        self._trial.set_user_attr("train_durations", train_durations)
        self._trial.set_user_attr("global_step", pl_module.global_step)

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        del self._trial.user_attrs["global_step"]

@nzw0301
Member

nzw0301 commented Aug 31, 2021

Is there any way to get the current model's global step from within a pruner?

I think not. To access a model attribute such as pl_module.global_step from a pruner, we need to save it as a user_attr or system_attr.

code

I'm not familiar with pytorch-lightning but the code looks nice to me.

Proposal

The total number of intermediate values in a study will probably become huge in practice, so saving a timestamp per intermediate value might take a lot of disk space. I'd suggest using user or system attributes for time-based pruning.

@cowwoc
Contributor Author

cowwoc commented Aug 31, 2021

What is the best way to remove an intermediate value prior to the trial being saved to disk? Is the above on_train_end() implementation the way to go?

@cowwoc
Contributor Author

cowwoc commented Sep 3, 2021

Here is what I've got so far:

import itertools
from datetime import datetime, timedelta
from typing import Callable, Iterable, Optional, Tuple

import optuna
from optuna.integration import PyTorchLightningPruningCallback
from optuna.pruners import BasePruner
from optuna.study import Study, StudyDirection
from optuna.trial import FrozenTrial, TrialState
from pytorch_lightning import LightningModule, Trainer


class PyTorchLightningPruningWithTimeCallback(PyTorchLightningPruningCallback):
    """
    A PyTorchLightningPruningCallback that records "epoch_to_time_elapsed" for each validation epoch

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g., ``val_loss`` or
            ``val_acc``. The metrics are obtained from the returned dictionaries from e.g.
            ``pytorch_lightning.LightningModule.training_step`` or
            ``pytorch_lightning.LightningModule.validation_epoch_end`` and the names thus depend on
            how this dictionary is formatted.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__(trial, monitor)

        self._trial = trial
        self.monitor = monitor

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
        self._trial.set_user_attr("tensorboard_version", trainer.logger.version)
        self.start_time = datetime.now()

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        epoch_to_time_elapsed = self._trial.user_attrs.get("epoch_to_time_elapsed", {})
        current_epoch = trainer.current_epoch
        epoch_to_time_elapsed[current_epoch] = (datetime.now() - self.start_time).total_seconds()
        self._trial.set_user_attr("epoch_to_time_elapsed", epoch_to_time_elapsed)
        self._trial.set_user_attr("current_epoch", current_epoch)
        super(PyTorchLightningPruningWithTimeCallback, self).on_validation_end(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
        del self._trial.user_attrs["current_epoch"]


class TimeBasedPruner(BasePruner):
    """
    Prunes a trial if another trial achieved a better objective value using the same or less time.
    """
    _logger = optuna.logging.get_logger(__name__)

    def __init__(self, min_duration: timedelta = timedelta(), max_duration: timedelta = timedelta.max,
                 patience: timedelta = timedelta()) -> None:
        """
        :param min_duration: the minimum amount of time a trial runs before it may be pruned
        :param max_duration: the maximum amount of time a trial may run before it is pruned unconditionally
        :param patience: the amount of time a trial is given to match a superior objective score held by
        other trials
        """
        self.min_time_elapsed = min_duration
        self.max_time_elapsed = max_duration
        self.patience = patience
        self.trial_to_other_to_epoch_processed = {}

    @staticmethod
    def get_stop_epoch(trial: FrozenTrial, time_elapsed: timedelta, since_epoch: int) -> int:
        """
        Returns the epoch completed after `time_elapsed`.

        :param trial: a trial
        :param time_elapsed: a relative time from the beginning of the trial
        :param since_epoch: the first epoch to search from
        :return: `len(trial.intermediate_values.values())` if no match was found
        """
        epoch_to_time_elapsed: dict = trial.user_attrs["epoch_to_time_elapsed"]
        max_epoch: int = len(trial.intermediate_values.values())
        for epoch, completion_time_elapsed in itertools.islice(epoch_to_time_elapsed.items(), since_epoch,
                                                               max_epoch):
            # Optuna invokes json.dumps() which converts dictionary keys to str
            result = int(epoch)
            completion_time_elapsed = timedelta(seconds=completion_time_elapsed)
            if completion_time_elapsed > time_elapsed:
                return result
        return max_epoch

    def get_pruning_context(self, other: FrozenTrial, other_start_epoch: int, other_stop_epoch: int,
                            trial: FrozenTrial, trial_start_epoch: int,
                            best_value: Callable[[Iterable[float]], float]
                            ) -> Optional[Tuple[float, timedelta, float]]:
        """
        Evaluates whether `other` performed better than `trial`.

        :param other: the other trial
        :param other_start_epoch: the index to start searching at (inclusive)
        :param other_stop_epoch: the index to stop searching at (exclusive)
        :param trial: the current trial
        :param trial_start_epoch: the index to start searching `trial` at
        :param best_value: a function that takes a sequence of objective values and returns the best one
        :return: None if `other` did not outperform `trial`; otherwise, the other trial's objective score,
        the time associated with that score, and the current trial's best score
        """
        other_epoch_to_time_elapsed: dict = other.user_attrs["epoch_to_time_elapsed"]
        for other_epoch in range(other_start_epoch, other_stop_epoch):
            if other_epoch not in other.intermediate_values:
                continue
            # Optuna invokes json.dumps() which converts dictionary keys to str
            other_time_elapsed = timedelta(seconds=other_epoch_to_time_elapsed[str(other_epoch)])
            trial_stop_epoch = self.get_stop_epoch(trial, other_time_elapsed + self.patience,
                                                   trial_start_epoch)
            trial_candidates = list(itertools.islice(trial.intermediate_values.values(), trial_start_epoch,
                                                     trial_stop_epoch))
            if len(trial_candidates) == 0:
                continue
            other_value = other.intermediate_values[other_epoch]
            trial_best_value = best_value(trial_candidates)
            if best_value([other_value, trial_best_value]) != trial_best_value:
                return other_value, other_time_elapsed, trial_best_value
        return None

    def prune(self, study: Study, trial: FrozenTrial) -> bool:
        # If validation is not invoked after every training step, user_attrs might not be set during the
        # first couple of steps.
        epoch_to_time_elapsed: dict = trial.user_attrs.get("epoch_to_time_elapsed", {})
        if len(epoch_to_time_elapsed) == 0:
            return False
        current_epoch = trial.user_attrs["current_epoch"]
        # Optuna invokes json.dumps() which converts dictionary keys to str
        time_elapsed = timedelta(seconds=epoch_to_time_elapsed[str(current_epoch)])
        if time_elapsed < self.min_time_elapsed or time_elapsed < self.patience:
            return False
        if time_elapsed > self.max_time_elapsed:
            return True
        trial_to_start_epoch = trial.user_attrs.get("trial_to_start_epoch", {})
        trial_start_epoch = trial_to_start_epoch.get(trial.number, 0)
        time_elapsed_before_patience = time_elapsed - self.patience
        if study.direction == StudyDirection.MAXIMIZE:
            best_value = max
            better_symbol = ">"
        else:
            best_value = min
            better_symbol = "<"

        trials = study.get_trials(deepcopy=False)
        for other in trials:
            if other.state != TrialState.COMPLETE:
                continue
            if len(other.intermediate_values) == 0:
                continue
            other_start_epoch = trial_to_start_epoch.get(other.number, 0)
            other_stop_epoch = self.get_stop_epoch(other, time_elapsed_before_patience,
                                                   other_start_epoch)
            context = self.get_pruning_context(other, other_start_epoch, other_stop_epoch, trial,
                                               trial_start_epoch, best_value)
            trial_to_start_epoch[other.number] = other_stop_epoch
            trial.set_user_attr("trial_to_start_epoch", trial_to_start_epoch)
            if context is not None:
                other_value, other_time_elapsed, trial_best_value = context
                self._logger.info(f"Pruning trial {trial.number} after {time_elapsed} because trial "
                                  f"{other.number} had a better objective value "
                                  f"({other_value} {better_symbol} {trial_best_value}) after "
                                  f"{other_time_elapsed}")
                return True
        trial_stop_epoch = self.get_stop_epoch(trial, time_elapsed_before_patience, trial_start_epoch)
        trial_to_start_epoch[trial.number] = trial_stop_epoch
        trial.set_user_attr("trial_to_start_epoch", trial_to_start_epoch)
        return False

This allows you to abort trials if another trial achieved a better score within the same elapsed time. You can configure min_duration to give trials time to "warm up" before being considered for pruning. max_duration aborts trials if they've run too long. patience only considers better scores in other trials if they've held that score for a minimum amount of time.

So for example you could say: "Run trials for a maximum of 1 hour each. After the first 30 minutes, abort trials if someone else held a better score (by this time) for at least 5 minutes."
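
In code, that example configuration would look roughly like this (a sketch; it assumes the classes defined above and a `val_loss` metric name):

pruner = TimeBasedPruner(
    min_duration=timedelta(minutes=30),  # warm-up: never prune during the first 30 minutes
    max_duration=timedelta(hours=1),     # hard cap: abort any trial after 1 hour
    patience=timedelta(minutes=5),       # only prune if a better score has been held for 5 minutes
)
study = optuna.create_study(direction="minimize", pruner=pruner)
# Inside the objective, pass the matching callback to the Trainer so the pruner gets its timing data:
# trainer = Trainer(callbacks=[PyTorchLightningPruningWithTimeCallback(trial, monitor="val_loss")])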

I'm not sure whether this will lead to better results than epoch-based pruning, but I'm hopeful. I hope this helps other people.

You'll need to use PyTorchLightningPruningWithTimeCallback instead of PyTorchLightningPruningCallback to provide the pruner with the timing information it needs. (I would like to find a way for the two callbacks to run alongside one another, but I can't think of one.)

@github-actions
Contributor

This issue has not seen any recent activity.

@github-actions github-actions bot added stale Exempt from stale bot labeling. and removed stale Exempt from stale bot labeling. labels Sep 19, 2021
@nzw0301
Member

nzw0301 commented Sep 25, 2021

Regarding PyTorchLightningPruningWithTimeCallback, I think we do not need to use trial.user_attr or system_attr because each tracked time does not depend on other trials' values. So I suggest defining instance attributes on the callback to store epoch_to_time_elapsed and current_epoch. That would be a better way, I suppose. By doing so, we might also make the code faster, since we would not need to access the trial's attributes, especially when using a database as Optuna's storage.

Again, I'm not familiar with pl, so I'm not sure.

@cowwoc
Contributor Author

cowwoc commented Sep 26, 2021

I think we do not need to use trial.user_attr or system_attr because each tracked time does not depend on other trials' values

I don't understand what you mean. How can a trial look up its own epoch_to_time_elapsed and current_epoch without user attributes? I don't think I can set these attributes directly on the FrozenTrial object and even if I could there doesn't seem to be a guarantee that these values will get saved anywhere.

When evaluating whether the current trial should be pruned, I need to look at the other trials' epoch_to_time_elapsed attribute...

@nzw0301
Member

nzw0301 commented Sep 26, 2021

Sorry, I did not look at TimeBasedPruner. I was wrong.

@github-actions
Contributor

This issue has not seen any recent activity.

@github-actions github-actions bot added the stale Exempt from stale bot labeling. label Oct 10, 2021
@droukd

droukd commented Jan 19, 2022

I just do trial.report(accuracy, budget).

@cowwoc
Contributor Author

cowwoc commented Jan 19, 2022

@droukd Sorry, what? How does that help? trial.report() does not prune trials and certainly does not look at time.

@github-actions github-actions bot removed the stale Exempt from stale bot labeling. label Jan 19, 2022
@github-actions
Contributor

github-actions bot commented Feb 3, 2022

This issue has not seen any recent activity.

@github-actions github-actions bot added the stale Exempt from stale bot labeling. label Feb 3, 2022
@cowwoc
Contributor Author

cowwoc commented Feb 3, 2022

In the end, I found it was easier to calculate "time elapsed" myself and add that factor into the objective value. This way faster runs score higher than slower ones and we don't need to play around with the pruner logic. Unless I missed something, this should be equivalent to the original request.

@UntotaufUrlaub

UntotaufUrlaub commented Feb 7, 2022

I am also interested in a time-based pruner. I assume it would be useful, but I am very open to any arguments!

Tweaking the objective seems quite unstable as a solution. I guess it's very easy to mess up the weight parameter that way, isn't it?

To be honest, the second code example from @cowwoc is hard to follow as an Optuna noob :( (Maybe some further comments could help.)

Furthermore, I have another approach to suggest which might be easier.
The idea is based on the percentile pruner. I used the report() function with timestamps instead of steps, and then figured out that my main problem was that the "step" is used as a key and only exact matches are considered. So my approach is to find the interval of the other trial (i.e., the key/step/timestamp above and below) and prune if my value is worse than the mean of the values at those keys. I guess that's quite basic, but it should be good enough as a basis.
It could easily be enhanced by more advanced forms of interpolation.
Please let me know any opinion about this!

from statistics import mean

import optuna
from optuna.pruners import BasePruner
from optuna.study import StudyDirection
from optuna.trial import TrialState


class TimestampBaseMedianPruner(BasePruner):
    """
    Inspired by:
    https://optuna.readthedocs.io/en/latest/_modules/optuna/pruners/_percentile.html#PercentilePruner
    https://optuna.readthedocs.io/en/latest/tutorial/20_recipes/006_user_defined_pruner.html

    Not pruning on matching the step exactly,
    but on comparing to the values reported for the related intervals
    """

    def __init__(
        self,
        n_startup_trials: int = 5,
        n_warmup_steps: int = 0,
    ) -> None:

        if n_startup_trials < 0:
            raise ValueError(
                "Number of startup trials cannot be negative but got {}.".format(n_startup_trials)
            )
        if n_warmup_steps < 0:
            raise ValueError(
                "Number of warmup steps cannot be negative but got {}.".format(n_warmup_steps)
            )

        self._n_startup_trials = n_startup_trials
        self._n_warmup_steps = n_warmup_steps

    def prune(self, study: "optuna.study.Study", trial: "optuna.trial.FrozenTrial") -> bool:

        all_trials = study.get_trials(deepcopy=False)
        n_trials = len([t for t in all_trials if t.state == TrialState.COMPLETE])

        if n_trials == 0:
            return False

        if n_trials < self._n_startup_trials:
            return False

        step = trial.last_step
        if step is None:
            return False

        n_warmup_steps = self._n_warmup_steps
        if step < n_warmup_steps:
            return False

        this_score = trial.intermediate_values[step]
        direction = study.direction

        mean_list = []
        for other_trial in all_trials[:-1]:
            value_dict = other_trial.intermediate_values
            if len(value_dict) == 0:
                continue
            keys = sorted(value_dict.keys())
            maxi = keys[-1]
            if maxi < step:  # this trial terminated earlier and can't be used for comparison.
                continue
            if step < keys[0]:  # this trial had not reported anything this early yet.
                continue

            # find the next greater timestamp / key to determine the interval,
            # take that key and the previous one, map them to their values and then store their mean.
            for i in range(len(keys)):
                if step < keys[i]:
                    v = [value_dict[k] for k in keys[(i - 1):(i + 1)]]
                    mean_list.append(mean(v))  # store the relevant value for the trial.
                    break

        if len(mean_list) < self._n_startup_trials:
            return False

        # todo: a greedier strategy could be to only take the self._n_startup_trials * 2 best elements
        #  of the mean list to calculate the mean_res, or even just take the best one.
        mean_res = mean(mean_list)

        if direction == StudyDirection.MAXIMIZE:
            return mean_res > this_score
        return mean_res < this_score
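
A possible usage sketch (not part of the pruner itself; the training loop is a placeholder, and elapsed seconds are rounded to integers so they remain valid report steps):

import time

import optuna


def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)  # example hyperparameter
    start = time.time()
    accuracy = 0.0
    for _ in range(20):
        time.sleep(0.05)  # placeholder for a real training step
        accuracy += lr    # placeholder: improvement rate depends on the hyperparameter
        trial.report(accuracy, int(time.time() - start))  # elapsed seconds as the "step"
        if trial.should_prune():
            raise optuna.TrialPruned()
    return accuracy


study = optuna.create_study(direction="maximize", pruner=TimestampBaseMedianPruner(n_startup_trials=2))
study.optimize(objective, n_trials=10)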

@cowwoc
Contributor Author

cowwoc commented Feb 7, 2022

Reopening the issue as there is more interest.

@cowwoc cowwoc reopened this Feb 7, 2022
@cowwoc
Contributor Author

cowwoc commented Feb 7, 2022

@UntotaufUrlaub Regarding the objective function messing up the weights: I use a dedicated objective function for Optuna and a separate one for training and validation. My understanding is that the time component will only influence which hyperparameters get chosen, as opposed to influencing the weights directly.

That said, it would be nice to take this out of the objective function if possible.

@UntotaufUrlaub

@cowwoc Separating it that way is pretty clever. So the Optuna objective function is a wrapper around your training function which adds some penalty for the consumed time? So you have to configure how big the penalty is manually?

@cowwoc
Contributor Author

cowwoc commented Feb 7, 2022

Right, so what I did was something along the lines of:

trial_loss = val_loss + training_duration_as_percentage, where training_duration_as_percentage = average_training_seconds_per_input / MAX_TRAINING_SECONDS_PER_INPUT

average_training_seconds_per_input is the average amount of time that each epoch takes to process. It is updated once per epoch. MAX_TRAINING_SECONDS_PER_INPUT is a guess of the maximum amount of time an epoch can take.

I wanted the penalty to be a value between 0 and 1 so it wouldn't overwhelm the actual val_loss but you could play around with these numbers to make them work any way you want. This is still a work in progress on my end so I can't tell you for certain that it really works. It just makes theoretical sense to me :)
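
To make that concrete, here is a minimal sketch of such a penalized objective (hypothetical names; train_one_epoch and validate stand in for the real training and validation code, and MAX_TRAINING_SECONDS_PER_INPUT is the manual guess mentioned above):

import time

import optuna

MAX_TRAINING_SECONDS_PER_INPUT = 10.0  # guessed upper bound; tune to your workload


def train_one_epoch() -> None:
    time.sleep(0.01)  # stand-in for the real training loop


def validate() -> float:
    return 1.0  # stand-in for the real validation loss


def objective(trial: optuna.Trial) -> float:
    epochs = 5
    total_seconds = 0.0
    for _ in range(epochs):
        start = time.time()
        train_one_epoch()
        total_seconds += time.time() - start
        val_loss = validate()
    average_training_seconds_per_input = total_seconds / epochs
    # Keep the penalty in [0, 1] so it does not overwhelm val_loss.
    penalty = average_training_seconds_per_input / MAX_TRAINING_SECONDS_PER_INPUT
    return val_loss + penalty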

@UntotaufUrlaub

Makes sense, but to be honest it feels a little like a workaround, and getting the right range for the penalty seems cumbersome to me :(
What do you think about my approach?

@cowwoc
Contributor Author

cowwoc commented Feb 7, 2022

I don't fully understand how/where your approach uses the real objective value. I can see the times being used but the former isn't clear to me.

@UntotaufUrlaub

I am using a callback every "x" updates and then report the accuracy / the current value of my objective:

trial.report(accuracy, time.time() - start_timestamp)
if trial.should_prune():
    raise optuna.TrialPruned()

The report is stored in the intermediate_values map using the time diff as the key and the accuracy as the value (if I got it right).
Is it any clearer? Please just let me know anything that's unclear :)

@cowwoc
Contributor Author

cowwoc commented Feb 7, 2022

And how does the prune method use the objective value for pruning? Does it only look at the timing?

@UntotaufUrlaub

UntotaufUrlaub commented Feb 7, 2022

Maybe an example makes things a little clearer.
Let's say the objective is an accuracy, i.e. how many samples are classified correctly, so we have a value between zero and one.
So the first trial could have reported for example:
{1: 0.1, 2: 0.3, 3: 0.5, 4: 0.7}
The keys are the time diffs (here, for simplicity, in straight minutes) and the values are the accuracy.
So in the first trial after 2 minutes we had an accuracy of 30%.
This is stored in the old trial's intermediate_values (in the code above, the old trial is other_trial = all_trials[0]).

Now the next trial runs slower and reports after 2.5 minutes an accuracy of 0.3.
Then trial.should_prune() is called which calls the prune method.
That accesses the described dictionary as trial.intermediate_values. The last step / timestamp of the current trial is found in step = trial.last_step # currently 2.5
and the corresponding value in
this_score = trial.intermediate_values[step] # currently 0.3; here trial = all_trials[1]

Then using the for loop the interval in the old trial which surrounds the current step is found as [2, 3]. This is translated to the memorized objective values [0.3, 0.5] in line
v = [value_dict[k] for k in keys[(i - 1):(i + 1)]] # note: value_dict = other_trial.intermediate_values
then the mean is calculated. So this interval is represented as 0.4. As there is only one old trial this is the value to compare the current objective value to.
So here it would prune.

@github-actions github-actions bot removed the stale Exempt from stale bot labeling. label Feb 7, 2022
@cowwoc
Contributor Author

cowwoc commented Feb 8, 2022

Thank you for the explanation. I think it makes sense so long as you implement some sort of value interpolation. As you hinted above, the time elapsed between epochs may vary drastically between trials. In one run it could be 100ms per epoch. In another, it could be 10 seconds. Ideally, we want to compare the objective value at time X (or between time X and Y) across both runs, even if the exact key does not exist.

@github-actions
Contributor

This issue has not seen any recent activity.

@github-actions github-actions bot added the stale Exempt from stale bot labeling. label Feb 22, 2022
@droukd

droukd commented Feb 28, 2022

If you insist,

start = time.time()
<...>
budget = time.time() - start
trial.report(accuracy, budget)
if trial.should_prune():
    raise optuna.TrialPruned()

I am happy with existing pruners.

@cowwoc
Contributor Author

cowwoc commented Feb 28, 2022

@droukd I think you missed the point of this pruner. The goal isn't to prune trials that run longer than others. The goal is to prune trials that take more time to achieve the same score as others achieved earlier.

So, the objective function score is important but so is the time at which the score was reached.

@github-actions github-actions bot removed the stale Exempt from stale bot labeling. label Feb 28, 2022
@UntotaufUrlaub

UntotaufUrlaub commented Mar 8, 2022

@cowwoc thank you for taking the time to understand my approach :)

I think it makes sense so long as you implement some sort of value interpolation.

Currently I am just using the mean of the interval as interpolation, but linear interpolation or even something more advanced would surely be better! (My intervals are quite small, so I hoped the mean would be enough, but I should update it.)
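
For what it's worth, a minimal sketch of linear interpolation over the reported (time, value) pairs, which could replace the plain mean (a hypothetical helper, not part of the pruner above):

from typing import Dict


def value_at(intermediate_values: Dict[float, float], t: float) -> float:
    """Linearly interpolate the reported objective value at elapsed time ``t``.

    ``intermediate_values`` maps elapsed time (used as the report "step") to the reported value.
    """
    times = sorted(intermediate_values)
    # Clamp to the reported range.
    if t <= times[0]:
        return intermediate_values[times[0]]
    if t >= times[-1]:
        return intermediate_values[times[-1]]
    # Find the surrounding interval and interpolate linearly within it.
    for lo, hi in zip(times, times[1:]):
        if lo <= t <= hi:
            frac = (t - lo) / (hi - lo)
            return intermediate_values[lo] + frac * (intermediate_values[hi] - intermediate_values[lo])
    return intermediate_values[times[-1]]  # unreachable given the clamps above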

@droukd
Thanks for contributing :)
Your approach is what I first tried, but it didn't work out for me. The issue seemed to be that the keys (budget, in your example) are only checked for exact matches. So if one trial reports its first value after 6.78 seconds and the next reports after 6.9 seconds, the two values won't be compared.
The percentile pruner just thinks nothing was reported for that budget yet and so doesn't prune.
Reporting time is always subject to some jitter, so exactly matching budgets are quite unlikely. That's why something more flexible is needed and we came up with a custom implementation.
You could either take the closest key or use an interpolation of some surrounding interval, as I proposed.

@droukd

droukd commented Mar 16, 2022

The issue seemed to be that the keys (in your example budget) are only checked at exact matches.

I was thinking I am getting what is expected. Although, I am reporting integers.

step (int) – Step of the trial (e.g., Epoch of neural network training). Note that pruners assume that step starts at zero. For example, MedianPruner simply checks if step is less than n_warmup_steps as the warmup mechanism.

[Screenshot omitted]

Maybe rounding can help.

budget = int(budget)  # or round(budget), or int(budget * 100)
# round every step to three / make steps "even", "matching":
# budget = int(budget / 3) * 3

If someone is able to quickly set up testing code, that would be amazing.

As not only the time spent is important, but the combination of time, score, and hyper-parameters (different sets of parameters result in different performance: one model can report a "100 score" with a slow start, another with a slow "90-to-100 jump"), I let it run for quite a long time, with different results.
But if I read something wrong, please lemme know.

@github-actions
Contributor

This issue has not seen any recent activity.

@github-actions github-actions bot added the stale Exempt from stale bot labeling. label Mar 30, 2022
@github-actions
Contributor

This issue was closed automatically because it had not seen any recent activity. If you want to discuss it, you can reopen it freely.
