Add numerically stable log1mexp = log(1 - exp(-|x|)) function #39242

Open
simonepri opened this issue May 29, 2020 · 5 comments
Labels
enhancement Not as big of a feature, but technically not a bug. Should be easy to fix function request A request for a new function or the addition of new arguments/modes to an existing function. module: numerical-stability Problems related to numerical stability of operations triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@simonepri

simonepri commented May 29, 2020

🚀 Feature

Add an implementation of a numerically stable log(1 - exp(-|x|)) function.

See https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf

Ref https://www.rdocumentation.org/packages/VGAM/versions/1.1-3/topics/log1mexp

Motivation

The function can be composed from existing PyTorch functions, but it is tricky to get right numerically.
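To illustrate why, here is a minimal sketch in plain Python that uses the standard `math` module instead of PyTorch (chosen only to keep it dependency-free). It contrasts direct evaluation of log(1 - exp(-a)) with the two-branch scheme from the Mächler note linked above. `log1mexp_naive` and `log1mexp_stable` are hypothetical names, not an existing API.

```python
import math


def log1mexp_naive(a: float) -> float:
    """Direct evaluation of log(1 - exp(-a)); loses precision at both ends."""
    return math.log(1.0 - math.exp(-a))


def log1mexp_stable(a: float) -> float:
    """log(1 - exp(-a)) for a > 0 using the two-branch scheme (Maechler 2012)."""
    if a <= 0.0:
        raise ValueError("log1mexp_stable requires a > 0")
    if a <= math.log(2.0):
        # 1 - exp(-a) is small: expm1 avoids the cancellation in 1 - exp(-a).
        return math.log(-math.expm1(-a))
    # exp(-a) is small: log1p keeps the tiny term that log(1 - y) rounds away.
    return math.log1p(-math.exp(-a))


print(log1mexp_naive(50.0), log1mexp_stable(50.0))  # naive prints 0.0
print(log1mexp_stable(1e-20))  # close to math.log(1e-20); naive raises ValueError here
```

At a = 50 the naive form returns exactly 0.0 because 1 - exp(-50) rounds to 1.0 in double precision, and at a = 1e-20 it raises `ValueError` because 1 - exp(-1e-20) rounds to 0.0. The two-branch form returns roughly -exp(-50) and log(1e-20) respectively.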

Some implementations found on GitHub:

@wouterkool/estimating-gradients-without-replacement/blob/9d8bf8b/bernoulli/gumbel.py#L7-L11

@visinf/n3net/blob/5d5883a/src_denoising/models/non_local.py#L94-L108

cc: @wouterkool @visinf

@simonepri simonepri changed the title Add stable log1mexp = log(1 - exp(-|x|)) function Add numerically stable log1mexp = log(1 - exp(-|x|)) function May 29, 2020
@gchanan gchanan added enhancement Not as big of a feature, but technically not a bug. Should be easy to fix needs research We need to decide whether or not this merits inclusion, based on research world module: operators triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Jun 1, 2020
@mruberry
Collaborator

mruberry commented Jun 9, 2020

Hey @simonepri, thanks for the suggestion. Would you be interested in submitting a PR for it?

Follow-up question: how numerically stable are PyTorch's log1p and expm1 implementations?

@xiaomengy
Contributor

If no one has taken this yet, I can help add it in my spare time.

@mruberry
Collaborator

@BIT-silence We already have log1p and expm1, which can be used to write a reasonably stable version of this function (as long as they are stable themselves). I'd start by checking the numerical stability of those ops.
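One way to check the stability of `expm1` and `log1p` is to compare them against a high-precision reference. The sketch below is a hypothetical harness, not an existing PyTorch test: it uses Python's `decimal` module for 60-digit reference values and, to stay dependency-free, exercises the standard-library `math.expm1` and `math.log1p` next to the naive `exp(x) - 1` and `log(1 + x)`. The helper names `ref_expm1`, `ref_log1p`, and `max_rel_error` are illustrative.

```python
import math
from decimal import Decimal, getcontext
from typing import Callable, Iterable

# 60 significant digits for the reference values (sets the global context).
getcontext().prec = 60


def ref_expm1(x: float) -> Decimal:
    """High-precision reference for exp(x) - 1."""
    return Decimal(x).exp() - 1


def ref_log1p(x: float) -> Decimal:
    """High-precision reference for log(1 + x), valid for x > -1."""
    return (1 + Decimal(x)).ln()


def max_rel_error(
    fn: Callable[[float], float],
    ref: Callable[[float], Decimal],
    xs: Iterable[float],
) -> float:
    """Largest relative error of fn(x) against ref(x) over the inputs xs."""
    worst = 0.0
    for x in xs:
        exact = ref(x)
        if exact == 0:
            continue  # relative error is undefined at an exact zero
        err = abs((Decimal(fn(x)) - exact) / exact)
        worst = max(worst, float(err))
    return worst


# Inputs of both signs from 1e-1 down to 1e-17, where cancellation bites.
xs = [s * 10.0 ** -k for k in range(1, 18) for s in (1.0, -1.0)]

print("expm1     ", max_rel_error(math.expm1, ref_expm1, xs))
print("exp(x) - 1", max_rel_error(lambda x: math.exp(x) - 1.0, ref_expm1, xs))
print("log1p     ", max_rel_error(math.log1p, ref_log1p, xs))
print("log(1 + x)", max_rel_error(lambda x: math.log(1.0 + x), ref_log1p, xs))
```

To point the same harness at PyTorch, `fn` could be, for example, `lambda v: torch.expm1(torch.tensor(v, dtype=torch.float32)).item()`, with the error judged against float32 epsilon rather than float64; that variant is not exercised here.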

@mruberry mruberry removed the needs research We need to decide whether or not this merits inclusion, based on research world label Jun 23, 2020
@cmpute
Contributor

cmpute commented Jun 25, 2020

Thanks for the links provided! This function is also useful in my work!

@mruberry mruberry added function request A request for a new function or the addition of new arguments/modes to an existing function. module: numerical-stability Problems related to numerical stability of operations and removed module: operators (deprecated) labels Oct 7, 2020
@Balandat
Contributor

If anyone still wants to add this (courtesy of @SebastianAment):

import math

import torch
from torch import Tensor


def log1mexp(x: Tensor) -> Tensor:
    """Numerically accurate evaluation of log(1 - exp(x)) for x < 0.
    See [Maechler2012accurate]_ for details.
    """
    # log(-expm1(x)) for -log(2) < x < 0; log1p(-exp(x)) for x <= -log(2).
    mask = -math.log(2) < x
    return torch.where(
        mask,
        (-x.expm1()).log(),
        (-x.exp()).log1p(),
    )

[Maechler2012accurate] M. Mächler. Accurately Computing log(1 - exp(-|a|)). Assessed by the Rmpfr package. Technical report, 2012. https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
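For reference, the two branches of the function above can be written in Mächler's convention log1mexp(a) = log(1 - e^{-a}) with a > 0. The posted function takes x = -a < 0, and the issue title's log(1 - exp(-|x|)) follows by passing -|x|. The cutoff is written as in the code (strict inequality on the expm1 side); both branches are accurate at a = log 2, so the boundary can go either way.

```latex
\operatorname{log1mexp}(a) = \log\left(1 - e^{-a}\right) =
\begin{cases}
\log\left(-\operatorname{expm1}(-a)\right), & 0 < a < \log 2,\\
\operatorname{log1p}\left(-e^{-a}\right), & a \ge \log 2,
\end{cases}
\qquad
\log\left(1 - e^{x}\right) = \operatorname{log1mexp}(-x) \quad (x < 0),
\qquad
\log\left(1 - e^{-|x|}\right) = \operatorname{log1mexp}(|x|) \quad (x \ne 0).
```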
