diff --git a/crates/spin-python-engine/src/lib.rs b/crates/spin-python-engine/src/lib.rs
index 564e07d..ecbe58d 100644
--- a/crates/spin-python-engine/src/lib.rs
+++ b/crates/spin-python-engine/src/lib.rs
@@ -545,13 +545,65 @@ fn llm_infer(
         .map(LLMInferencingResult::from)
 }
 
+#[pyo3::pyfunction]
+fn generate_embeddings(model: &str, text: Vec<String>) -> Result<LLMEmbeddingsResult, Anyhow> {
+    let model = match model {
+        "all-minilm-l6-v2" => llm::EmbeddingModel::AllMiniLmL6V2,
+        _ => llm::EmbeddingModel::Other(model),
+    };
+
+    let text = text.iter().map(|s| s.as_str()).collect::<Vec<_>>();
+
+    llm::generate_embeddings(model, &text)
+        .map_err(Anyhow::from)
+        .map(LLMEmbeddingsResult::from)
+}
+
+#[derive(Clone)]
+#[pyo3::pyclass]
+#[pyo3(name = "LLMEmbeddingsUsage")]
+struct LLMEmbeddingsUsage {
+    #[pyo3(get)]
+    prompt_token_count: u32,
+}
+
+impl From<llm::EmbeddingsUsage> for LLMEmbeddingsUsage {
+    fn from(result: llm::EmbeddingsUsage) -> Self {
+        LLMEmbeddingsUsage {
+            prompt_token_count: result.prompt_token_count,
+        }
+    }
+}
+
+#[derive(Clone)]
+#[pyo3::pyclass]
+#[pyo3(name = "LLMEmbeddingsResult")]
+struct LLMEmbeddingsResult {
+    #[pyo3(get)]
+    embeddings: Vec<Vec<f32>>,
+    #[pyo3(get)]
+    usage: LLMEmbeddingsUsage,
+}
+
+impl From<llm::EmbeddingsResult> for LLMEmbeddingsResult {
+    fn from(result: llm::EmbeddingsResult) -> Self {
+        LLMEmbeddingsResult {
+            embeddings: result.embeddings,
+            usage: LLMEmbeddingsUsage::from(result.usage),
+        }
+    }
+}
+
 #[pyo3::pymodule]
 #[pyo3(name = "spin_llm")]
 fn spin_llm_module(_py: Python<'_>, module: &PyModule) -> PyResult<()> {
     module.add_function(pyo3::wrap_pyfunction!(llm_infer, module)?)?;
+    module.add_function(pyo3::wrap_pyfunction!(generate_embeddings, module)?)?;
     module.add_class::<LLMInferencingResult>()?;
     module.add_class::<LLMInferencingUsage>()?;
-    module.add_class::<LLMInferencingParams>()
+    module.add_class::<LLMInferencingParams>()?;
+    module.add_class::<LLMEmbeddingsUsage>()?;
+    module.add_class::<LLMEmbeddingsResult>()
 }
 
 pub fn run_ctors() {
diff --git a/examples/llm/app.py b/examples/llm/app.py
index 29b2733..2cfc695 100644
--- a/examples/llm/app.py
+++ b/examples/llm/app.py
@@ -1,10 +1,18 @@
+import json
 from spin_http import Response
-from spin_llm import llm_infer
+from spin_llm import llm_infer, generate_embeddings
 
 
 def handle_request(request):
     prompt="You are a stand up comedy writer. Tell me a joke."
     result=llm_infer("llama2-chat", prompt)
+
+    embeddings = generate_embeddings("all-minilm-l6-v2", ["hat", "cat", "bat"])
+
+    body = (f"joke: {result.text}\n\n"
+            f"embeddings: {json.dumps(embeddings.embeddings)}\n"
+            f"prompt token count: {embeddings.usage.prompt_token_count}")
+
     return Response(200,
                     {"content-type": "text/plain"},
-                    bytes(result.text, "utf-8"))
+                    bytes(body, "utf-8"))
diff --git a/examples/llm/spin.toml b/examples/llm/spin.toml
index f7c2e82..846116b 100644
--- a/examples/llm/spin.toml
+++ b/examples/llm/spin.toml
@@ -8,7 +8,7 @@ version = "0.1.0"
 [[component]]
 id = "python-sdk-example"
 source = "app.wasm"
-ai_models = ["llama2-chat"]
+ai_models = ["llama2-chat", "all-minilm-l6-v2"]
 [component.trigger]
 route = "/..."
 [component.build]
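Not part of the patch above, but a minimal sketch of how the new `generate_embeddings` binding could be exercised beyond the bundled example: comparing the returned vectors with cosine similarity. It assumes the updated `spin_llm` module is available inside a Spin Python component and that `LLMEmbeddingsResult.embeddings` comes back as a list of float vectors, one per input string, as the pyo3 getters suggest; the `cosine_similarity` helper is purely illustrative.

```python
import math

from spin_http import Response
from spin_llm import generate_embeddings


def cosine_similarity(a, b):
    # Plain-Python cosine similarity; no numpy, since the component
    # runs on a CPython-on-Wasm build.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)


def handle_request(request):
    words = ["hat", "cat", "bat"]
    result = generate_embeddings("all-minilm-l6-v2", words)

    # result.embeddings holds one vector per input string, in order.
    pairs = []
    for i in range(len(words)):
        for j in range(i + 1, len(words)):
            score = cosine_similarity(result.embeddings[i], result.embeddings[j])
            pairs.append(f"{words[i]} vs {words[j]}: {score:.3f}")

    body = "\n".join(pairs)
    return Response(200,
                    {"content-type": "text/plain"},
                    bytes(body, "utf-8"))
```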