Throw error if torch.set_deterministic(True) is called with nondeterministic CuBLAS config #41377
Conversation
💊 CI failures summary and remediations (as of commit 3f4c77c; more details on the Dr. CI page):
🚧 3 fixed upstream failures: these were probably caused by upstream breakages that were already fixed. Please rebase on the …
Force-pushed from 740dcac to 6f22888.
Force-pushed from 99a3dff to ad08b74.
The reason that I was only calling … So I just need to create a function …
Force-pushed from 082d01c to 5a047d0.
Force-pushed from 5a047d0 to 1ff8d99.
  result.resize_({ self.size(0), mat2.size(1) });
  return addmm_out_cuda_impl(result, result, self, mat2, 0, 1);
}

Tensor mm_cuda(const Tensor& self, const Tensor& mat2) {
  // Raise an error up front if torch.set_deterministic(True) is in effect and
  // the CuBLAS workspace configuration is nondeterministic, before any memory
  // is touched.
  globalContext().alertCuBLASConfigNotDeterministic();
Initially, I only added alerts in the internal CuBLAS wrapper functions, in CUDABlas.cpp and THCBlas.cu. But then when I created and ran tests for a handful of the torch operations that use these functions (like torch.mm, torch.dot, etc.), I was getting some CUDA memory access errors when I ran the tests back to back.
The problem seemed to be that the error was being thrown halfway through some operations and memory could sometimes be left in an unsafe state. So I had to call the alert function here instead, before the operations have a chance to do anything to the memory.
I think we should keep the alerts in the CuBLAS wrappers though, so that we automatically have error coverage over every operation that calls them.
But I wonder if I should continue to add alerts and tests for each existing torch operation that calls the CuBLAS wrappers. I feel like that might be overkill, but not sure.
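For context, here is a rough sketch of the user-facing behavior this placement produces (illustrative only, not the PR's test code; it assumes a CUDA >= 10.2 build with CUBLAS_WORKSPACE_CONFIG unset):

import torch

torch.set_deterministic(True)

a = torch.randn(3, 3, device='cuda')
b = torch.randn(3, 3, device='cuda')

try:
    # mm_cuda calls alertCuBLASConfigNotDeterministic() before touching any
    # memory, so the operation fails up front instead of partway through.
    torch.mm(a, b)
except RuntimeError as e:
    print(e)  # the message suggests setting CUBLAS_WORKSPACE_CONFIG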
Oof, that's not very nice. In principle it should be safe to unwind the stack at any given point if we are using proper C++ destructors; I wonder if something is still using legacy behavior. It's possible this is related specifically to the th_ wrappers. This is probably expedient for now.
I'll pin down exactly what the issue was.
It was simpler than I assumed. at::native::dot_cuda() changes the CuBLAS pointer mode with cublasSetPointerMode() before calling at::cuda::blas::dot(). It restores the previous setting after the call, but only if the call doesn't throw an error. So I had to just put the at::cuda::blas::dot() call in a try-catch. If there's an error, it will now restore the pointer mode and rethrow.
It's even simpler than that: dot is always deterministic, regardless of the workspace setting.
"gemv and gemm" meaning only the non-batched versions? Or are the batched ones affected too?
Only the non-batched versions; the batched ones are not affected. What if the batch size is 1, you might ask? Well, I don't know, I'm just relaying the message. @ptrblck?
Batched versions of gemv and gemm with a batch size of 1 can be non-deterministic, so we would need to disallow them.
Great, thanks for letting me know! So I'll remove the alerts for everything but gemv and gemm, batched and unbatched. I think there are a couple of questions to answer about how to handle the batch size 1 case though. It seems like these are our options:
1. Alert batched gemv and gemm always
2. Alert batched gemv and gemm only if batch size is 1
   a. Add a message to the alert explaining that it's nondeterministic because the batch size is 1
   b. Don't add the message
I think 2.a. might be the best, but I'm not sure.
I removed the errors for ger and dot, so now only mm, mv, and batched multiplies throw errors. This is option 1 from my above comment. Please let me know if it would be better to actually check the batch size and provide more detail in the error message.
Force-pushed from 65dd62a to 2cc3b42.
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
It would be good to also get a look from @ngimel.
Force-pushed from ea76ce6 to f2211c9.
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from f29abaa to a731bc9.
Thank you for the sleuthing, this seems like a good outcome to me. Maybe we should file an nvbug, not sure.
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
# Flip the CUBLAS_WORKSPACE_CONFIG state for this test case: unset the
# variable if it is present, otherwise set it to the desired config.
if os.environ.get(cublas_var_name) is not None:
    del os.environ[cublas_var_name]
else:
    os.environ[cublas_var_name] = config
This is not the idiomatic way to set environment variables in a subprocess, as manual edits to os.environ here are process-global and will affect all other tests in the process. Instead, you can simply pass the desired environment variables directly to the subprocess call itself.
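A minimal sketch of that suggestion (config and script stand in for the test's real variables; the values here are hypothetical):

import os
import subprocess
import sys

config = ':4096:8'        # hypothetical test parameter; may also be None
script = 'import torch'   # hypothetical code string run by the child

# Build the child's environment without mutating the parent's os.environ.
env = os.environ.copy()
if config is None:
    env.pop('CUBLAS_WORKSPACE_CONFIG', None)  # ensure the variable is unset
else:
    env['CUBLAS_WORKSPACE_CONFIG'] = config

# The env kwarg scopes the variables to this one subprocess.
p = subprocess.Popen([sys.executable, '-c', script], env=env,
                     stdout=subprocess.PIPE, stderr=subprocess.PIPE)
p.communicate()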
def test_case_info():
    return 'function "%s", config "%s"' % (fn_name, '' if config is None else config)

# Wait for each process to finish and check for correct error behavior
Do you... actually want to run each process in parallel? Running them serially seems a lot safer.
processes.append((p, fn_name, config, should_throw_error))

def test_case_info():
    return 'function "%s", config "%s"' % (fn_name, '' if config is None else config)
nit: use more modern formatting here, 'function "{}"'.format(fn_name), or even better an f-string, f'function "{fn_name}"'.
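For comparison, both spellings produce the same string (a trivial sketch; fn_name stands in for the test's variable):

fn_name = 'mm'
old_style = 'function "%s"' % fn_name  # current %-formatting
f_style = f'function "{fn_name}"'      # suggested f-string
assert old_style == f_style == 'function "mm"'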
# It would have been preferable to use the `multiprocessing` module to avoid having
# to execute code from a string, but that caused issues in Windows
# https://github.com/pytorch/pytorch/pull/41377#issuecomment-666641223
p = subprocess.Popen(
I was expecting to see check_output here, which would have reduced a lot of the Popen boilerplate. Here is one way you could make this happen: catch the error thrown in the subprocess itself, and then convert that into a success condition (and raise errors otherwise). Then, you only need to test for the exit code in the parent process.
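A sketch of that pattern (the script body and the torch.mm call are illustrative; it assumes the expected failure is a RuntimeError from the tested op, and an env dict could be passed as in the earlier sketch):

import subprocess
import sys

# The child checks the expected behavior itself: exit code 0 if the expected
# RuntimeError is raised, nonzero otherwise.
script = '''
import torch
torch.set_deterministic(True)
try:
    a = torch.randn(2, 2, device="cuda")
    torch.mm(a, a)
except RuntimeError:
    pass  # expected with a nondeterministic CuBLAS config
else:
    raise SystemExit("expected a RuntimeError but none was raised")
'''

# check_output runs the cases serially and raises CalledProcessError on a
# nonzero exit code, so the parent needs only this one call per test case.
subprocess.check_output([sys.executable, '-c', script])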
Approving to move things along
Summary: Adds an RAII guard for `cublasSetPointerMode()`. Updates `dot_cuda` to use the guard, rather than exception catching. Addresses this comment: #41377 (comment)
Pull Request resolved: #42639
Reviewed By: malfet
Differential Revision: D22969985
Pulled By: ezyang
fbshipit-source-id: b05c35d1884bb890f8767d6a4ef8b4724a329471
…st (#42627) Summary: Addresses some comments that were left unaddressed after PR #41377 was merged:
* Use `check_output` instead of `Popen` to run each subprocess sequentially
* Use f-strings rather than old python format string style
* Provide environment variables to subprocess through the `env` kwarg
* Check for correct error behavior inside the subprocess, and raise another error if incorrect. Then the main process fails the test if any error is raised
Pull Request resolved: #42627
Reviewed By: malfet
Differential Revision: D22969231
Pulled By: ezyang
fbshipit-source-id: 38d5f3f0d641c1590a93541a5e14d90c2e20acec
Summary: Follow up to pytorch#41377 to update the error message to match the removed arguments
Pull Request resolved: pytorch#46397
Reviewed By: malfet
Differential Revision: D24336009
Pulled By: albanD
fbshipit-source-id: b9bf2f9ef7fd2ae622c4079384afc93e9c473f47
For CUDA >= 10.2, the CUBLAS_WORKSPACE_CONFIG environment variable must be set to either `:4096:8` or `:16:8` to ensure deterministic CUDA stream usage. This PR adds some logic inside torch.set_deterministic() to raise an error if this environment variable is not set properly and CUDA >= 10.2.
Issue #15359
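Put together, the intended usage looks roughly like this (a sketch, not an official example; torch.set_deterministic() is the API name from this PR's era, later renamed torch.use_deterministic_algorithms()):

import os

# One of the two deterministic workspace configurations; must be set before
# cuBLAS is initialized.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'

import torch

# With this PR, an error is raised (here or when a CuBLAS op runs) if
# CUDA >= 10.2 and the workspace config is missing or nondeterministic.
torch.set_deterministic(True)

a = torch.randn(4, 4, device='cuda')
print(torch.mm(a, a))  # deterministic across runs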