Skip to content

Commit

Permalink
refactor: rename SendData to CompletionData (#553)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden committed May 30, 2024
1 parent fa4bf14 commit 54a8377
Show file tree
Hide file tree
Showing 17 changed files with 152 additions and 99 deletions.
10 changes: 7 additions & 3 deletions src/client/azure_openai.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
openai::*, AzureOpenAIClient, Client, ExtraConfig, Model, ModelData, ModelPatches,
PromptAction, PromptKind, SendData,
openai::*, AzureOpenAIClient, Client, CompletionData, ExtraConfig, Model, ModelData,
ModelPatches, PromptAction, PromptKind,
};

use anyhow::Result;
Expand Down Expand Up @@ -33,7 +33,11 @@ impl AzureOpenAIClient {
),
];

fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
fn request_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
) -> Result<RequestBuilder> {
let api_base = self.get_api_base()?;
let api_key = self.get_api_key()?;

Expand Down
26 changes: 15 additions & 11 deletions src/client/bedrock.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{
prompt_format::*, claude::*,
catch_error, BedrockClient, Client, CompletionOutput, ExtraConfig, Model, ModelData,
ModelPatches, PromptAction, PromptKind, SendData, SseHandler,
catch_error, claude::*, prompt_format::*, BedrockClient, Client, CompletionData,
CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind,
SseHandler,
};

use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
Expand Down Expand Up @@ -41,7 +41,7 @@ impl Client for BedrockClient {
async fn send_message_inner(
&self,
client: &ReqwestClient,
data: SendData,
data: CompletionData,
) -> Result<CompletionOutput> {
let model_category = ModelCategory::from_str(self.model.name())?;
let builder = self.request_builder(client, data, &model_category)?;
Expand All @@ -52,7 +52,7 @@ impl Client for BedrockClient {
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: SendData,
data: CompletionData,
) -> Result<()> {
let model_category = ModelCategory::from_str(self.model.name())?;
let builder = self.request_builder(client, data, &model_category)?;
Expand Down Expand Up @@ -84,7 +84,7 @@ impl BedrockClient {
fn request_builder(
&self,
client: &ReqwestClient,
data: SendData,
data: CompletionData,
model_category: &ModelCategory,
) -> Result<RequestBuilder> {
let access_key_id = self.get_access_key_id()?;
Expand Down Expand Up @@ -211,7 +211,11 @@ async fn send_message_streaming(
Ok(())
}

fn build_body(data: SendData, model: &Model, model_category: &ModelCategory) -> Result<Value> {
fn build_body(
data: CompletionData,
model: &Model,
model_category: &ModelCategory,
) -> Result<Value> {
match model_category {
ModelCategory::Anthropic => {
let mut body = claude_build_body(data, model)?;
Expand All @@ -227,8 +231,8 @@ fn build_body(data: SendData, model: &Model, model_category: &ModelCategory) ->
}
}

fn meta_llama_build_body(data: SendData, model: &Model, pt: PromptFormat) -> Result<Value> {
let SendData {
fn meta_llama_build_body(data: CompletionData, model: &Model, pt: PromptFormat) -> Result<Value> {
let CompletionData {
messages,
temperature,
top_p,
Expand All @@ -251,8 +255,8 @@ fn meta_llama_build_body(data: SendData, model: &Model, pt: PromptFormat) -> Res
Ok(body)
}

fn mistral_build_body(data: SendData, model: &Model) -> Result<Value> {
let SendData {
fn mistral_build_body(data: CompletionData, model: &Model) -> Result<Value> {
let CompletionData {
messages,
temperature,
top_p,
Expand Down
14 changes: 9 additions & 5 deletions src/client/claude.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{
catch_error, extract_system_message, message::*, sse_stream, ClaudeClient, Client,
CompletionOutput, ExtraConfig, ImageUrl, MessageContent, MessageContentPart, Model, ModelData,
ModelPatches, PromptAction, PromptKind, SendData, SseHandler, SseMmessage, ToolCall,
CompletionData, CompletionOutput, ExtraConfig, ImageUrl, MessageContent, MessageContentPart,
Model, ModelData, ModelPatches, PromptAction, PromptKind, SseHandler, SseMmessage, ToolCall,
};

use anyhow::{bail, Context, Result};
Expand All @@ -27,7 +27,11 @@ impl ClaudeClient {
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];

fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
fn request_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();

let mut body = claude_build_body(data, &self.model)?;
Expand Down Expand Up @@ -131,8 +135,8 @@ pub async fn claude_send_message_streaming(
sse_stream(builder, handle).await
}

pub fn claude_build_body(data: SendData, model: &Model) -> Result<Value> {
let SendData {
pub fn claude_build_body(data: CompletionData, model: &Model) -> Result<Value> {
let CompletionData {
mut messages,
temperature,
top_p,
Expand Down
14 changes: 9 additions & 5 deletions src/client/cloudflare.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
catch_error, sse_stream, Client, CloudflareClient, CompletionOutput, ExtraConfig, Model,
ModelData, ModelPatches, PromptAction, PromptKind, SendData, SseHandler, SseMmessage,
catch_error, sse_stream, Client, CloudflareClient, CompletionData, CompletionOutput,
ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, SseHandler, SseMmessage,
};

use anyhow::{anyhow, Result};
Expand Down Expand Up @@ -30,7 +30,11 @@ impl CloudflareClient {
("api_key", "API Key:", true, PromptKind::String),
];

fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
fn request_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
) -> Result<RequestBuilder> {
let account_id = self.get_account_id()?;
let api_key = self.get_api_key()?;

Expand Down Expand Up @@ -79,8 +83,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
sse_stream(builder, handle).await
}

fn build_body(data: SendData, model: &Model) -> Result<Value> {
let SendData {
fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
let CompletionData {
messages,
temperature,
top_p,
Expand Down
14 changes: 9 additions & 5 deletions src/client/cohere.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{
catch_error, extract_system_message, json_stream, message::*, Client, CohereClient,
CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind,
SendData, SseHandler, ToolCall,
CompletionData, CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction,
PromptKind, SseHandler, ToolCall,
};

use anyhow::{bail, Result};
Expand All @@ -27,7 +27,11 @@ impl CohereClient {
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];

fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
fn request_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
) -> Result<RequestBuilder> {
let api_key = self.get_api_key()?;

let mut body = build_body(data, &self.model)?;
Expand Down Expand Up @@ -93,8 +97,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
Ok(())
}

fn build_body(data: SendData, model: &Model) -> Result<Value> {
let SendData {
fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
let CompletionData {
mut messages,
temperature,
top_p,
Expand Down
14 changes: 7 additions & 7 deletions src/client/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ macro_rules! impl_client_trait {
async fn send_message_inner(
&self,
client: &reqwest::Client,
data: $crate::client::SendData,
data: $crate::client::CompletionData,
) -> anyhow::Result<$crate::client::CompletionOutput> {
let builder = self.request_builder(client, data)?;
$send_message(builder).await
Expand All @@ -213,7 +213,7 @@ macro_rules! impl_client_trait {
&self,
client: &reqwest::Client,
handler: &mut $crate::client::SseHandler,
data: $crate::client::SendData,
data: $crate::client::CompletionData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
$send_message_streaming(builder, handler).await
Expand Down Expand Up @@ -289,7 +289,7 @@ pub trait Client: Sync + Send {
}
let client = self.build_client()?;

let data = input.prepare_send_data(self.model(), false)?;
let data = input.prepare_completion_data(self.model(), false)?;
self.send_message_inner(&client, data)
.await
.with_context(|| "Failed to get answer")
Expand Down Expand Up @@ -318,7 +318,7 @@ pub trait Client: Sync + Send {
return Ok(());
}
let client = self.build_client()?;
let data = input.prepare_send_data(self.model(), true)?;
let data = input.prepare_completion_data(self.model(), true)?;
self.send_message_streaming_inner(&client, handler, data).await
} => {
handler.done()?;
Expand All @@ -343,14 +343,14 @@ pub trait Client: Sync + Send {
async fn send_message_inner(
&self,
client: &ReqwestClient,
data: SendData,
data: CompletionData,
) -> Result<CompletionOutput>;

async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: SendData,
data: CompletionData,
) -> Result<()>;
}

Expand Down Expand Up @@ -391,7 +391,7 @@ pub fn select_model_patch<'a>(
}

#[derive(Debug)]
pub struct SendData {
pub struct CompletionData {
pub messages: Vec<Message>,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
Expand Down
21 changes: 12 additions & 9 deletions src/client/ernie.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use super::access_token::*;
use super::{
maybe_catch_error, patch_system_message, sse_stream, Client, CompletionOutput, ErnieClient,
ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, SendData, SseHandler,
SseMmessage,
access_token::*, maybe_catch_error, patch_system_message, sse_stream, Client, CompletionData,
CompletionOutput, ErnieClient, ExtraConfig, Model, ModelData, ModelPatches, PromptAction,
PromptKind, SseHandler, SseMmessage,
};

use anyhow::{anyhow, Context, Result};
Expand Down Expand Up @@ -32,7 +31,11 @@ impl ErnieClient {
("secret_key", "Secret Key:", true, PromptKind::String),
];

fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
fn request_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
) -> Result<RequestBuilder> {
let mut body = build_body(data, &self.model);
self.patch_request_body(&mut body);

Expand Down Expand Up @@ -81,7 +84,7 @@ impl Client for ErnieClient {
async fn send_message_inner(
&self,
client: &ReqwestClient,
data: SendData,
data: CompletionData,
) -> Result<CompletionOutput> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
Expand All @@ -92,7 +95,7 @@ impl Client for ErnieClient {
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: SendData,
data: CompletionData,
) -> Result<()> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
Expand Down Expand Up @@ -120,8 +123,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
sse_stream(builder, handle).await
}

fn build_body(data: SendData, model: &Model) -> Value {
let SendData {
fn build_body(data: CompletionData, model: &Model) -> Value {
let CompletionData {
mut messages,
temperature,
top_p,
Expand Down
10 changes: 7 additions & 3 deletions src/client/gemini.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
vertexai::*, Client, ExtraConfig, GeminiClient, Model, ModelData, ModelPatches, PromptAction,
PromptKind, SendData,
vertexai::*, Client, CompletionData, ExtraConfig, GeminiClient, Model, ModelData, ModelPatches,
PromptAction, PromptKind,
};

use anyhow::Result;
Expand All @@ -27,7 +27,11 @@ impl GeminiClient {
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];

fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
fn request_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
) -> Result<RequestBuilder> {
let api_key = self.get_api_key()?;

let func = match data.stream {
Expand Down
14 changes: 9 additions & 5 deletions src/client/ollama.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
catch_error, json_stream, message::*, Client, CompletionOutput, ExtraConfig, Model, ModelData,
ModelPatches, OllamaClient, PromptAction, PromptKind, SendData, SseHandler,
catch_error, json_stream, message::*, Client, CompletionData, CompletionOutput, ExtraConfig,
Model, ModelData, ModelPatches, OllamaClient, PromptAction, PromptKind, SseHandler,
};

use anyhow::{anyhow, bail, Result};
Expand Down Expand Up @@ -35,7 +35,11 @@ impl OllamaClient {
),
];

fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
fn request_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
) -> Result<RequestBuilder> {
let api_base = self.get_api_base()?;
let api_auth = self.get_api_auth().ok();

Expand Down Expand Up @@ -101,8 +105,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
Ok(())
}

fn build_body(data: SendData, model: &Model) -> Result<Value> {
let SendData {
fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
let CompletionData {
messages,
temperature,
top_p,
Expand Down
Loading

0 comments on commit 54a8377

Please sign in to comment.