Added Ollama support in transform function #106

Merged 14 commits on Jun 24, 2024
54 changes: 54 additions & 0 deletions core/src/transformers/ollama.rs
@@ -2,6 +2,8 @@ use anyhow::Result;
use ollama_rs::{generation::completion::request::GenerationRequest, Ollama};
use url::Url;

use super::types::EmbeddingRequest;

pub struct OllamaInstance {
pub model_name: String,
pub instance: Ollama,
@@ -11,6 +13,8 @@ pub trait LLMFunctions {
fn new(model_name: String, url: String) -> Self;
#[allow(async_fn_in_trait)]
async fn generate_response(&self, prompt_text: String) -> Result<String, String>;
#[allow(async_fn_in_trait)]
async fn generate_embedding(&self, input: String) -> Result<Vec<f64>, String>;
}

impl LLMFunctions for OllamaInstance {
@@ -38,6 +42,13 @@ impl LLMFunctions for OllamaInstance {
Err(e) => Err(e.to_string()),
}
}
async fn generate_embedding(&self, input: String) -> Result<Vec<f64>, String> {
let embed = self.instance.generate_embeddings(self.model_name.clone(), input, None).await;
match embed {
Ok(res) => Ok(res.embeddings),
Err(e) => Err(e.to_string()),
}
}
}

pub fn ollama_embedding_dim(model_name: &str) -> i32 {
@@ -46,3 +57,46 @@ pub fn ollama_embedding_dim(model_name: &str) -> i32 {
_ => 1536,
}
}
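
For context, a minimal usage sketch of the lookup above (an editor's sketch, not part of the diff; it assumes the collapsed match arms map known models such as llama2 to their native embedding width, with the 1536 fallback shown here):

// Editor's sketch: assumes a "llama2" arm exists in the collapsed match above.
let dim = ollama_embedding_dim("llama2");
println!("llama2 embedding dimension: {}", dim);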

pub fn check_model_host(url: &str) -> Result<String, String> {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap_or_else(|e| panic!("failed to initialize tokio runtime: {}", e));

runtime.block_on(async {
// return an Err instead of panicking when the host is unreachable
let response = reqwest::get(url).await.map_err(|e| e.to_string())?;
match response.status() {
reqwest::StatusCode::OK => Ok(format!("Success! {:?}", response)),
_ => Err(format!("Error! {:?}", response)),
}
})
}
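
A quick usage sketch (editor's sketch, not part of the diff), assuming an Ollama server on its default local address:

// Editor's sketch: http://localhost:11434 is Ollama's default bind address.
match check_model_host("http://localhost:11434") {
Ok(msg) => println!("{}", msg),
Err(e) => eprintln!("Ollama host unreachable: {}", e),
}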

pub fn generate_embeddings(request: EmbeddingRequest) -> Result<Vec<Vec<f64>>> {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap_or_else(|e| panic!("failed to initialize tokio runtime: {}", e));

runtime.block_on(async {
let instance = OllamaInstance::new(request.payload.model, request.url);
let mut embeddings: Vec<Vec<f64>> = vec![];
for input in request.payload.input {
// propagate the error rather than panicking inside the runtime
let embedding = instance.generate_embedding(input).await.map_err(|e| anyhow::anyhow!("unable to generate embeddings: {}", e))?;
embeddings.push(embedding);
}
Ok(embeddings)
})
}
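
An end-to-end sketch of this synchronous entry point (editor's sketch, not part of the diff), using the EmbeddingPayload and EmbeddingRequest fields that appear later in this diff; the URL and model name are illustrative:

// Editor's sketch; call from a function returning anyhow::Result<()>.
use vectorize_core::transformers::types::{EmbeddingPayload, EmbeddingRequest};

let request = EmbeddingRequest {
url: "http://localhost:11434".to_string(),
payload: EmbeddingPayload {
input: vec!["the quick brown fox".to_string()],
model: "llama2".to_string(),
},
api_key: None,
};
let embeddings = generate_embeddings(request)?;
assert_eq!(embeddings.len(), 1); // one embedding per input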
4 changes: 4 additions & 0 deletions core/src/types.rs
@@ -1,4 +1,5 @@
use chrono::serde::ts_seconds_option::deserialize as from_tsopt;
use log::info;
use serde::{Deserialize, Serialize};
use sqlx::types::chrono::Utc;
use sqlx::FromRow;
@@ -200,12 +201,15 @@ pub enum ModelError {
impl Model {
pub fn new(input: &str) -> Result<Self, ModelError> {
let mut parts: Vec<&str> = input.split('/').collect();
info!("{:?}", parts);
let missing_source = parts.len() != 2;
if missing_source && parts[0] == "text-embedding-ada-002" {
// for backwards compatibility, prepend "openai" to text-embedding-ada-002
parts.insert(0, "openai");
} else if missing_source && parts[0] == "all-MiniLM-L12-v2" {
parts.insert(0, "sentence-transformers");
} else if missing_source && parts[0] == "llama2" {
parts.insert(0, "ollama");
} else if missing_source {
return Err(ModelError::InvalidFormat(input.to_string()));
}
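
The effect of the new branch: a bare llama2 now resolves to the ollama source instead of failing validation. A small sketch (editor's, not part of the diff; assumes ModelSource derives Debug and PartialEq, and that Model exposes the source field used elsewhere in this PR):

// Editor's sketch.
let model = Model::new("llama2").expect("accepted after this change");
let qualified = Model::new("ollama/llama2").expect("fully qualified form");
assert_eq!(model.source, ModelSource::Ollama);
assert_eq!(model.source, qualified.source);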
17 changes: 16 additions & 1 deletion extension/src/search.rs
@@ -8,6 +8,7 @@ use crate::util;

use anyhow::{Context, Result};
use pgrx::prelude::*;
use vectorize_core::transformers::ollama::check_model_host;
use vectorize_core::types::{self, Model, ModelSource, TableMethod, VectorizeMeta};

#[allow(clippy::too_many_arguments)]
@@ -72,7 +73,21 @@ pub fn init_table(
.context("transformer does not exist")?;
}
ModelSource::Ollama => {
error!("Ollama not implemented for search yet");
let url = match guc::get_guc(guc::VectorizeGuc::OllamaServiceUrl) {
Some(k) => k,
None => {
error!("failed to get Ollama url from GUC");
}
};
let res = check_model_host(&url);
match res {
Ok(_) => {
info!("Model host active!")
}
Err(e) => {
error!("Error with model host: {:?}", e)
}
}
}
}

30 changes: 28 additions & 2 deletions extension/src/transformers/mod.rs
@@ -10,6 +10,7 @@ use vectorize_core::transformers::http_handler::openai_embedding_request;
use vectorize_core::transformers::openai::OPENAI_EMBEDDING_URL;
use vectorize_core::transformers::types::{EmbeddingPayload, EmbeddingRequest};
use vectorize_core::transformers::ollama::generate_embeddings;
use vectorize_core::types::{Model, ModelSource};

pub fn transform(input: &str, transformer: &Model, api_key: Option<String>) -> Vec<Vec<f64>> {
let runtime = tokio::runtime::Builder::new_current_thread()
@@ -52,12 +53,37 @@ pub fn transform(input: &str, transformer: &Model, api_key: Option<String>) -> Vec<Vec<f64>> {
api_key: api_key.map(|s| s.to_string()),
}
}
ModelSource::Ollama => error!("Ollama transformer not implemented yet"),
ModelSource::Ollama => {
let url = match guc::get_guc(guc::VectorizeGuc::OllamaServiceUrl) {
Some(k) => k,
None => {
error!("failed to get Ollama url from GUC");
}
};

let embedding_request = EmbeddingPayload {
input: vec![input.to_string()],
model: transformer.name.to_string(),
};

EmbeddingRequest {
url,
payload: embedding_request,
api_key: None,
}
}
};
let timeout = EMBEDDING_REQ_TIMEOUT_SEC.get();

match transformer.source {
ModelSource::Ollama => error!("Ollama transformer not implemented yet"),
ModelSource::Ollama => {
// Call the embeddings generation function
let embeddings = generate_embeddings(embedding_request);
match embeddings {
Ok(k) => k,
Err(e) => error!("error getting embeddings: {}", e),
}
},
ModelSource::OpenAI | ModelSource::SentenceTransformers => {
match runtime
.block_on(async { openai_embedding_request(embedding_request, timeout).await })
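
With both arms in place, transform can serve Ollama models end to end. A final sketch (editor's, not part of the diff), assuming the GUC backing VectorizeGuc::OllamaServiceUrl is configured so the URL lookup above succeeds:

// Editor's sketch.
let model = Model::new("ollama/llama2").expect("valid model string");
let embeddings: Vec<Vec<f64>> = transform("hello world", &model, None);
// api_key is None: the Ollama path above builds its request without one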