
Conversation

Contributor

@ailzhang ailzhang commented Nov 28, 2018

Fixes #6622 .
We used to average over all elements for KL divergence, which is not aligned with its math definition.
This PR corrects the default reduction behavior of KL divergence so that it now averages over the batch dimension.

  • In KL, the default behavior reduction='mean' averages over the batch dimension, while for most other loss functions reduction='mean' averages over all elements (a short sketch of the difference follows this list).
  • We used to support scalar tensors as well. For backward-compatibility purposes we still support them; no reduction is performed on scalar tensors.
  • Added a new reduction mode called batchmean, which has the correct behavior for KL. Also added a warning that batchmean will become the default for KL instead of mean in the next major release.
  • [deprecated] I chose not to add a new reduction option, since "mean over batch dimension" is rather special and only makes sense in a few cases like KL. We don't want to explain why there's an option "batchmean" that is not applicable to all other loss functions. I'm open to discussion on this one, as I cannot think of a perfect solution.
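
For illustration, here is a minimal sketch of how the two reductions differ (hypothetical shapes; it assumes the new 'batchmean' option is available):

import torch
import torch.nn.functional as F

# Hypothetical example: a batch of 4 distributions over 10 classes.
log_probs = F.log_softmax(torch.randn(4, 10), dim=1)  # input is expected as log-probabilities
target = F.softmax(torch.randn(4, 10), dim=1)         # target is expected as probabilities

pointwise = F.kl_div(log_probs, target, reduction='none')  # shape (4, 10), no reduction

# 'mean' divides the summed loss by the number of elements (4 * 10), while
# 'batchmean' divides only by the batch size (4), matching the math definition of KL divergence.
loss_mean = F.kl_div(log_probs, target, reduction='mean')            # == pointwise.mean()
loss_batchmean = F.kl_div(log_probs, target, reduction='batchmean')  # == pointwise.sum() / 4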

@ailzhang ailzhang added 1.0 module: bc-breaking Related to a BC-breaking change labels Nov 28, 2018
@ailzhang ailzhang requested a review from ssnl November 28, 2018 06:53
Collaborator

@ssnl ssnl left a comment


Can we get an expect test that this averages along the batch dimension?

@ailzhang
Contributor Author

@ssnl I added a test to compare it with reduction='none'. Let me know what you think, thanks!
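
A minimal sketch of what such a comparison test could look like (hypothetical shapes; not the exact test added in this PR):

import torch
import torch.nn.functional as F

def test_kl_div_batchmean_matches_unreduced_loss():
    # Hypothetical shapes; the real test in the PR may differ.
    log_probs = F.log_softmax(torch.randn(8, 5), dim=1)
    target = F.softmax(torch.randn(8, 5), dim=1)

    # 'batchmean' should equal the unreduced loss summed and divided by the batch size.
    expected = F.kl_div(log_probs, target, reduction='none').sum() / log_probs.size(0)
    actual = F.kl_div(log_probs, target, reduction='batchmean')

    assert torch.allclose(actual, expected)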

Contributor

@facebook-github-bot facebook-github-bot left a comment


@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ailzhang ailzhang changed the title Fix kl_div default behavior Add new reduction mode in kl_div Nov 28, 2018
@ailzhang ailzhang removed the module: bc-breaking Related to a BC-breaking change label Nov 28, 2018

specifying either of those two args will override :attr:`reduction`. Default: 'mean'
'none' | 'batchmean' | 'sum' | 'mean'. 'none': no reduction will be applied,
'batchmean': the sum of the output will be divided by the number of
batches in the output, 'sum': the output will be summed, 'mean': the output will be
Member


"the number of batches" -> "the batch size" or "the number of elements in the input batch"


@yf225
Contributor

yf225 commented Dec 3, 2018

@pytorchbot retest this please

Collaborator

@ssnl ssnl left a comment


I don't think batch_mean should be an enum element. It would also introduce weird behavior for the losses still defined in TH if used with batch_mean, since the if-statements there are not written with batch_mean in mind.

None, // Do not reduce
Mean, // (Possibly weighted) mean of losses
Sum, // Sum losses
BatchMean, // Mean over batches. = Sum / batchsize
Collaborator


Can't you just hack this up in Python? It should only be a workaround for kl_div and doesn't make sense for other losses.
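
A rough sketch of the kind of Python-level workaround being suggested, written as a hypothetical standalone helper rather than the code that ended up in this PR:

import torch.nn.functional as F

def kl_div_batchmean(log_probs, target):
    # Hypothetical helper: reduce with 'sum', then divide by the batch size,
    # which is the behavior the 'batchmean' mode is meant to provide for KL divergence.
    loss = F.kl_div(log_probs, target, reduction='sum')
    if log_probs.dim() != 0:  # keep the backward-compatible behavior for scalar tensors
        loss = loss / log_probs.size(0)
    return loss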


Collaborator

@ssnl ssnl left a comment


I'm now wondering if reduction='none' is wrong... But this generally LGTM.

reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
else:
if reduction == 'mean':
warnings.warn("reduction=mean doesn't give the true kl divergence value. "
Collaborator


Better to put quotes around mean and batchmean. Also, I think this can be worded more clearly, like:

reduction: "mean" divides the total loss by both the batch size and the support size. 
"batchmean" divides only by the batch size, and aligns with the KL divergence math definition. 
"mean" will be changed to behave the same as "batchmean" in the next major release.

'mean': the output will be divided by the number of elements in the output
Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated,
and in the meantime, specifying either of those two args will override :attr:`reduction`.
Note: `reduction='mean'` doesn't return the true kl divergence value, please use
Collaborator


I would make this a more obvious note using .. note::. Same for the module doc.

@ailzhang ailzhang deleted the fix_kl branch December 7, 2018 00:03
@ezyang ezyang added this to the 1.0 milestone Apr 1, 2019
@ezyang ezyang added the merged label Jun 25, 2019

Development

Successfully merging this pull request may close these issues: Possible KL_loss bug on output dimension average.