diff --git a/README.md b/README.md
index 9f08fb64..8bcfbca4 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,7 @@ Chat REPL mode:
   - Gemini (free, vision)
   - Claude: (paid)
   - Mistral (paid)
+  - Cohere (paid)
   - OpenAI-Compatible (local)
   - Ollama (free, local)
   - Azure-OpenAI (paid)
diff --git a/config.example.yaml b/config.example.yaml
index 52a1593e..8d6052d4 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -45,9 +45,14 @@ clients:
   - type: claude
     api_key: sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
 
+  # See https://docs.mistral.ai/
   - type: mistral
     api_key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
 
+  # See https://docs.cohere.com/docs/the-cohere-platform
+  - type: cohere
+    api_key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
+
   # Any openai-compatible API providers
   - type: openai-compatible
     name: localai
diff --git a/src/client/cohere.rs b/src/client/cohere.rs
new file mode 100644
index 00000000..ca770352
--- /dev/null
+++ b/src/client/cohere.rs
@@ -0,0 +1,244 @@
+use super::{
+    message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model, PromptType,
+    SendData, TokensCountFactors,
+};
+
+use crate::{render::ReplyHandler, utils::PromptKind};
+
+use anyhow::{bail, Result};
+use async_trait::async_trait;
+use futures_util::StreamExt;
+use reqwest::{Client as ReqwestClient, RequestBuilder};
+use serde::Deserialize;
+use serde_json::{json, Value};
+
+const API_URL: &str = "https://api.cohere.ai/v1/chat";
+
+const MODELS: [(&str, usize, &str); 2] = [
+    // https://docs.cohere.com/docs/command-r
+    ("command-r", 128000, "text"),
+    ("command-r-plus", 128000, "text"),
+];
+
+const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
+
+#[derive(Debug, Clone, Deserialize, Default)]
+pub struct CohereConfig {
+    pub name: Option<String>,
+    pub api_key: Option<String>,
+    pub extra: Option<ExtraConfig>,
+}
+
+#[async_trait]
+impl Client for CohereClient {
+    client_common_fns!();
+
+    async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
+        let builder = self.request_builder(client, data)?;
+        send_message(builder).await
+    }
+
+    async fn send_message_streaming_inner(
+        &self,
+        client: &ReqwestClient,
+        handler: &mut ReplyHandler,
+        data: SendData,
+    ) -> Result<()> {
+        let builder = self.request_builder(client, data)?;
+        send_message_streaming(builder, handler).await
+    }
+}
+
+impl CohereClient {
+    config_get_fn!(api_key, get_api_key);
+
+    pub const PROMPTS: [PromptType<'static>; 1] =
+        [("api_key", "API Key:", false, PromptKind::String)];
+
+    pub fn list_models(local_config: &CohereConfig) -> Vec<Model> {
+        let client_name = Self::name(local_config);
+        MODELS
+            .into_iter()
+            .map(|(name, max_input_tokens, capabilities)| {
+                Model::new(client_name, name)
+                    .set_capabilities(capabilities.into())
+                    .set_max_input_tokens(Some(max_input_tokens))
+                    .set_tokens_count_factors(TOKENS_COUNT_FACTORS)
+            })
+            .collect()
+    }
+
+    fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
+        let api_key = self.get_api_key().ok();
+
+        let mut body = build_body(data, self.model.name.clone())?;
+        self.model.merge_extra_fields(&mut body);
+
+        let url = API_URL;
+
+        debug!("Cohere Request: {url} {body}");
+
+        let mut builder = client.post(url).json(&body);
+        if let Some(api_key) = api_key {
+            builder = builder.bearer_auth(api_key);
+        }
+
+        Ok(builder)
+    }
+}
+
+pub(crate) async fn send_message(builder: RequestBuilder) -> Result<String> {
+    let res = builder.send().await?;
+    let status = res.status();
+    let data: Value = res.json().await?;
+    if status != 200 {
+        check_error(&data)?;
+    }
+    let output = extract_text(&data)?;
+    Ok(output.to_string())
+}
+
+pub(crate) async fn send_message_streaming(
+    builder: RequestBuilder,
+    handler: &mut ReplyHandler,
+) -> Result<()> {
+    let res = builder.send().await?;
+    if res.status() != 200 {
+        let data: Value = res.json().await?;
+        check_error(&data)?;
+    } else {
+        let mut buffer = vec![];
+        let mut cursor = 0;
+        let mut start = 0;
+        let mut balances = vec![];
+        let mut quoting = false;
+        let mut stream = res.bytes_stream();
+        while let Some(chunk) = stream.next().await {
+            let chunk = chunk?;
+            let chunk = std::str::from_utf8(&chunk)?;
+            buffer.extend(chunk.chars());
+            for i in cursor..buffer.len() {
+                let ch = buffer[i];
+                if quoting {
+                    if ch == '"' && buffer[i - 1] != '\\' {
+                        quoting = false;
+                    }
+                    continue;
+                }
+                match ch {
+                    '"' => quoting = true,
+                    '{' => {
+                        if balances.is_empty() {
+                            start = i;
+                        }
+                        balances.push(ch);
+                    }
+                    '[' => {
+                        if start != 0 {
+                            balances.push(ch);
+                        }
+                    }
+                    '}' => {
+                        balances.pop();
+                        if balances.is_empty() {
+                            let value: String = buffer[start..=i].iter().collect();
+                            let value: Value = serde_json::from_str(&value)?;
+                            if let Some("text-generation") = value["event_type"].as_str() {
+                                handler.text(extract_text(&value)?)?;
+                            }
+                        }
+                    }
+                    ']' => {
+                        balances.pop();
+                    }
+                    _ => {}
+                }
+            }
+            cursor = buffer.len();
+        }
+    }
+    Ok(())
+}
+
+fn extract_text(data: &Value) -> Result<&str> {
+    match data["text"].as_str() {
+        Some(text) => Ok(text),
+        None => {
+            bail!("Invalid response data: {data}")
+        }
+    }
+}
+
+fn check_error(data: &Value) -> Result<()> {
+    if let Some(message) = data["message"].as_str() {
+        bail!("{message}");
+    } else {
+        bail!("Error {}", data);
+    }
+}
+
+pub(crate) fn build_body(data: SendData, model: String) -> Result<Value> {
+    let SendData {
+        mut messages,
+        temperature,
+        stream,
+    } = data;
+
+    patch_system_message(&mut messages);
+
+    let mut image_urls = vec![];
+    let mut messages: Vec<Value> = messages
+        .into_iter()
+        .map(|message| {
+            let role = match message.role {
+                MessageRole::User => "USER",
+                _ => "CHATBOT",
+            };
+            match message.content {
+                MessageContent::Text(text) => json!({
+                    "role": role,
+                    "message": text,
+                }),
+                MessageContent::Array(list) => {
+                    let list: Vec<String> = list
+                        .into_iter()
+                        .filter_map(|item| match item {
+                            MessageContentPart::Text { text } => Some(text),
+                            MessageContentPart::ImageUrl {
+                                image_url: ImageUrl { url },
+                            } => {
+                                image_urls.push(url.clone());
+                                None
+                            }
+                        })
+                        .collect();
+                    json!({ "role": role, "message": list.join("\n\n") })
+                }
+            }
+        })
+        .collect();
+
+    if !image_urls.is_empty() {
+        bail!("The model does not support images: {:?}", image_urls);
+    }
+    let message = messages.pop().unwrap();
+    let message = message["message"].as_str().unwrap_or_default();
+
+    let mut body = json!({
+        "model": model,
+        "message": message,
+    });
+
+    if !messages.is_empty() {
+        body["chat_history"] = messages.into();
+    }
+
+    if let Some(temperature) = temperature {
+        body["temperature"] = temperature.into();
+    }
+    if stream {
+        body["stream"] = true.into();
+    }
+
+    Ok(body)
+}
diff --git a/src/client/mod.rs b/src/client/mod.rs
index 37775f35..d66b2822 100644
--- a/src/client/mod.rs
+++ b/src/client/mod.rs
@@ -12,6 +12,7 @@ register_client!(
     (gemini, "gemini", GeminiConfig, GeminiClient),
     (claude, "claude", ClaudeConfig, ClaudeClient),
     (mistral, "mistral", MistralConfig, MistralClient),
+    (cohere, "cohere", CohereConfig, CohereClient),
     (
         openai_compatible,
         "openai-compatible",