Open
Description
Describe the issue
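We would like to report that the following CCCL 3.0 deprecations need to be addressed. In particular, `cub::Max` and `cub::Sum` no longer exist, as the build errors below show.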
https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html
We have started seeing compilation failures when building against NVIDIA products that are currently in development.
Urgency
It needs to be patched before June.
Target platform
Linux
Build script
build [DEBUG] - Command line arguments:
--build_dir /workspace/onnxruntime/build/Linux --config Release --skip_submodule_sync --parallel --build_shared_lib --compile_no_warning_as_error --build_dir /workspace/build --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES='"'"'75;80;86;89;90;100;120'"'"'' --cmake_extra_defines CMAKE_POLICY_VERSION_MINIMUM=3.5 --update --build --use_cuda --cuda_home /usr/local/cuda --cudnn_home /usr --use_tensorrt --use_tensorrt_builtin_parser --tensorrt_home /usr/src/tensorrt --allow_running_as_root --use_openvino CPU
Error / output
[ 80%] Building CUDA object CMakeFiles/onnxruntime_providers_cuda.dir/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu.o
/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu(98): error: namespace "cub" has no member "Max"
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max());
^
/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu(117): error: namespace "cub" has no member "Sum"
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, cub::Sum());
^
/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu(174): error: namespace "cub" has no member "Max"
const auto max = BlockReduce(tmp_storage).Reduce(input_data, cub::Max(), end);
^
/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu(187): error: namespace "cub" has no member "Sum"
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end);
^
/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu(243): error: namespace "cub" has no member "Max"
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);
^
/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu(257): error: namespace "cub" has no member "Sum"
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end);
^
/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu(346): error: namespace "cub" has no member "Max"
const float max = BlockReduce(tmp_storage).Reduce(max_thread_data, cub::Max(), total_sequence_length);
^
/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu(360): error: namespace "cub" has no member "Sum"
const auto sum = BlockReduce(tmp_storage).Reduce(sum_thread_data_exp, cub::Sum(), TPB);
^
/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu(444): error: namespace "cub" has no member "Max"
const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), total_sequence_length);
^
/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu(453): error: namespace "cub" has no member "Sum"
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), total_sequence_length);
^
/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu(599): error: namespace "cub" has no member "Max"
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);
^
/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu(612): error: namespace "cub" has no member "Sum"
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end);
^
12 errors detected in the compilation of "/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu".
gmake[2]: *** [CMakeFiles/onnxruntime_providers_cuda.dir/build.make:5978: CMakeFiles/onnxruntime_providers_cuda.dir/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu.o] Error 2
gmake[2]: *** Waiting for unfinished jobs....
/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu(386): error: namespace "cub" has no member "Max"
int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x);
^
1 error detected in the compilation of "/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu".
gmake[2]: *** [CMakeFiles/onnxruntime_providers_cuda.dir/build.make:6023: CMakeFiles/onnxruntime_providers_cuda.dir/workspace/onnxruntime/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu.o] Error 2
Visual Studio Version
No response
GCC / Compiler Version
No response