From 54a837784c3df7e73c7c95adfef96e68c842984b Mon Sep 17 00:00:00 2001 From: sigoden Date: Thu, 30 May 2024 17:10:02 +0800 Subject: [PATCH] refactor: rename `SendData` to `CompletionData` (#553) --- src/client/azure_openai.rs | 10 +++++++--- src/client/bedrock.rs | 26 +++++++++++++++----------- src/client/claude.rs | 14 +++++++++----- src/client/cloudflare.rs | 14 +++++++++----- src/client/cohere.rs | 14 +++++++++----- src/client/common.rs | 14 +++++++------- src/client/ernie.rs | 21 ++++++++++++--------- src/client/gemini.rs | 10 +++++++--- src/client/ollama.rs | 14 +++++++++----- src/client/openai.rs | 16 ++++++++++------ src/client/openai_compatible.rs | 10 +++++++--- src/client/qianwen.rs | 20 ++++++++++++-------- src/client/replicate.rs | 16 ++++++++-------- src/client/vertexai.rs | 18 +++++++++++------- src/client/vertexai_claude.rs | 14 +++++++++----- src/config/input.rs | 8 ++++---- src/serve.rs | 12 +++++++----- 17 files changed, 152 insertions(+), 99 deletions(-) diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs index 3779039c..75f8455b 100644 --- a/src/client/azure_openai.rs +++ b/src/client/azure_openai.rs @@ -1,6 +1,6 @@ use super::{ - openai::*, AzureOpenAIClient, Client, ExtraConfig, Model, ModelData, ModelPatches, - PromptAction, PromptKind, SendData, + openai::*, AzureOpenAIClient, Client, CompletionData, ExtraConfig, Model, ModelData, + ModelPatches, PromptAction, PromptKind, }; use anyhow::Result; @@ -33,7 +33,11 @@ impl AzureOpenAIClient { ), ]; - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + fn request_builder( + &self, + client: &ReqwestClient, + data: CompletionData, + ) -> Result { let api_base = self.get_api_base()?; let api_key = self.get_api_key()?; diff --git a/src/client/bedrock.rs b/src/client/bedrock.rs index 1f13c7fb..355565f5 100644 --- a/src/client/bedrock.rs +++ b/src/client/bedrock.rs @@ -1,7 +1,7 @@ use super::{ - prompt_format::*, claude::*, - catch_error, BedrockClient, Client, CompletionOutput, ExtraConfig, Model, ModelData, - ModelPatches, PromptAction, PromptKind, SendData, SseHandler, + catch_error, claude::*, prompt_format::*, BedrockClient, Client, CompletionData, + CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, + SseHandler, }; use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256}; @@ -41,7 +41,7 @@ impl Client for BedrockClient { async fn send_message_inner( &self, client: &ReqwestClient, - data: SendData, + data: CompletionData, ) -> Result { let model_category = ModelCategory::from_str(self.model.name())?; let builder = self.request_builder(client, data, &model_category)?; @@ -52,7 +52,7 @@ impl Client for BedrockClient { &self, client: &ReqwestClient, handler: &mut SseHandler, - data: SendData, + data: CompletionData, ) -> Result<()> { let model_category = ModelCategory::from_str(self.model.name())?; let builder = self.request_builder(client, data, &model_category)?; @@ -84,7 +84,7 @@ impl BedrockClient { fn request_builder( &self, client: &ReqwestClient, - data: SendData, + data: CompletionData, model_category: &ModelCategory, ) -> Result { let access_key_id = self.get_access_key_id()?; @@ -211,7 +211,11 @@ async fn send_message_streaming( Ok(()) } -fn build_body(data: SendData, model: &Model, model_category: &ModelCategory) -> Result { +fn build_body( + data: CompletionData, + model: &Model, + model_category: &ModelCategory, +) -> Result { match model_category { ModelCategory::Anthropic => { let mut body = claude_build_body(data, model)?; @@ -227,8 +231,8 @@ fn build_body(data: SendData, model: &Model, model_category: &ModelCategory) -> } } -fn meta_llama_build_body(data: SendData, model: &Model, pt: PromptFormat) -> Result { - let SendData { +fn meta_llama_build_body(data: CompletionData, model: &Model, pt: PromptFormat) -> Result { + let CompletionData { messages, temperature, top_p, @@ -251,8 +255,8 @@ fn meta_llama_build_body(data: SendData, model: &Model, pt: PromptFormat) -> Res Ok(body) } -fn mistral_build_body(data: SendData, model: &Model) -> Result { - let SendData { +fn mistral_build_body(data: CompletionData, model: &Model) -> Result { + let CompletionData { messages, temperature, top_p, diff --git a/src/client/claude.rs b/src/client/claude.rs index ccedb51d..3194fe23 100644 --- a/src/client/claude.rs +++ b/src/client/claude.rs @@ -1,7 +1,7 @@ use super::{ catch_error, extract_system_message, message::*, sse_stream, ClaudeClient, Client, - CompletionOutput, ExtraConfig, ImageUrl, MessageContent, MessageContentPart, Model, ModelData, - ModelPatches, PromptAction, PromptKind, SendData, SseHandler, SseMmessage, ToolCall, + CompletionData, CompletionOutput, ExtraConfig, ImageUrl, MessageContent, MessageContentPart, + Model, ModelData, ModelPatches, PromptAction, PromptKind, SseHandler, SseMmessage, ToolCall, }; use anyhow::{bail, Context, Result}; @@ -27,7 +27,11 @@ impl ClaudeClient { pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key:", true, PromptKind::String)]; - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + fn request_builder( + &self, + client: &ReqwestClient, + data: CompletionData, + ) -> Result { let api_key = self.get_api_key().ok(); let mut body = claude_build_body(data, &self.model)?; @@ -131,8 +135,8 @@ pub async fn claude_send_message_streaming( sse_stream(builder, handle).await } -pub fn claude_build_body(data: SendData, model: &Model) -> Result { - let SendData { +pub fn claude_build_body(data: CompletionData, model: &Model) -> Result { + let CompletionData { mut messages, temperature, top_p, diff --git a/src/client/cloudflare.rs b/src/client/cloudflare.rs index 659c7c06..15dd5bb4 100644 --- a/src/client/cloudflare.rs +++ b/src/client/cloudflare.rs @@ -1,6 +1,6 @@ use super::{ - catch_error, sse_stream, Client, CloudflareClient, CompletionOutput, ExtraConfig, Model, - ModelData, ModelPatches, PromptAction, PromptKind, SendData, SseHandler, SseMmessage, + catch_error, sse_stream, Client, CloudflareClient, CompletionData, CompletionOutput, + ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, SseHandler, SseMmessage, }; use anyhow::{anyhow, Result}; @@ -30,7 +30,11 @@ impl CloudflareClient { ("api_key", "API Key:", true, PromptKind::String), ]; - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + fn request_builder( + &self, + client: &ReqwestClient, + data: CompletionData, + ) -> Result { let account_id = self.get_account_id()?; let api_key = self.get_api_key()?; @@ -79,8 +83,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle sse_stream(builder, handle).await } -fn build_body(data: SendData, model: &Model) -> Result { - let SendData { +fn build_body(data: CompletionData, model: &Model) -> Result { + let CompletionData { messages, temperature, top_p, diff --git a/src/client/cohere.rs b/src/client/cohere.rs index 459c64ae..e093778a 100644 --- a/src/client/cohere.rs +++ b/src/client/cohere.rs @@ -1,7 +1,7 @@ use super::{ catch_error, extract_system_message, json_stream, message::*, Client, CohereClient, - CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, - SendData, SseHandler, ToolCall, + CompletionData, CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, + PromptKind, SseHandler, ToolCall, }; use anyhow::{bail, Result}; @@ -27,7 +27,11 @@ impl CohereClient { pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key:", true, PromptKind::String)]; - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + fn request_builder( + &self, + client: &ReqwestClient, + data: CompletionData, + ) -> Result { let api_key = self.get_api_key()?; let mut body = build_body(data, &self.model)?; @@ -93,8 +97,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle Ok(()) } -fn build_body(data: SendData, model: &Model) -> Result { - let SendData { +fn build_body(data: CompletionData, model: &Model) -> Result { + let CompletionData { mut messages, temperature, top_p, diff --git a/src/client/common.rs b/src/client/common.rs index 044ed2e8..336fa784 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -203,7 +203,7 @@ macro_rules! impl_client_trait { async fn send_message_inner( &self, client: &reqwest::Client, - data: $crate::client::SendData, + data: $crate::client::CompletionData, ) -> anyhow::Result<$crate::client::CompletionOutput> { let builder = self.request_builder(client, data)?; $send_message(builder).await @@ -213,7 +213,7 @@ macro_rules! impl_client_trait { &self, client: &reqwest::Client, handler: &mut $crate::client::SseHandler, - data: $crate::client::SendData, + data: $crate::client::CompletionData, ) -> Result<()> { let builder = self.request_builder(client, data)?; $send_message_streaming(builder, handler).await @@ -289,7 +289,7 @@ pub trait Client: Sync + Send { } let client = self.build_client()?; - let data = input.prepare_send_data(self.model(), false)?; + let data = input.prepare_completion_data(self.model(), false)?; self.send_message_inner(&client, data) .await .with_context(|| "Failed to get answer") @@ -318,7 +318,7 @@ pub trait Client: Sync + Send { return Ok(()); } let client = self.build_client()?; - let data = input.prepare_send_data(self.model(), true)?; + let data = input.prepare_completion_data(self.model(), true)?; self.send_message_streaming_inner(&client, handler, data).await } => { handler.done()?; @@ -343,14 +343,14 @@ pub trait Client: Sync + Send { async fn send_message_inner( &self, client: &ReqwestClient, - data: SendData, + data: CompletionData, ) -> Result; async fn send_message_streaming_inner( &self, client: &ReqwestClient, handler: &mut SseHandler, - data: SendData, + data: CompletionData, ) -> Result<()>; } @@ -391,7 +391,7 @@ pub fn select_model_patch<'a>( } #[derive(Debug)] -pub struct SendData { +pub struct CompletionData { pub messages: Vec, pub temperature: Option, pub top_p: Option, diff --git a/src/client/ernie.rs b/src/client/ernie.rs index 68c4224f..49e158bf 100644 --- a/src/client/ernie.rs +++ b/src/client/ernie.rs @@ -1,8 +1,7 @@ -use super::access_token::*; use super::{ - maybe_catch_error, patch_system_message, sse_stream, Client, CompletionOutput, ErnieClient, - ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, SendData, SseHandler, - SseMmessage, + access_token::*, maybe_catch_error, patch_system_message, sse_stream, Client, CompletionData, + CompletionOutput, ErnieClient, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, + PromptKind, SseHandler, SseMmessage, }; use anyhow::{anyhow, Context, Result}; @@ -32,7 +31,11 @@ impl ErnieClient { ("secret_key", "Secret Key:", true, PromptKind::String), ]; - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + fn request_builder( + &self, + client: &ReqwestClient, + data: CompletionData, + ) -> Result { let mut body = build_body(data, &self.model); self.patch_request_body(&mut body); @@ -81,7 +84,7 @@ impl Client for ErnieClient { async fn send_message_inner( &self, client: &ReqwestClient, - data: SendData, + data: CompletionData, ) -> Result { self.prepare_access_token().await?; let builder = self.request_builder(client, data)?; @@ -92,7 +95,7 @@ impl Client for ErnieClient { &self, client: &ReqwestClient, handler: &mut SseHandler, - data: SendData, + data: CompletionData, ) -> Result<()> { self.prepare_access_token().await?; let builder = self.request_builder(client, data)?; @@ -120,8 +123,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle sse_stream(builder, handle).await } -fn build_body(data: SendData, model: &Model) -> Value { - let SendData { +fn build_body(data: CompletionData, model: &Model) -> Value { + let CompletionData { mut messages, temperature, top_p, diff --git a/src/client/gemini.rs b/src/client/gemini.rs index 0d82a48c..05614ce8 100644 --- a/src/client/gemini.rs +++ b/src/client/gemini.rs @@ -1,6 +1,6 @@ use super::{ - vertexai::*, Client, ExtraConfig, GeminiClient, Model, ModelData, ModelPatches, PromptAction, - PromptKind, SendData, + vertexai::*, Client, CompletionData, ExtraConfig, GeminiClient, Model, ModelData, ModelPatches, + PromptAction, PromptKind, }; use anyhow::Result; @@ -27,7 +27,11 @@ impl GeminiClient { pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key:", true, PromptKind::String)]; - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + fn request_builder( + &self, + client: &ReqwestClient, + data: CompletionData, + ) -> Result { let api_key = self.get_api_key()?; let func = match data.stream { diff --git a/src/client/ollama.rs b/src/client/ollama.rs index a650f40f..014824bc 100644 --- a/src/client/ollama.rs +++ b/src/client/ollama.rs @@ -1,6 +1,6 @@ use super::{ - catch_error, json_stream, message::*, Client, CompletionOutput, ExtraConfig, Model, ModelData, - ModelPatches, OllamaClient, PromptAction, PromptKind, SendData, SseHandler, + catch_error, json_stream, message::*, Client, CompletionData, CompletionOutput, ExtraConfig, + Model, ModelData, ModelPatches, OllamaClient, PromptAction, PromptKind, SseHandler, }; use anyhow::{anyhow, bail, Result}; @@ -35,7 +35,11 @@ impl OllamaClient { ), ]; - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + fn request_builder( + &self, + client: &ReqwestClient, + data: CompletionData, + ) -> Result { let api_base = self.get_api_base()?; let api_auth = self.get_api_auth().ok(); @@ -101,8 +105,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle Ok(()) } -fn build_body(data: SendData, model: &Model) -> Result { - let SendData { +fn build_body(data: CompletionData, model: &Model) -> Result { + let CompletionData { messages, temperature, top_p, diff --git a/src/client/openai.rs b/src/client/openai.rs index 881365ff..bf719aec 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -1,7 +1,7 @@ use super::{ - catch_error, message::*, sse_stream, Client, CompletionOutput, ExtraConfig, Model, ModelData, - ModelPatches, OpenAIClient, PromptAction, PromptKind, SendData, SseHandler, SseMmessage, - ToolCall, + catch_error, message::*, sse_stream, Client, CompletionData, CompletionOutput, ExtraConfig, + Model, ModelData, ModelPatches, OpenAIClient, PromptAction, PromptKind, SseHandler, + SseMmessage, ToolCall, }; use anyhow::{bail, Result}; @@ -30,7 +30,11 @@ impl OpenAIClient { pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key:", true, PromptKind::String)]; - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + fn request_builder( + &self, + client: &ReqwestClient, + data: CompletionData, + ) -> Result { let api_key = self.get_api_key()?; let api_base = self.get_api_base().unwrap_or_else(|_| API_BASE.to_string()); @@ -121,8 +125,8 @@ pub async fn openai_send_message_streaming( sse_stream(builder, handle).await } -pub fn openai_build_body(data: SendData, model: &Model) -> Value { - let SendData { +pub fn openai_build_body(data: CompletionData, model: &Model) -> Value { + let CompletionData { messages, temperature, top_p, diff --git a/src/client/openai_compatible.rs b/src/client/openai_compatible.rs index fd101ae7..f6c0a85e 100644 --- a/src/client/openai_compatible.rs +++ b/src/client/openai_compatible.rs @@ -1,6 +1,6 @@ use super::{ - openai::*, Client, ExtraConfig, Model, ModelData, ModelPatches, OpenAICompatibleClient, - PromptAction, PromptKind, SendData, OPENAI_COMPATIBLE_PLATFORMS, + openai::*, Client, CompletionData, ExtraConfig, Model, ModelData, ModelPatches, + OpenAICompatibleClient, PromptAction, PromptKind, OPENAI_COMPATIBLE_PLATFORMS, }; use anyhow::Result; @@ -36,7 +36,11 @@ impl OpenAICompatibleClient { ), ]; - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + fn request_builder( + &self, + client: &ReqwestClient, + data: CompletionData, + ) -> Result { let api_base = match self.get_api_base() { Ok(v) => v, Err(err) => { diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs index 2063f203..3f4b73aa 100644 --- a/src/client/qianwen.rs +++ b/src/client/qianwen.rs @@ -1,7 +1,7 @@ use super::{ - maybe_catch_error, message::*, sse_stream, Client, CompletionOutput, ExtraConfig, Model, - ModelData, ModelPatches, PromptAction, PromptKind, QianwenClient, SendData, SseHandler, - SseMmessage, + maybe_catch_error, message::*, sse_stream, Client, CompletionData, CompletionOutput, + ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, QianwenClient, + SseHandler, SseMmessage, }; use crate::utils::{base64_decode, sha256}; @@ -38,7 +38,11 @@ impl QianwenClient { pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key:", true, PromptKind::String)]; - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + fn request_builder( + &self, + client: &ReqwestClient, + data: CompletionData, + ) -> Result { let api_key = self.get_api_key()?; let stream = data.stream; @@ -71,7 +75,7 @@ impl Client for QianwenClient { async fn send_message_inner( &self, client: &ReqwestClient, - mut data: SendData, + mut data: CompletionData, ) -> Result { let api_key = self.get_api_key()?; patch_messages(self.model.name(), &api_key, &mut data.messages).await?; @@ -83,7 +87,7 @@ impl Client for QianwenClient { &self, client: &ReqwestClient, handler: &mut SseHandler, - mut data: SendData, + mut data: CompletionData, ) -> Result<()> { let api_key = self.get_api_key()?; patch_messages(self.model.name(), &api_key, &mut data.messages).await?; @@ -129,8 +133,8 @@ async fn send_message_streaming( sse_stream(builder, handle).await } -fn build_body(data: SendData, model: &Model) -> Result<(Value, bool)> { - let SendData { +fn build_body(data: CompletionData, model: &Model) -> Result<(Value, bool)> { + let CompletionData { messages, temperature, top_p, diff --git a/src/client/replicate.rs b/src/client/replicate.rs index c39cb07e..3b0787a5 100644 --- a/src/client/replicate.rs +++ b/src/client/replicate.rs @@ -1,7 +1,7 @@ use super::{ - catch_error, prompt_format::*, sse_stream, Client, CompletionOutput, ExtraConfig, Model, - ModelData, ModelPatches, PromptAction, PromptKind, ReplicateClient, SendData, SseHandler, - SseMmessage, + catch_error, prompt_format::*, sse_stream, Client, CompletionData, CompletionOutput, + ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, ReplicateClient, + SseHandler, SseMmessage, }; use anyhow::{anyhow, Result}; @@ -32,7 +32,7 @@ impl ReplicateClient { fn request_builder( &self, client: &ReqwestClient, - data: SendData, + data: CompletionData, api_key: &str, ) -> Result { let mut body = build_body(data, &self.model)?; @@ -55,7 +55,7 @@ impl Client for ReplicateClient { async fn send_message_inner( &self, client: &ReqwestClient, - data: SendData, + data: CompletionData, ) -> Result { let api_key = self.get_api_key()?; let builder = self.request_builder(client, data, &api_key)?; @@ -66,7 +66,7 @@ impl Client for ReplicateClient { &self, client: &ReqwestClient, handler: &mut SseHandler, - data: SendData, + data: CompletionData, ) -> Result<()> { let api_key = self.get_api_key()?; let builder = self.request_builder(client, data, &api_key)?; @@ -135,8 +135,8 @@ async fn send_message_streaming( sse_stream(sse_builder, handle).await } -fn build_body(data: SendData, model: &Model) -> Result { - let SendData { +fn build_body(data: CompletionData, model: &Model) -> Result { + let CompletionData { messages, temperature, top_p, diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index 1abbffce..102b5cbb 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -1,7 +1,7 @@ use super::{ access_token::*, catch_error, json_stream, message::*, patch_system_message, Client, - CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, - SendData, SseHandler, ToolCall, VertexAIClient, + CompletionData, CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, + PromptKind, SseHandler, ToolCall, VertexAIClient, }; use anyhow::{anyhow, bail, Context, Result}; @@ -35,7 +35,11 @@ impl VertexAIClient { ("location", "Location", true, PromptKind::String), ]; - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + fn request_builder( + &self, + client: &ReqwestClient, + data: CompletionData, + ) -> Result { let project_id = self.get_project_id()?; let location = self.get_location()?; let access_token = get_access_token(self.name())?; @@ -66,7 +70,7 @@ impl Client for VertexAIClient { async fn send_message_inner( &self, client: &ReqwestClient, - data: SendData, + data: CompletionData, ) -> Result { prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?; let builder = self.request_builder(client, data)?; @@ -77,7 +81,7 @@ impl Client for VertexAIClient { &self, client: &ReqwestClient, handler: &mut SseHandler, - data: SendData, + data: CompletionData, ) -> Result<()> { prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?; let builder = self.request_builder(client, data)?; @@ -177,8 +181,8 @@ fn gemini_extract_completion_text(data: &Value) -> Result { Ok(output) } -pub(crate) fn gemini_build_body(data: SendData, model: &Model) -> Result { - let SendData { +pub(crate) fn gemini_build_body(data: CompletionData, model: &Model) -> Result { + let CompletionData { mut messages, temperature, top_p, diff --git a/src/client/vertexai_claude.rs b/src/client/vertexai_claude.rs index 6420611a..fd065186 100644 --- a/src/client/vertexai_claude.rs +++ b/src/client/vertexai_claude.rs @@ -1,6 +1,6 @@ use super::{ - access_token::*, claude::*, vertexai::*, Client, CompletionOutput, ExtraConfig, Model, - ModelData, ModelPatches, PromptAction, PromptKind, SendData, SseHandler, VertexAIClaudeClient, + access_token::*, claude::*, vertexai::*, Client, CompletionData, CompletionOutput, ExtraConfig, + Model, ModelData, ModelPatches, PromptAction, PromptKind, SseHandler, VertexAIClaudeClient, }; use anyhow::Result; @@ -29,7 +29,11 @@ impl VertexAIClaudeClient { ("location", "Location", true, PromptKind::String), ]; - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + fn request_builder( + &self, + client: &ReqwestClient, + data: CompletionData, + ) -> Result { let project_id = self.get_project_id()?; let location = self.get_location()?; let access_token = get_access_token(self.name())?; @@ -62,7 +66,7 @@ impl Client for VertexAIClaudeClient { async fn send_message_inner( &self, client: &ReqwestClient, - data: SendData, + data: CompletionData, ) -> Result { prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?; let builder = self.request_builder(client, data)?; @@ -73,7 +77,7 @@ impl Client for VertexAIClaudeClient { &self, client: &ReqwestClient, handler: &mut SseHandler, - data: SendData, + data: CompletionData, ) -> Result<()> { prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?; let builder = self.request_builder(client, data)?; diff --git a/src/config/input.rs b/src/config/input.rs index fd42bf8b..44bbb4bb 100644 --- a/src/config/input.rs +++ b/src/config/input.rs @@ -1,8 +1,8 @@ use super::{role::Role, session::Session, GlobalConfig}; use crate::client::{ - init_client, list_models, Client, ImageUrl, Message, MessageContent, MessageContentPart, - MessageRole, Model, SendData, + init_client, list_models, Client, CompletionData, ImageUrl, Message, MessageContent, + MessageContentPart, MessageRole, Model, }; use crate::function::{ToolCallResult, ToolResults}; use crate::utils::{base64_encode, sha256}; @@ -149,7 +149,7 @@ impl Input { init_client(&self.config, Some(self.model())) } - pub fn prepare_send_data(&self, model: &Model, stream: bool) -> Result { + pub fn prepare_completion_data(&self, model: &Model, stream: bool) -> Result { if !self.medias.is_empty() && !model.supports_vision() { bail!("The current model does not support vision."); } @@ -176,7 +176,7 @@ impl Input { }; functions = config.function.select(function_matcher); }; - Ok(SendData { + Ok(CompletionData { messages, temperature, top_p, diff --git a/src/serve.rs b/src/serve.rs index 0e7bda3d..5f43d88c 100644 --- a/src/serve.rs +++ b/src/serve.rs @@ -1,7 +1,7 @@ use crate::{ client::{ - init_client, list_models, ClientConfig, CompletionOutput, Message, Model, ModelData, - SendData, SseEvent, SseHandler, + init_client, list_models, ClientConfig, CompletionData, CompletionOutput, Message, Model, + ModelData, SseEvent, SseHandler, }, config::{Config, GlobalConfig, Role}, utils::create_abort_signal, @@ -270,7 +270,7 @@ impl Server { let completion_id = generate_completion_id(); let created = Utc::now().timestamp(); - let send_data: SendData = SendData { + let completion_data: CompletionData = CompletionData { messages, temperature, top_p, @@ -306,7 +306,7 @@ impl Server { } tokio::select! { _ = map_event(rx2, &tx, &mut is_first) => {} - ret = client.send_message_streaming_inner(&http_client, &mut handler, send_data) => { + ret = client.send_message_streaming_inner(&http_client, &mut handler, completion_data) => { if let Err(err) = ret { send_first_event(&tx, Some(format!("{err:?}")), &mut is_first) } @@ -350,7 +350,9 @@ impl Server { .body(BodyExt::boxed(StreamBody::new(stream)))?; Ok(res) } else { - let output = client.send_message_inner(&http_client, send_data).await?; + let output = client + .send_message_inner(&http_client, completion_data) + .await?; let res = Response::builder() .header("Content-Type", "application/json") .body(