Skip to content

Commit

Permalink
refactor: remove Model.client_index, match client by name (#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden committed Nov 7, 2023
1 parent 87aec71 commit 9a8b302
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 41 deletions.
4 changes: 2 additions & 2 deletions src/client/azure_openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ impl AzureOpenAIClient {
),
];

pub fn list_models(local_config: &AzureOpenAIConfig, client_index: usize) -> Vec<Model> {
pub fn list_models(local_config: &AzureOpenAIConfig) -> Vec<Model> {
let client_name = Self::name(local_config);

local_config
.models
.iter()
.map(|v| {
Model::new(client_index, client_name, &v.name)
Model::new(client_name, &v.name)
.set_max_tokens(v.max_tokens)
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})
Expand Down
29 changes: 13 additions & 16 deletions src/client/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,24 @@ macro_rules! register_client {

pub fn init(global_config: &$crate::config::GlobalConfig) -> Option<Box<dyn Client>> {
let model = global_config.read().model.clone();
let config = {
if let ClientConfig::$config(c) = &global_config.read().clients[model.client_index] {
c.clone()
} else {
return None;
let config = global_config.read().clients.iter().find_map(|client_config| {
if let ClientConfig::$config(c) = client_config {
if Self::name(c) == &model.client_name {
return Some(c.clone())
}
}
};
None
})?;

Some(Box::new(Self {
global_config: global_config.clone(),
config,
model,
}))
}

pub fn name(local_config: &$config) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
pub fn name(config: &$config) -> &str {
config.name.as_deref().unwrap_or(Self::NAME)
}
}

Expand All @@ -80,11 +82,7 @@ macro_rules! register_client {
$(.or_else(|| $client::init(config)))+
.ok_or_else(|| {
let model = config.read().model.clone();
anyhow::anyhow!(
"Unknown client '{}' at config.clients[{}]",
&model.client_name,
&model.client_index
)
anyhow::anyhow!("Unknown client '{}'", &model.client_name)
})
}

Expand All @@ -105,9 +103,8 @@ macro_rules! register_client {
config
.clients
.iter()
.enumerate()
.flat_map(|(i, v)| match v {
$(ClientConfig::$config(c) => $client::list_models(c, i),)+
.flat_map(|v| match v {
$(ClientConfig::$config(c) => $client::list_models(c),)+
ClientConfig::Unknown => vec![],
})
.collect()
Expand Down
6 changes: 3 additions & 3 deletions src/client/ernie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ impl ErnieClient {
("secret_key", "Secret Key:", true, PromptKind::String),
];

pub fn list_models(local_config: &ErnieConfig, client_index: usize) -> Vec<Model> {
pub fn list_models(local_config: &ErnieConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, _)| Model::new(client_index, client_name, name))
.map(|(name, _)| Model::new(client_name, name))
.collect()
}

Expand All @@ -79,7 +79,7 @@ impl ErnieClient {
let (_, chat_endpoint) = MODELS
.iter()
.find(|(v, _)| v == &model)
.ok_or_else(|| anyhow!("Miss Model '{}' in {}", model, self.model.client_name))?;
.ok_or_else(|| anyhow!("Miss Model '{}'", self.model.id()))?;

let url = format!("{API_BASE}{chat_endpoint}?access_token={}", unsafe {
&ACCESS_TOKEN
Expand Down
4 changes: 2 additions & 2 deletions src/client/localai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ impl LocalAIClient {
),
];

pub fn list_models(local_config: &LocalAIConfig, client_index: usize) -> Vec<Model> {
pub fn list_models(local_config: &LocalAIConfig) -> Vec<Model> {
let client_name = Self::name(local_config);

local_config
.models
.iter()
.map(|v| {
Model::new(client_index, client_name, &v.name)
Model::new(client_name, &v.name)
.set_max_tokens(v.max_tokens)
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})
Expand Down
6 changes: 2 additions & 4 deletions src/client/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ pub type TokensCountFactors = (usize, usize); // (per-messages, bias)

#[derive(Debug, Clone)]
pub struct Model {
pub client_index: usize,
pub client_name: String,
pub name: String,
pub max_tokens: Option<usize>,
Expand All @@ -17,14 +16,13 @@ pub struct Model {

impl Default for Model {
fn default() -> Self {
Model::new(0, "", "")
Model::new("", "")
}
}

impl Model {
pub fn new(client_index: usize, client_name: &str, name: &str) -> Self {
pub fn new(client_name: &str, name: &str) -> Self {
Self {
client_index,
client_name: client_name.into(),
name: name.into(),
max_tokens: None,
Expand Down
4 changes: 2 additions & 2 deletions src/client/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ impl OpenAIClient {
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];

pub fn list_models(local_config: &OpenAIConfig, client_index: usize) -> Vec<Model> {
pub fn list_models(local_config: &OpenAIConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens)| {
Model::new(client_index, client_name, name)
Model::new(client_name, name)
.set_max_tokens(Some(max_tokens))
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})
Expand Down
16 changes: 6 additions & 10 deletions src/client/palm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};

const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta2";
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta2/models/";

const MODELS: [(&str, usize, &str); 1] = [("chat-bison-001", 4096, "/models/chat-bison-001")];
const MODELS: [(&str, usize); 1] = [("chat-bison-001", 4096)];

const TOKENS_COUNT_FACTORS: TokensCountFactors = (3, 8);

Expand Down Expand Up @@ -49,12 +49,12 @@ impl PaLMClient {
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];

pub fn list_models(local_config: &PaLMConfig, client_index: usize) -> Vec<Model> {
pub fn list_models(local_config: &PaLMConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens, _)| {
Model::new(client_index, client_name, name)
.map(|(name, max_tokens)| {
Model::new(client_name, name)
.set_max_tokens(Some(max_tokens))
.set_tokens_count_factors(TOKENS_COUNT_FACTORS)
})
Expand All @@ -67,12 +67,8 @@ impl PaLMClient {
let body = build_body(data, self.model.name.clone());

let model = self.model.name.clone();
let (_, _, endpoint) = MODELS
.iter()
.find(|(v, _, _)| v == &model)
.ok_or_else(|| anyhow!("Miss Model '{}' in {}", model, self.model.client_name))?;

let url = format!("{API_BASE}{endpoint}:generateMessage?key={}", api_key);
let url = format!("{API_BASE}{}:generateMessage?key={}", model, api_key);

let builder = client.post(url).json(&body);

Expand Down
4 changes: 2 additions & 2 deletions src/client/qianwen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ impl QianwenClient {
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];

pub fn list_models(local_config: &QianwenConfig, client_index: usize) -> Vec<Model> {
pub fn list_models(local_config: &QianwenConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens)| Model::new(client_index, client_name, name).set_max_tokens(Some(max_tokens)))
.map(|(name, max_tokens)| Model::new(client_name, name).set_max_tokens(Some(max_tokens)))
.collect()
}

Expand Down

0 comments on commit 9a8b302

Please sign in to comment.