Epsilon decay implementation is inconsistent with description #478

Open
jaschau opened this issue Dec 15, 2023 · 0 comments

jaschau commented Dec 15, 2023

Describe the bug

Either I am misunderstanding the documentation or the implementation of epsilon scheduling is not consistent with the description.

The documentation of the `Epsilon` scheduler class currently states:

That value [the regularization strength] is either the final, targeted regularization, or one that is larger, obtained by geometric decay of an initial value that is larger than the intended target.
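Here is one way to write that sentence as a formula. This is the reading this report assumes; the docstring itself gives no formula. `init`, `decay` and `target` are written as ε_init, d and ε_target:

```math
\varepsilon_t = \max\left(\varepsilon_{\mathrm{init}}\, d^{\,t},\ \varepsilon_{\mathrm{target}}\right),
\qquad \varepsilon_{\mathrm{init}} > \varepsilon_{\mathrm{target}}, \quad 0 < d < 1 .
```

Under this reading ε_0 = ε_init. The schedule stays at ε_target for all t ≥ log(ε_target / ε_init) / log d. With `init=1.`, `decay=0.5` and `target=1e-5`, that threshold is t ≥ 16.6, so the target is reached from iteration 17 on.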

To Reproduce

```python
import ott.geometry.epsilon_scheduler

schedule = ott.geometry.epsilon_scheduler.Epsilon(target=1e-5, init=1., decay=0.5)
schedule.at(iteration=0)
# expected: 1.
# actual: 1e-5
```

The implementation of `Epsilon.at` is:

```python
def at(self, iteration: Optional[int] = 1) -> float:
  """Return (intermediate) regularizer value at a given iteration."""
  if iteration is None:
    return self.target
  # check the decay is smaller than 1.0.
  decay = jnp.minimum(self._decay, 1.0)
  # the multiple is either 1.0 or a larger init value that is decayed.
  multiple = jnp.maximum(self._init * (decay ** iteration), 1.0)
  return multiple * self.target
```
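To see why the reproduction returns `1e-5`, here is a plain-Python transcription of the arithmetic in `at` above. It is a sketch, not the library code: the built-ins `min`/`max` stand in for `jnp.minimum`/`jnp.maximum`, and the function name `current_at` is hypothetical. Under this formula `init` acts as a multiplier on `target`. So `init=1.` yields `target` at every iteration, and a start value of `1.0` would need `init=1e5`.

```python
from typing import Optional


def current_at(target: float, init: float, decay: float,
               iteration: Optional[int] = 1) -> float:
    """Plain-Python mirror of the arithmetic in Epsilon.at (ott-jax 0.4.4)."""
    if iteration is None:
        return target
    # Cap the decay at 1.0, as the original does with jnp.minimum.
    decay = min(decay, 1.0)
    # `init` scales `target`; the multiple never drops below 1.0.
    multiple = max(init * decay ** iteration, 1.0)
    return multiple * target


# The reproduction from this issue: init=1. means "multiply target by 1".
print(current_at(target=1e-5, init=1.0, decay=0.5, iteration=0))  # 1e-05
# To start at 1.0 under this formula, init would have to be 1 / target.
print(current_at(target=1e-5, init=1e5, decay=0.5, iteration=0))
```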

Could it be that the last two lines should be replaced by the following?

```python
return jnp.maximum(self._init * (decay ** iteration), self.target)
```
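For comparison, here is a plain-Python sketch of the schedule that the suggested line would produce. In it, `init` is the absolute starting value, which decays geometrically until it is clamped at `target`. As before, `proposed_at` is a hypothetical name and the built-in `max` stands in for `jnp.maximum`. This illustrates the suggestion in this report, not a confirmed library change.

```python
from typing import Optional


def proposed_at(target: float, init: float, decay: float,
                iteration: Optional[int] = 1) -> float:
    """Schedule implied by max(init * decay**iteration, target)."""
    if iteration is None:
        return target
    decay = min(decay, 1.0)  # keep the same cap on the decay as the original
    # `init` is the starting epsilon itself; decay it, but never below target.
    return max(init * decay ** iteration, target)


# With the reproduction's settings the schedule starts at 1.0 ...
print(proposed_at(target=1e-5, init=1.0, decay=0.5, iteration=0))  # 1.0
# ... halves each iteration and first reaches the target at iteration 17,
# since 0.5**16 (about 1.5e-5) is above 1e-5 but 0.5**17 (about 7.6e-6) is below.
first = next(t for t in range(100) if proposed_at(1e-5, 1.0, 0.5, t) == 1e-5)
print(first)  # 17
```

Under this reading, `init` and `target` are on the same scale, so `init` must exceed `target` for any decay to happen. In the current formula, `init` is a dimensionless multiplier instead.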

I am using ott-jax 0.4.4.
