From 08f0ce0f919101f1bdcde49dee362811d4ee58d3 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Thu, 9 Oct 2025 21:37:23 -0700 Subject: [PATCH 01/10] [aoti-et] Enable multimodal runner for Voxtral on CUDA --- .github/workflows/cuda.yml | 154 ++++++++++++++---- backends/aoti/common_shims.cpp | 45 ++++- backends/cuda/cuda_backend.py | 2 + backends/cuda/runtime/cuda_backend.cpp | 9 +- examples/models/voxtral/CMakeLists.txt | 7 + extension/llm/runner/multimodal_prefiller.cpp | 35 +++- extension/llm/runner/util.h | 37 +++++ tools/cmake/executorch-config.cmake | 1 + 8 files changed, 256 insertions(+), 34 deletions(-) diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index 8dbbb254ac3..b49db641062 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -87,8 +87,8 @@ jobs: export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH PYTHON_EXECUTABLE=python source .ci/scripts/test_model.sh "${{ matrix.model }}" cmake cuda - test-voxtral-cuda-e2e: - name: test-voxtral-cuda-e2e + export-voxtral-cuda-artifact: + name: export-voxtral-cuda-artifact uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main permissions: id-token: write @@ -104,6 +104,7 @@ jobs: gpu-arch-version: 12.6 use-custom-docker-registry: false submodules: recursive + upload-artifact: voxtral-cuda-export ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | set -eux @@ -118,6 +119,7 @@ jobs: OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} pip install mistral-common librosa + pip list echo "::endgroup::" echo "::group::Export Voxtral" @@ -129,9 +131,58 @@ jobs: --device cuda \ --max_seq_len 1024 \ --output_dir ./ + python -m executorch.extension.audio.mel_spectrogram \ + --feature_size 128 \ + --stack_output \ + --max_audio_len 300 \ + --output_file voxtral_preprocessor.pte + + test -f model.pte + test -f aoti_cuda_blob.ptd + test -f voxtral_preprocessor.pte echo "::endgroup::" - echo "::group::Build Voxtral Runner" + echo "::group::Store Voxtral Artifacts" + mkdir -p "${RUNNER_ARTIFACT_DIR}" + cp model.pte "${RUNNER_ARTIFACT_DIR}/" + cp aoti_cuda_blob.ptd "${RUNNER_ARTIFACT_DIR}/" + cp voxtral_preprocessor.pte "${RUNNER_ARTIFACT_DIR}/" + ls -al "${RUNNER_ARTIFACT_DIR}" + echo "::endgroup::" + + benchmark-voxtral-cuda: + name: benchmark-voxtral-cuda + needs: export-voxtral-cuda-artifact + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + strategy: + fail-fast: false + with: + timeout: 90 + runner: linux.g5.4xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: 12.6 + use-custom-docker-registry: false + submodules: recursive + download-artifact: voxtral-cuda-export + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + script: | + set -eux + + echo "::group::Setup ExecuTorch Requirements" + CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_requirements.sh + pip list + echo "::endgroup::" + + echo "::group::Prepare Voxtral Artifacts" + cp "${RUNNER_ARTIFACT_DIR}/model.pte" . + cp "${RUNNER_ARTIFACT_DIR}/aoti_cuda_blob.ptd" . + ls -al model.pte aoti_cuda_blob.ptd + echo "::endgroup::" + + echo "::group::Build Voxtral Benchmark" cmake -DCMAKE_BUILD_TYPE=Release \ -DEXECUTORCH_BUILD_CUDA=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ @@ -142,31 +193,76 @@ jobs: cmake --build cmake-out -j$(( $(nproc) - 1 )) --target voxtral_runner echo "::endgroup::" - echo "::group::Run Voxtral Runner" - # Capture output and allow exit code 139 if we have the expected printout - set +e + echo "::group::Run Voxtral Benchmark" + export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH - OUTPUT=$(cmake-out/backends/cuda/voxtral_runner model.pte aoti_cuda_blob.ptd 2>&1) - EXIT_CODE=$? - set -e - - echo "$OUTPUT" - - # Check if the output contains "Run latency (ms):" - if echo "$OUTPUT" | grep -q "Run latency (ms):"; then - echo "Found expected output: 'Run latency (ms):'" - if [ $EXIT_CODE -eq 139 ]; then - echo "Exit code 139 (segfault) detected, but passing since we have the expected output" - exit 0 - elif [ $EXIT_CODE -ne 0 ]; then - echo "Unexpected exit code: $EXIT_CODE" - exit $EXIT_CODE - else - echo "Command succeeded with exit code 0" - exit 0 - fi - else - echo "Expected output 'Run latency (ms):' not found in output" - exit 1 - fi + cmake-out/backends/cuda/voxtral_runner model.pte aoti_cuda_blob.ptd + + echo "::endgroup::" + + test-voxtral-cuda-e2e: + name: test-voxtral-cuda-e2e + needs: export-voxtral-cuda-artifact + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + strategy: + fail-fast: false + with: + timeout: 90 + runner: linux.g5.4xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: 12.6 + use-custom-docker-registry: false + submodules: recursive + download-artifact: voxtral-cuda-export + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + script: | + set -eux + + echo "::group::Setup ExecuTorch Requirements" + CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_requirements.sh + pip list echo "::endgroup::" + + echo "::group::Prepare Voxtral Artifacts" + cp "${RUNNER_ARTIFACT_DIR}/model.pte" . + cp "${RUNNER_ARTIFACT_DIR}/aoti_cuda_blob.ptd" . + cp "${RUNNER_ARTIFACT_DIR}/voxtral_preprocessor.pte" . + TOKENIZER_URL="https://huggingface.co/mistralai/Voxtral-Mini-3B-2507/resolve/main/tekken.json" + curl -L $TOKENIZER_URL -o tekken.json + ls -al model.pte aoti_cuda_blob.ptd voxtral_preprocessor.pte tekken.json + echo "::endgroup::" + + echo "::group::Download Test Audio File" + AUDIO_URL="https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/testaudio/16000/test01_20s.wav" + curl -L $AUDIO_URL -o dancing.wav + echo "::endgroup::" + + echo "::group::Build Voxtral Runner" + cmake --preset llm \ + -DEXECUTORCH_BUILD_CUDA=ON \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DCMAKE_BUILD_TYPE=Release \ + -Bcmake-out -S. + cmake --build cmake-out -j$(( $(nproc) - 1 )) --target install --config Release + + cmake -DEXECUTORCH_BUILD_CUDA=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -Sexamples/models/voxtral \ + -Bcmake-out/examples/models/voxtral/ + cmake --build cmake-out/examples/models/voxtral --target voxtral_runner --config Release + echo "::endgroup::" + + echo "::group::Run Voxtral Runner" + + export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH + cmake-out/examples/models/voxtral/voxtral_runner \ + --model_path model.pte \ + --data_path aoti_cuda_blob.ptd \ + --tokenizer_path tekken.json \ + --audio_path dancing.wav \ + --processor_path voxtral_preprocessor.pte \ + --temperature 0 + echo "::endgroup::" \ No newline at end of file diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index abc83779443..4d15f870a41 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -51,13 +51,34 @@ AOTITorchError aoti_torch_get_storage_offset( AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) { auto it = internal::tensor_to_strides.find(tensor); + bool needs_update = false; + if (it == internal::tensor_to_strides.end()) { + needs_update = true; + } else { + // Check if cached values are still valid + auto tensor_strides = tensor->strides(); + if (it->second.size() != static_cast(tensor->dim())) { + needs_update = true; + } else { + for (int i = 0; i < tensor->dim(); i++) { + if (it->second[i] != tensor_strides[i]) { + needs_update = true; + break; + } + } + } + } + + if (needs_update) { std::vector strides(tensor->dim()); auto tensor_strides = tensor->strides(); for (int i = 0; i < tensor->dim(); i++) { strides[i] = tensor_strides[i]; } - it = internal::tensor_to_strides.emplace(tensor, std::move(strides)).first; + it = + internal::tensor_to_strides.insert_or_assign(tensor, std::move(strides)) + .first; } // For 0D tensors, data() returns nullptr on empty vectors, but we need to @@ -80,13 +101,33 @@ AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) { AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) { auto it = internal::tensor_to_sizes.find(tensor); + bool needs_update = false; + if (it == internal::tensor_to_sizes.end()) { + needs_update = true; + } else { + // Check if cached values are still valid + auto tensor_sizes = tensor->sizes(); + if (it->second.size() != static_cast(tensor->dim())) { + needs_update = true; + } else { + for (int i = 0; i < tensor->dim(); i++) { + if (it->second[i] != tensor_sizes[i]) { + needs_update = true; + break; + } + } + } + } + + if (needs_update) { std::vector sizes(tensor->dim()); auto tensor_sizes = tensor->sizes(); for (int i = 0; i < tensor->dim(); i++) { sizes[i] = tensor_sizes[i]; } - it = internal::tensor_to_sizes.emplace(tensor, std::move(sizes)).first; + it = internal::tensor_to_sizes.insert_or_assign(tensor, std::move(sizes)) + .first; } // For 0D tensors, data() returns nullptr on empty vectors, but we need to diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 8ed8cdefbb1..795fdb598e2 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -135,6 +135,8 @@ def preprocess( "aot_inductor.link_libtorch": False, # Package model constants and other generated files directly in the shared object (.so) file "aot_inductor.package_constants_in_so": True, + # Enable debug mode if the DEBUG environment variable is set + "aot_inductor.debug_compile": os.environ.get("DEBUG") == "1", # Enable maximum automatic tuning for optimal performance "max_autotune": True, # Use TRITON for GEMM (General Matrix Multiply) operations tuning only to avoid using operators in libtorch diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index 58ab54e1aac..10a71b267ea 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -165,6 +165,13 @@ class ET_EXPERIMENTAL CudaBackend final Span args) const override { AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; + // Need to re-register all the symbols from the so_handle hosted by this + // CudaBackend instance. The reason is that these symbols are + // static/singleton across the whole process. When we share multiple methods + // (meaning multiple so_handle) in the same process, we need to re-register + // the symbols from the so_handle that is being used in this execution. + register_shared_library_functions(handle->so_handle); + size_t n_inputs; AOTInductorModelContainerGetNumInputs(handle->container_handle, &n_inputs); @@ -223,7 +230,6 @@ class ET_EXPERIMENTAL CudaBackend final "Failed to copy input %d from CPU to GPU", i); } - ET_LOG(Info, "Inputs copied to GPU"); // Process output tensors: create GPU counterparts for ExecuTorch CPU // tensors for (int i = 0; i < n_outputs; i++) { @@ -253,7 +259,6 @@ class ET_EXPERIMENTAL CudaBackend final gpu_outputs[i] = gpu_output_handle; } - ET_LOG(Info, "Outputs created on GPU"); // Run AOTI container with GPU tensors AOTIRuntimeError error = AOTInductorModelContainerRun( handle->container_handle, diff --git a/examples/models/voxtral/CMakeLists.txt b/examples/models/voxtral/CMakeLists.txt index 85c6a13e0ff..3995f5533e6 100644 --- a/examples/models/voxtral/CMakeLists.txt +++ b/examples/models/voxtral/CMakeLists.txt @@ -86,6 +86,13 @@ list( extension_flat_tensor ) +# Link CUDA backend +if(EXECUTORCH_BUILD_CUDA) + find_package(CUDAToolkit REQUIRED) + list(APPEND link_libraries aoti_cuda) + executorch_target_link_options_shared_lib(aoti_cuda) +endif() + # Add tokenizers list(APPEND link_libraries tokenizers::tokenizers) diff --git a/extension/llm/runner/multimodal_prefiller.cpp b/extension/llm/runner/multimodal_prefiller.cpp index 2c83df24f55..e77053f8747 100644 --- a/extension/llm/runner/multimodal_prefiller.cpp +++ b/extension/llm/runner/multimodal_prefiller.cpp @@ -93,14 +93,47 @@ Result MultimodalPrefiller::prefill( } else if (input.is_audio()) { Audio audio = input.get_audio(); - // Use Audio::toTensor() for tensor creation + auto method_meta = ET_UNWRAP( + module_->method_meta(kAudioEncoderMethod), + "Failed to get method_meta for %s", + kAudioEncoderMethod); + + ET_CHECK_OR_RETURN_ERROR( + method_meta.num_inputs() > 0, + InvalidArgument, + "Audio encoder should have at least 1 input"); + auto input_meta = ET_UNWRAP( + method_meta.input_tensor_meta(0), + "Cannot get input tensor meta at index 0"); + auto expected_dtype = input_meta.scalar_type(); + + // Create tensor with original dtype auto audio_tensor = ET_UNWRAP(audio.toTensor(), "Failed to convert audio to tensor"); + + // Convert to expected dtype if needed + if (audio_tensor->scalar_type() != expected_dtype) { + if (expected_dtype == ::executorch::aten::ScalarType::BFloat16) { + // Convert to bfloat16 + audio_tensor = ET_UNWRAP( + convert_to_bfloat16(audio_tensor), + "Failed to convert audio tensor to bfloat16"); + } else { + ET_LOG( + Error, + "Unsupported audio encoder input dtype: %s. Expecting %s", + ::executorch::runtime::toString(audio_tensor->scalar_type()), + ::executorch::runtime::toString(expected_dtype)); + return ::executorch::runtime::Error::NotSupported; + } + } + ET_LOG( Info, "Audio tensor dim: %zu, dtype: %s", audio_tensor->dim(), ::executorch::runtime::toString(audio_tensor->scalar_type())); + // Run audio encoder auto audio_encoder_result = module_->execute(kAudioEncoderMethod, audio_tensor); diff --git a/extension/llm/runner/util.h b/extension/llm/runner/util.h index 8fb245107ab..73b9963b28a 100644 --- a/extension/llm/runner/util.h +++ b/extension/llm/runner/util.h @@ -141,6 +141,43 @@ inline runtime::Result populate_start_pos_or_cache_position( } } +/** + * Helper function to convert a float tensor to bfloat16. + * Creates a new tensor with bfloat16 dtype and copies/converts the data. + */ +::executorch::runtime::Result<::executorch::extension::TensorPtr> +convert_to_bfloat16(const ::executorch::extension::TensorPtr& src_tensor) { + ET_CHECK_OR_RETURN_ERROR( + src_tensor->scalar_type() == ::executorch::aten::ScalarType::Float, + InvalidArgument, + "BFloat16 conversion only supported from Float source data"); + + size_t num_elements = src_tensor->numel(); + auto sizes = src_tensor->sizes(); + + // Allocate memory for bfloat16 data + auto* bf16_data = new uint16_t[num_elements]; + const float* float_data = src_tensor->const_data_ptr(); + + // Convert float to bfloat16 + for (size_t i = 0; i < num_elements; ++i) { + // bfloat16 is the upper 16 bits of float32 + uint32_t float_bits; + std::memcpy(&float_bits, &float_data[i], sizeof(float)); + + // Rounding: add 0x7FFF to round to nearest even + uint32_t rounding_bias = 0x7FFF + ((float_bits >> 16) & 1); + bf16_data[i] = static_cast((float_bits + rounding_bias) >> 16); + } + + // Create tensor with deleter to free allocated memory + return ::executorch::extension::from_blob( + bf16_data, + {sizes.begin(), sizes.end()}, + ::executorch::aten::ScalarType::BFloat16, + [](void* ptr) { delete[] static_cast(ptr); }); +} + } // namespace llm } // namespace extension } // namespace executorch diff --git a/tools/cmake/executorch-config.cmake b/tools/cmake/executorch-config.cmake index 6c27e8ba616..3df8e947459 100644 --- a/tools/cmake/executorch-config.cmake +++ b/tools/cmake/executorch-config.cmake @@ -53,6 +53,7 @@ set(EXECUTORCH_FOUND ON) include("${CMAKE_CURRENT_LIST_DIR}/ExecuTorchTargets.cmake") set(optional_lib_list + aoti_cuda flatccrt etdump bundled_program From 0f1659aeb8ade51e74cc1aa7a59b7d2b08d363f0 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Thu, 9 Oct 2025 21:44:20 -0700 Subject: [PATCH 02/10] Chcek output --- .github/workflows/cuda.yml | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index b49db641062..40e05fb91db 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -256,13 +256,27 @@ jobs: echo "::endgroup::" echo "::group::Run Voxtral Runner" - + set +e export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH - cmake-out/examples/models/voxtral/voxtral_runner \ + OUTPUT=$(cmake-out/examples/models/voxtral/voxtral_runner \ --model_path model.pte \ --data_path aoti_cuda_blob.ptd \ --tokenizer_path tekken.json \ --audio_path dancing.wav \ --processor_path voxtral_preprocessor.pte \ - --temperature 0 - echo "::endgroup::" \ No newline at end of file + --temperature 0 2>&1) + EXIT_CODE=$? + set -e + + echo "$OUTPUT" + + if ! echo "$OUTPUT" | grep -iq "dancing"; then + echo "Expected output 'dancing' not found in output" + exit 1 + fi + + if [ $EXIT_CODE -ne 0 ]; then + echo "Unexpected exit code: $EXIT_CODE" + exit $EXIT_CODE + fi + echo "::endgroup::" From 18088240df22600a9af1ecd9767c68dc43a960fc Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 10 Oct 2025 00:17:33 -0700 Subject: [PATCH 03/10] Address comments --- backends/aoti/common_shims.cpp | 70 ++++--------------- backends/cuda/runtime/cuda_backend.cpp | 3 +- extension/llm/runner/multimodal_prefiller.cpp | 12 ++-- extension/llm/runner/util.h | 2 +- 4 files changed, 22 insertions(+), 65 deletions(-) diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index 4d15f870a41..73cba6f11e7 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -50,36 +50,14 @@ AOTITorchError aoti_torch_get_storage_offset( } AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) { - auto it = internal::tensor_to_strides.find(tensor); - bool needs_update = false; - - if (it == internal::tensor_to_strides.end()) { - needs_update = true; - } else { - // Check if cached values are still valid - auto tensor_strides = tensor->strides(); - if (it->second.size() != static_cast(tensor->dim())) { - needs_update = true; - } else { - for (int i = 0; i < tensor->dim(); i++) { - if (it->second[i] != tensor_strides[i]) { - needs_update = true; - break; - } - } - } - } - - if (needs_update) { - std::vector strides(tensor->dim()); - auto tensor_strides = tensor->strides(); - for (int i = 0; i < tensor->dim(); i++) { - strides[i] = tensor_strides[i]; - } - it = - internal::tensor_to_strides.insert_or_assign(tensor, std::move(strides)) - .first; + std::vector strides(tensor->dim()); + auto tensor_strides = tensor->strides(); + for (ssize_t i = 0; i < tensor->dim(); i++) { + strides[i] = static_cast(tensor_strides[i]); } + auto it = + internal::tensor_to_strides.insert_or_assign(tensor, std::move(strides)) + .first; // For 0D tensors, data() returns nullptr on empty vectors, but we need to // return a valid pointer @@ -100,35 +78,13 @@ AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) { } AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) { - auto it = internal::tensor_to_sizes.find(tensor); - bool needs_update = false; - - if (it == internal::tensor_to_sizes.end()) { - needs_update = true; - } else { - // Check if cached values are still valid - auto tensor_sizes = tensor->sizes(); - if (it->second.size() != static_cast(tensor->dim())) { - needs_update = true; - } else { - for (int i = 0; i < tensor->dim(); i++) { - if (it->second[i] != tensor_sizes[i]) { - needs_update = true; - break; - } - } - } - } - - if (needs_update) { - std::vector sizes(tensor->dim()); - auto tensor_sizes = tensor->sizes(); - for (int i = 0; i < tensor->dim(); i++) { - sizes[i] = tensor_sizes[i]; - } - it = internal::tensor_to_sizes.insert_or_assign(tensor, std::move(sizes)) - .first; + std::vector sizes(tensor->dim()); + auto tensor_sizes = tensor->sizes(); + for (ssize_t i = 0; i < tensor->dim(); i++) { + sizes[i] = static_cast(tensor_sizes[i]); } + auto it = internal::tensor_to_sizes.insert_or_assign(tensor, std::move(sizes)) + .first; // For 0D tensors, data() returns nullptr on empty vectors, but we need to // return a valid pointer diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index 10a71b267ea..805c54ff55c 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -170,7 +170,8 @@ class ET_EXPERIMENTAL CudaBackend final // static/singleton across the whole process. When we share multiple methods // (meaning multiple so_handle) in the same process, we need to re-register // the symbols from the so_handle that is being used in this execution. - register_shared_library_functions(handle->so_handle); + ET_CHECK_OK_OR_RETURN_ERROR( + register_shared_library_functions(handle->so_handle)); size_t n_inputs; AOTInductorModelContainerGetNumInputs(handle->container_handle, &n_inputs); diff --git a/extension/llm/runner/multimodal_prefiller.cpp b/extension/llm/runner/multimodal_prefiller.cpp index e77053f8747..8ab315d08c0 100644 --- a/extension/llm/runner/multimodal_prefiller.cpp +++ b/extension/llm/runner/multimodal_prefiller.cpp @@ -67,11 +67,11 @@ Result MultimodalPrefiller::prefill( InvalidArgument, "Model expects uint8_t image data, but image has float data."); } else { - ET_LOG( - Error, + ET_CHECK_OR_RETURN_ERROR( + false, + NotSupported, "Unsupported image encoder input dtype: %s", ::executorch::runtime::toString(expected_dtype)); - return ::executorch::runtime::Error::NotSupported; } // The model might expect a 4D tensor (NCHW), but toTensor() returns a 3D @@ -119,12 +119,12 @@ Result MultimodalPrefiller::prefill( convert_to_bfloat16(audio_tensor), "Failed to convert audio tensor to bfloat16"); } else { - ET_LOG( - Error, + ET_CHECK_OR_RETURN_ERROR( + false, + NotSupported, "Unsupported audio encoder input dtype: %s. Expecting %s", ::executorch::runtime::toString(audio_tensor->scalar_type()), ::executorch::runtime::toString(expected_dtype)); - return ::executorch::runtime::Error::NotSupported; } } diff --git a/extension/llm/runner/util.h b/extension/llm/runner/util.h index 73b9963b28a..6c8406f81fc 100644 --- a/extension/llm/runner/util.h +++ b/extension/llm/runner/util.h @@ -145,7 +145,7 @@ inline runtime::Result populate_start_pos_or_cache_position( * Helper function to convert a float tensor to bfloat16. * Creates a new tensor with bfloat16 dtype and copies/converts the data. */ -::executorch::runtime::Result<::executorch::extension::TensorPtr> +inline ::executorch::runtime::Result<::executorch::extension::TensorPtr> convert_to_bfloat16(const ::executorch::extension::TensorPtr& src_tensor) { ET_CHECK_OR_RETURN_ERROR( src_tensor->scalar_type() == ::executorch::aten::ScalarType::Float, From a7c55b14042a91cb68b54a8916e120d7ea7ef00a Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 10 Oct 2025 09:47:35 -0700 Subject: [PATCH 04/10] Check for poem --- .github/workflows/cuda.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index 40e05fb91db..c1b22e692ab 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -194,10 +194,10 @@ jobs: echo "::endgroup::" echo "::group::Run Voxtral Benchmark" - + export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH cmake-out/backends/cuda/voxtral_runner model.pte aoti_cuda_blob.ptd - + echo "::endgroup::" test-voxtral-cuda-e2e: @@ -234,10 +234,10 @@ jobs: curl -L $TOKENIZER_URL -o tekken.json ls -al model.pte aoti_cuda_blob.ptd voxtral_preprocessor.pte tekken.json echo "::endgroup::" - + echo "::group::Download Test Audio File" AUDIO_URL="https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/testaudio/16000/test01_20s.wav" - curl -L $AUDIO_URL -o dancing.wav + curl -L $AUDIO_URL -o poem.wav echo "::endgroup::" echo "::group::Build Voxtral Runner" @@ -262,7 +262,7 @@ jobs: --model_path model.pte \ --data_path aoti_cuda_blob.ptd \ --tokenizer_path tekken.json \ - --audio_path dancing.wav \ + --audio_path poem.wav \ --processor_path voxtral_preprocessor.pte \ --temperature 0 2>&1) EXIT_CODE=$? @@ -270,8 +270,8 @@ jobs: echo "$OUTPUT" - if ! echo "$OUTPUT" | grep -iq "dancing"; then - echo "Expected output 'dancing' not found in output" + if ! echo "$OUTPUT" | grep -iq "poem"; then + echo "Expected output 'poem' not found in output" exit 1 fi From be5d187680a2b1df8d15eac1db471bf00d745c14 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 10 Oct 2025 10:17:40 -0700 Subject: [PATCH 05/10] Remove debug config --- backends/cuda/cuda_backend.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 795fdb598e2..8ed8cdefbb1 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -135,8 +135,6 @@ def preprocess( "aot_inductor.link_libtorch": False, # Package model constants and other generated files directly in the shared object (.so) file "aot_inductor.package_constants_in_so": True, - # Enable debug mode if the DEBUG environment variable is set - "aot_inductor.debug_compile": os.environ.get("DEBUG") == "1", # Enable maximum automatic tuning for optimal performance "max_autotune": True, # Use TRITON for GEMM (General Matrix Multiply) operations tuning only to avoid using operators in libtorch From fb4940ee4ab27235729436c7505be034a7395050 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 10 Oct 2025 11:50:47 -0700 Subject: [PATCH 06/10] Add unit tests for convert_to_bfloat --- backends/aoti/common_shims.cpp | 70 +++++++++++++++++++----- extension/llm/runner/test/CMakeLists.txt | 3 +- extension/llm/runner/test/targets.bzl | 10 ++++ extension/llm/runner/test/test_util.cpp | 59 ++++++++++++++++++++ extension/llm/runner/util.h | 25 ++------- 5 files changed, 134 insertions(+), 33 deletions(-) create mode 100644 extension/llm/runner/test/test_util.cpp diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index 73cba6f11e7..4d15f870a41 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -50,14 +50,36 @@ AOTITorchError aoti_torch_get_storage_offset( } AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) { - std::vector strides(tensor->dim()); - auto tensor_strides = tensor->strides(); - for (ssize_t i = 0; i < tensor->dim(); i++) { - strides[i] = static_cast(tensor_strides[i]); + auto it = internal::tensor_to_strides.find(tensor); + bool needs_update = false; + + if (it == internal::tensor_to_strides.end()) { + needs_update = true; + } else { + // Check if cached values are still valid + auto tensor_strides = tensor->strides(); + if (it->second.size() != static_cast(tensor->dim())) { + needs_update = true; + } else { + for (int i = 0; i < tensor->dim(); i++) { + if (it->second[i] != tensor_strides[i]) { + needs_update = true; + break; + } + } + } + } + + if (needs_update) { + std::vector strides(tensor->dim()); + auto tensor_strides = tensor->strides(); + for (int i = 0; i < tensor->dim(); i++) { + strides[i] = tensor_strides[i]; + } + it = + internal::tensor_to_strides.insert_or_assign(tensor, std::move(strides)) + .first; } - auto it = - internal::tensor_to_strides.insert_or_assign(tensor, std::move(strides)) - .first; // For 0D tensors, data() returns nullptr on empty vectors, but we need to // return a valid pointer @@ -78,13 +100,35 @@ AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) { } AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) { - std::vector sizes(tensor->dim()); - auto tensor_sizes = tensor->sizes(); - for (ssize_t i = 0; i < tensor->dim(); i++) { - sizes[i] = static_cast(tensor_sizes[i]); + auto it = internal::tensor_to_sizes.find(tensor); + bool needs_update = false; + + if (it == internal::tensor_to_sizes.end()) { + needs_update = true; + } else { + // Check if cached values are still valid + auto tensor_sizes = tensor->sizes(); + if (it->second.size() != static_cast(tensor->dim())) { + needs_update = true; + } else { + for (int i = 0; i < tensor->dim(); i++) { + if (it->second[i] != tensor_sizes[i]) { + needs_update = true; + break; + } + } + } + } + + if (needs_update) { + std::vector sizes(tensor->dim()); + auto tensor_sizes = tensor->sizes(); + for (int i = 0; i < tensor->dim(); i++) { + sizes[i] = tensor_sizes[i]; + } + it = internal::tensor_to_sizes.insert_or_assign(tensor, std::move(sizes)) + .first; } - auto it = internal::tensor_to_sizes.insert_or_assign(tensor, std::move(sizes)) - .first; // For 0D tensors, data() returns nullptr on empty vectors, but we need to // return a valid pointer diff --git a/extension/llm/runner/test/CMakeLists.txt b/extension/llm/runner/test/CMakeLists.txt index 934a5797da1..27edc6a9e62 100644 --- a/extension/llm/runner/test/CMakeLists.txt +++ b/extension/llm/runner/test/CMakeLists.txt @@ -19,7 +19,8 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) set(_test_srcs test_generation_config.cpp test_text_llm_runner.cpp test_text_prefiller.cpp - test_text_decoder_runner.cpp test_multimodal_input.cpp test_wav_loader.cpp + test_text_decoder_runner.cpp test_multimodal_input.cpp test_util.cpp + test_wav_loader.cpp ) # Add LSan stub for Apple platforms diff --git a/extension/llm/runner/test/targets.bzl b/extension/llm/runner/test/targets.bzl index 0571b39ccdb..1109ff315ac 100644 --- a/extension/llm/runner/test/targets.bzl +++ b/extension/llm/runner/test/targets.bzl @@ -45,6 +45,16 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "test_util", + srcs = ["test_util.cpp"], + deps = [ + "//executorch/extension/llm/runner:stats", + "//executorch/extension/tensor:tensor", + "//executorch/runtime/core:core", + ], + ) + runtime.cxx_test( name = "test_wav_loader", srcs = ["test_wav_loader.cpp"], diff --git a/extension/llm/runner/test/test_util.cpp b/extension/llm/runner/test/test_util.cpp new file mode 100644 index 00000000000..242e48e6871 --- /dev/null +++ b/extension/llm/runner/test/test_util.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include + +namespace { + +using ::executorch::aten::ScalarType; +using ::executorch::extension::make_tensor_ptr; +using ::executorch::extension::llm::convert_to_bfloat16; + +TEST(ConvertToBFloat16Test, ConvertsFloatTensorData) { + auto source_tensor = make_tensor_ptr( + {2, 2}, std::vector{0.0f, 1.5f, -2.0f, 3.25f}); + + auto result = convert_to_bfloat16(source_tensor); + ASSERT_TRUE(result.ok()); + auto bf16_tensor = *result; + + EXPECT_EQ(bf16_tensor->scalar_type(), ScalarType::BFloat16); + EXPECT_EQ(bf16_tensor->numel(), source_tensor->numel()); + + auto src_sizes = source_tensor->sizes(); + auto dst_sizes = bf16_tensor->sizes(); + ASSERT_EQ(dst_sizes.size(), src_sizes.size()); + for (size_t dim = 0; dim < dst_sizes.size(); ++dim) { + EXPECT_EQ(dst_sizes[dim], src_sizes[dim]); + } + + const auto* converted_data = bf16_tensor->const_data_ptr<::c10::BFloat16>(); + const auto* original_data = source_tensor->const_data_ptr(); + ASSERT_NE(converted_data, nullptr); + ASSERT_NE(original_data, nullptr); + + for (size_t i = 0; i < static_cast(source_tensor->numel()); ++i) { + EXPECT_NEAR(static_cast(converted_data[i]), original_data[i], 1e-2f); + } +} + +TEST(ConvertToBFloat16Test, RejectsNonFloatTensor) { + auto non_float_tensor = + make_tensor_ptr({3}, std::vector{1, 2, 3}); + + auto result = convert_to_bfloat16(non_float_tensor); + EXPECT_FALSE(result.ok()); + EXPECT_EQ(result.error(), ::executorch::runtime::Error::InvalidArgument); +} + +} // namespace diff --git a/extension/llm/runner/util.h b/extension/llm/runner/util.h index 6c8406f81fc..2d924fc88a3 100644 --- a/extension/llm/runner/util.h +++ b/extension/llm/runner/util.h @@ -152,30 +152,17 @@ convert_to_bfloat16(const ::executorch::extension::TensorPtr& src_tensor) { InvalidArgument, "BFloat16 conversion only supported from Float source data"); - size_t num_elements = src_tensor->numel(); - auto sizes = src_tensor->sizes(); - - // Allocate memory for bfloat16 data - auto* bf16_data = new uint16_t[num_elements]; + const auto num_elements = static_cast(src_tensor->numel()); const float* float_data = src_tensor->const_data_ptr(); - // Convert float to bfloat16 + auto bf16_tensor = ::executorch::extension::empty_like( + src_tensor, ::executorch::aten::ScalarType::BFloat16); + auto* bf16_data = bf16_tensor->mutable_data_ptr<::c10::BFloat16>(); for (size_t i = 0; i < num_elements; ++i) { - // bfloat16 is the upper 16 bits of float32 - uint32_t float_bits; - std::memcpy(&float_bits, &float_data[i], sizeof(float)); - - // Rounding: add 0x7FFF to round to nearest even - uint32_t rounding_bias = 0x7FFF + ((float_bits >> 16) & 1); - bf16_data[i] = static_cast((float_bits + rounding_bias) >> 16); + bf16_data[i] = ::c10::BFloat16(float_data[i]); } - // Create tensor with deleter to free allocated memory - return ::executorch::extension::from_blob( - bf16_data, - {sizes.begin(), sizes.end()}, - ::executorch::aten::ScalarType::BFloat16, - [](void* ptr) { delete[] static_cast(ptr); }); + return bf16_tensor; } } // namespace llm From 88873b7cda7b0e5af18354c5930dc3e20a9b4125 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 10 Oct 2025 11:51:27 -0700 Subject: [PATCH 07/10] Lint --- extension/llm/runner/test/CMakeLists.txt | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/extension/llm/runner/test/CMakeLists.txt b/extension/llm/runner/test/CMakeLists.txt index 27edc6a9e62..81b69c0ab9a 100644 --- a/extension/llm/runner/test/CMakeLists.txt +++ b/extension/llm/runner/test/CMakeLists.txt @@ -18,8 +18,12 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) set(_test_srcs - test_generation_config.cpp test_text_llm_runner.cpp test_text_prefiller.cpp - test_text_decoder_runner.cpp test_multimodal_input.cpp test_util.cpp + test_generation_config.cpp + test_text_llm_runner.cpp + test_text_prefiller.cpp + test_text_decoder_runner.cpp + test_multimodal_input.cpp + test_util.cpp test_wav_loader.cpp ) From f40b1fb64cc6d22869a2f592f3246203db740353 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 10 Oct 2025 13:26:13 -0700 Subject: [PATCH 08/10] Address comments --- backends/aoti/common_shims.cpp | 40 +++++++++++------------- examples/models/voxtral/README.md | 51 +++++++++++++++++++++++++++++++ extension/llm/runner/util.h | 5 +-- 3 files changed, 72 insertions(+), 24 deletions(-) diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index 4d15f870a41..f0c134a716c 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -56,18 +56,16 @@ AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) { if (it == internal::tensor_to_strides.end()) { needs_update = true; } else { - // Check if cached values are still valid + // CRITICAL: Multimodal models reuse tensors with different shapes across + // executions (e.g., variable-length audio). We MUST validate cached + // metadata matches current tensor state, or CUDA kernels will receive + // incorrect shapes leading to memory corruption and segfaults. auto tensor_strides = tensor->strides(); - if (it->second.size() != static_cast(tensor->dim())) { - needs_update = true; - } else { - for (int i = 0; i < tensor->dim(); i++) { - if (it->second[i] != tensor_strides[i]) { - needs_update = true; - break; - } - } - } + needs_update = !std::equal( + it->second.begin(), + it->second.end(), + tensor_strides.begin(), + tensor_strides.end()); } if (needs_update) { @@ -106,18 +104,16 @@ AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) { if (it == internal::tensor_to_sizes.end()) { needs_update = true; } else { - // Check if cached values are still valid + // CRITICAL: Multimodal models reuse tensors with different shapes across + // executions (e.g., variable-length audio). We MUST validate cached + // metadata matches current tensor state, or CUDA kernels will receive + // incorrect shapes leading to memory corruption and segfaults. auto tensor_sizes = tensor->sizes(); - if (it->second.size() != static_cast(tensor->dim())) { - needs_update = true; - } else { - for (int i = 0; i < tensor->dim(); i++) { - if (it->second[i] != tensor_sizes[i]) { - needs_update = true; - break; - } - } - } + needs_update = !std::equal( + it->second.begin(), + it->second.end(), + tensor_sizes.begin(), + tensor_sizes.end()); } if (needs_update) { diff --git a/examples/models/voxtral/README.md b/examples/models/voxtral/README.md index 4e9ddcf34a4..861043fe2a7 100644 --- a/examples/models/voxtral/README.md +++ b/examples/models/voxtral/README.md @@ -36,6 +36,29 @@ optimum-cli export executorch \ This exports Voxtral with XNNPack backend acceleration and 4-bit weight/8-bit activation linear quantization. +## CUDA Support +If your environment has CUDA support, you can enable the runner to run on CUDA for improved performance. Follow the export and runtime commands below: + +**Note:** We are currently working on quantization support for CUDA. Currently, only bfloat16 dtype is supported for CUDA execution. + +### Exporting with CUDA +``` +optimum-cli export executorch \ + --model "mistralai/Voxtral-Mini-3B-2507" \ + --task "multimodal-text-to-text" \ + --recipe "cuda" \ + --dtype bfloat16 \ + --device cuda \ + --max_seq_len 1024 \ + --output_dir="voxtral" +``` + +This will generate: +- `model.pte` - The exported model +- `aoti_cuda_blob.ptd` - The CUDA kernel blob required for runtime + +See the "Building the multimodal runner" section below for instructions on building with CUDA support, and the "Running the model" section for runtime instructions. + # Running the model To run the model, we will use the Voxtral runner, which utilizes ExecuTorch's MultiModal runner API. The Voxtral runner will do the following things: @@ -56,6 +79,8 @@ python -m executorch.extension.audio.mel_spectrogram --feature_size 128 --stack_ ``` ## Building the multimodal runner + +### Building for CPU (XNNPack) ``` # Build and install ExecuTorch cmake --preset llm -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=cmake-out -DEXECUTORCH_ENABLE_LOGGING=ON && cmake --build cmake-out -j16 --target install --config Release @@ -64,6 +89,26 @@ cmake --preset llm -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=cmake-out - cmake -DCMAKE_INSTALL_PREFIX=cmake-out -DBUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=Release -Bcmake-out/examples/models/voxtral examples/models/voxtral && cmake --build cmake-out/examples/models/voxtral -j16 --config Release ``` +### Building for CUDA +``` +# Install ExecuTorch with CUDA support +CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_executorch.sh + +# Build the multimodal runner with CUDA +cmake --preset llm \ + -DEXECUTORCH_BUILD_CUDA=ON \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DCMAKE_BUILD_TYPE=Release \ + -Bcmake-out -S. +cmake --build cmake-out -j16 --target install --config Release + +cmake -DEXECUTORCH_BUILD_CUDA=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -Sexamples/models/voxtral \ + -Bcmake-out/examples/models/voxtral/ +cmake --build cmake-out/examples/models/voxtral --target voxtral_runner --config Release +``` + ## Running the model You can download the `tekken.json` tokenizer from [Voxtral's HuggingFace repo](https://huggingface.co/mistralai/Voxtral-Mini-3B-2507). @@ -88,6 +133,12 @@ If you already have a preprocessed mel spectrogram saved as a `.bin` file, you c --audio_path path/to/preprocessed_audio.bin ``` + +**For CUDA:** Add the `--data_path` argument to provide the CUDA kernel blob to the commands above: +``` + --data_path path/to/aoti_cuda_blob.ptd +``` + Example output: ``` The speaker in this audio seems to be talking about their concerns about a device called the model or maybe they're just talking about the model in general. They mention that the model was trained with the speaker for inference, which suggests that diff --git a/extension/llm/runner/util.h b/extension/llm/runner/util.h index 2d924fc88a3..c587058c101 100644 --- a/extension/llm/runner/util.h +++ b/extension/llm/runner/util.h @@ -157,9 +157,10 @@ convert_to_bfloat16(const ::executorch::extension::TensorPtr& src_tensor) { auto bf16_tensor = ::executorch::extension::empty_like( src_tensor, ::executorch::aten::ScalarType::BFloat16); - auto* bf16_data = bf16_tensor->mutable_data_ptr<::c10::BFloat16>(); + auto* bf16_data = + bf16_tensor->mutable_data_ptr<::executorch::aten::ScalarType::BFloat16>(); for (size_t i = 0; i < num_elements; ++i) { - bf16_data[i] = ::c10::BFloat16(float_data[i]); + bf16_data[i] = ::executorch::aten::ScalarType::BFloat16(float_data[i]); } return bf16_tensor; From 7ab6b25b74467f5117b09a02aa3b5763cb3b2f05 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 10 Oct 2025 13:44:50 -0700 Subject: [PATCH 09/10] Fix typo --- extension/llm/runner/util.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/llm/runner/util.h b/extension/llm/runner/util.h index c587058c101..a267e666bc5 100644 --- a/extension/llm/runner/util.h +++ b/extension/llm/runner/util.h @@ -160,7 +160,7 @@ convert_to_bfloat16(const ::executorch::extension::TensorPtr& src_tensor) { auto* bf16_data = bf16_tensor->mutable_data_ptr<::executorch::aten::ScalarType::BFloat16>(); for (size_t i = 0; i < num_elements; ++i) { - bf16_data[i] = ::executorch::aten::ScalarType::BFloat16(float_data[i]); + bf16_data[i] = ::executorch::aten::BFloat16(float_data[i]); } return bf16_tensor; From afc2159af1781dfa9f756b10d4737e8faf35bac0 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 10 Oct 2025 15:09:19 -0700 Subject: [PATCH 10/10] Fix typo --- extension/llm/runner/util.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/llm/runner/util.h b/extension/llm/runner/util.h index a267e666bc5..ec08ecfb647 100644 --- a/extension/llm/runner/util.h +++ b/extension/llm/runner/util.h @@ -158,7 +158,7 @@ convert_to_bfloat16(const ::executorch::extension::TensorPtr& src_tensor) { auto bf16_tensor = ::executorch::extension::empty_like( src_tensor, ::executorch::aten::ScalarType::BFloat16); auto* bf16_data = - bf16_tensor->mutable_data_ptr<::executorch::aten::ScalarType::BFloat16>(); + bf16_tensor->mutable_data_ptr<::executorch::aten::BFloat16>(); for (size_t i = 0; i < num_elements; ++i) { bf16_data[i] = ::executorch::aten::BFloat16(float_data[i]); }