
Implement torch.xlogy #22656

Closed
colesbury opened this issue Jul 9, 2019 · 6 comments
Labels
function request A request for a new function or the addition of new arguments/modes to an existing function. triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@colesbury
Member

🚀 Feature

Implement torch.xlogy, which computes x * log(y) if x != 0 and 0 if x == 0.

Motivation

Often one wants to compute x * log(x) or x * log(y) but define the result to be zero when x is 0 (instead of NaN). The current alternatives are to use torch.where to mask out NaN values or to add a small epsilon to x, as we do in binary cross entropy loss.
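To illustrate the problem (a minimal sketch, not from the issue itself): a naive x * log(x) yields NaN at x == 0 because 0 * (-inf) is NaN, while the torch.where workaround defines that entry as 0.

```python
import torch

x = torch.tensor([0.0, 0.5])

# Naive computation: log(0) = -inf, and 0 * -inf = NaN.
naive = x * torch.log(x)

# torch.where workaround: pick 0 where x == 0, x * log(x) elsewhere.
masked = torch.where(x == 0, torch.zeros_like(x), x * torch.log(x))

print(naive)   # first entry is nan
print(masked)  # first entry is 0.0
```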

Additional context

This is implemented in both SciPy and TensorFlow:

https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.xlogy.html
https://www.tensorflow.org/api_docs/python/tf/math/xlogy

cc @IssamLaradji

@IssamLaradji

perfect! thanks :)

@IssamLaradji

Will this work? I think it can be simplified further.

def xlogy(x, y):
    # Create the zero scalar directly on x's device (works for both CPU and CUDA).
    z = torch.zeros((), device=x.device)
    return x * torch.where(x == 0., z, torch.log(y))

@vishwakftw vishwakftw added feature A request for a proper, new feature. module: operators labels Jul 10, 2019
@ezyang ezyang added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jul 10, 2019
@ezyang
Contributor

ezyang commented Jul 10, 2019

It's a good start, but ideally we would get it fused. Maybe the JIT can automatically fuse this code?

@fmassa
Member

fmassa commented Jul 10, 2019

@ezyang I think the JIT can fuse pretty much everything (we can instead just use an expanded tensor as in Issam's example)

In [1]: import torch

In [2]: @torch.jit.script
   ...: def f(x, y):
   ...:     z = torch.zeros_like(x)
   ...:     return x * torch.where(x == 0, z, torch.log(y))
   ...:

In [3]: f(torch.rand(2).cuda(), torch.rand(2).cuda())
Out[3]: tensor([-0.1152, -0.0522], device='cuda:0')

In [4]: f.graph_for(torch.rand(2).cuda(), torch.rand(2).cuda())
Out[4]:
graph(%x.1 : Float(*),
      %y.1 : Float(*)):
  %z.1 : Float(*) = aten::zeros_like(%x.1) # <ipython-input-2-2db415bb6205>:3:9
  %8 : Float(*) = prim::FusionGroup_0(%x.1, %z.1, %y.1)
  return (%8)
with prim::FusionGroup_0 = graph(%0 : Float(*),
      %4 : Float(*),
      %7 : Float(*)):
  %9 : int = prim::Constant[value=0]() # <ipython-input-2-2db415bb6205>:4:33
  %10 : Byte(*) = aten::eq(%0, %9) # <ipython-input-2-2db415bb6205>:4:28
  %8 : Float(*) = aten::log(%7) # <ipython-input-2-2db415bb6205>:4:39
  %6 : Float(*) = aten::where(%10, %4, %8) # <ipython-input-2-2db415bb6205>:4:16
  %2 : Float(*) = aten::mul(%0, %6) # <ipython-input-2-2db415bb6205>:4:12
  return (%2)

@colesbury
Member Author

To be clear, the request is to implement this function as a public API, not to make sure that the JIT can fuse a multi-op implementation.

The Python implementation with @torch.jit.script is probably not sufficient (currently) for a public API:

  1. It doesn't work in C++
  2. JIT error messages are still substantially worse when things go wrong
  3. It's not fused for CPU

@vadimkantorov
Contributor

vadimkantorov commented Oct 3, 2019

Will gradients still be bad if (the gradients of) torch.log(y) contain NaNs?
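A sketch of the concern (my own illustration, not from the thread): even though the torch.where mask keeps the forward pass clean, the backward of torch.log still divides by y, so the masked-out position can receive 0 / 0 = NaN in y.grad.

```python
import torch

x = torch.tensor([0.0, 1.0])
y = torch.tensor([0.0, 2.0], requires_grad=True)

# Forward pass is fine: where() hides log(0) at the position where x == 0.
out = x * torch.where(x == 0, torch.zeros_like(x), torch.log(y))
out.sum().backward()

# Backward: log's gradient is grad_output / y. At the masked position the
# incoming gradient is 0 and y is 0, so 0 / 0 = NaN leaks into y.grad.
print(y.grad)  # tensor([nan, 0.5000])
```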
