Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement NLLLossNd #4035

Merged
merged 3 commits into from Dec 18, 2017
Merged

Implement NLLLossNd #4035

merged 3 commits into from Dec 18, 2017

Conversation

zou3519
Copy link
Contributor

@zou3519 zou3519 commented Dec 5, 2017

Needed for #3556

I'm not sure this is the best way to implement because the .contiguous() calls might be slow.

One alternative way to implement this is to copy and modify gather. Without any of the extra keyword modifiers, with reduce=False, the following is equivalent to NLLLossNd:

def nlllossNd(input, target):
    target = target.unsqueeze(1)
    out = torch.gather(input, 1, target)
    return out.squeeze(1)

I tried benchmarking this against what I have right now (this diff that uses .contiguous() calls and NLLLoss2d) and using gather is around 2x slower, even for non-contiguous inputs, so I went with this approach.

Test Plan

Unit tests for NLLLossNd with a NLLLossNd reference function

@@ -322,7 +325,7 @@ def smoothl1loss_reference(input, target, size_average=True, reduce=True):
loss_reference_fns = {
'KLDivLoss': kldivloss_reference,
'NLLLoss': nllloss_reference,
'NLLLoss2d': nllloss2d_reference,
'NLLLossNd': nlllossNd_reference,

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

@zou3519
Copy link
Contributor Author

zou3519 commented Dec 6, 2017

@pytorchbot retest this please

@pietern
Copy link
Contributor

pietern commented Dec 7, 2017

There was some CI maintenance happening this morning -- retriggering build.

@pytorchbot retest this please

@soumith soumith merged commit 30e6898 into pytorch:master Dec 18, 2017
@zou3519 zou3519 deleted the nlllossNd branch January 3, 2018 19:58
@soumith soumith added the 0.3.1 label Feb 4, 2018
soumith pushed a commit that referenced this pull request Feb 7, 2018
* Implement NLLLossNd

* Fix tests and typos

* Fix tests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants