From 8afb4eed32e9c45c1a852d0cdec68a2f4245efd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Kopci=C5=84ski?= Date: Tue, 23 Sep 2025 13:24:55 +0200 Subject: [PATCH 1/7] initial draft of token batching --- .../common/rnexecutorch/models/llm/LLM.cpp | 18 ++++++++++++++++++ .../common/runner/text_token_generator.h | 12 ++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp index 82a804852..825702d61 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp @@ -1,7 +1,9 @@ #include "LLM.h" +#include #include #include +#include namespace rnexecutorch::models::llm { using namespace facebook; @@ -41,7 +43,23 @@ void LLM::generate(std::string input, std::shared_ptr callback) { std::to_string(static_cast(error))); } } +// // Sequence counter to maintain callback order if needed +// auto sequenceCounter = std::make_shared>(0); +// // Create a native callback that uses the thread pool for JS callback +// execution auto nativeCallback = [this, callback, sequenceCounter](const +// std::string &token) { +// // Get current sequence number +// uint64_t currentSeq = sequenceCounter->fetch_add(1); + +// // Submit callback execution to thread pool +// threads::GlobalThreadPool::detach([this, callback, token, currentSeq]() { +// // Execute the JS callback via callInvoker on the JS thread +// callInvoker->invokeAsync([callback, token](jsi::Runtime &runtime) { +// callback->call(runtime, jsi::String::createFromUtf8(runtime, token)); +// }); +// }); +// }; void LLM::interrupt() { if (!runner || !runner->is_loaded()) { throw std::runtime_error("Can't interrupt a model that's not loaded!"); diff --git a/packages/react-native-executorch/common/runner/text_token_generator.h b/packages/react-native-executorch/common/runner/text_token_generator.h index cc69e3b73..a09f82a34 100644 --- a/packages/react-native-executorch/common/runner/text_token_generator.h +++ b/packages/react-native-executorch/common/runner/text_token_generator.h @@ -11,6 +11,7 @@ #include "stats.h" #include "text_decoder_runner.h" +#include #include #include #include @@ -27,7 +28,7 @@ class TextTokenGenerator { Stats *stats) : tokenizer_(tokenizer), text_decoder_runner_(text_decoder_runner), eos_ids_(std::move(eos_ids)), use_kv_cache_(use_kv_cache), - stats_(stats) {} + stats_(stats), timestamp_(std::chrono::high_resolution_clock::now()) {} /** * Token generation loop. @@ -109,9 +110,14 @@ class TextTokenGenerator { token_cache.push_back(static_cast(cur_token)); const std::string cache_decoded = tokenizer_->Decode(token_cache); - if (cache_decoded != "�" && cache_decoded != " �") { + if (!cache_decoded.ends_with("�") && !cache_decoded.ends_with(" �") && + (token_cache.size() > 10 || + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - timestamp_) > + interval_)) { token_callback(cache_decoded); token_cache.clear(); + timestamp_ = std::chrono::high_resolution_clock::now(); } if (should_stop_) { @@ -138,6 +144,8 @@ class TextTokenGenerator { TextDecoderRunner *text_decoder_runner_; std::unique_ptr> eos_ids_; bool use_kv_cache_; + std::chrono::milliseconds interval_{120}; + std::chrono::high_resolution_clock::time_point timestamp_; // state machine bool should_stop_ = false; From 5b015e97f20d46dae0a1fd00299a2506946d3ebc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Kopci=C5=84ski?= Date: Wed, 24 Sep 2025 12:33:16 +0200 Subject: [PATCH 2/7] reused runner.stats for token data --- .../host_objects/ModelHostObject.h | 5 +++++ .../common/rnexecutorch/models/llm/LLM.cpp | 20 ++++--------------- .../common/rnexecutorch/models/llm/LLM.h | 3 ++- .../common/runner/runner.h | 3 ++- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index 97f769e3a..f0ddca803 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -115,6 +115,11 @@ template class ModelHostObject : public JsiHostObject { ModelHostObject, synchronousHostFunction<&Model::interrupt>, "interrupt")); + addFunctions(JSI_EXPORT_FUNCTION( + ModelHostObject, + synchronousHostFunction<&Model::getGeneratedTokenCount>, + "getGeneratedTokenCount")); + addFunctions( JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp index 825702d61..a7738f60e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp @@ -43,23 +43,7 @@ void LLM::generate(std::string input, std::shared_ptr callback) { std::to_string(static_cast(error))); } } -// // Sequence counter to maintain callback order if needed -// auto sequenceCounter = std::make_shared>(0); -// // Create a native callback that uses the thread pool for JS callback -// execution auto nativeCallback = [this, callback, sequenceCounter](const -// std::string &token) { -// // Get current sequence number -// uint64_t currentSeq = sequenceCounter->fetch_add(1); - -// // Submit callback execution to thread pool -// threads::GlobalThreadPool::detach([this, callback, token, currentSeq]() { -// // Execute the JS callback via callInvoker on the JS thread -// callInvoker->invokeAsync([callback, token](jsi::Runtime &runtime) { -// callback->call(runtime, jsi::String::createFromUtf8(runtime, token)); -// }); -// }); -// }; void LLM::interrupt() { if (!runner || !runner->is_loaded()) { throw std::runtime_error("Can't interrupt a model that's not loaded!"); @@ -67,6 +51,10 @@ void LLM::interrupt() { runner->stop(); } +std::size_t getGeneratedTokenCount() const noexcept { + return runner->stats_.num_generated_tokens; +} + std::size_t LLM::getMemoryLowerBound() const noexcept { return memorySizeLowerBound; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h index c6a72f861..bf2f3ec81 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h @@ -21,7 +21,8 @@ class LLM { void generate(std::string input, std::shared_ptr callback); void interrupt(); void unload() noexcept; - std::size_t getMemoryLowerBound() const noexcept; + size_t getGeneratedTokenCount() const noexcept; + size_t getMemoryLowerBound() const noexcept; private: size_t memorySizeLowerBound; diff --git a/packages/react-native-executorch/common/runner/runner.h b/packages/react-native-executorch/common/runner/runner.h index 4162d4485..11cb1796b 100644 --- a/packages/react-native-executorch/common/runner/runner.h +++ b/packages/react-native-executorch/common/runner/runner.h @@ -45,6 +45,8 @@ class Runner : public executorch::extension::llm::IRunner { ::executorch::runtime::Error warmup(const std::string &prompt); void stop(); + ::executorch::extension::llm::Stats stats_; + private: float temperature_; bool shouldStop_{false}; @@ -61,7 +63,6 @@ class Runner : public executorch::extension::llm::IRunner { text_token_generator_; // stats - ::executorch::extension::llm::Stats stats_; }; } // namespace example From 5736da812bc4728ba6f8564e33b6c1cf6aab22bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Kopci=C5=84ski?= Date: Fri, 26 Sep 2025 11:04:07 +0200 Subject: [PATCH 3/7] added token baatching to llms --- .../host_objects/ModelHostObject.h | 8 +++++ .../common/rnexecutorch/models/llm/LLM.cpp | 15 +++++++-- .../common/rnexecutorch/models/llm/LLM.h | 2 ++ .../common/runner/runner.cpp | 9 +++++ .../common/runner/runner.h | 4 +-- .../common/runner/text_token_generator.h | 28 ++++++++++++---- .../src/controllers/LLMController.ts | 33 ++++++++++++++++--- .../natural_language_processing/useLLM.ts | 19 +++++++++++ .../react-native-executorch/src/types/llm.ts | 3 ++ 9 files changed, 105 insertions(+), 16 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index f0ddca803..501b91da8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -120,6 +120,14 @@ template class ModelHostObject : public JsiHostObject { synchronousHostFunction<&Model::getGeneratedTokenCount>, "getGeneratedTokenCount")); + addFunctions( + JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::setCountInterval>, + "setCountInterval")); + + addFunctions(JSI_EXPORT_FUNCTION( + ModelHostObject, + synchronousHostFunction<&Model::setTimeInterval>, "setTimeInterval")); addFunctions( JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp index a7738f60e..552302173 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp @@ -51,14 +51,25 @@ void LLM::interrupt() { runner->stop(); } -std::size_t getGeneratedTokenCount() const noexcept { +size_t LLM::getGeneratedTokenCount() const noexcept { + if (!runner || !runner->is_loaded()) { + return 0; + } return runner->stats_.num_generated_tokens; } -std::size_t LLM::getMemoryLowerBound() const noexcept { +size_t LLM::getMemoryLowerBound() const noexcept { return memorySizeLowerBound; } +void LLM::setCountInterval(size_t countInterval) { + runner->set_count_interval(countInterval); +} + +void LLM::setTimeInterval(size_t timeInterval) { + runner->set_time_interval(timeInterval); +} + void LLM::unload() noexcept { runner.reset(nullptr); } } // namespace rnexecutorch::models::llm diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h index bf2f3ec81..fbaef0c3c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h @@ -23,6 +23,8 @@ class LLM { void unload() noexcept; size_t getGeneratedTokenCount() const noexcept; size_t getMemoryLowerBound() const noexcept; + void setCountInterval(size_t countInterval); + void setTimeInterval(size_t timeInterval); private: size_t memorySizeLowerBound; diff --git a/packages/react-native-executorch/common/runner/runner.cpp b/packages/react-native-executorch/common/runner/runner.cpp index 55ef36fdd..5880c23ea 100644 --- a/packages/react-native-executorch/common/runner/runner.cpp +++ b/packages/react-native-executorch/common/runner/runner.cpp @@ -271,4 +271,13 @@ void Runner::stop() { ET_LOG(Error, "Token generator is not loaded, cannot stop"); } } + +void Runner::set_count_interval(size_t count_interval) { + text_token_generator_->set_count_interval(count_interval); +} + +void Runner::set_time_interval(size_t time_interval) { + text_token_generator_->set_time_interval(time_interval); +} + } // namespace example diff --git a/packages/react-native-executorch/common/runner/runner.h b/packages/react-native-executorch/common/runner/runner.h index 11cb1796b..5b75ba6bc 100644 --- a/packages/react-native-executorch/common/runner/runner.h +++ b/packages/react-native-executorch/common/runner/runner.h @@ -43,6 +43,8 @@ class Runner : public executorch::extension::llm::IRunner { stats_callback = {}, bool echo = true, bool warming = false); ::executorch::runtime::Error warmup(const std::string &prompt); + void set_count_interval(size_t count_interval); + void set_time_interval(size_t time_interval); void stop(); ::executorch::extension::llm::Stats stats_; @@ -61,8 +63,6 @@ class Runner : public executorch::extension::llm::IRunner { std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller_; std::unique_ptr<::executorch::extension::llm::TextTokenGenerator> text_token_generator_; - - // stats }; } // namespace example diff --git a/packages/react-native-executorch/common/runner/text_token_generator.h b/packages/react-native-executorch/common/runner/text_token_generator.h index a09f82a34..103f2fa79 100644 --- a/packages/react-native-executorch/common/runner/text_token_generator.h +++ b/packages/react-native-executorch/common/runner/text_token_generator.h @@ -79,7 +79,7 @@ class TextTokenGenerator { from_blob(&pos, {1}, executorch::aten::ScalarType::Long); should_stop_ = false; - + timestamp_ = std::chrono::high_resolution_clock::now(); // Generate our tokens while (pos < seq_len - 1) { // Run the model @@ -110,11 +110,16 @@ class TextTokenGenerator { token_cache.push_back(static_cast(cur_token)); const std::string cache_decoded = tokenizer_->Decode(token_cache); + const auto timeIntervalElapsed = + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - timestamp_) > + time_interval_; + const auto countIntervalElapsed = token_cache.size() > count_interval_; + const auto eos_reached = eos_ids_->find(cur_token) != eos_ids_->end(); + if (!cache_decoded.ends_with("�") && !cache_decoded.ends_with(" �") && - (token_cache.size() > 10 || - std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - timestamp_) > - interval_)) { + (countIntervalElapsed || timeIntervalElapsed || should_stop_ || + eos_reached)) { token_callback(cache_decoded); token_cache.clear(); timestamp_ = std::chrono::high_resolution_clock::now(); @@ -125,7 +130,7 @@ class TextTokenGenerator { } // data-dependent terminating condition: we have n_eos_ number of EOS - if (eos_ids_->find(cur_token) != eos_ids_->end()) { + if (eos_reached) { printf("\n"); ET_LOG(Info, "\nReached to the end of generation"); break; @@ -139,12 +144,21 @@ class TextTokenGenerator { */ inline void stop() { should_stop_ = true; } + void set_count_interval(size_t count_interval) { + count_interval_ = count_interval; + } + + void set_time_interval(size_t time_interval) { + time_interval_ = std::chrono::milliseconds(time_interval); + } + private: tokenizers::Tokenizer *tokenizer_; TextDecoderRunner *text_decoder_runner_; std::unique_ptr> eos_ids_; bool use_kv_cache_; - std::chrono::milliseconds interval_{120}; + size_t count_interval_{10}; + std::chrono::milliseconds time_interval_{120}; std::chrono::high_resolution_clock::time_point timestamp_; // state machine diff --git a/packages/react-native-executorch/src/controllers/LLMController.ts b/packages/react-native-executorch/src/controllers/LLMController.ts index a1d2fb103..718972ed8 100644 --- a/packages/react-native-executorch/src/controllers/LLMController.ts +++ b/packages/react-native-executorch/src/controllers/LLMController.ts @@ -132,13 +132,23 @@ export class LLMController { this.nativeModule = global.loadLLM(modelPath, tokenizerPath); this.isReadyCallback(true); this.onToken = (data: string) => { + if (!data) { + return; + } + if ( - !data || - (SPECIAL_TOKENS.EOS_TOKEN in this.tokenizerConfig && - data === this.tokenizerConfig.eos_token) || - (SPECIAL_TOKENS.PAD_TOKEN in this.tokenizerConfig && - data === this.tokenizerConfig.pad_token) + SPECIAL_TOKENS.EOS_TOKEN in this.tokenizerConfig && + data.indexOf(this.tokenizerConfig.eos_token) >= 0 ) { + data = data.replaceAll(this.tokenizerConfig.eos_token, ''); + } + if ( + SPECIAL_TOKENS.PAD_TOKEN in this.tokenizerConfig && + data.indexOf(this.tokenizerConfig.pad_token) >= 0 + ) { + data = data.replaceAll(this.tokenizerConfig.pad_token, ''); + } + if (data.length === 0) { return; } @@ -206,6 +216,11 @@ export class LLMController { this.nativeModule.interrupt(); } + public getGeneratedTokenCount(): number { + console.log('kappa', this.nativeModule); + return this.nativeModule.getGeneratedTokenCount(); + } + public async generate(messages: Message[], tools?: LLMTool[]) { if (!this._isReady) { throw new Error(getError(ETError.ModuleNotLoaded)); @@ -302,4 +317,12 @@ export class LLMController { }); return result; } + + public setCountInterval(countInteval: number) { + this.nativeModule.setCountInterval(countInteval); + } + + public setTimeInterval(timeInteval: number) { + this.nativeModule.setTimeInterval(timeInteval); + } } diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts index df34ab333..63319bab8 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts @@ -114,6 +114,22 @@ export const useLLM = ({ [controllerInstance] ); + const getGeneratedTokenCount = useCallback( + () => controllerInstance.getGeneratedTokenCount(), + [controllerInstance] + ); + + const setCountInterval = useCallback( + (countInterval: number) => + controllerInstance.setCountInterval(countInterval), + [controllerInstance] + ); + + const setTimeInterval = useCallback( + (timeInterval: number) => controllerInstance.setTimeInterval(timeInterval), + [controllerInstance] + ); + return { messageHistory, response, @@ -122,6 +138,9 @@ export const useLLM = ({ isGenerating, downloadProgress, error, + getGeneratedTokenCount: getGeneratedTokenCount, + setTimeInterval: setTimeInterval, + setCountInterval: setCountInterval, configure: configure, generate: generate, sendMessage: sendMessage, diff --git a/packages/react-native-executorch/src/types/llm.ts b/packages/react-native-executorch/src/types/llm.ts index 776e1c9f5..eebb84e82 100644 --- a/packages/react-native-executorch/src/types/llm.ts +++ b/packages/react-native-executorch/src/types/llm.ts @@ -6,6 +6,9 @@ export interface LLMType { isGenerating: boolean; downloadProgress: number; error: string | null; + getGeneratedTokenCount: () => number; + setTimeInterval: (timeInterval: number) => void; + setCountInterval: (countInterval: number) => void; configure: ({ chatConfig, toolsConfig, From 26c512e6222678b1367a6c4ab6d306b17f44efc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Kopci=C5=84ski?= Date: Fri, 26 Sep 2025 16:21:54 +0200 Subject: [PATCH 4/7] small refactor, added docs --- .../01-natural-language-processing/useLLM.md | 48 +++++++++++++------ .../LLMModule.md | 44 +++++++++++------ .../src/controllers/LLMController.ts | 19 ++++---- .../natural_language_processing/useLLM.ts | 23 ++++----- .../natural_language_processing/LLMModule.ts | 12 ++++- .../react-native-executorch/src/types/llm.ts | 7 ++- 6 files changed, 97 insertions(+), 56 deletions(-) diff --git a/docs/docs/02-hooks/01-natural-language-processing/useLLM.md b/docs/docs/02-hooks/01-natural-language-processing/useLLM.md index b9b462a96..549b9bf80 100644 --- a/docs/docs/02-hooks/01-natural-language-processing/useLLM.md +++ b/docs/docs/02-hooks/01-natural-language-processing/useLLM.md @@ -60,20 +60,21 @@ For more information on loading resources, take a look at [loading models](../.. ### Returns -| Field | Type | Description | -| ------------------ | --------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------- | -| `generate()` | `(messages: Message[], tools?: LLMTool[]) => Promise` | Runs model to complete chat passed in `messages` argument. It doesn't manage conversation context. | -| `interrupt()` | `() => void` | Function to interrupt the current inference. | -| `response` | `string` | State of the generated response. This field is updated with each token generated by the model. | -| `token` | `string` | The most recently generated token. | -| `isReady` | `boolean` | Indicates whether the model is ready. | -| `isGenerating` | `boolean` | Indicates whether the model is currently generating a response. | -| `downloadProgress` | `number` | Represents the download progress as a value between 0 and 1, indicating the extent of the model file retrieval. | -| `error` | string | null | Contains the error message if the model failed to load. | -| `configure` | `({ chatConfig?: Partial, toolsConfig?: ToolsConfig }) => void` | Configures chat and tool calling. See more details in [configuring the model](#configuring-the-model). | -| `sendMessage` | `(message: string) => Promise` | Function to add user message to conversation. After model responds, `messageHistory` will be updated with both user message and model response. | -| `deleteMessage` | `(index: number) => void` | Deletes all messages starting with message on `index` position. After deletion `messageHistory` will be updated. | -| `messageHistory` | `Message[]` | History containing all messages in conversation. This field is updated after model responds to `sendMessage`. | +| Field | Type | Description | +| ------------------------ | -------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------- | +| `generate()` | `(messages: Message[], tools?: LLMTool[]) => Promise` | Runs model to complete chat passed in `messages` argument. It doesn't manage conversation context. | +| `interrupt()` | `() => void` | Function to interrupt the current inference. | +| `response` | `string` | State of the generated response. This field is updated with each token generated by the model. | +| `token` | `string` | The most recently generated token. | +| `isReady` | `boolean` | Indicates whether the model is ready. | +| `isGenerating` | `boolean` | Indicates whether the model is currently generating a response. | +| `downloadProgress` | `number` | Represents the download progress as a value between 0 and 1, indicating the extent of the model file retrieval. | +| `error` | string | null | Contains the error message if the model failed to load. | +| `configure` | `({chatConfig?: Partial, toolsConfig?: ToolsConfig, generationConfig?: GenerationConfig}) => void` | Configures chat and tool calling. See more details in [configuring the model](#configuring-the-model). | +| `sendMessage` | `(message: string) => Promise` | Function to add user message to conversation. After model responds, `messageHistory` will be updated with both user message and model response. | +| `deleteMessage` | `(index: number) => void` | Deletes all messages starting with message on `index` position. After deletion `messageHistory` will be updated. | +| `messageHistory` | `Message[]` | History containing all messages in conversation. This field is updated after model responds to `sendMessage`. | +| `getGeneratedTokenCount` | `() => number` | Returns the number of tokens generated in the last response. |
Type definitions @@ -102,9 +103,11 @@ interface LLMType { configure: ({ chatConfig, toolsConfig, + generationConfig, }: { chatConfig?: Partial; toolsConfig?: ToolsConfig; + generationConfig?: GenerationConfig; }) => void; generate: (messages: Message[], tools?: LLMTool[]) => Promise; sendMessage: (message: string) => Promise; @@ -138,6 +141,11 @@ interface ToolCall { arguments: Object; } +interface GenerationConfig { + outputTokenBatchSize: number; + batchTimeInterval: number; +} + type LLMTool = Object; ``` @@ -147,7 +155,7 @@ type LLMTool = Object; You can use functions returned from this hooks in two manners: -1. Functional/pure - we will not keep any state for you. You'll need to keep conversation history and handle function calling yourself. Use `generate` (and rarely `forward`) and `response`. Note that you don't need to run `configure` to use those. Furthermore, it will not have any effect on those functions. +1. Functional/pure - we will not keep any state for you. You'll need to keep conversation history and handle function calling yourself. Use `generate` (and rarely `forward`) and `response`. Note that you don't need to run `configure` to use those. Furthermore, `chatConfig` and `toolsConfig` will not have any effect on those functions. 2. Managed/stateful - we will manage conversation state. Tool calls will be parsed and called automatically after passing appropriate callbacks. See more at [managed LLM chat](#managed-llm-chat). @@ -267,6 +275,12 @@ To configure model (i.e. change system prompt, load initial conversation history - **`displayToolCalls`** - If set to true, JSON tool calls will be displayed in chat. If false, only answers will be displayed. +**`generationConfig`** - Object configuring generation settings, currently only output token batching. + +- **`outputTokenBatchSize`** - Soft upper limit on the number of tokens in each token batch (in certain cases there can be more tokens in given batch, i.e. when the batch would end with special emoji join character). + +- **`batchTimeInterval`** - Upper limit on the time interval between consecutive token batches. + ### Sending a message In order to send a message to the model, one can use the following code: @@ -459,6 +473,10 @@ The response should include JSON: } ``` +## Token Batching + +Depending on selected model and the user's device generation speed can be above 60 tokens per second. If the `tokenCallback` triggers rerenders and is invoked on every single token it can significantly decrease the app's performance. To alleviate this and help improve performance we implement token batching. To configure this you need to call `configure` method and pass `generationConfig`. Inside you can set two parameters `outputTokenBatchSize` and `batchTimeInterval`. They set the size of the batch before tokens are emitted and the maximum time interval between consecutive batches respectively. Each batch is emitted if either `timeInterval` elapses since last batch or `countInterval` number of tokens are generated. This allows for smooth generation even if model lags during generation. Default parameters are set to 10 tokens and 80ms for time interval (~12 batches per second). + ## Available models | Model Family | Sizes | Quantized | diff --git a/docs/docs/03-typescript-api/01-natural-language-processing/LLMModule.md b/docs/docs/03-typescript-api/01-natural-language-processing/LLMModule.md index 14d3b4d74..6d20f8171 100644 --- a/docs/docs/03-typescript-api/01-natural-language-processing/LLMModule.md +++ b/docs/docs/03-typescript-api/01-natural-language-processing/LLMModule.md @@ -30,18 +30,19 @@ llm.delete(); ### Methods -| Method | Type | Description | -| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `constructor` | `({tokenCallback?: (token: string) => void, responseCallback?: (response: string) => void, messageHistoryCallback?: (messageHistory: Message[]) => void})` | Creates a new instance of LLMModule with optional callbacks. | -| `load` | `(model: { modelSource: ResourceSource; tokenizerSource: ResourceSource; tokenizerConfigSource: ResourceSource }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model. | -| `setTokenCallback` | `{tokenCallback: (token: string) => void}) => void` | Sets new token callback. | -| `generate` | `(messages: Message[], tools?: LLMTool[]) => Promise` | Runs model to complete chat passed in `messages` argument. It doesn't manage conversation context. | -| `forward` | `(input: string) => Promise` | Runs model inference with raw input string. You need to provide entire conversation and prompt (in correct format and with special tokens!) in input string to this method. It doesn't manage conversation context. It is intended for users that need access to the model itself without any wrapper. If you want a simple chat with model the consider using`sendMessage` | -| `configure` | `({chatConfig?: Partial, toolsConfig?: ToolsConfig}) => void` | Configures chat and tool calling. See more details in [configuring the model](#configuring-the-model). | -| `sendMessage` | `(message: string) => Promise` | Method to add user message to conversation. After model responds it will call `messageHistoryCallback()`containing both user message and model response. It also returns them. | -| `deleteMessage` | `(index: number) => void` | Deletes all messages starting with message on `index` position. After deletion it will call `messageHistoryCallback()` containing new history. It also returns it. | -| `delete` | `() => void` | Method to delete the model from memory. Note you cannot delete model while it's generating. You need to interrupt it first and make sure model stopped generation. | -| `interrupt` | `() => void` | Interrupts model generation. It may return one more token after interrupt. | +| Method | Type | Description | +| ------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `constructor` | `({tokenCallback?: (token: string) => void, responseCallback?: (response: string) => void, messageHistoryCallback?: (messageHistory: Message[]) => void})` | Creates a new instance of LLMModule with optional callbacks. | +| `load` | `(model: { modelSource: ResourceSource; tokenizerSource: ResourceSource; tokenizerConfigSource: ResourceSource }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model. | +| `setTokenCallback` | `{tokenCallback: (token: string) => void}) => void` | Sets new token callback invoked on every token batch. | +| `generate` | `(messages: Message[], tools?: LLMTool[]) => Promise` | Runs model to complete chat passed in `messages` argument. It doesn't manage conversation context. | +| `forward` | `(input: string) => Promise` | Runs model inference with raw input string. You need to provide entire conversation and prompt (in correct format and with special tokens!) in input string to this method. It doesn't manage conversation context. It is intended for users that need access to the model itself without any wrapper. If you want a simple chat with model the consider using `sendMessage` | +| `configure` | `({chatConfig?: Partial, toolsConfig?: ToolsConfig, generationConfig?: GenerationConfig}) => void` | Configures chat and tool calling and generation settings. See more details in [configuring the model](#configuring-the-model). | +| `sendMessage` | `(message: string) => Promise` | Method to add user message to conversation. After model responds it will call `messageHistoryCallback()`containing both user message and model response. It also returns them. | +| `deleteMessage` | `(index: number) => void` | Deletes all messages starting with message on `index` position. After deletion it will call `messageHistoryCallback()` containing new history. It also returns it. | +| `delete` | `() => void` | Method to delete the model from memory. Note you cannot delete model while it's generating. You need to interrupt it first and make sure model stopped generation. | +| `interrupt` | `() => void` | Interrupts model generation. It may return one more token after interrupt. | +| `getGeneratedTokenCount` | `() => number` | Returns the number of tokens generated in the last response. |
Type definitions @@ -68,6 +69,11 @@ interface ToolsConfig { displayToolCalls?: boolean; } +interface GenerationConfig { + outputTokenBatchSize: number; + batchTimeInterval: number; +} + interface ToolCall { toolName: string; arguments: Object; @@ -124,10 +130,14 @@ To subscribe to the token generation event, you can pass `tokenCallback` or `mes In order to interrupt the model, you can use the `interrupt` method. +## Token Batching + +Depending on selected model and the user's device generation speed can be above 60 tokens per second. If the `tokenCallback` triggers rerenders and is invoked on every single token it can significantly decrease the app's performance. To alleviate this and help improve performance we implement token batching. To configure this you need to call `configure` method and pass `generationConfig`. Inside you can set two parameters `outputTokenBatchSize` and `batchTimeInterval`. They set the size of the batch before tokens are emitted and the maximum time interval between consecutive batches respectively. Each batch is emitted if either `timeInterval` elapses since last batch or `countInterval` number of tokens are generated. This allows for smooth generation even if model lags during generation. Default parameters are set to 10 tokens and 80ms for time interval (~12 batches per second). + ## Configuring the model -To configure model (i.e. change system prompt, load initial conversation history or manage tool calling) you can use -`configure` method. It is only applied to managed chats i.e. when using `sendMessage` (see: [Functional vs managed](../../02-hooks/01-natural-language-processing/useLLM.md#functional-vs-managed)) It accepts object with following fields: +To configure model (i.e. change system prompt, load initial conversation history or manage tool calling, set generation settings) you can use +`configure` method. `chatConfig` and `toolsConfig` is only applied to managed chats i.e. when using `sendMessage` (see: [Functional vs managed](../../02-hooks/01-natural-language-processing/useLLM.md#functional-vs-managed)) It accepts object with following fields: **`chatConfig`** - Object configuring chat management: @@ -145,6 +155,12 @@ To configure model (i.e. change system prompt, load initial conversation history - **`displayToolCalls`** - If set to true, JSON tool calls will be displayed in chat. If false, only answers will be displayed. +**`generationConfig`** - Object configuring generation settings, currently only output token batching. + +- **`outputTokenBatchSize`** - Soft upper limit on the number of tokens in each token batch (in certain cases there can be more tokens in given batch, i.e. when the batch would end with special emoji join character). + +- **`batchTimeInterval`** - Upper limit on the time interval between consecutive token batches. + ## Deleting the model from memory To delete the model from memory, you can use the `delete` method. diff --git a/packages/react-native-executorch/src/controllers/LLMController.ts b/packages/react-native-executorch/src/controllers/LLMController.ts index 718972ed8..f189df8b1 100644 --- a/packages/react-native-executorch/src/controllers/LLMController.ts +++ b/packages/react-native-executorch/src/controllers/LLMController.ts @@ -6,6 +6,7 @@ import { DEFAULT_CHAT_CONFIG } from '../constants/llmDefaults'; import { readAsStringAsync } from 'expo-file-system'; import { ChatConfig, + GenerationConfig, LLMTool, Message, SPECIAL_TOKENS, @@ -168,13 +169,22 @@ export class LLMController { public configure({ chatConfig, toolsConfig, + generationConfig, }: { chatConfig?: Partial; toolsConfig?: ToolsConfig; + generationConfig?: GenerationConfig; }) { this.chatConfig = { ...DEFAULT_CHAT_CONFIG, ...chatConfig }; this.toolsConfig = toolsConfig; + if (generationConfig?.outputTokenBatchSize) { + this.nativeModule.setCountInterval(generationConfig.outputTokenBatchSize); + } + if (generationConfig?.batchTimeInterval) { + this.nativeModule.setTimeInterval(generationConfig.batchTimeInterval); + } + // reset inner state when loading new configuration this.responseCallback(''); this.messageHistoryCallback(this.chatConfig.initialMessageHistory); @@ -217,7 +227,6 @@ export class LLMController { } public getGeneratedTokenCount(): number { - console.log('kappa', this.nativeModule); return this.nativeModule.getGeneratedTokenCount(); } @@ -317,12 +326,4 @@ export class LLMController { }); return result; } - - public setCountInterval(countInteval: number) { - this.nativeModule.setCountInterval(countInteval); - } - - public setTimeInterval(timeInteval: number) { - this.nativeModule.setTimeInterval(timeInteval); - } } diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts index 63319bab8..25cad7969 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts @@ -2,6 +2,7 @@ import { useCallback, useEffect, useState } from 'react'; import { ResourceSource } from '../../types/common'; import { ChatConfig, + GenerationConfig, LLMTool, LLMType, Message, @@ -81,10 +82,17 @@ export const useLLM = ({ ({ chatConfig, toolsConfig, + generationConfig, }: { chatConfig?: Partial; toolsConfig?: ToolsConfig; - }) => controllerInstance.configure({ chatConfig, toolsConfig }), + generationConfig?: GenerationConfig; + }) => + controllerInstance.configure({ + chatConfig, + toolsConfig, + generationConfig, + }), [controllerInstance] ); @@ -119,17 +127,6 @@ export const useLLM = ({ [controllerInstance] ); - const setCountInterval = useCallback( - (countInterval: number) => - controllerInstance.setCountInterval(countInterval), - [controllerInstance] - ); - - const setTimeInterval = useCallback( - (timeInterval: number) => controllerInstance.setTimeInterval(timeInterval), - [controllerInstance] - ); - return { messageHistory, response, @@ -139,8 +136,6 @@ export const useLLM = ({ downloadProgress, error, getGeneratedTokenCount: getGeneratedTokenCount, - setTimeInterval: setTimeInterval, - setCountInterval: setCountInterval, configure: configure, generate: generate, sendMessage: sendMessage, diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts index 299330807..778e4bf52 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts @@ -1,6 +1,12 @@ import { LLMController } from '../../controllers/LLMController'; import { ResourceSource } from '../../types/common'; -import { ChatConfig, LLMTool, Message, ToolsConfig } from '../../types/llm'; +import { + ChatConfig, + GenerationConfig, + LLMTool, + Message, + ToolsConfig, +} from '../../types/llm'; export class LLMModule { private controller: LLMController; @@ -46,11 +52,13 @@ export class LLMModule { configure({ chatConfig, toolsConfig, + generationConfig, }: { chatConfig?: Partial; toolsConfig?: ToolsConfig; + generationConfig?: GenerationConfig; }) { - this.controller.configure({ chatConfig, toolsConfig }); + this.controller.configure({ chatConfig, toolsConfig, generationConfig }); } async forward(input: string): Promise { diff --git a/packages/react-native-executorch/src/types/llm.ts b/packages/react-native-executorch/src/types/llm.ts index eebb84e82..a0e4e8af7 100644 --- a/packages/react-native-executorch/src/types/llm.ts +++ b/packages/react-native-executorch/src/types/llm.ts @@ -7,8 +7,6 @@ export interface LLMType { downloadProgress: number; error: string | null; getGeneratedTokenCount: () => number; - setTimeInterval: (timeInterval: number) => void; - setCountInterval: (countInterval: number) => void; configure: ({ chatConfig, toolsConfig, @@ -50,6 +48,11 @@ export interface ToolsConfig { displayToolCalls?: boolean; } +export interface GenerationConfig { + outputTokenBatchSize?: number; + batchTimeInterval?: number; +} + export const SPECIAL_TOKENS = { BOS_TOKEN: 'bos_token', EOS_TOKEN: 'eos_token', From cd47f968b35cfbb661e9e4b7c465feb0a3d1626e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Kopci=C5=84ski?= Date: Fri, 26 Sep 2025 16:47:22 +0200 Subject: [PATCH 5/7] fixed bug where first token was emitted before batch --- packages/react-native-executorch/common/runner/runner.cpp | 3 --- .../common/runner/text_token_generator.h | 8 ++------ packages/react-native-executorch/src/types/llm.ts | 2 ++ 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/packages/react-native-executorch/common/runner/runner.cpp b/packages/react-native-executorch/common/runner/runner.cpp index 5880c23ea..830d8509c 100644 --- a/packages/react-native-executorch/common/runner/runner.cpp +++ b/packages/react-native-executorch/common/runner/runner.cpp @@ -218,9 +218,6 @@ Error Runner::generate(const std::string &prompt, RUNNER_ET_LOG(warmup, "RSS after prompt prefill: %f MiB (0 if unsupported)", llm::get_rss_bytes() / 1024.0 / 1024.0); - if (cur_decoded != "�") { - wrapped_callback(cur_decoded); - } // start the main loop prompt_tokens_uint64.push_back(cur_token); int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate( diff --git a/packages/react-native-executorch/common/runner/text_token_generator.h b/packages/react-native-executorch/common/runner/text_token_generator.h index 103f2fa79..c8374821c 100644 --- a/packages/react-native-executorch/common/runner/text_token_generator.h +++ b/packages/react-native-executorch/common/runner/text_token_generator.h @@ -55,12 +55,8 @@ class TextTokenGenerator { uint64_t cur_token = tokens.back(); // cache to keep tokens if they were decoded into illegal character std::vector token_cache; - // if first token after prefill was part of multi-token character we need to - // add this to cache here - if (tokenizer_->Decode( - std::vector{static_cast(cur_token)}) == "�") { - token_cache.push_back(static_cast(cur_token)); - } + // add first token after prefill to cache here + token_cache.push_back(static_cast(cur_token)); if (use_kv_cache_) { // hard code these to size 1 as kv cache is locked to static size right diff --git a/packages/react-native-executorch/src/types/llm.ts b/packages/react-native-executorch/src/types/llm.ts index a0e4e8af7..2d0a60c98 100644 --- a/packages/react-native-executorch/src/types/llm.ts +++ b/packages/react-native-executorch/src/types/llm.ts @@ -10,9 +10,11 @@ export interface LLMType { configure: ({ chatConfig, toolsConfig, + generationConfig, }: { chatConfig?: Partial; toolsConfig?: ToolsConfig; + generationConfig?: GenerationConfig; }) => void; generate: (messages: Message[], tools?: LLMTool[]) => Promise; sendMessage: (message: string) => Promise; From 85f93288e029e68904dc53ac18b576274c35158c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Kopci=C5=84ski?= Date: Wed, 1 Oct 2025 11:30:54 +0200 Subject: [PATCH 6/7] review changes --- docs/docs/02-hooks/01-natural-language-processing/useLLM.md | 2 +- .../01-natural-language-processing/LLMModule.md | 2 +- .../common/rnexecutorch/models/llm/LLM.cpp | 6 ++++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/docs/02-hooks/01-natural-language-processing/useLLM.md b/docs/docs/02-hooks/01-natural-language-processing/useLLM.md index 549b9bf80..37ce45b55 100644 --- a/docs/docs/02-hooks/01-natural-language-processing/useLLM.md +++ b/docs/docs/02-hooks/01-natural-language-processing/useLLM.md @@ -475,7 +475,7 @@ The response should include JSON: ## Token Batching -Depending on selected model and the user's device generation speed can be above 60 tokens per second. If the `tokenCallback` triggers rerenders and is invoked on every single token it can significantly decrease the app's performance. To alleviate this and help improve performance we implement token batching. To configure this you need to call `configure` method and pass `generationConfig`. Inside you can set two parameters `outputTokenBatchSize` and `batchTimeInterval`. They set the size of the batch before tokens are emitted and the maximum time interval between consecutive batches respectively. Each batch is emitted if either `timeInterval` elapses since last batch or `countInterval` number of tokens are generated. This allows for smooth generation even if model lags during generation. Default parameters are set to 10 tokens and 80ms for time interval (~12 batches per second). +Depending on selected model and the user's device generation speed can be above 60 tokens per second. If the `tokenCallback` triggers rerenders and is invoked on every single token it can significantly decrease the app's performance. To alleviate this and help improve performance we've implemented token batching. To configure this you need to call `configure` method and pass `generationConfig`. Inside you can set two parameters `outputTokenBatchSize` and `batchTimeInterval`. They set the size of the batch before tokens are emitted and the maximum time interval between consecutive batches respectively. Each batch is emitted if either `timeInterval` elapses since last batch or `countInterval` number of tokens are generated. This allows for smooth generation even if model lags during generation. Default parameters are set to 10 tokens and 80ms for time interval (~12 batches per second). ## Available models diff --git a/docs/docs/03-typescript-api/01-natural-language-processing/LLMModule.md b/docs/docs/03-typescript-api/01-natural-language-processing/LLMModule.md index 6d20f8171..00af4d561 100644 --- a/docs/docs/03-typescript-api/01-natural-language-processing/LLMModule.md +++ b/docs/docs/03-typescript-api/01-natural-language-processing/LLMModule.md @@ -132,7 +132,7 @@ In order to interrupt the model, you can use the `interrupt` method. ## Token Batching -Depending on selected model and the user's device generation speed can be above 60 tokens per second. If the `tokenCallback` triggers rerenders and is invoked on every single token it can significantly decrease the app's performance. To alleviate this and help improve performance we implement token batching. To configure this you need to call `configure` method and pass `generationConfig`. Inside you can set two parameters `outputTokenBatchSize` and `batchTimeInterval`. They set the size of the batch before tokens are emitted and the maximum time interval between consecutive batches respectively. Each batch is emitted if either `timeInterval` elapses since last batch or `countInterval` number of tokens are generated. This allows for smooth generation even if model lags during generation. Default parameters are set to 10 tokens and 80ms for time interval (~12 batches per second). +Depending on selected model and the user's device generation speed can be above 60 tokens per second. If the `tokenCallback` triggers rerenders and is invoked on every single token it can significantly decrease the app's performance. To alleviate this and help improve performance we've implemented token batching. To configure this you need to call `configure` method and pass `generationConfig`. Inside you can set two parameters `outputTokenBatchSize` and `batchTimeInterval`. They set the size of the batch before tokens are emitted and the maximum time interval between consecutive batches respectively. Each batch is emitted if either `timeInterval` elapses since last batch or `countInterval` number of tokens are generated. This allows for smooth generation even if model lags during generation. Default parameters are set to 10 tokens and 80ms for time interval (~12 batches per second). ## Configuring the model diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp index 552302173..d8d4f9819 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp @@ -63,10 +63,16 @@ size_t LLM::getMemoryLowerBound() const noexcept { } void LLM::setCountInterval(size_t countInterval) { + if (!runner || !runner->is_loaded()) { + throw std::runtime_error("Can't configure a model that's not loaded!"); + } runner->set_count_interval(countInterval); } void LLM::setTimeInterval(size_t timeInterval) { + if (!runner || !runner->is_loaded()) { + throw std::runtime_error("Can't configure a model that's not loaded!"); + } runner->set_time_interval(timeInterval); } From a60bb331a907dd0817099bd95c4b97d3b3ad3d5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Kopci=C5=84ski?= Date: Wed, 1 Oct 2025 11:46:01 +0200 Subject: [PATCH 7/7] review changes --- .../common/runner/text_token_generator.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/react-native-executorch/common/runner/text_token_generator.h b/packages/react-native-executorch/common/runner/text_token_generator.h index c8374821c..e616ad7f3 100644 --- a/packages/react-native-executorch/common/runner/text_token_generator.h +++ b/packages/react-native-executorch/common/runner/text_token_generator.h @@ -111,9 +111,9 @@ class TextTokenGenerator { std::chrono::high_resolution_clock::now() - timestamp_) > time_interval_; const auto countIntervalElapsed = token_cache.size() > count_interval_; - const auto eos_reached = eos_ids_->find(cur_token) != eos_ids_->end(); + const auto eos_reached = eos_ids_->contains(cur_token); - if (!cache_decoded.ends_with("�") && !cache_decoded.ends_with(" �") && + if (!cache_decoded.ends_with("�") && (countIntervalElapsed || timeIntervalElapsed || should_stop_ || eos_reached)) { token_callback(cache_decoded);