feat: text embeddings #163
Merged
19 commits
cc85723 Add text embedding model (iOS)
c60bf18 Add text embedding model (Android)
5dac0bb Add useTextEmbeddings hook
bfbba7c Code improvements
28f414b Fix download progress calculation
014bbd7 Format code
83fd07c Add text embeddings URLs and change types
f2ad795 Add text embeddings demo app
1af1680 Update ios/RnExecutorch/models/text_embeddings/TextEmbeddingsModel.mm
jakmro eb92c35 Add text embeddings docs and benchmarks
83f1a1a Add model status information and fix styling
68a658b docs: Remove Llama Export, refactor docs (#174)
jakmro ba25300 chore: @jakmro/native code refactor (#177)
jakmro f58f8f1 chore: @jakmro/code refactor (#172)
jakmro a9c8e0f docs: Add tokenizer documentation (#202)
jakmro 568007b Repair S2T
8c3340f Repair S2T v2
4018133 Repair S2T v3
297e839 Fix android LLM imports
51 changes: 51 additions & 0 deletions
android/src/main/java/com/swmansion/rnexecutorch/TextEmbeddings.kt
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| package com.swmansion.rnexecutorch | ||
| | ||
| import com.facebook.react.bridge.Promise | ||
| import com.facebook.react.bridge.ReactApplicationContext | ||
| import com.facebook.react.bridge.WritableNativeArray | ||
| import com.swmansion.rnexecutorch.models.textEmbeddings.TextEmbeddingsModel | ||
| import com.swmansion.rnexecutorch.utils.ETError | ||
| | ||
| class TextEmbeddings( | ||
| reactContext: ReactApplicationContext, | ||
| ) : NativeTextEmbeddingsSpec(reactContext) { | ||
| private lateinit var textEmbeddingsModel: TextEmbeddingsModel | ||
| | ||
| companion object { | ||
| const val NAME = "TextEmbeddings" | ||
| } | ||
| | ||
| override fun loadModule( | ||
| modelSource: String, | ||
| tokenizerSource: String, | ||
| promise: Promise, | ||
| ) { | ||
| try { | ||
| textEmbeddingsModel = TextEmbeddingsModel(reactApplicationContext) | ||
| | ||
| textEmbeddingsModel.loadModel(modelSource) | ||
| textEmbeddingsModel.loadTokenizer(tokenizerSource) | ||
| | ||
| promise.resolve(0) | ||
| } catch (e: Exception) { | ||
| promise.reject(e.message ?: e.toString(), ETError.InvalidModelSource.toString()) | ||
| } | ||
| } | ||
| | ||
| override fun forward( | ||
| input: String, | ||
| promise: Promise, | ||
| ) { | ||
| try { | ||
| val output = textEmbeddingsModel.runModel(input) | ||
| val writableArray = WritableNativeArray() | ||
| output.forEach { writableArray.pushDouble(it) } | ||
| | ||
| promise.resolve(writableArray) | ||
| } catch (e: Exception) { | ||
| promise.reject(e.message ?: e.toString(), e.message) | ||
| } | ||
| } | ||
| | ||
| override fun getName(): String = NAME | ||
| } | ||
48 changes: 48 additions & 0 deletions
...oid/src/main/java/com/swmansion/rnexecutorch/models/TextEmbeddings/TextEmbeddingsModel.kt
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| package com.swmansion.rnexecutorch.models.textEmbeddings | ||
| | ||
| import com.facebook.react.bridge.ReactApplicationContext | ||
| import com.swmansion.rnexecutorch.models.BaseModel | ||
| import org.pytorch.executorch.EValue | ||
| import org.pytorch.executorch.HuggingFaceTokenizer | ||
| import org.pytorch.executorch.Tensor | ||
| | ||
| class TextEmbeddingsModel( | ||
| reactApplicationContext: ReactApplicationContext, | ||
| ) : BaseModel<String, DoubleArray>(reactApplicationContext) { | ||
| private lateinit var tokenizer: HuggingFaceTokenizer | ||
| | ||
| fun loadTokenizer(tokenizerSource: String) { | ||
| tokenizer = HuggingFaceTokenizer(tokenizerSource) | ||
| } | ||
| | ||
| fun preprocess(input: String): Array<LongArray> { | ||
| val inputIds = tokenizer.encode(input).map { it.toLong() }.toLongArray() | ||
| val attentionMask = inputIds.map { if (it != 0L) 1L else 0L }.toLongArray() // treats token id 0 as padding | ||
| return arrayOf(inputIds, attentionMask) // Shape: [2, max_length] | ||
| } | ||
| | ||
| fun postprocess( | ||
| modelOutput: FloatArray, // [max_length * embedding_dim] | ||
| attentionMask: LongArray, // [max_length] | ||
| ): DoubleArray { | ||
| val modelOutputDouble = modelOutput.map { it.toDouble() }.toDoubleArray() | ||
| val embeddings = TextEmbeddingsUtils.meanPooling(modelOutputDouble, attentionMask) | ||
| return TextEmbeddingsUtils.normalize(embeddings) | ||
| } | ||
| | ||
| override fun runModel(input: String): DoubleArray { | ||
| val modelInput = preprocess(input) | ||
| val inputIds = modelInput[0] | ||
| val attentionMask = modelInput[1] | ||
| | ||
| val inputIdsShape = longArrayOf(1, inputIds.size.toLong()) | ||
| val attentionMaskShape = longArrayOf(1, attentionMask.size.toLong()) | ||
| | ||
| val inputIdsEValue = EValue.from(Tensor.fromBlob(inputIds, inputIdsShape)) | ||
| val attentionMaskEValue = EValue.from(Tensor.fromBlob(attentionMask, attentionMaskShape)) | ||
| | ||
| val modelOutput = forward(inputIdsEValue, attentionMaskEValue)[0].toTensor().dataAsFloatArray | ||
| | ||
| return postprocess(modelOutput, attentionMask) | ||
| } | ||
| } | ||
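The `preprocess` step in `TextEmbeddingsModel.kt` builds the attention mask by marking every token id other than 0 as a real token. A minimal standalone sketch of that step, assuming the tokenizer pads with id 0 (`PAD_TOKEN_ID` and `buildAttentionMask` are illustrative names, not part of this PR, and the ids in `main` are made up):

```kotlin
// Hypothetical constant: the PR compares against a literal 0L.
// The actual pad id depends on the tokenizer configuration.
const val PAD_TOKEN_ID = 0L

// Returns 1 for real tokens and 0 for padding positions,
// mirroring the mask that preprocess() derives from the input ids.
fun buildAttentionMask(inputIds: LongArray, padTokenId: Long = PAD_TOKEN_ID): LongArray =
    LongArray(inputIds.size) { i -> if (inputIds[i] != padTokenId) 1L else 0L }

fun main() {
    // Illustrative ids only: three real tokens followed by two padding positions.
    val inputIds = longArrayOf(101L, 7592L, 102L, 0L, 0L)
    val mask = buildAttentionMask(inputIds)
    println(mask.joinToString()) // 1, 1, 1, 0, 0
}
```

If a tokenizer pads with a different id, or assigns id 0 to a real token, a mask built this way would be wrong, so the assumption is worth checking against the tokenizer configuration in use.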
37 changes: 37 additions & 0 deletions
...oid/src/main/java/com/swmansion/rnexecutorch/models/TextEmbeddings/TextEmbeddingsUtils.kt
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| package com.swmansion.rnexecutorch.models.textEmbeddings | ||
| | ||
| import kotlin.math.sqrt | ||
| | ||
| class TextEmbeddingsUtils { | ||
| companion object { | ||
| fun meanPooling( | ||
| modelOutput: DoubleArray, | ||
| attentionMask: LongArray, | ||
| ): DoubleArray { | ||
| val attentionMaskLength = attentionMask.size | ||
| val modelOutputLength = modelOutput.size | ||
| val embeddingDim = modelOutputLength / attentionMaskLength | ||
| | ||
| val result = DoubleArray(embeddingDim) | ||
| var sumMask = attentionMask.sum().toDouble() | ||
| sumMask = maxOf(sumMask, 1e-9) | ||
| | ||
| for (i in 0 until embeddingDim) { | ||
| var sum = 0.0 | ||
| for (j in 0 until attentionMaskLength) { | ||
| sum += modelOutput[j * embeddingDim + i] * attentionMask[j] | ||
| } | ||
| result[i] = sum / sumMask | ||
| } | ||
| | ||
| return result | ||
| } | ||
| | ||
| fun normalize(embeddings: DoubleArray): DoubleArray { | ||
| val sumOfSquares = embeddings.sumOf { it * it } | ||
| val norm = maxOf(sqrt(sumOfSquares), 1e-9) | ||
| | ||
| return embeddings.map { it / norm }.toDoubleArray() | ||
| } | ||
| } | ||
| } | ||
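The two helpers in `TextEmbeddingsUtils.kt` implement masked mean pooling followed by L2 normalization. The following is a small worked example of the same arithmetic, written as a standalone sketch rather than a copy of the PR code: `meanPool` and `l2Normalize` are illustrative names, the input values are made up, and the `require` checks are an addition that the PR does not include.

```kotlin
import kotlin.math.sqrt

// Masked mean over the token axis of a flattened [tokens * dim] buffer.
fun meanPool(tokenEmbeddings: DoubleArray, attentionMask: LongArray): DoubleArray {
    require(attentionMask.isNotEmpty()) { "attentionMask must not be empty" }
    require(tokenEmbeddings.size % attentionMask.size == 0) {
        "tokenEmbeddings length must be a multiple of attentionMask length"
    }
    val dim = tokenEmbeddings.size / attentionMask.size
    // Clamp the divisor so an all-zero mask does not divide by zero.
    val maskSum = maxOf(attentionMask.sum().toDouble(), 1e-9)
    return DoubleArray(dim) { i ->
        var sum = 0.0
        for (j in attentionMask.indices) {
            sum += tokenEmbeddings[j * dim + i] * attentionMask[j]
        }
        sum / maskSum
    }
}

// Scales a vector to unit Euclidean length (with the same 1e-9 floor on the norm).
fun l2Normalize(v: DoubleArray): DoubleArray {
    val norm = maxOf(sqrt(v.sumOf { it * it }), 1e-9)
    return DoubleArray(v.size) { v[it] / norm }
}

fun main() {
    // Three token positions with embedding dimension 2; the last position is padding.
    val tokens = doubleArrayOf(
        1.0, 2.0,
        5.0, 6.0,
        100.0, 100.0,
    )
    val mask = longArrayOf(1L, 1L, 0L)

    val pooled = meanPool(tokens, mask) // averages only the two unmasked rows
    val unit = l2Normalize(pooled)

    println(pooled.joinToString()) // 3.0, 4.0
    println(unit.joinToString())   // 0.6, 0.8
}
```

The padded row `(100.0, 100.0)` has no effect on the result because its mask value is 0, which is why the mask passed to `postprocess` has to match the one used to build the model input.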
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| { | ||
| "label": "Benchmarks", | ||
| - "position": 8, | ||
| + "position": 7, | ||
| "link": { | ||
| "type": "generated-index" | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| { | ||
| "label": "Computer Vision", | ||
| - "position": 4, | ||
| + "position": 3, | ||
| "link": { | ||
| "type": "generated-index" | ||
| } | ||
2 changes: 1 addition & 1 deletion
docs/docs/hookless-api/_category_.json → .../docs/executorch-bindings/_category_.json
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| { | ||
| - "label": "Hookless API", | ||
| + "label": "ExecuTorch Bindings", | ||
| "position": 5, | ||
| "link": { | ||
| "type": "generated-index" | ||
2 changes: 1 addition & 1 deletion
docs/docs/module-api/executorch-bindings.md → ...xecutorch-bindings/useExecutorchModule.md
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| --- | ||
| - title: ExecuTorch Bindings | ||
| + title: useExecutorchModule | ||
| sidebar_position: 1 | ||
| --- | ||
| | ||
2 changes: 1 addition & 1 deletion
docs/docs/utils/_category_.json → docs/docs/faq/_category_.json
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| { | ||
| - "label": "Utils", | ||
| + "label": "FAQ", | ||
| "position": 7, | ||
| "link": { | ||
| "type": "generated-index" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| --- | ||
| title: Frequently asked questions | ||
| sidebar_position: 1 | ||
| --- | ||
| | ||
| This section answers common community questions, especially about the ExecuTorch runtime and adding your own models. If you don't see an answer to your question, feel free to open a [discussion](https://github.com/software-mansion/react-native-executorch/discussions/new/choose). | ||
| | ||
| ### What models are supported? | ||
| | ||
| Each hook's documentation page (useClassification, useLLM, etc.) contains a supported models section listing the models that run in the library with close to no setup. To run your own custom models, refer to `ExecutorchModule` or `useExecutorchModule`. | ||
| | ||
| ### How can I run my own AI model? | ||
| | ||
| To run your own model, you need to access the underlying [ExecuTorch Module API](https://pytorch.org/executorch/stable/extension-module.html) directly. We provide an experimental [React hook](../executorch-bindings/useExecutorchModule.md) along with a [TypeScript alternative](../typescript-api/ExecutorchModule.md), which let you use that API without diving into native code. To get a model into a format the runtime can execute, you'll need some ExecuTorch knowledge. For guides on exporting models, refer to the [ExecuTorch tutorials](https://pytorch.org/executorch/stable/tutorials/export-to-executorch-tutorial.html). Once you have your model in the `.pte` format, you can run it with `useExecutorchModule` or `ExecutorchModule`. | ||
| | ||
| ### Can you do function calling with useLLM? | ||
| | ||
| We currently don't provide an out-of-the-box solution for function calling, but modifying system prompts for Llama models should be enough for simple use cases. For more details, refer to [this comment](https://github.com/software-mansion/react-native-executorch/issues/173#issuecomment-2775082278). | ||
| | ||
| ### Can I use React Native ExecuTorch in bare React Native apps? | ||
| | ||
| Yes, but you need to install Expo Modules first. For a setup guide, refer to [this tutorial](https://docs.expo.dev/bare/installing-expo-modules/). This is because we use Expo File System under the hood to download and manage the model binaries. | ||
| | ||
| ### Do you support the old architecture? | ||
| | ||
| The old architecture is not supported, and we're currently not planning to add support. | ||
| | ||
| ### Can I run GGUF models using the library? | ||
| | ||
| No. As of now, the ExecuTorch runtime doesn't provide a reliable way to use GGUF models. | ||
| | ||
| ### Are the models leveraging GPU acceleration? | ||
| | ||
| It is possible to run some models using Core ML on iOS, a backend that utilizes the CPU, GPU and ANE, but we currently don't have many models exported to Core ML. On Android, GPU acceleration is currently quite limited. There are attempts at running models with a Vulkan backend; however, operator support is very limited, so the resulting performance is often inferior to XNNPACK. Hence, most of the models use XNNPACK, a highly optimized and mature CPU backend that runs on both Android and iOS. | ||
| | ||
| ### Does this library support XNNPACK and Core ML? | ||
| | ||
| Yes, all of the backends are linked, so the only thing you need to do is export the model with the backend you're interested in using. | ||
Where is max_length specified? I think mentioning it here would be nice
max_length is specified inside tokenizer.json