Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

call contiguous on BMM inputs for NT on CUDA #88108

Closed
wants to merge 2 commits into from

Conversation

drisspg
Copy link
Contributor

@drisspg drisspg commented Oct 31, 2022

Fixes #87713

BMM for cpu supports non-contiguous nested tensor inputs, while BMM for Cuda does not support currently non-contiguous inputs.

The derivative for BMM:

- name: bmm(Tensor self, Tensor mat2) -> Tensor
  self: grad.bmm(mat2.transpose(1, 2).conj())
  mat2: self.transpose(1, 2).conj().bmm(grad)
  result: self_t.bmm(mat2_p) + self_p.bmm(mat2_t)

When calling backward it was impossible for this function to succeed since the inputs were always discontiguous, regardless of the user input. This adds contiguous calls to BMM_cuda implementation for nested tensors.

This was not caught by tests because grad_check is currently only done on CPU in test_nestedtensors. This PR updates the autograd test to also be run on GPU.

As a result I found one more issue with the backward for to_padded_tensor erroring instead of calling the generic version.

cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki

@pytorch-bot pytorch-bot bot added the release notes: cuda release notes category label Oct 31, 2022
@pytorch-bot
Copy link

pytorch-bot bot commented Oct 31, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88108

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 8 Pending

As of commit 7a607f8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@drisspg drisspg requested review from mikaylagawarecki and cpuhrsch and removed request for mikaylagawarecki October 31, 2022 19:14
@drisspg drisspg force-pushed the call_contiguous_on_bmm_for_nt branch from e2542b6 to 7a607f8 Compare October 31, 2022 19:15
@drisspg drisspg added module: nestedtensor NestedTensor tag see issue #25032 release notes: nested tensor Changes that have a direct impact on nested tensors topic: bug fixes topic category labels Oct 31, 2022
@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 31, 2022
@drisspg
Copy link
Contributor Author

drisspg commented Oct 31, 2022

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
Fixes pytorch#87713

BMM for cpu supports  non-contiguous nested tensor inputs, while BMM for Cuda does not support currently non-contiguous inputs.

The derivative for BMM:
```
- name: bmm(Tensor self, Tensor mat2) -> Tensor
  self: grad.bmm(mat2.transpose(1, 2).conj())
  mat2: self.transpose(1, 2).conj().bmm(grad)
  result: self_t.bmm(mat2_p) + self_p.bmm(mat2_t)
```

When calling backward it was impossible for this function to succeed since the inputs were always discontiguous, regardless of the user input.  This adds contiguous calls to BMM_cuda implementation for nested tensors.

This was not caught by tests because grad_check is currently only done on CPU in test_nestedtensors. This PR updates the autograd test to also be run on GPU.

As a result I found one more issue with the backward for to_padded_tensor erroring instead of calling the generic version.

cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki
Pull Request resolved: pytorch#88108
Approved by: https://github.com/cpuhrsch
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Fixes pytorch#87713

BMM for cpu supports  non-contiguous nested tensor inputs, while BMM for Cuda does not support currently non-contiguous inputs.

The derivative for BMM:
```
- name: bmm(Tensor self, Tensor mat2) -> Tensor
  self: grad.bmm(mat2.transpose(1, 2).conj())
  mat2: self.transpose(1, 2).conj().bmm(grad)
  result: self_t.bmm(mat2_p) + self_p.bmm(mat2_t)
```

When calling backward it was impossible for this function to succeed since the inputs were always discontiguous, regardless of the user input.  This adds contiguous calls to BMM_cuda implementation for nested tensors.

This was not caught by tests because grad_check is currently only done on CPU in test_nestedtensors. This PR updates the autograd test to also be run on GPU.

As a result I found one more issue with the backward for to_padded_tensor erroring instead of calling the generic version.

cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki
Pull Request resolved: pytorch#88108
Approved by: https://github.com/cpuhrsch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged module: nestedtensor NestedTensor tag see issue #25032 release notes: cuda release notes category release notes: nested tensor Changes that have a direct impact on nested tensors topic: bug fixes topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Custom Autograd Functions Don't Work If Forward Pass Outputs a List of Tensors
3 participants