Fix GPU-CPU device mismatch error in util filter_dilated_rows #633

tklausen
Contributor

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Docs change / refactoring / dependency upgrade

Motivation and Context / Related issue

The function filter_dilated_rows in tensor_utils.py converts a tensor to an ndarray, modifies the ndarray, and converts the modified ndarray back to a tensor.

Bug:
If the input tensor is not on the CPU (for example, it lives on a CUDA device), the conversion to ndarray fails because tensor.numpy() is called without first copying the tensor to host memory with tensor.cpu():

File "opacus/utils/tensor_utils.py", line 328, in filter_dilated_rows
    tensor_np = tensor.numpy()
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Fix:
This PR filters the tensor directly, without ever converting it to an ndarray. This fixes the bug and is more efficient than the original implementation, which converted to NumPy and back.
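
As a rough illustration of the tensor-only approach described above, here is a minimal sketch. It is not the code merged in this PR: the function name `filter_dilated_rows_sketch`, its signature, and the use of `torch.index_select` are assumptions made for illustration, and the example inputs are illustrative too. The idea is to build the indices to keep on the same device as the input, so no host copy is needed.

```python
from typing import Tuple

import torch


def filter_dilated_rows_sketch(
    tensor: torch.Tensor,
    dilation: Tuple[int, ...],
    dilated_kernel_size: Tuple[int, ...],
    kernel_size: Tuple[int, ...],
) -> torch.Tensor:
    """Hypothetical helper: keep every dilation[i]-th entry along the
    trailing kernel axes, staying on the input tensor's device."""
    kernel_rank = len(kernel_size)
    # The kernel axes are assumed to be the last `kernel_rank` axes.
    axis_offset = tensor.dim() - kernel_rank

    for dim in range(kernel_rank):
        # Indices are created on tensor.device, so this works for CPU and CUDA.
        keep = torch.arange(
            0, dilated_kernel_size[dim], dilation[dim], device=tensor.device
        )
        tensor = torch.index_select(tensor, axis_offset + dim, keep)

    return tensor


# Illustrative input: a 5x5x5 dilated kernel window with dilation 2.
x = torch.zeros([1, 1, 3, 3, 3, 5, 5, 5])
out = filter_dilated_rows_sketch(x, (2, 2, 2), (5, 5, 5), (3, 3, 3))
print(out.shape)  # torch.Size([1, 1, 3, 3, 3, 3, 3, 3])
```

Because the result is produced by tensor operations only, it stays on the same device as the input, and the same call works unchanged on a CUDA tensor.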

How Has This Been Tested (if it applies)

Manually tested with the example provided in the function's docstring.

Also, filter_dilated_rows is called whenever the dilation of a 3D convolution is not 1, so this function is tested implicitly by tests/grad_samples/conv3d_test.py.

Checklist

  • The documentation is up-to-date with the changes I made.
  • I have read the CONTRIBUTING document and completed the CLA (see CONTRIBUTING).
  • All tests passed, and additional code has been covered with new tests.

@facebook-github-bot added the CLA Signed label on Feb 26, 2024
@facebook-github-bot
Contributor

@facebook-github-bot has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

This pull request has been merged in 32a465b.

@karthikprasad
Contributor

Thanks for the fix @tklausen :)

@karthikprasad self-requested a review on March 5, 2024, 21:51