5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -630,7 +630,12 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE)
list(APPEND _executorch_extensions extension_module_static)
endif()

if(EXECUTORCH_BUILD_EXTENSION_AUDIO)
message(STATUS "Audio/ASR extension enabled")
endif()

if(EXECUTORCH_BUILD_EXTENSION_LLM)
message(STATUS "LLM extension enabled")
if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
set(SUPPORT_REGEX_LOOKAHEAD ON)
# llama/runner/CMakeLists.txt builds a shared library libllama_runner.so
2 changes: 2 additions & 0 deletions backends/qualcomm/scripts/build.sh
@@ -81,6 +81,7 @@ if [ "$BUILD_AARCH64" = true ]; then
-DCMAKE_BUILD_TYPE=$BUILD_TYPE \
-DEXECUTORCH_BUILD_QNN=ON \
-DEXECUTORCH_BUILD_DEVTOOLS=ON \
-DEXECUTORCH_BUILD_EXTENSION_AUDIO=ON \
-DEXECUTORCH_BUILD_EXTENSION_LLM=ON \
-DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
@@ -150,6 +151,7 @@ if [ "$BUILD_X86_64" = true ]; then
-DQNN_SDK_ROOT=${QNN_SDK_ROOT} \
-DEXECUTORCH_BUILD_QNN=ON \
-DEXECUTORCH_BUILD_DEVTOOLS=ON \
-DEXECUTORCH_BUILD_EXTENSION_AUDIO=ON \
-DEXECUTORCH_BUILD_EXTENSION_LLM=ON \
-DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
1 change: 1 addition & 0 deletions examples/qualcomm/oss_scripts/whisper/CMakeLists.txt
@@ -14,6 +14,7 @@ set(_qnn_whisper_runner__srcs
${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
${EXECUTORCH_ROOT}/extension/llm/sampler/sampler.cpp
${EXECUTORCH_ROOT}/extension/audio/runner/asr_runner.h
)

# build qnn whisper runner
12 changes: 10 additions & 2 deletions examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp
@@ -14,6 +14,7 @@
*/

#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/runner.h>
#include <executorch/extension/llm/runner/audio.h>
#include <executorch/runtime/platform/log.h>
#include <gflags/gflags.h>
#include <fstream>
@@ -97,7 +98,7 @@ std::vector<std::vector<std::vector<char>>> parse_input_list_file(
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
// create llama runner
example::Runner runner(FLAGS_model_path, FLAGS_tokenizer_json_path);
example::WhisperRunner runner(FLAGS_model_path, FLAGS_tokenizer_json_path);

std::vector<std::vector<std::vector<char>>> multi_turns_input_buffers =
parse_input_list_file(FLAGS_input_list_path);
@@ -110,7 +111,14 @@ int main(int argc, char** argv) {
}
};
// generate tokens
runner.transcribe(FLAGS_seq_len, multi_turns_input_buffers[iter], callback);
executorch::extension::llm::Audio audio{
std::vector<uint8_t>(
multi_turns_input_buffers[iter][0].begin(),
multi_turns_input_buffers[iter][0].end()),
1,
80,
3000};
runner.transcribe(FLAGS_seq_len, audio, callback);
auto output_file_name =
FLAGS_output_folder_path + "/output_" + std::to_string(iter) + ".txt";
std::ofstream fout(output_file_name);
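Note: the llm::Audio value constructed above carries the raw feature bytes together with the {1, 80, 3000} shape that was previously hard-coded at the from_blob call site. A minimal sketch of the assumed layout behind extension/llm/runner/audio.h, with field names inferred from the aggregate initializer here and the from_blob call in runner.cpp (the actual header may differ):

#include <cstdint>
#include <vector>

namespace executorch::extension::llm {

// Assumed payload for preprocessed ASR input: serialized little-endian
// float features plus the (batch_size, n_bins, n_frames) tensor shape.
struct Audio {
  std::vector<uint8_t> data; // raw bytes of the float feature tensor
  int32_t batch_size; // 1 for a single utterance
  int32_t n_bins; // 80 mel bins for Whisper
  int32_t n_frames; // 3000 frames (30 s of audio at a 10 ms hop)
};

} // namespace executorch::extension::llm

Moving the shape into the payload lets the runner size and validate its input tensor per call instead of baking Whisper's dimensions into the runner.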
29 changes: 12 additions & 17 deletions examples/qualcomm/oss_scripts/whisper/runner/runner.cpp
@@ -27,20 +27,20 @@ static constexpr auto kDecoderStartTokenId = "decoder_start_token_id";
static constexpr auto kEosId = "get_eos_id";
static constexpr auto kMaxContextLen = "get_max_context_len";
} // namespace
Runner::Runner(
WhisperRunner::WhisperRunner(
[Review comment, Contributor] You should refactor the general-purpose ASR logic out of this file and put it into asr_runner.h/cpp.

const std::string& model_path,
const std::string& tokenizer_json_path)
: tokenizer_json_path_(tokenizer_json_path) {
encoder_ = std::make_unique<WhisperEncoder>(model_path);
decoder_ = std::make_unique<WhisperDecoder>(model_path);
tokenizer_ = std::make_unique<tokenizers::HFTokenizer>();
}
bool Runner::is_loaded() const {
bool WhisperRunner::is_loaded() const {
return encoder_->is_method_loaded() && decoder_->is_method_loaded() &&
tokenizer_->is_loaded() && sampler_;
}

Error Runner::load() {
Error WhisperRunner::load() {
if (is_loaded()) {
return Error::Ok;
}
@@ -108,33 +108,28 @@ Error Runner::load() {

return Error::Ok;
}
uint64_t Runner::logits_to_token(
uint64_t WhisperRunner::logits_to_token(
const executorch::aten::Tensor& logits_tensor) {
return sampler_->sample(logits_tensor.data_ptr<float>());
}
/**
* @param inputs: A vector containing one element: a vector of bytes that
* encodes a float tensor in little-endian byte order.
*
*/
Error Runner::transcribe(
Error WhisperRunner::transcribe(
int32_t seq_len,
std::vector<std::vector<char>>& inputs,
std::function<void(const std::string&)> token_callback) {
executorch::extension::llm::Audio& audio,
std::function<void(const std::string&)> token_callback,
std::function<void(const executorch::extension::llm::Stats&)>
stats_callback) {
if (!is_loaded()) {
stats_.model_load_start_ms = time_in_ms();
ET_CHECK_OK_OR_RETURN_ERROR(load());
stats_.model_load_end_ms = time_in_ms();
}
ET_CHECK_MSG(inputs.size() == 1, "The input size of whisper should be one.");

ET_LOG(Info, "Start Encoding");
stats_.encoder_inference_start_ms = time_in_ms();
auto input_features_tensor_ptr = from_blob(
inputs[0].data(),
audio.data.data(),
// (1, processor.feature_extractor.feature_size,
// processor.feature_extractor.nb_max_frames)
{1, 80, 3000},
{audio.batch_size, audio.n_bins, audio.n_frames}, // {1, 80, 3000}
ScalarType::Float);
Result<Tensor> encoder_out = encoder_->encode(input_features_tensor_ptr);
auto encoder_out_tensor_ptr = make_tensor_ptr(encoder_out.get());
@@ -188,7 +183,7 @@ Error Runner::transcribe(
return Error::Ok;
}

Error Runner::print_performance() {
Error WhisperRunner::print_performance() {
ET_LOG(Info, "\tTotal Generated token:\t\t\t\t%ld", num_generated_token_);

ET_LOG(
13 changes: 9 additions & 4 deletions examples/qualcomm/oss_scripts/whisper/runner/runner.h
@@ -13,6 +13,9 @@

#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/decoder.h>
#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/encoder.h>
#include <executorch/extension/audio/runner/asr_runner.h>
#include <executorch/extension/llm/runner/audio.h>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/runtime/core/error.h>
#include <pytorch/tokenizers/tokenizer.h>
@@ -24,9 +27,9 @@

namespace example {

class Runner {
class WhisperRunner : public executorch::extension::audio::ASRRunner {
public:
explicit Runner(
explicit WhisperRunner(
const std::string& model_path,
const std::string& tokenizer_json_path);

@@ -51,8 +54,10 @@ class Runner {
executorch::runtime::Error load();
executorch::runtime::Error transcribe(
int32_t seq_len,
std::vector<std::vector<char>>& inputs,
std::function<void(const std::string&)> token_callback = {});
executorch::extension::llm::Audio& audio,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const executorch::extension::llm::Stats&)>
stats_callback = {});

private:
executorch::runtime::Error print_performance();
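For context, asr_runner.h included above is the general-purpose ASR interface that the review comment asks Whisper-specific logic to be separated into. A minimal sketch of what extension/audio/runner/asr_runner.h could declare, inferred from the virtual surface WhisperRunner implements here (is_loaded, load, transcribe); the actual header may differ:

#pragma once

#include <cstdint>
#include <functional>
#include <string>

#include <executorch/extension/llm/runner/audio.h>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/runtime/core/error.h>

namespace executorch::extension::audio {

// Backend-agnostic ASR runner interface; WhisperRunner is one implementation.
class ASRRunner {
 public:
  virtual ~ASRRunner() = default;

  // True once the model, tokenizer, and sampler are all ready.
  virtual bool is_loaded() const = 0;

  // Load eagerly; otherwise loading happens on the first transcribe().
  virtual executorch::runtime::Error load() = 0;

  // Decode up to seq_len tokens from the preprocessed audio features,
  // streaming text through token_callback and timing through stats_callback.
  virtual executorch::runtime::Error transcribe(
      int32_t seq_len,
      executorch::extension::llm::Audio& audio,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const executorch::extension::llm::Stats&)>
          stats_callback = {}) = 0;
};

} // namespace executorch::extension::audio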
29 changes: 27 additions & 2 deletions extension/android/CMakeLists.txt
@@ -166,10 +166,12 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
)
endif()

if(EXECUTORCH_BUILD_LLAMA_JNI)
if(EXECUTORCH_BUILD_EXTENSION_LLM)
target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp)
list(APPEND link_libraries extension_llm_runner)
target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1)
target_compile_definitions(
executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_LLM=1
)

if(QNN_SDK_ROOT)
target_sources(
@@ -221,6 +223,29 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
endif()
endif()

if(EXECUTORCH_BUILD_EXTENSION_AUDIO)
target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp)
target_compile_definitions(
executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_AUDIO=1
)

if(QNN_SDK_ROOT)
target_sources(
executorch_jni
PRIVATE
${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner/encoder.cpp
${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner/decoder.cpp
${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp
)

target_include_directories(
executorch_jni
PRIVATE ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner
)
target_compile_definitions(executorch_jni PRIVATE EXECUTORCH_BUILD_QNN=1)
endif()
endif()

target_include_directories(
executorch_jni
PRIVATE
@@ -0,0 +1,61 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch.extension.audio;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import java.io.File;
import org.pytorch.executorch.ExecuTorchRuntime;
import org.pytorch.executorch.annotations.Experimental;
import org.pytorch.executorch.extension.llm.LlmCallback;

/**
* ASRModule is a wrapper around ExecuTorch ASR runners, such as the Whisper runner.
*
* <p>Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
public class ASRModule {

@DoNotStrip private final HybridData mHybridData;

@DoNotStrip
private static native HybridData initHybrid(
String modulePath, String tokenizerPath);

public ASRModule(
String modulePath, String tokenizerPath) {
ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime();

File modelFile = new File(modulePath);
if (!modelFile.canRead() || !modelFile.isFile()) {
throw new RuntimeException("Cannot load model path " + modulePath);
}
File tokenizerFile = new File(tokenizerPath);
if (!tokenizerFile.canRead() || !tokenizerFile.isFile()) {
throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath);
}
mHybridData = initHybrid(modulePath, tokenizerPath);
}

public void resetNative() {
mHybridData.resetNative();
}

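/**
 * Transcribe the given audio features into text.
 *
 * @param seqLen maximum number of tokens to generate
 * @param inputs preprocessed audio feature buffers, one byte[] of little-endian floats per input
 * @param callback callback invoked for each generated token
 * @param n_bins number of mel feature bins (80 for Whisper)
 * @param n_frames number of feature frames (3000 for Whisper)
 * @return status code from the native runner (0 on success)
 */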
@DoNotStrip
public native int transcribe(
int seqLen,
byte[][] inputs,
LlmCallback callback,
int n_bins,
int n_frames);

/** Force loading the module. Otherwise the model is loaded during the first transcribe(). */
@DoNotStrip
public native int load();
}
@@ -0,0 +1,2 @@
/** Extension for audio and ASR related use cases for ExecuTorch Android Java/JNI package. */
package org.pytorch.executorch.extension.audio;
2 changes: 1 addition & 1 deletion extension/android/jni/BUCK
@@ -101,7 +101,7 @@ non_fbcode_target(_kind = fb_android_cxx_library,
srcs = ["jni_layer.cpp", "jni_layer_llama.cpp", "jni_layer_runtime.cpp", "jni_helper.cpp"],
allow_jni_merging = False,
compiler_flags = ET_JNI_COMPILER_FLAGS + [
"-DEXECUTORCH_BUILD_LLAMA_JNI",
"-DEXECUTORCH_BUILD_EXTENSION_LLM",
],
soname = "libexecutorch.$(ext)",
visibility = ["PUBLIC"],
9 changes: 8 additions & 1 deletion extension/android/jni/jni_layer.cpp
@@ -508,7 +508,13 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
};
} // namespace executorch::extension

#ifdef EXECUTORCH_BUILD_LLAMA_JNI
#ifdef EXECUTORCH_BUILD_EXTENSION_AUDIO
extern void register_natives_for_asr();
#else
void register_natives_for_asr() {}
#endif

#ifdef EXECUTORCH_BUILD_EXTENSION_LLM
extern void register_natives_for_llm();
#else
// No op if we don't build LLM
@@ -526,6 +532,7 @@ void register_natives_for_training() {}
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(vm, [] {
executorch::extension::ExecuTorchJni::registerNatives();
register_natives_for_asr();
register_natives_for_llm();
register_natives_for_runtime();
register_natives_for_training();
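For reference, register_natives_for_asr is only declared extern in jni_layer.cpp; its definition lives with the audio JNI sources. A minimal sketch of the expected shape, mirroring the existing register_natives_for_llm pattern; the hybrid class name ExecuTorchAsrJni and its method table are assumptions, not taken from this PR:

#include <fbjni/fbjni.h>

namespace executorch::extension {

// Assumed hybrid class backing org.pytorch.executorch.extension.audio.ASRModule.
class ExecuTorchAsrJni : public facebook::jni::HybridClass<ExecuTorchAsrJni> {
 public:
  static constexpr const char* kJavaDescriptor =
      "Lorg/pytorch/executorch/extension/audio/ASRModule;";

  static void registerNatives() {
    javaClassStatic()->registerNatives({
        // makeNativeMethod bindings for initHybrid, load, and transcribe
        // would be listed here.
    });
  }
};

} // namespace executorch::extension

void register_natives_for_asr() {
  executorch::extension::ExecuTorchAsrJni::registerNatives();
}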