Add new reduction mode in kl_div #14457
Conversation
can we get an expect test that this averages along batch dimension?
@ssnl I added a test to compare it with …
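(For context, a minimal sketch of what a batch-dimension comparison test could look like; the test name, shapes, and use of `allclose` are illustrative assumptions, not the test added in this PR:)

```python
import torch
import torch.nn.functional as F

def test_kl_div_batchmean_averages_over_batch():
    # 'batchmean' should equal the summed loss divided by the batch size
    # (the first dimension), not by the total number of elements.
    batch, dim = 4, 10
    log_prob = F.log_softmax(torch.randn(batch, dim), dim=1)  # input in log-space
    target = F.softmax(torch.randn(batch, dim), dim=1)        # target as probabilities

    batchmean = F.kl_div(log_prob, target, reduction='batchmean')
    expected = F.kl_div(log_prob, target, reduction='sum') / batch

    assert torch.allclose(batchmean, expected)
```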
@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
torch/nn/functional.py
Outdated
specifying either of those two args will override :attr:`reduction`. Default: 'mean'
'none' | 'batchmean' | 'sum' | 'mean'. 'none': no reduction will be applied,
'batchmean': the sum of the output will be divided by the number of
batches in the output, 'sum': the output will be summed, 'mean': the output will be
"the number of batches" -> "the batch size" or "the number of elements in the input batch"
@pytorchbot retest this please
I don't think `batch_mean` should be an enum element. It will also introduce weird behavior for the losses still defined in TH if used with `batch_mean`, since the if-statements there are not written with `batch_mean` in mind.
aten/src/ATen/core/Reduction.h
Outdated
None,       // Do not reduce
Mean,       // (Possibly weighted) mean of losses
Sum,        // Sum losses
BatchMean,  // Mean over batches. = Sum / batchsize
Can't you just hack this up in python? It should only be a workaround for kl_div and doesn't make sense for other losses.
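(For context, a Python-level workaround along those lines could look roughly like this; the helper name `kl_div_batchmean` is hypothetical and not code from this PR:)

```python
import torch.nn.functional as F

def kl_div_batchmean(log_prob, target):
    # Sketch of the suggested workaround: reuse the existing 'sum' reduction and
    # divide by the batch size in Python, rather than adding a new ATen enum value.
    return F.kl_div(log_prob, target, reduction='sum') / log_prob.size(0)
```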
I'm now wondering if `reduction='none'` is wrong... But this generally LGTM.
torch/nn/functional.py
Outdated
    reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
else:
    if reduction == 'mean':
        warnings.warn("reduction=mean doesn't give the true kl divergence value. "
Better to put quotes around `mean` and `batchmean`. Also, I think this can be worded more clearly, like:

reduction: "mean" divides the total loss by both the batch size and the support size.
"batchmean" divides only by the batch size, and aligns with the KL divergence math definition.
"mean" will be changed to behave the same as "batchmean" in the next major release.
torch/nn/functional.py
Outdated
'mean': the output will be divided by the number of elements in the output | ||
Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, | ||
and in the meantime, specifying either of those two args will override :attr:`reduction`. | ||
Note: `reduction='mean'` doesn't return the true kl divergence value, please use |
I would make this a more obvious note using `.. note::`. Same for the module doc.
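(For reference, such a note in the `kl_div` docstring could read roughly as follows; the exact wording is illustrative:)

```python
r"""
.. note::
    :attr:`reduction` = ``'mean'`` doesn't return the true kl divergence value,
    please use :attr:`reduction` = ``'batchmean'`` which aligns with the KL math
    definition.
"""
```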
Fixes #6622.

We used to average over all elements for KL divergence, which is not aligned with its math definition. This PR corrects the default reduction behavior of KL divergence so that it averages over the batch dimension:

- In KL, `reduction='mean'` averages over the batch dimension, while for most other loss functions `reduction='mean'` averages over all elements.
- A new reduction mode `batchmean` is added, which has the correct behavior for KL. A warning is added to make `batchmean` the default for KL instead of `mean` in the next major release.
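For illustration, a sketch of how the two modes relate on a concrete input (shapes are arbitrary; `log_prob` holds log-probabilities and `target` holds probabilities):

```python
import torch
import torch.nn.functional as F

batch, dim = 8, 5
log_prob = F.log_softmax(torch.randn(batch, dim), dim=1)
target = F.softmax(torch.randn(batch, dim), dim=1)

total = F.kl_div(log_prob, target, reduction='sum')

# 'mean' divides the total loss by the total element count (batch * dim) and,
# per this PR, emits a warning pointing users to 'batchmean'.
mean_loss = F.kl_div(log_prob, target, reduction='mean')             # total / (batch * dim)

# 'batchmean' divides only by the batch size, matching the KL math definition.
batchmean_loss = F.kl_div(log_prob, target, reduction='batchmean')   # total / batch
```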