Skip to content

Commit

Permalink
feat: support rag-dedicated clients (jina and voyageai) (#645)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden committed Jun 24, 2024
1 parent ed71901 commit 2fbb527
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 47 deletions.
12 changes: 12 additions & 0 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,15 @@ clients:
name: together
api_base: https://api.together.xyz/v1
api_key: xxx # ENV: {client}_API_KEY

# See https://jina.ai
- type: rag-dedicated
name: jina
api_base: https://api.jina.ai/v1
api_key: xxx # ENV: {client}_API_KEY

# See https://docs.voyageai.com/docs/introduction
- type: rag-dedicated
name: voyageai
api_base: https://api.voyageai.com/v1
api_key: xxx # ENV: {client}_API_KEY
109 changes: 108 additions & 1 deletion models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1145,4 +1145,111 @@
input_price: 0.008
output_vector_size: 768
default_chunk_size: 1000
max_batch_size: 100
max_batch_size: 100

- platform: jina
# docs:
# - https://jina.ai/
# - https://api.jina.ai/redoc
models:
- name: jina-embeddings-v2-base-en
type: embedding
max_input_tokens: 8192
input_price: 0.02
output_vector_size: 768
default_chunk_size: 1500
max_batch_size: 100
- name: jina-embeddings-v2-small-en
type: embedding
max_input_tokens: 8192
input_price: 0.02
output_vector_size: 512
default_chunk_size: 1000
max_batch_size: 100
- name: jina-embeddings-v2-base-zh
type: embedding
max_input_tokens: 8192
input_price: 0.02
output_vector_size: 768
default_chunk_size: 1500
max_batch_size: 100
- name: jina-embeddings-v2-base-code
type: embedding
max_input_tokens: 8192
input_price: 0.02
output_vector_size: 768
default_chunk_size: 1500
max_batch_size: 100
- name: jina-colbert-v1-en
type: embedding
max_input_tokens: 8192
input_price: 0.02
output_vector_size: 768
default_chunk_size: 1500
max_batch_size: 100
- name: jina-reranker-v1-base-en
type: rerank
max_input_tokens: 8192
input_price: 0.02
- name: jina-reranker-v1-turbo-en
type: rerank
max_input_tokens: 8192
input_price: 0.02
- name: jina-colbert-v1-en
type: rerank
max_input_tokens: 8192
input_price: 0.02
- name: jina-reranker-v1-base-multilingual
type: rerank
max_input_tokens: 8192
input_price: 0.02

- platform: voyageai
# docs:
# - https://docs.voyageai.com/docs/embeddings
# - https://docs.voyageai.com/docs/pricing
# - https://docs.voyageai.com/reference/embeddings-api
models:
- name: voyage-large-2-instruct
type: embedding
max_input_tokens: 16000
input_price: 0.12
output_vector_size: 1024
default_chunk_size: 2000
max_batch_size: 128
- name: voyage-large-2
type: embedding
max_input_tokens: 16000
input_price: 0.12
output_vector_size: 1536
default_chunk_size: 3000
max_batch_size: 128
- name: voyage-multilingual-2
type: embedding
max_input_tokens: 32000
input_price: 0.12
output_vector_size: 1024
default_chunk_size: 2000
max_batch_size: 128
- name: voyage-code-2
type: embedding
max_input_tokens: 16000
input_price: 0.12
output_vector_size: 1536
default_chunk_size: 3000
max_batch_size: 128
- name: voyage-2
type: embedding
max_input_tokens: 4000
input_price: 0.1
output_vector_size: 1024
default_chunk_size: 2000
max_batch_size: 128
- name: rerank-1
type: rerank
max_input_tokens: 8000
input_price: 0.05
- name: rerank-lite-1
type: rerank
max_input_tokens: 4000
input_price: 0.02
36 changes: 3 additions & 33 deletions src/client/cohere.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::rag_dedicated::*;
use super::*;

use anyhow::{bail, Context, Result};
Expand Down Expand Up @@ -74,7 +75,7 @@ impl CohereClient {
fn rerank_builder(&self, client: &ReqwestClient, data: RerankData) -> Result<RequestBuilder> {
let api_key = self.get_api_key()?;

let body = cohere_build_rerank_body(data, &self.model);
let body = rag_dedicated_build_rerank_body(data, &self.model);

let url = RERANK_API_URL;

Expand All @@ -91,7 +92,7 @@ impl_client_trait!(
chat_completions,
chat_completions_streaming,
embeddings,
cohere_rerank
rag_dedicated_rerank
);

async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
Expand Down Expand Up @@ -162,22 +163,6 @@ struct EmbeddingsResBody {
embeddings: Vec<Vec<f32>>,
}

/// Sends a prepared rerank request and decodes the response into a `RerankOutput`.
///
/// The response body is first parsed as generic JSON. When the HTTP status is
/// not a success, that JSON and the numeric status code are handed to
/// `catch_error`, and any error it returns is propagated to the caller.
///
/// # Errors
/// Fails if sending the request or reading the JSON body fails, if
/// `catch_error` reports an error for a non-success status, or if the body
/// cannot be deserialized into `RerankResBody` (context: "Invalid rerank data").
pub async fn cohere_rerank(builder: RequestBuilder) -> Result<RerankOutput> {
let res = builder.send().await?;
// Read the status up front; it is still needed after the body has been parsed.
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
// If `catch_error` returns Ok, execution falls through to the
// deserialization below, which is then expected to fail on an error payload.
catch_error(&data, status.as_u16())?;
}
let res_body: RerankResBody = serde_json::from_value(data).context("Invalid rerank data")?;
Ok(res_body.results)
}

/// Deserialization target for a rerank response body.
///
/// Declares only the `results` field; `cohere_rerank` returns that field as
/// its output.
#[derive(Deserialize)]
struct RerankResBody {
// Rerank results, typed as the `RerankOutput` alias defined elsewhere.
results: RerankOutput,
}

fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
let ChatCompletionsData {
mut messages,
Expand Down Expand Up @@ -309,21 +294,6 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
Ok(body)
}

/// Builds the JSON request body for a rerank call.
///
/// The resulting object has four keys: `model` (taken from `model.name()`),
/// and `query`, `documents` and `top_n`, which are copied directly from `data`.
pub fn cohere_build_rerank_body(data: RerankData, model: &Model) -> Value {
// Destructure so each field is moved into the JSON body by name.
let RerankData {
query,
documents,
top_n,
} = data;

json!({
"model": model.name(),
"query": query,
"documents": documents,
"top_n": top_n
})
}

fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["text"].as_str().unwrap_or_default();

Expand Down
10 changes: 6 additions & 4 deletions src/client/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ macro_rules! register_client {
let client_name = Self::name(local_config);
if local_config.models.is_empty() {
if let Some(models) = $crate::client::ALL_MODELS.iter().find(|v| {
v.platform == $name || ($name == "openai-compatible" && local_config.name.as_deref() == Some(&v.platform))
v.platform == $name ||
($name == OpenAICompatibleClient::NAME && local_config.name.as_deref() == Some(&v.platform)) ||
($name == RagDedicatedClient::NAME && local_config.name.as_deref() == Some(&v.platform))
}) {
return Model::from_config(client_name, &models.models);
}
Expand Down Expand Up @@ -432,15 +434,15 @@ pub trait Client: Sync + Send {
_client: &ReqwestClient,
_data: EmbeddingsData,
) -> Result<EmbeddingsOutput> {
bail!("No embeddings api")
bail!("The client doesn't support embeddings api")
}

async fn rerank_inner(
&self,
_client: &ReqwestClient,
_data: RerankData,
) -> Result<RerankOutput> {
bail!("No rerank api")
bail!("The client doesn't support rerank api")
}
}

Expand Down Expand Up @@ -566,7 +568,7 @@ pub fn create_openai_compatible_client_config(client: &str) -> Result<Option<(St
None => Ok(None),
Some((name, api_base)) => {
let mut config = json!({
"type": "openai-compatible",
"type": OpenAICompatibleClient::NAME,
"name": name,
"api_base": api_base,
});
Expand Down
9 changes: 3 additions & 6 deletions src/client/ernie.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::access_token::*;
use super::rag_dedicated::*;
use super::*;

use anyhow::{anyhow, bail, Context, Result};
Expand Down Expand Up @@ -220,15 +221,11 @@ struct EmbeddingsResBodyEmbedding {
async fn rerank(builder: RequestBuilder) -> Result<RerankOutput> {
let data: Value = builder.send().await?.json().await?;
maybe_catch_error(&data)?;
let res_body: RerankResBody = serde_json::from_value(data).context("Invalid rerank data")?;
let res_body: RagDedicatedRerankResBody =
serde_json::from_value(data).context("Invalid rerank data")?;
Ok(res_body.results)
}

#[derive(Deserialize)]
struct RerankResBody {
results: RerankOutput,
}

fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value {
let ChatCompletionsData {
mut messages,
Expand Down
11 changes: 11 additions & 0 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ register_client!(
OpenAICompatibleConfig,
OpenAICompatibleClient
),
(
rag_dedicated,
"rag-dedicated",
RagDedicatedConfig,
RagDedicatedClient
),
(gemini, "gemini", GeminiConfig, GeminiClient),
(claude, "claude", ClaudeConfig, ClaudeClient),
(cohere, "cohere", CohereConfig, CohereClient),
Expand Down Expand Up @@ -60,3 +66,8 @@ pub const OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 13] = [
("zhipuai", "https://open.bigmodel.cn/api/paas/v4"),
("lingyiwanwu", "https://api.lingyiwanwu.com/v1"),
];

/// Table of the known rag-dedicated platforms, as pairs of string slices.
///
/// The array length in the type (`2`) must be kept in sync with the number of
/// entries when platforms are added or removed.
// NOTE(review): each pair is presumably (client name, default API base URL),
// mirroring the `name` / `api_base` keys of a `rag-dedicated` client config —
// confirm against the code that consumes this table.
pub const RAG_DEDICATED_PLATFORMS: [(&str, &str); 2] = [
("jina", "https://api.jina.ai/v1"),
("voyageai", "https://api.voyageai.com/v1"),
];
6 changes: 3 additions & 3 deletions src/client/openai_compatible.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::cohere::*;
use super::openai::*;
use super::rag_dedicated::*;
use super::*;

use anyhow::Result;
Expand Down Expand Up @@ -90,7 +90,7 @@ impl OpenAICompatibleClient {
let api_key = self.get_api_key().ok();
let api_base = self.get_api_base_ext()?;

let body = cohere_build_rerank_body(data, &self.model);
let body = rag_dedicated_build_rerank_body(data, &self.model);

let url = format!("{api_base}/rerank");

Expand Down Expand Up @@ -131,5 +131,5 @@ impl_client_trait!(
openai_chat_completions,
openai_chat_completions_streaming,
openai_embeddings,
cohere_rerank
rag_dedicated_rerank
);
Loading

0 comments on commit 2fbb527

Please sign in to comment.