[XLA] When ptxas do not know about an SM, fallback to the driver. #43888
@@ -198,6 +198,42 @@ absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
  return absl::nullopt;
}

// Try to load ptx from files defined in the FLAGS. If successful, return true.
bool MaybeLoadPtxFromFile(const HloModule* module, std::string* ptx) {
The diff rendering is misleading here: I didn't move MaybeLoadPtxFromFile.
I moved the function WarnIfBadDriverJITVersion out of the unnamed namespace.
@@ -415,7 +415,10 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
"using $PATH.",
hlo_module_config);
}
} else {
} else if (maybe_cubin.status().code() !=
XLA used to silently fall back to the driver when ptxas couldn't be found or compilation failed. That silent fallback led to several user-reported bugs that were hard to reproduce and diagnose. We therefore decided to turn ptxas issues into fatal errors (few people notice warnings in logs) and to allow a manual override via the flag --xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found.
Have you considered introducing a flag for the _ptxas_too_old case? I think that would be the better option.
This isn't a real compilation failure here; ptxas simply doesn't know a specific SM version.
Where can I find more information about the bugs that triggered this decision?
I'll think about your suggestion and get back to you.
Quick follow-up:
For currently supported GPUs, nothing changes. The new fallback applies only to new GPUs: it is used when a current or old container runs on a newer GPU.
Force-pushed from ce59dbc to 40c7faf.
I amended the commit because it still contained one debug leftover.
@@ -113,6 +113,7 @@ cc_library(
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
"//tensorflow/compiler/xla/service/gpu:nvptx_compiler_impl",
"//tensorflow/compiler/xla/service/gpu:launch_dimensions",
"//tensorflow/compiler/xla/service/gpu:nvptx_compiler",
This extra dependency breaks tensorflow/compiler/xla/service/mlir_gpu/tests. The linking step registers a compiler, and these tests fail with: Check failed: factories->find(platform_id) == factories->end() Compiler factory already registered for platform.
Please consider moving WarnIfBadDriverJITVersion to a different place. I believe asm_compiler.cc would work.
I amended the commit to simply remove this line, which fixes the mlir_gpu tests. //tensorflow/compiler/xla/service/gpu:nvptx_compiler_impl is already included, and that is enough. Since it was already included, I don't see value in moving that function.
You are right: the dependency on nvptx_compiler_impl is enough, and the other one was redundant/wrong.
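For readers unfamiliar with the "Compiler factory already registered for platform" failure quoted in this thread, here is a minimal sketch of the general pattern, not XLA's actual code. It assumes a registry keyed by platform id that rejects duplicates; the names Factories and RegisterCompilerFactory are hypothetical, and the sketch throws an exception where the real check aborts the process.

```cpp
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical factory type: returns the name of the compiler it would build.
using CompilerFactory = std::function<std::string()>;

// Hypothetical process-wide registry, keyed by platform id.
std::map<int, CompilerFactory>& Factories() {
  static std::map<int, CompilerFactory> factories;
  return factories;
}

// Registers a factory once per platform. A second registration for the same
// platform is rejected; this sketch throws, whereas the check quoted in the
// thread terminates the program.
void RegisterCompilerFactory(int platform_id, CompilerFactory factory) {
  auto& factories = Factories();
  if (factories.find(platform_id) != factories.end()) {
    throw std::runtime_error(
        "Compiler factory already registered for platform");
  }
  factories[platform_id] = std::move(factory);
}
```

If two linked targets each run such a registration for the same platform at startup, the second one trips the check, which is consistent with the fix of dropping the redundant dependency.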
Force-pushed from 40c7faf to 7544886.
Currently, XLA always uses ptxas. If a user has an old container but a newer GPU, ptxas won't know the GPU's SM version.
In that case, instead of raising an error, fall back to the driver for compilation instead of ptxas.
The result won't have all optimizations, but it will work.
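The policy discussed in this PR can be summarized as a small decision function. This is only a sketch under stated assumptions: the enums PtxasResult and Action and the function Decide are hypothetical names for illustration, not XLA's status codes or API, and treating every other ptxas failure as fatal is inferred from the review comments rather than taken from the patch.

```cpp
// Hypothetical outcomes of invoking ptxas (illustrative, not XLA's codes).
enum class PtxasResult { kOk, kNotFound, kUnknownSm, kCompileError };

// Hypothetical actions the compiler can take afterwards.
enum class Action { kUsePtxasCubin, kFallBackToDriver, kFatalError };

// `allow_fallback_on_not_found` models the flag
// --xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found.
constexpr Action Decide(PtxasResult result, bool allow_fallback_on_not_found) {
  switch (result) {
    case PtxasResult::kOk:
      return Action::kUsePtxasCubin;
    case PtxasResult::kUnknownSm:
      // The case this PR adds: ptxas is too old for the GPU's SM version,
      // so let the driver JIT-compile the PTX instead of failing.
      return Action::kFallBackToDriver;
    case PtxasResult::kNotFound:
      // Missing ptxas stays fatal unless the user opts in via the flag.
      return allow_fallback_on_not_found ? Action::kFallBackToDriver
                                         : Action::kFatalError;
    case PtxasResult::kCompileError:
      // Assumed: genuine compilation failures remain fatal.
      return Action::kFatalError;
  }
  return Action::kFatalError;
}
```

In this sketch, the unknown-SM case does not consult the flag at all, which mirrors the point that nothing changes for GPUs that the installed ptxas already supports.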