Skip to content

Commit

Permalink
feat: support vertexai (#308)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden committed Feb 15, 2024
1 parent 3bf0c37 commit 5e42109
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 5 deletions.
9 changes: 7 additions & 2 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ clients:

# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai
api_base: https://RESOURCE.openai.azure.com
api_base: https://{RESOURCE}.openai.azure.com
api_key: xxx
models:
- name: MyGPT4 # Model deployment name
Expand All @@ -69,4 +69,9 @@ clients:

# See https://help.aliyun.com/zh/dashscope/
- type: qianwen
api_key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
api_key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx

# See https://cloud.google.com/vertex-ai
- type: vertexai
api_base: https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models
api_key: xxx
6 changes: 3 additions & 3 deletions src/client/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ impl GeminiClient {
}
}

async fn send_message(builder: RequestBuilder) -> Result<String> {
pub(crate) async fn send_message(builder: RequestBuilder) -> Result<String> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
Expand All @@ -102,7 +102,7 @@ async fn send_message(builder: RequestBuilder) -> Result<String> {
Ok(output.to_string())
}

async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
pub(crate) async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
let res = builder.send().await?;
if res.status() != 200 {
let data: Value = res.json().await?;
Expand Down Expand Up @@ -178,7 +178,7 @@ fn check_error(data: &Value) -> Result<()> {
}
}

fn build_body(data: SendData, _model: String) -> Result<Value> {
pub(crate) fn build_body(data: SendData, _model: String) -> Result<Value> {
let SendData {
mut messages,
temperature,
Expand Down
1 change: 1 addition & 0 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ register_client!(
),
(ernie, "ernie", ErnieConfig, ErnieClient),
(qianwen, "qianwen", QianwenConfig, QianwenClient),
(vertexai, "vertexai", VertexAIConfig, VertexAIClient),
);
92 changes: 92 additions & 0 deletions src/client/vertexai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use super::{
Client, ExtraConfig, VertexAIClient, Model, PromptType,
SendData, TokensCountFactors,
};
use super::gemini::{build_body, send_message, send_message_streaming};

use crate::{render::ReplyHandler, utils::PromptKind};

use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;

// Vertex AI models exposed by this client: (model name, max input tokens, capabilities).
// Capabilities string is parsed elsewhere via `.into()` (see `list_models`).
const MODELS: [(&str, usize, &str); 2] = [
    ("gemini-pro", 32760, "text"),
    ("gemini-pro-vision", 16384, "text,vision"),
];

// NOTE(review): tuple semantics are defined by `TokensCountFactors` elsewhere in
// the crate — presumably (per-message overhead, bias) for token estimation; confirm.
const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);

/// User-facing configuration for the `vertexai` client, deserialized from the
/// `clients` section of the config file (see `config.example.yaml`).
#[derive(Debug, Clone, Deserialize, Default)]
pub struct VertexAIConfig {
    // Optional display/client name override — presumably falls back to "vertexai"
    // via `Self::name`; confirm against the `register_client!` macro.
    pub name: Option<String>,
    // Base URL of the publisher models endpoint, e.g.
    // `https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models`.
    pub api_base: Option<String>,
    // Credential sent as an HTTP bearer token (see `request_builder`).
    pub api_key: Option<String>,
    // Extra client-agnostic options shared by all client configs.
    pub extra: Option<ExtraConfig>,
}

#[async_trait]
impl Client for VertexAIClient {
    client_common_fns!();

    /// Perform a single non-streaming request and return the reply text.
    async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
        send_message(self.request_builder(client, data)?).await
    }

    /// Perform a streaming request, forwarding chunks into `handler`.
    async fn send_message_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut ReplyHandler,
        data: SendData,
    ) -> Result<()> {
        let request = self.request_builder(client, data)?;
        send_message_streaming(request, handler).await
    }
}

impl VertexAIClient {
    config_get_fn!(api_base, get_api_base);
    config_get_fn!(api_key, get_api_key);

    /// Interactive-setup prompts: both the API base and the API key are required.
    pub const PROMPTS: [PromptType<'static>; 2] = [
        ("api_base", "API Base:", true, PromptKind::String),
        ("api_key", "API Key:", true, PromptKind::String),
    ];

    /// Build the list of `Model`s this client offers from the static `MODELS` table.
    pub fn list_models(local_config: &VertexAIConfig) -> Vec<Model> {
        let client_name = Self::name(local_config);
        MODELS
            .into_iter()
            .map(|(name, max_tokens, capabilities)| {
                Model::new(client_name, name)
                    .set_capabilities(capabilities.into())
                    .set_max_tokens(Some(max_tokens))
                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
            })
            .collect()
    }

    /// Construct the HTTP request for a Vertex AI `generateContent` /
    /// `streamGenerateContent` call against `{api_base}/{model}:{func}`.
    ///
    /// # Errors
    /// Fails if `api_base`/`api_key` are missing from the config or if the
    /// request body cannot be built.
    fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
        let api_base = self.get_api_base()?;
        let api_key = self.get_api_key()?;

        // Streaming and non-streaming use different REST methods on the same
        // model resource. (`if` over `match` on a bool — clippy `match_bool`.)
        let func = if data.stream {
            "streamGenerateContent"
        } else {
            "generateContent"
        };

        // Build the URL before `data` is consumed below; borrow the model name
        // here instead of cloning it a second time.
        let url = format!("{api_base}/{}:{}", self.model.name, func);

        let body = build_body(data, self.model.name.clone())?;

        debug!("VertexAI Request: {url} {body}");

        let builder = client.post(url).bearer_auth(api_key).json(&body);

        Ok(builder)
    }
}

0 comments on commit 5e42109

Please sign in to comment.