Value of torch.backends.cudnn.benchmark Baked into JIT-Traced Modules (150x slowdown on ConvTranspose2d()) [jit] [libtorch] [cudnn] #18776
Comments
Yeah, in the current implementation it's expected behavior for the loaded module to use the flag as a constant set at trace time. We could potentially add a warning for it.
If you adjust it to not trace `_convolution`…
I think the way the code is currently structured, there are a lot of call sites of `_convolution`.
@ezyang In tracing we deliberately chose not to trace the function that wraps around `_convolution`. I'm not very familiar with the reason behind it, though.
We need the backend flag to compute the backward pass correctly.
@soumith we don't. I wrote symbolic AD for batch norm which allows for multiple backends. I think the idea was that someone might have code like:

```python
torch.backends.cudnn.enabled = False
some_op()  # cuDNN bug
torch.backends.cudnn.enabled = True
```

and we wanted to preserve those semantics. I think we should save those flags when tracing.
Currently, tracing a convolution operation records the internal implementation function `_convolution`. As a result, `benchmark` and other flags are hard-wired into the trace, and changes to these context flags at later stages are ignored. The proposed fix is to remove `convolution` from DONT_RECORD_TRACE. Note that all convolution operations end up calling the `convolution` function, so this fix applies to all of them. [ghstack-poisoned]
Pull Request resolved: #28027
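For context, a paraphrased sketch of the tracer skip-list the fix touches; its exact location and the surrounding entries are assumptions here, and only the removal of `'convolution'` comes from the comment above:

```python
# Sketch of the tracer skip-list (assumed to live in the autograd codegen,
# e.g. tools/autograd/gen_variable_type.py). Ops listed here are not recorded
# in traces themselves; the tracer records what they call internally instead.
# With 'convolution' present, traces record aten::_convolution, whose
# benchmark/deterministic/cudnn_enabled arguments are frozen at trace time.
DONT_RECORD_TRACE = {
    'convolution',  # the proposed fix deletes this entry
    # ... other entries elided
}
```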
Can we distinguish the case where the user explicitly sets the flag during the run of the model from the case where the flag is set outside the model? Specifically, I mean that we should trace/script assignments to the flag. In such a design,

```python
torch.backends.cudnn.enabled = False

@script
def model(x):
    return F.convolution(x, .....)
```

will respect whatever the flag is set to in the context where the model is executed, while

```python
@script
def model(x):
    torch.backends.cudnn.enabled = False
    return F.convolution(x, .....)
```

will always run without cuDNN regardless of the context.
🐛 Bug
If you trace a module with `torch.jit.trace(...)` and load that script module in C++ via LibTorch, the resulting behavior in C++ depends on whether or not the `torch.backends.cudnn.benchmark` flag was set. Calls to `at::globalContext().setBenchmarkCuDNN(true/false)` from the C++ API at runtime appear to have no effect.

To Reproduce
NOTE: I was not able to verify whether this issue still exists on the latest nightly (20190402) because the latest nightly (at least on Windows) appears unable to run JIT-traced models; even the simplest model fails with an error.

Run `python test.py 0` or `python test.py 1` and observe:
a) Average time per call. I see ~0.8 ms in the Python script and either ~0.8 ms or ~120 ms in C++, depending on the flag used in Python. In either case, C++ sets benchmarking ON. (GTX 1080)
b) Kernel run by cuDNN. With either setting of the flag, the Python code runs `cudnn::detail::dgrad_engine<...>`. With the flag ON, it runs `cudnn::detail::dgrad2d_alg1_1<...>` once (taking ~120 ms) and then chooses the faster `dgrad_engine`. If the flag was ON in Python, C++ also chooses `dgrad_engine`, but if the flag was OFF in Python, it always chooses `dgrad2d_alg1_1` regardless of the flag setting in C++. I observed the choice of kernel using `nvprof python test.py 0/1`.

Python Script (`test.py`):
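The original listing did not survive this page; below is a minimal sketch of what the script plausibly looked like, based on the title and the description above. The layer shape, iteration counts, and output file name are assumptions.

```python
# Hypothetical reconstruction of test.py; the layer shape, iteration counts,
# and output file name are assumptions, not the original code.
import sys
import time

import torch

# argv[1] is 0 or 1 and selects the torch.backends.cudnn.benchmark setting.
torch.backends.cudnn.benchmark = bool(int(sys.argv[1]))

model = torch.nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2).cuda().eval()
x = torch.randn(1, 16, 128, 128, device="cuda")

traced = torch.jit.trace(model, x)
traced.save("traced.pt")  # loaded by the C++ program below

# Warm up, then time the traced module.
for _ in range(10):
    traced(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    traced(x)
torch.cuda.synchronize()
print("avg ms per call:", (time.time() - start) / 100 * 1000)
```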
C++ Code:
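The C++ listing is likewise missing; this is a sketch under the same assumptions (file name, input shape, iteration counts), using the LibTorch API approximately as described in the report:

```cpp
// Hypothetical reconstruction of the C++ side; not the original listing.
#include <torch/script.h>
#include <torch/cuda.h>

#include <chrono>
#include <iostream>

int main() {
  // Per the report, this call appears to have no effect on the loaded module.
  at::globalContext().setBenchmarkCuDNN(true);

  torch::jit::script::Module module = torch::jit::load("traced.pt");
  module.to(torch::kCUDA);

  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::randn({1, 16, 128, 128}, torch::kCUDA));

  for (int i = 0; i < 10; ++i) module.forward(inputs);  // warm-up
  torch::cuda::synchronize();

  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < 100; ++i) module.forward(inputs);
  torch::cuda::synchronize();
  auto ms = std::chrono::duration<double, std::milli>(
                std::chrono::steady_clock::now() - start)
                .count();
  std::cout << "avg ms per call: " << ms / 100 << std::endl;
  return 0;
}
```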
Expected Behavior
I would expect that either: 1) the C++ setting of `at::globalContext().setBenchmarkCuDNN(true)` should be respected (choosing the correct algorithm), or 2) a warning should at least be printed that it is being overridden by the value of the flag at trace time.

Additional Info
I printed the JIT graphs generated with benchmarking ON and OFF. The only change when the flag is ON is that register %17 is 1 instead of 0. I suppose this is where the "hardcoding" of the flag might be happening?
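The graph dumps themselves did not survive here; as a small sketch (assuming `traced` is the module from the `test.py` sketch above), the dump can be reproduced with:

```python
# Dump the traced graph; the cuDNN flags show up as prim::Constant booleans
# (e.g. the %17 mentioned above) that are passed into aten::_convolution.
print(traced.graph)
```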
Environment
Python code was run on Linux; C++ code was run on Windows.
How you installed PyTorch (`conda`, `pip`, source): conda (pytorch-nightly)

cc @suo