Jit Error with CUDA and FP16 -- identifier "aten_add_flat__1" is undefined #47138
Comments
This looks like an issue with the variable uniquing rules in the TE compiler. I'll take a look.
@erikwijmans I think I fixed this in #47229; at least, I used your repro as the test case for it (thanks for the minimal repro, btw!). Please let me know if the issue persists and I'll jump back on it.
@nickgg Glad the repro was helpful! I built from scratch on the branch from the PR and this fixes it (both the repro and the full model). Thank you!
…47448)
Summary: Take 2 of this fix. I removed the repro from the issue, which is a bit flaky due to parallelism. It broke on Windows, but I don't think that is specific to Windows or to this fix. I'll make sure all the tests pass this time (cc zou3519).
Fixes an issue where fp16 scalars created by the registerizer could be referenced as floats, causing invalid conversions that would crash in the NVRTC compile. I also noticed that we were inserting patterns like float(half(float(X))) and added a pass to collapse those down inside the CudaHalfScalarRewriter.
Fixes #47138
Pull Request resolved: #47448
Reviewed By: glaringlee
Differential Revision: D24765070
Pulled By: nickgg
fbshipit-source-id: 5297e647534d53657bef81f4798e8aa6a93d1fbd
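To illustrate what collapsing a pattern like float(half(float(X))) can mean, here is a minimal toy sketch in plain Python. It is not the CudaHalfScalarRewriter and does not use any PyTorch API; the `Var`, `Cast`, `collapse_casts`, and `show` names are hypothetical, and the two rewrite rules are assumptions chosen because they preserve values (a cast to the type an expression already has is a no-op, and widening a half to float and narrowing it back is exact).

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Var:
    """A named value with a known dtype ("half" or "float")."""
    name: str
    dtype: str


@dataclass(frozen=True)
class Cast:
    """A conversion of `operand` to `dtype`."""
    dtype: str
    operand: "Expr"


Expr = Union[Var, Cast]


def collapse_casts(expr: Expr) -> Expr:
    """Remove value-preserving casts from a chain, innermost first."""
    if isinstance(expr, Var):
        return expr
    inner = collapse_casts(expr.operand)
    # Rule 1 (assumption): casting to the dtype the operand already has is a no-op.
    if inner.dtype == expr.dtype:
        return inner
    # Rule 2 (assumption): half(float(h)) with h already half returns h,
    # because widening half to float loses no information.
    if (
        expr.dtype == "half"
        and isinstance(inner, Cast)
        and inner.dtype == "float"
        and inner.operand.dtype == "half"
    ):
        return inner.operand
    return Cast(expr.dtype, inner)


def show(expr: Expr) -> str:
    """Render an expression in the float(half(X)) style used above."""
    if isinstance(expr, Var):
        return expr.name
    return f"{expr.dtype}({show(expr.operand)})"


x = Var("X", "half")
chain = Cast("float", Cast("half", Cast("float", x)))
print(show(chain), "->", show(collapse_casts(chain)))
```

Note that the sketch deliberately leaves float(half(X)) alone when X is a float, since that round trip rounds the value and is not a no-op.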
I am still having this issue:
The fix for this wasn't cherry-picked into 1.7.1. It should be in nightly, though.
Oh, got it. Will it be included in 1.8? Do you have a ballpark estimate of when that will be released?
I am not on the PyTorch team, so I don't know the exact details, but since the fix is in master I assume it will be. I don't have a guess for the 1.8 release timeline, though.
I have the same issue with pytorch 1.7.1 py3.8_cuda10.2.89_cudnn7.6.5_0. It fails consistently using the repro.py script from above.
Hello, are there any updates on this issue? Thank you in advance.
🐛 Bug
When running a scripted module on a CUDA device with fp16, I get the following error when computing the backward pass:
To Reproduce
Steps to reproduce the behavior:
python repro.py
where repro has the following contents:
I have minified this as much as I can (I started with the actual module from my network and removed the parts that didn't cause the error).
Running this produces the following:
The exact number of iterations it runs for before erroring seems to be somewhat stochastic, but I have never seen it error on the first iteration and have only seen it error on the 2nd or 3rd.
With slightly different variants of this, I have also seen that aten_mul_flat__1 is undefined. I assume the root cause is the same, but thought I would point this out.
Expected behavior
Does not crash
Environment
1.7.0-py3.6_cuda10.1.243_cudnn7.6.3_0
Installed via (conda, pip, source): conda
1.6.0-py3.6_cuda10.1.243_cudnn7.6.3_0
cc @gmagogsfm