Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Ollama support in transform function #106

Merged
merged 14 commits into from
Jun 24, 2024
52 changes: 52 additions & 0 deletions core/src/transformers/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use anyhow::Result;
use ollama_rs::{generation::completion::request::GenerationRequest, Ollama};
use url::Url;

use super::types::EmbeddingRequest;

pub struct OllamaInstance {
pub model_name: String,
pub instance: Ollama,
Expand All @@ -11,6 +13,8 @@ pub trait LLMFunctions {
fn new(model_name: String, url: String) -> Self;
#[allow(async_fn_in_trait)]
async fn generate_reponse(&self, prompt_text: String) -> Result<String, String>;
#[allow(async_fn_in_trait)]
async fn generate_embedding(&self, inputs: String) -> Result<Vec<f64>, String>;
}

impl LLMFunctions for OllamaInstance {
Expand Down Expand Up @@ -38,6 +42,16 @@ impl LLMFunctions for OllamaInstance {
Err(e) => Err(e.to_string()),
}
}
async fn generate_embedding(&self, input: String) -> Result<Vec<f64>, String> {
    // Ask the configured Ollama model for an embedding of `input`,
    // converting the client error into a plain string for the caller.
    self.instance
        .generate_embeddings(self.model_name.clone(), input, None)
        .await
        .map(|res| res.embeddings)
        .map_err(|e| e.to_string())
}
}

pub fn ollama_embedding_dim(model_name: &str) -> i32 {
Expand All @@ -46,3 +60,41 @@ pub fn ollama_embedding_dim(model_name: &str) -> i32 {
_ => 1536,
}
}

/// Probes the Ollama host at `url` with a blocking HTTP GET.
///
/// Returns `Ok` with a description of the response when the host answers
/// with HTTP 200, and `Err` with a description otherwise. Network-level
/// failures (connection refused, DNS, timeout) are returned as `Err`
/// rather than panicking, matching the function's `Result` signature.
pub fn check_model_host(url: &str) -> Result<String, String> {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_io()
        .enable_time()
        .build()
        // Runtime construction failing indicates a broken environment, not
        // a recoverable request error, so a panic here is acceptable.
        .unwrap_or_else(|e| panic!("failed to initialize tokio runtime: {}", e));

    runtime.block_on(async {
        // Propagate transport errors instead of unwrapping: a dead host
        // must surface as Err, not a panic inside the extension.
        let response = reqwest::get(url)
            .await
            .map_err(|e| format!("Error connecting to model host: {}", e))?;
        match response.status() {
            reqwest::StatusCode::OK => Ok(format!("Success! {:?}", response)),
            _ => Err(format!("Error! {:?}", response)),
        }
    })
}

pub fn generate_embeddings(request: EmbeddingRequest) -> Result<Vec<Vec<f64>>> {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap_or_else(|e| panic!("failed to initialize tokio runtime: {}", e));

runtime.block_on(async {
let instance = OllamaInstance::new(request.payload.model, request.url);
let mut embeddings: Vec<Vec<f64>> = vec![];
for input in request.payload.input {
let response = instance.generate_embedding(input).await;
let embedding = match response {
Ok(embed) => embed,
Err(e) => panic!("Unable to generate embeddings.\nError: {:?}", e),
};
embeddings.push(embedding);
}
Ok(embeddings)
})
}
3 changes: 3 additions & 0 deletions core/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use chrono::serde::ts_seconds_option::deserialize as from_tsopt;

use serde::{Deserialize, Serialize};
use sqlx::types::chrono::Utc;
use sqlx::FromRow;
Expand Down Expand Up @@ -168,10 +169,12 @@ pub enum ModelError {
impl Model {
pub fn new(input: &str) -> Result<Self, ModelError> {
let mut parts: Vec<&str> = input.split('/').collect();

let missing_source = parts.len() < 2;
if parts.len() > 3 {
return Err(ModelError::InvalidFormat(input.to_string()));
}

if missing_source && parts[0] == "text-embedding-ada-002" {
// for backwards compatibility, prepend "openai" to text-embedding-ada-2
parts.insert(0, "openai");
Expand Down
3 changes: 1 addition & 2 deletions extension/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ install-pgvector:

install-pgmq:
git clone https://github.com/tembo-io/pgmq.git && \
cd pgmq && \
PG_CONFIG=${PGRX_PG_CONFIG} make clean && \
cd pgmq/pgmq-extension && \
PG_CONFIG=${PGRX_PG_CONFIG} make && \
PG_CONFIG=${PGRX_PG_CONFIG} make install && \
cd .. && rm -rf pgmq
Expand Down
20 changes: 19 additions & 1 deletion extension/src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::util;

use anyhow::{Context, Result};
use pgrx::prelude::*;
use vectorize_core::transformers::ollama::check_model_host;
use vectorize_core::types::{self, Model, ModelSource, TableMethod, VectorizeMeta};

#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -69,9 +70,26 @@ pub fn init_table(
sync_get_model_info(&transformer.fullname, api_key.clone())
.context("transformer does not exist")?;
}
ModelSource::Ollama | ModelSource::Tembo => {
ModelSource::Tembo => {
error!("Ollama/Tembo not implemented for search yet");
}
ModelSource::Ollama => {
let url = match guc::get_guc(guc::VectorizeGuc::OllamaServiceUrl) {
Some(k) => k,
None => {
error!("failed to get Ollama url from GUC");
}
};
let res = check_model_host(&url);
match res {
Ok(_) => {
info!("Model host active!")
}
Err(e) => {
error!("Error with model host: {:?}", e)
}
}
}
}

let valid_params = types::JobParams {
Expand Down
35 changes: 32 additions & 3 deletions extension/src/transformers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use generic::get_env_interpolated_guc;
use pgrx::prelude::*;

use vectorize_core::transformers::http_handler::openai_embedding_request;
use vectorize_core::transformers::ollama::generate_embeddings;
use vectorize_core::transformers::openai::OPENAI_BASE_URL;
use vectorize_core::transformers::types::{EmbeddingPayload, EmbeddingRequest};
use vectorize_core::types::{Model, ModelSource};
Expand Down Expand Up @@ -61,14 +62,38 @@ pub fn transform(input: &str, transformer: &Model, api_key: Option<String>) -> V
api_key: api_key.map(|s| s.to_string()),
}
}
ModelSource::Ollama => error!("Ollama transformer not implemented yet"),
ModelSource::Ollama => {
let url = match guc::get_guc(guc::VectorizeGuc::OllamaServiceUrl) {
Some(k) => k,
None => {
error!("failed to get Ollama url from GUC");
}
};

let embedding_request = EmbeddingPayload {
input: vec![input.to_string()],
model: transformer.name.to_string(),
};

EmbeddingRequest {
url,
payload: embedding_request,
api_key: None,
}
}
};
let timeout = EMBEDDING_REQ_TIMEOUT_SEC.get();

match transformer.source {
ModelSource::Ollama | ModelSource::Tembo => {
error!("Ollama/Tembo transformer not implemented yet")
ModelSource::Ollama => {
// Call the embeddings generation function
let embeddings = generate_embeddings(embedding_request);
match embeddings {
Ok(k) => k,
Err(e) => error!("error getting embeddings: {}", e),
}
}

ModelSource::OpenAI | ModelSource::SentenceTransformers => {
match runtime
.block_on(async { openai_embedding_request(embedding_request, timeout).await })
Expand All @@ -79,5 +104,9 @@ pub fn transform(input: &str, transformer: &Model, api_key: Option<String>) -> V
}
}
}

ModelSource::Tembo => {
error!("Embeddings support not added for Tembo yet!")
}
}
}
Loading