PR #12228: [GPU] Fix hang with cudnn layer norm by moving build phase to Initialize()

Imported from GitHub PR openxla/xla#12228

The first time that a NormThunk is executed, it will build a cudnn execution plan. This build step can hang if a NCCL collective is running at the same time. To fix this, I've moved the build step to take place during thunk initialization. We only observe this hang when using cudnn 9.
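
The shape of the fix is to hoist the lazy, potentially blocking plan build out of the per-step execute path and into the thunk's one-time initialization. Below is a minimal, self-contained sketch of that pattern (hypothetical `ExampleNormThunk` and `ExecutionPlan` types, not XLA's real classes), mirroring the call-once style visible in the backtrace:

```cpp
#include <memory>
#include <mutex>

// Stand-in for a cudnn execution plan; building one may block on driver-level
// locks while other GPU work (e.g. a collective) is in flight.
struct ExecutionPlan {};

class ExampleNormThunk {
 public:
  // Runs once per thunk before the thunk sequence starts executing, so the
  // potentially blocking build happens before any collective is launched.
  void Initialize() { GetOrBuildPlan(); }

  // Hot path: with the plan built during Initialize(), this only reads the
  // cached plan and launches work.
  void ExecuteOnStream() {
    ExecutionPlan& plan = GetOrBuildPlan();
    (void)plan;  // launch using the prebuilt plan
  }

 private:
  // Build-once accessor, analogous to LazyOpRunner::GetOrCreateRunner in the
  // backtrace: safe to call from both Initialize() and ExecuteOnStream().
  ExecutionPlan& GetOrBuildPlan() {
    std::call_once(once_, [this] { plan_ = std::make_unique<ExecutionPlan>(); });
    return *plan_;
  }

  std::once_flag once_;
  std::unique_ptr<ExecutionPlan> plan_;
};

int main() {
  ExampleNormThunk thunk;
  thunk.Initialize();       // plan built here, at a safe point
  thunk.ExecuteOnStream();  // no build on the hot path
}
```

Because the build stays behind a call-once guard, the execute path still works even if initialization were skipped; moving the call only ensures the first (blocking) build happens at a safe point. The actual patch below follows the same idea and leaves the GetOrCreateRunner call in the execute path in place.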

Here's a backtrace from the hang that will be fixed:
```
Thread 585 (Thread 0x7fb9391ff640 (LWP 41364) "main.py"):
#0  0x00007fd3d17cffd9 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
#1  0x00007fd3d17da24f in pthread_rwlock_wrlock () from /lib/x86_64-linux-gnu/libc.so.6
#2  0x00007fd070967dfe in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007fd0709c928a in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f1970d76102 in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
#5  0x00007f1970f2c999 in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
#6  0x00007f1970a7d4ab in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
#7  0x00007f1970d0a9cb in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
#8  0x00007fce60b2a98c in cudnn::backend::ExecutionPlan::finalize_internal() () from /lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
#9  0x00007fce60aefbb1 in cudnn::backend::Descriptor::finalize() () from /lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
#10 0x00007fce60b15bec in cudnnBackendFinalize () from /lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
#11 0x00007fd2521b8f39 in cudnn_frontend::ExecutionPlanBuilder_v8::build() () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#12 0x00007fd2521734ba in stream_executor::gpu::(anonymous namespace)::GetExecPlanFromHeuristics(cudnn_frontend::OperationGraph_v8&&, stream_executor::gpu::(anonymous namespace)::CudnnHandle const&, bool) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#13 0x00007fd25216ff9b in stream_executor::gpu::CudnnSupport::NormRunnerFromDesc(stream_executor::Stream*, stream_executor::dnn::AlgorithmDesc const&, stream_executor::dnn::NormKind, double, stream_executor::dnn::TensorDescriptor const&, stream_executor::dnn::TensorDescriptor const&, stream_executor::dnn::TensorDescriptor const&, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#14 0x00007fd24e36b88b in stream_executor::dnn::NormOp::RunnerFromAlgorithmDesc(stream_executor::dnn::AlgorithmDesc const&, stream_executor::dnn::NormOp::Config, stream_executor::Stream*) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#15 0x00007fd24e36ae37 in stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*)::{lambda()#1}::operator()() const () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#16 0x00007fd24e36adbc in void absl::lts_20230802::base_internal::CallOnceImpl<stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*)::{lambda()#1}>(std::atomic<unsigned int>*, absl::lts_20230802::base_internal::SchedulingMode, stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*)::{lambda()#1}&&) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#17 0x00007fd24e36a9bd in stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#18 0x00007fd24e369d29 in xla::gpu::RunGpuNorm(xla::gpu::GpuNormConfig const&, stream_executor::DeviceMemoryBase const&, stream_executor::DeviceMemoryBase const&, stream_executor::DeviceMemoryBase const&, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, stream_executor::DeviceMemoryBase const&, stream_executor::Stream*, xla::gpu::RunNormOptions) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#19 0x00007fd24e368be6 in xla::gpu::NormThunk::ExecuteOnStream(xla::gpu::Thunk::ExecuteParams const&) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
```
Copybara import of the project:

--
f53533087ba1ddcf65ad7cc6268ee89de4690d15 by Trevor Morris <tmorris@nvidia.com>:

Fix hang with cudnn layer norm by moving cudnn init to Initialize()

Merging this change closes #12228

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12228 from trevor-m:tmorris-norm-init f53533087ba1ddcf65ad7cc6268ee89de4690d15
PiperOrigin-RevId: 633207155
trevor-m authored and tensorflower-gardener committed May 13, 2024
1 parent bd8a712 commit a3dca26
Showing 4 changed files with 28 additions and 16 deletions.
18 changes: 2 additions & 16 deletions third_party/xla/xla/service/gpu/gpu_norm_runner.cc
```diff
@@ -44,22 +44,8 @@ absl::Status RunGpuNorm(const gpu::GpuNormConfig& config,
                         se::Stream* stream, RunNormOptions options) {
   se::dnn::LazyOpRunner<se::dnn::NormOp>* lazy_runner =
       options.norm_runner->AsNormRunner();
-  std::optional<se::dnn::LazyOpRunner<se::dnn::NormOp>> local_runner;
-
-  TF_ASSIGN_OR_RETURN(se::dnn::NormKind kind,
-                      GetDNNNormKindFromCudnnNormKind(config.kind));
-
-  se::dnn::NormOp::Config ln_config{kind,
-                                    config.epsilon,
-                                    config.x_descriptor,
-                                    config.scale_descriptor,
-                                    config.y_or_dx_descriptor,
-                                    config.bias_descriptor,
-                                    config.dy_descriptor,
-                                    config.expectation_descriptor,
-                                    config.norm_factor_descriptor,
-                                    config.dscale_descriptor,
-                                    config.dbias_descriptor};
+  TF_ASSIGN_OR_RETURN(se::dnn::NormOp::Config ln_config,
+                      config.AsDnnNormOpConfig());
   TF_ASSIGN_OR_RETURN(auto* runner,
                       lazy_runner->GetOrCreateRunner(ln_config, stream));
 
```
16 changes: 16 additions & 0 deletions third_party/xla/xla/service/gpu/gpu_norm_runner.h
```diff
@@ -118,6 +118,22 @@ struct GpuNormConfig {
     return config;
   }
 
+  absl::StatusOr<se::dnn::NormOp::Config> AsDnnNormOpConfig() const {
+    TF_ASSIGN_OR_RETURN(se::dnn::NormKind norm_kind,
+                        GetDNNNormKindFromCudnnNormKind(kind));
+    return se::dnn::NormOp::Config{norm_kind,
+                                   epsilon,
+                                   x_descriptor,
+                                   scale_descriptor,
+                                   y_or_dx_descriptor,
+                                   bias_descriptor,
+                                   dy_descriptor,
+                                   expectation_descriptor,
+                                   norm_factor_descriptor,
+                                   dscale_descriptor,
+                                   dbias_descriptor};
+  }
+
   double epsilon;
   CudnnNormKind kind;
   se::dnn::AlgorithmDesc algorithm;
```
9 changes: 9 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/norm_thunk.cc
```diff
@@ -106,5 +106,14 @@ absl::Status NormThunk::ExecuteOnStream(const ExecuteParams& params) {
   return absl::OkStatus();
 }
 
+absl::Status NormThunk::Initialize(const InitializeParams& params) {
+  // Create the runner at initialization time to avoid hangs if we try to build
+  // the execution plan while a NCCL collective is running.
+  se::dnn::LazyOpRunner<se::dnn::NormOp>* lazy_runner =
+      GetOrCreateRunner(params.stream).AsNormRunner();
+  TF_ASSIGN_OR_RETURN(auto ln_config, config_.AsDnnNormOpConfig());
+  return lazy_runner->GetOrCreateRunner(ln_config, params.stream).status();
+}
+
 }  // namespace gpu
 }  // namespace xla
```
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/runtime/norm_thunk.h
```diff
@@ -49,6 +49,7 @@ class NormThunk : public Thunk {
   NormThunk& operator=(const NormThunk&) = delete;
 
   absl::Status ExecuteOnStream(const ExecuteParams& params) override;
+  absl::Status Initialize(const InitializeParams& params) override;
 
  private:
  BufferAllocation::Slice x_buffer_;
```
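
One note on why Initialize() is a safe point (my reading, not stated explicitly in the PR text): the GPU runtime initializes all thunks before it begins executing the thunk sequence, so the cudnn execution plan is built before any NCCL collective in the same program has been launched, and the build can no longer overlap with a running collective.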
