150 changes: 130 additions & 20 deletions .github/workflows/cuda.yml
@@ -87,8 +87,8 @@ jobs:
export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
PYTHON_EXECUTABLE=python source .ci/scripts/test_model.sh "${{ matrix.model }}" cmake cuda

test-voxtral-cuda-e2e:
name: test-voxtral-cuda-e2e
export-voxtral-cuda-artifact:
name: export-voxtral-cuda-artifact
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
id-token: write
@@ -104,6 +104,7 @@
gpu-arch-version: 12.6
use-custom-docker-registry: false
submodules: recursive
upload-artifact: voxtral-cuda-export
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: |
set -eux
@@ -118,6 +119,7 @@
OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt)
pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}
pip install mistral-common librosa
pip list
echo "::endgroup::"

echo "::group::Export Voxtral"
@@ -129,9 +131,58 @@
--device cuda \
--max_seq_len 1024 \
--output_dir ./
python -m executorch.extension.audio.mel_spectrogram \
--feature_size 128 \
--stack_output \
--max_audio_len 300 \
--output_file voxtral_preprocessor.pte

test -f model.pte
test -f aoti_cuda_blob.ptd
test -f voxtral_preprocessor.pte
echo "::endgroup::"

echo "::group::Build Voxtral Runner"
echo "::group::Store Voxtral Artifacts"
mkdir -p "${RUNNER_ARTIFACT_DIR}"
cp model.pte "${RUNNER_ARTIFACT_DIR}/"
cp aoti_cuda_blob.ptd "${RUNNER_ARTIFACT_DIR}/"
cp voxtral_preprocessor.pte "${RUNNER_ARTIFACT_DIR}/"
ls -al "${RUNNER_ARTIFACT_DIR}"
echo "::endgroup::"

benchmark-voxtral-cuda:
name: benchmark-voxtral-cuda
needs: export-voxtral-cuda-artifact
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
id-token: write
contents: read
strategy:
fail-fast: false
with:
timeout: 90
runner: linux.g5.4xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: 12.6
use-custom-docker-registry: false
submodules: recursive
download-artifact: voxtral-cuda-export
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: |
set -eux

echo "::group::Setup ExecuTorch Requirements"
CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_requirements.sh
pip list
echo "::endgroup::"

echo "::group::Prepare Voxtral Artifacts"
cp "${RUNNER_ARTIFACT_DIR}/model.pte" .
cp "${RUNNER_ARTIFACT_DIR}/aoti_cuda_blob.ptd" .
ls -al model.pte aoti_cuda_blob.ptd
echo "::endgroup::"

echo "::group::Build Voxtral Benchmark"
cmake -DCMAKE_BUILD_TYPE=Release \
-DEXECUTORCH_BUILD_CUDA=ON \
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
@@ -142,31 +193,90 @@ jobs:
cmake --build cmake-out -j$(( $(nproc) - 1 )) --target voxtral_runner
echo "::endgroup::"

echo "::group::Run Voxtral Benchmark"

export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
cmake-out/backends/cuda/voxtral_runner model.pte aoti_cuda_blob.ptd

echo "::endgroup::"

test-voxtral-cuda-e2e:
name: test-voxtral-cuda-e2e
needs: export-voxtral-cuda-artifact
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
id-token: write
contents: read
strategy:
fail-fast: false
with:
timeout: 90
runner: linux.g5.4xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: 12.6
use-custom-docker-registry: false
submodules: recursive
download-artifact: voxtral-cuda-export
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: |
set -eux

echo "::group::Setup ExecuTorch Requirements"
CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_requirements.sh
pip list
echo "::endgroup::"

echo "::group::Prepare Voxtral Artifacts"
cp "${RUNNER_ARTIFACT_DIR}/model.pte" .
cp "${RUNNER_ARTIFACT_DIR}/aoti_cuda_blob.ptd" .
cp "${RUNNER_ARTIFACT_DIR}/voxtral_preprocessor.pte" .
TOKENIZER_URL="https://huggingface.co/mistralai/Voxtral-Mini-3B-2507/resolve/main/tekken.json"
curl -L $TOKENIZER_URL -o tekken.json
ls -al model.pte aoti_cuda_blob.ptd voxtral_preprocessor.pte tekken.json
echo "::endgroup::"

echo "::group::Download Test Audio File"
AUDIO_URL="https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/testaudio/16000/test01_20s.wav"
curl -L $AUDIO_URL -o poem.wav
echo "::endgroup::"

echo "::group::Build Voxtral Runner"
cmake --preset llm \
-DEXECUTORCH_BUILD_CUDA=ON \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Release \
-Bcmake-out -S.
cmake --build cmake-out -j$(( $(nproc) - 1 )) --target install --config Release

cmake -DEXECUTORCH_BUILD_CUDA=ON \
-DCMAKE_BUILD_TYPE=Release \
-Sexamples/models/voxtral \
-Bcmake-out/examples/models/voxtral/
cmake --build cmake-out/examples/models/voxtral --target voxtral_runner --config Release
echo "::endgroup::"

echo "::group::Run Voxtral Runner"
# Capture output and allow exit code 139 if we have the expected printout
set +e
export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
OUTPUT=$(cmake-out/backends/cuda/voxtral_runner model.pte aoti_cuda_blob.ptd 2>&1)
OUTPUT=$(cmake-out/examples/models/voxtral/voxtral_runner \
--model_path model.pte \
--data_path aoti_cuda_blob.ptd \
--tokenizer_path tekken.json \
--audio_path poem.wav \
--processor_path voxtral_preprocessor.pte \
--temperature 0 2>&1)
EXIT_CODE=$?
set -e

echo "$OUTPUT"

# Check if the output contains "Run latency (ms):"
if echo "$OUTPUT" | grep -q "Run latency (ms):"; then
echo "Found expected output: 'Run latency (ms):'"
if [ $EXIT_CODE -eq 139 ]; then
echo "Exit code 139 (segfault) detected, but passing since we have the expected output"
exit 0
elif [ $EXIT_CODE -ne 0 ]; then
echo "Unexpected exit code: $EXIT_CODE"
exit $EXIT_CODE
else
echo "Command succeeded with exit code 0"
exit 0
fi
else
echo "Expected output 'Run latency (ms):' not found in output"
if ! echo "$OUTPUT" | grep -iq "poem"; then
echo "Expected output 'poem' not found in output"
exit 1
fi

if [ $EXIT_CODE -ne 0 ]; then
echo "Unexpected exit code: $EXIT_CODE"
exit $EXIT_CODE
fi
echo "::endgroup::"
41 changes: 39 additions & 2 deletions backends/aoti/common_shims.cpp
@@ -51,13 +51,32 @@ AOTITorchError aoti_torch_get_storage_offset(

AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) {
auto it = internal::tensor_to_strides.find(tensor);
bool needs_update = false;

Contributor:
Can you make the docblock something like this?

// CRITICAL: Multimodal models reuse tensors with different shapes across
// executions (e.g., variable-length audio). We MUST validate cached metadata
// matches current tensor state, or CUDA kernels will receive incorrect shapes
// leading to memory corruption and segfaults.

if (it == internal::tensor_to_strides.end()) {
needs_update = true;
} else {
// CRITICAL: Multimodal models reuse tensors with different shapes across
// executions (e.g., variable-length audio). We MUST validate cached
// metadata matches current tensor state, or CUDA kernels will receive
// incorrect shapes leading to memory corruption and segfaults.
auto tensor_strides = tensor->strides();
needs_update = !std::equal(
it->second.begin(),
it->second.end(),
tensor_strides.begin(),
tensor_strides.end());
}

if (needs_update) {
std::vector<int64_t> strides(tensor->dim());
auto tensor_strides = tensor->strides();
for (int i = 0; i < tensor->dim(); i++) {
strides[i] = tensor_strides[i];
}
it = internal::tensor_to_strides.emplace(tensor, std::move(strides)).first;
it =
internal::tensor_to_strides.insert_or_assign(tensor, std::move(strides))
.first;
}

// For 0D tensors, data() returns nullptr on empty vectors, but we need to
@@ -80,13 +99,31 @@ AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {

AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) {
auto it = internal::tensor_to_sizes.find(tensor);
bool needs_update = false;

if (it == internal::tensor_to_sizes.end()) {
needs_update = true;
} else {
// CRITICAL: Multimodal models reuse tensors with different shapes across
// executions (e.g., variable-length audio). We MUST validate cached
// metadata matches current tensor state, or CUDA kernels will receive
// incorrect shapes leading to memory corruption and segfaults.
auto tensor_sizes = tensor->sizes();
needs_update = !std::equal(
it->second.begin(),
it->second.end(),
tensor_sizes.begin(),
tensor_sizes.end());
}

if (needs_update) {
std::vector<int64_t> sizes(tensor->dim());
auto tensor_sizes = tensor->sizes();
for (int i = 0; i < tensor->dim(); i++) {
sizes[i] = tensor_sizes[i];
}
it = internal::tensor_to_sizes.emplace(tensor, std::move(sizes)).first;
it = internal::tensor_to_sizes.insert_or_assign(tensor, std::move(sizes))
.first;
}

// For 0D tensors, data() returns nullptr on empty vectors, but we need to
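For context, a minimal sketch of the failure mode this revalidation guards against (hypothetical usage with made-up helpers, not part of the diff):

  // The same Tensor object is reused across executions with different
  // shapes (e.g., variable-length audio in a multimodal model).
  Tensor* t = get_audio_features_tensor();   // hypothetical helper
  resize_to(t, {1, 128, 120});               // hypothetical resize, first run
  int64_t* sizes = nullptr;
  aoti_torch_get_sizes(t, &sizes);           // caches {1, 128, 120}

  resize_to(t, {1, 128, 300});               // second run, longer audio
  aoti_torch_get_sizes(t, &sizes);
  // Previously the stale cached entry {1, 128, 120} was returned, so CUDA
  // kernels saw the wrong shape; with the std::equal check plus
  // insert_or_assign, the cached metadata is refreshed to {1, 128, 300}.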
10 changes: 8 additions & 2 deletions backends/cuda/runtime/cuda_backend.cpp
@@ -165,6 +165,14 @@ class ET_EXPERIMENTAL CudaBackend final
Span<EValue*> args) const override {
AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;

// Need to re-register all the symbols from the so_handle hosted by this
// CudaBackend instance. The reason is that these symbols are
// static/singleton across the whole process. When we share multiple methods
// (meaning multiple so_handle) in the same process, we need to re-register
// the symbols from the so_handle that is being used in this execution.
ET_CHECK_OK_OR_RETURN_ERROR(
register_shared_library_functions(handle->so_handle));

Contributor (on lines +168 to +175):
If we're loading the model once and doing execute/inference multiple times, it will register multiple times, no?

Can you do something like this?

  void* last_registered_handle = nullptr;

  if (handle->so_handle != last_registered_handle) {
      ET_CHECK_OK_OR_RETURN_ERROR(
          register_shared_library_functions(handle->so_handle));
      last_registered_handle = handle->so_handle;
  }

Contributor (author):
So the so_handle won't change. It's just we are mapping the symbols differently, especially AOTInductorModelContainerRun. Let's say we do the following:

  1. load(token_embedding)
  2. load(audio_encoder)
  3. load(text_decoder)
  4. run(audio_encoder) <-- here AOTInductorModelContainerRun maps to the symbol in text_decoder.so, so we need to remap the symbol to audio_encoder.so

Contributor (@mergennachin, Oct 10, 2025):
@larryliu0820

Can you store the AOTInductorModelContainerRunFunc inside AOTIDelegateHandle?

  struct AOTIDelegateHandle {
      void* so_handle;
      std::string so_path;
      AOTInductorModelContainerHandle container_handle;
      void* cuda_stream;

      AOTInductorModelContainerRunFunc run_func;
      // ... etc for all symbols
  };
 Result<DelegateHandle*> init(...) const override {
      AOTIDelegateHandle* handle = new AOTIDelegateHandle();
      handle->so_handle = so_handle;

      // Load symbols into THIS handle's struct (not global)
      handle->run_func = reinterpret_cast<AOTInductorModelContainerRunFunc>(
          dlsym(so_handle, "AOTInductorModelContainerRun"));
      // ... etc

      ET_CHECK_OR_RETURN_ERROR(
          handle->run_func != nullptr,
          AccessFailed,
          "Failed to load AOTInductorModelContainerRun");

      return (DelegateHandle*)handle;
  }
  Error execute(..., DelegateHandle* handle_, ...) const override {
      AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;

      // NO re-registration, use the handle's local symbols
      AOTIRuntimeError error = handle->run_func(
          ...)

      // ... rest of execution ...
  }

size_t n_inputs;
AOTInductorModelContainerGetNumInputs(handle->container_handle, &n_inputs);

@@ -223,7 +231,6 @@ class ET_EXPERIMENTAL CudaBackend final
"Failed to copy input %d from CPU to GPU",
i);
}
ET_LOG(Info, "Inputs copied to GPU");
// Process output tensors: create GPU counterparts for ExecuTorch CPU
// tensors
for (int i = 0; i < n_outputs; i++) {
@@ -253,7 +260,6 @@ class ET_EXPERIMENTAL CudaBackend final

gpu_outputs[i] = gpu_output_handle;
}
ET_LOG(Info, "Outputs created on GPU");
// Run AOTI container with GPU tensors
AOTIRuntimeError error = AOTInductorModelContainerRun(
handle->container_handle,
7 changes: 7 additions & 0 deletions examples/models/voxtral/CMakeLists.txt
@@ -86,6 +86,13 @@ list(
extension_flat_tensor
)

# Link CUDA backend
if(EXECUTORCH_BUILD_CUDA)
find_package(CUDAToolkit REQUIRED)
list(APPEND link_libraries aoti_cuda)
executorch_target_link_options_shared_lib(aoti_cuda)
endif()

# Add tokenizers
list(APPEND link_libraries tokenizers::tokenizers)

51 changes: 51 additions & 0 deletions examples/models/voxtral/README.md
@@ -36,6 +36,29 @@ optimum-cli export executorch \

This exports Voxtral with XNNPack backend acceleration and 4-bit weight/8-bit activation linear quantization.

## CUDA Support
If your environment has CUDA support, you can run the model on CUDA for improved performance. Follow the export and runtime commands below:

**Note:** Quantization support for CUDA is a work in progress; currently, only the bfloat16 dtype is supported for CUDA execution.

### Exporting with CUDA
```
optimum-cli export executorch \
--model "mistralai/Voxtral-Mini-3B-2507" \
--task "multimodal-text-to-text" \
--recipe "cuda" \
--dtype bfloat16 \
--device cuda \
--max_seq_len 1024 \
--output_dir="voxtral"
```

This will generate:
- `model.pte` - The exported model
- `aoti_cuda_blob.ptd` - The CUDA kernel blob required for runtime

See the "Building the multimodal runner" section below for instructions on building with CUDA support, and the "Running the model" section for runtime instructions.

# Running the model
To run the model, we will use the Voxtral runner, which utilizes ExecuTorch's MultiModal runner API.
The Voxtral runner will do the following things:
@@ -56,6 +79,8 @@ python -m executorch.extension.audio.mel_spectrogram --feature_size 128 --stack_
```

## Building the multimodal runner

### Building for CPU (XNNPack)
```
# Build and install ExecuTorch
cmake --preset llm -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=cmake-out -DEXECUTORCH_ENABLE_LOGGING=ON && cmake --build cmake-out -j16 --target install --config Release
@@ -64,6 +89,26 @@ cmake --preset llm -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=cmake-out -
cmake -DCMAKE_INSTALL_PREFIX=cmake-out -DBUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=Release -Bcmake-out/examples/models/voxtral examples/models/voxtral && cmake --build cmake-out/examples/models/voxtral -j16 --config Release
```

### Building for CUDA
```
# Install ExecuTorch with CUDA support
CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_executorch.sh

# Build the multimodal runner with CUDA
cmake --preset llm \
-DEXECUTORCH_BUILD_CUDA=ON \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Release \
-Bcmake-out -S.
cmake --build cmake-out -j16 --target install --config Release

cmake -DEXECUTORCH_BUILD_CUDA=ON \
-DCMAKE_BUILD_TYPE=Release \
-Sexamples/models/voxtral \
-Bcmake-out/examples/models/voxtral/
cmake --build cmake-out/examples/models/voxtral --target voxtral_runner --config Release
```

## Running the model
You can download the `tekken.json` tokenizer from [Voxtral's HuggingFace repo](https://huggingface.co/mistralai/Voxtral-Mini-3B-2507).

@@ -88,6 +133,12 @@ If you already have a preprocessed mel spectrogram saved as a `.bin` file, you c
--audio_path path/to/preprocessed_audio.bin
```


**For CUDA:** Add the `--data_path` argument to the commands above to provide the CUDA kernel blob:
```
--data_path path/to/aoti_cuda_blob.ptd
```
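
Putting it together, a full CUDA invocation might look like the following sketch (mirroring the flags used in the CI workflow; the artifact paths are placeholders for your own exported files):

```
cmake-out/examples/models/voxtral/voxtral_runner \
  --model_path model.pte \
  --data_path aoti_cuda_blob.ptd \
  --tokenizer_path tekken.json \
  --audio_path poem.wav \
  --processor_path voxtral_preprocessor.pte \
  --temperature 0
```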

Example output:
```
The speaker in this audio seems to be talking about their concerns about a device called the model or maybe they're just talking about the model in general. They mention that the model was trained with the speaker for inference, which suggests that