
Fix cuda launch error in reflection_pad2d #56451

Closed
wants to merge 5 commits

Conversation

xwang233
Collaborator

Fix #55222

@facebook-github-bot
Contributor

facebook-github-bot commented Apr 20, 2021

💊 CI failures summary and remediations

As of commit 808dd2e (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test1 (1/1)

Step: "Test"

======================================================================
FAIL [4.760s]: test_cudnn_multiple_threads_same_device (__main__.TestCuda)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 439, in wrapper
    fn(*args, **kwargs)
  File "test_cuda.py", line 2505, in test_cudnn_multiple_threads_same_device
    (2048 - test_iters) * (2048 - test_iters))
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 1371, in assertEqual
    super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg))
AssertionError: False is not true : Scalars failed to compare as equal! Comparing 1890625.0 and 1098304 gives a difference of 792321.0, but the allowed difference with rtol=1.3e-06 and atol=1e-05 is only 1.4278052!

----------------------------------------------------------------------
Ran 159 tests in 78.603s

FAILED (failures=1, skipped=67)

Generating XML reports...
Generated XML report: test-reports\python-unittest\test_cuda\TEST-TestCuda-20210421073255.xml
Generated XML report: test-reports\python-unittest\test_cuda\TEST-TestCudaComm-20210421073255.xml

This comment was automatically generated by Dr. CI.


@xwang233 xwang233 requested a review from ngimel April 20, 2021 05:40
@xwang233 xwang233 requested a review from ptrblck April 20, 2021 05:44

@ngimel ngimel left a comment


Looks good, small nit about legacy header.

for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));

dim3 grid_size(THCCeilDiv(output_plane_size, static_cast<int64_t>(256)), block_y_size, block_z_size);

please use cuda::ATenCeilDiv here, don't include legacy header
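The loop in the diff above exists because CUDA caps `gridDim.y` and `gridDim.z` at 65535; a `reflection_pad2d` input large enough to exceed that limit in a single launch is what triggered the "invalid configuration argument" error in #55222. Below is a minimal host-side sketch of just the chunking arithmetic. The `ceil_div` helper is a self-contained stand-in for the `cuda::ATenCeilDiv` the reviewer asks for, and `z_chunk_sizes` is a hypothetical name introduced here for illustration, not a function from the PR:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Stand-in for at::cuda::ATenCeilDiv, reimplemented so this sketch
// compiles on its own.
int64_t ceil_div(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

// CUDA limits gridDim.y and gridDim.z to 65535, so a launch covering a
// large size_z must be split into successive launches whose z extent is
// at most 65535 blocks, exactly as the loop in the diff does.
std::vector<int64_t> z_chunk_sizes(int64_t size_z) {
  std::vector<int64_t> chunks;
  for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
    chunks.push_back(std::min(size_z - block_z, static_cast<int64_t>(65535)));
  }
  return chunks;
}
```

Each chunk size would then feed the third component of `dim3 grid_size(...)` for one kernel launch, with `block_z` passed to the kernel so it can offset its indexing into the full tensor.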

@facebook-github-bot
Contributor

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel merged this pull request in 3ec6bf5.

krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021
Summary:
Fix pytorch#55222

Pull Request resolved: pytorch#56451

Reviewed By: malfet

Differential Revision: D27912184

Pulled By: ngimel

fbshipit-source-id: 3fc80273c30a68a247289d3fb698f99b92931731

Successfully merging this pull request may close these issues.

RuntimeError: CUDA error: invalid configuration argument when using PyTorch code
4 participants