
Make kl_div accept target in log space #34586

Closed
wants to merge 26 commits into from

Conversation

@nikitaved (Collaborator) commented Mar 11, 2020

Fixes #32520, implements #34536.

Here are some benchmarks:

import torch
import torch.nn.functional as F
from IPython import get_ipython

ipython = get_ipython()

torch.set_num_threads(1)

for d in [5, 10, 20, 50, 100, 1000]:
    i = torch.rand(d, d)
    t = torch.rand(d, d)
    print(f"Size: {d}x{d}")
    ipython.magic("timeit F.kl_div(i, t, reduction='none', log_target=False)")
    ipython.magic("timeit F.kl_div(i, t.log(), reduction='none', log_target=True)")

Output:

Size: 5x5
16 µs ± 33 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
8.24 µs ± 17.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Size: 10x10
16.7 µs ± 17.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
8.7 µs ± 20.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Size: 20x20
17.7 µs ± 47.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
9.7 µs ± 28.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Size: 50x50
23.6 µs ± 60.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
15 µs ± 33.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Size: 100x100
42.8 µs ± 223 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
34 µs ± 17.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Size: 1000x1000
3.9 ms ± 1.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.45 ms ± 364 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
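Beyond speed, the two code paths should agree numerically: with the new flag, passing probability targets and passing the same targets already in log space must produce identical results. A minimal sanity check (a sketch, assuming a PyTorch build where `log_target` is available, i.e. with this PR merged):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
inp = torch.randn(4, 4).log_softmax(dim=1)       # input: log-probabilities
tgt = torch.softmax(torch.randn(4, 4), dim=1)    # target: probabilities

# Probability-space targets vs. the same targets in log space.
a = F.kl_div(inp, tgt, reduction="none", log_target=False)
b = F.kl_div(inp, tgt.log(), reduction="none", log_target=True)

# Both compute tgt * (log(tgt) - inp) elementwise, so they should match
# up to floating-point error.
assert torch.allclose(a, b, atol=1e-6)
```

The `log_target=True` path skips the internal `log()` of the target, which is where the speedup in the benchmark above comes from.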

@dr-ci bot commented Mar 11, 2020

💊 CircleCI build failures summary and remediations

As of commit cf4814f (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

XLA failure

Job pytorch_xla_linux_xenial_py3_6_clang7_build is failing. Please create an issue with a title prefixed by [PT_BREAK] in pytorch/xla and link to this PR. If you have questions, please reach out to @ailzhang / @dlibenzi / @JackCaoG.



@nikitaved (Collaborator, Author) commented:

@ezyang, @ngimel, could you please advise on the best way to handle buffers in this case? I need to store that boolean flag for the backward pass.

@nikitaved nikitaved force-pushed the nikved/kl_div_fix branch 2 times, most recently from 32acc63 to e6063ce Compare March 20, 2020 14:36
@nikitaved nikitaved changed the title [WIP] Make kl_div accept target in log space Make kl_div accept target in log space Mar 20, 2020
@nikitaved nikitaved requested a review from ezyang March 20, 2020 14:57
@yf225 yf225 added the module: operators and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Mar 20, 2020
@ezyang ezyang added the module: xla Related to XLA support label Mar 23, 2020
@ezyang (Contributor) commented Mar 23, 2020

cc @ailzhang @dlibenzi this is going to affect xla bindings

@ezyang (Contributor) commented Mar 23, 2020

Needs some benchmarks for the new codepath. Compare log_space=False/True

@ezyang (Contributor) left a review comment:


tests, benchmarks

@nikitaved (Collaborator, Author) commented Mar 23, 2020

Needs some benchmarks for the new codepath. Compare log_space=False/True

If I sample from Dirichlet distributions of different size and then measure the performance, would that be sufficient?

@ezyang (Contributor) commented Mar 24, 2020

If I sample from Dirichlet distributions of different size and then measure the performance, would that be sufficient?

SGTM
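The Dirichlet-based benchmark discussed above might look roughly like the following. This is a hypothetical sketch, not the script actually used in the PR: rows of a Dirichlet sample are valid probability vectors, and `timeit` stands in for the IPython `%timeit` magics so it runs as a plain script.

```python
import timeit

import torch
import torch.nn.functional as F

torch.set_num_threads(1)

for d in [10, 100]:
    # Each row of the sample is a probability vector (sums to 1).
    t = torch.distributions.Dirichlet(torch.ones(d)).sample((d,))
    i = torch.rand(d, d)
    for log_target in (False, True):
        tgt = t.log() if log_target else t
        sec = timeit.timeit(
            lambda: F.kl_div(i, tgt, reduction="none", log_target=log_target),
            number=1000,
        )
        # total seconds for 1000 loops -> microseconds per loop
        print(f"{d}x{d} log_target={log_target}: {sec * 1e3:.1f} us/loop")
```

As in the earlier benchmark, the `log_target=True` path should be faster because it avoids recomputing `log(target)` on every call.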

@nikitaved (Collaborator, Author) commented:

OK, tests are fixed. TODO: benchmarks.

@nikitaved nikitaved force-pushed the nikved/kl_div_fix branch 4 times, most recently from d667941 to 97e195c Compare March 25, 2020 09:09
@nikitaved (Collaborator, Author) commented:

Hopefully fixed now!

@facebook-github-bot (Contributor) left a comment:


@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented:

@ezyang merged this pull request in 35cdb78.

@@ -294,7 +294,8 @@ class KLDivLoss(_Loss):

     As with :class:`~torch.nn.NLLLoss`, the `input` given is expected to contain
     *log-probabilities* and is not restricted to a 2D Tensor.
-    The targets are given as *probabilities* (i.e. without taking the logarithm).
+    The targets are given as *probabilities* by default, but could be passed
A reviewer (Contributor) commented on this line:
nit: this isn't really accurate (and I realize this was before your change) -- the targets are given however the users give them. They are interpreted as probabilities by default.
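The reviewer's point — that targets are interpreted, not validated — can be illustrated with a minimal sketch (assuming a PyTorch build with `log_target`): if the user hands over log-probabilities but leaves `log_target=False`, the function silently takes `log()` of negative numbers internally and yields NaNs rather than raising an error.

```python
import torch
import torch.nn.functional as F

inp = torch.randn(3, 5).log_softmax(dim=1)
log_tgt = torch.randn(3, 5).log_softmax(dim=1)  # targets already in log space

# Correct interpretation: tell kl_div the targets are log-probabilities.
ok = F.kl_div(inp, log_tgt, reduction="batchmean", log_target=True)

# Misinterpretation: the negative log-probabilities are treated as
# probabilities, so the internal log() produces NaNs, not an error.
bad = F.kl_div(inp, log_tgt, reduction="none", log_target=False)
assert torch.isnan(bad).any()
```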

@nikitaved nikitaved mentioned this pull request Apr 7, 2020
facebook-github-bot pushed a commit that referenced this pull request Apr 9, 2020
Summary:
Fixes doc for KLDivLoss as per [this comment](#34586 (comment)).
Pull Request resolved: #36137

Differential Revision: D20932395

Pulled By: gchanan

fbshipit-source-id: ecc395e6bc689fbf758e2cdca946049de8963856
ashishfarmer pushed a commit to ashishfarmer/pytorch that referenced this pull request Apr 13, 2020
Labels
Merged · module: xla (Related to XLA support) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Numerical problems with torch.nn.functional.kl_div
7 participants