Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support bot #579

Merged
merged 12 commits into from
Jun 11, 2024
11 changes: 0 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ hmac = "0.12.1"
aws-smithy-eventstream = "0.60.4"
urlencoding = "2.1.3"
unicode-segmentation = "1.11.0"
num_cpus = "1.16.0"
threadpool = "1.8.1"
json-patch = { version = "2.0.0", default-features = false }
bitflags = "2.5.0"
path-absolutize = "3.1.1"
Expand Down
10 changes: 8 additions & 2 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ summarize_prompt: 'Summarize the discussion briefly in 200 words or less to use
summary_prompt: 'This is a summary of the chat history as a recap: '

# Custom REPL prompt, see https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt
left_prompt: '{color.green}{?session {session}{?role /}}{role}{?rag #{rag}{color.cyan}{?session )}{!session >}{color.reset} '
left_prompt: '{color.green}{?session {?bot {bot}#}{session}{?role /}}{!session {?bot {bot}}}{role}{?rag @{rag}}{color.cyan}{?session )}{!session >}{color.reset} '
right_prompt: '{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}'

clients:
Expand Down Expand Up @@ -247,4 +247,10 @@ clients:
- type: openai-compatible
name: together
api_base: https://api.together.xyz/v1
api_key: xxx # ENV: {client}_API_KEY
api_key: xxx # ENV: {client}_API_KEY

bots:
- name: todo-sh
model: null
temperature: null
top_p: null
23 changes: 18 additions & 5 deletions src/client/model.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use super::{
list_chat_models,
message::{Message, MessageContent},
EmbeddingsData,
};

use crate::config::Config;
use crate::utils::{estimate_token_length, format_option_value};

use anyhow::{bail, Result};
Expand Down Expand Up @@ -41,21 +43,28 @@ impl Model {
.collect()
}

pub fn find(models: &[&Self], value: &str) -> Option<Self> {
pub fn retrieve(config: &Config, model_id: &str) -> Result<Self> {
match Self::find(&list_chat_models(config), model_id) {
Some(v) => Ok(v),
None => bail!("Invalid model '{model_id}'"),
}
}

pub fn find(models: &[&Self], model_id: &str) -> Option<Self> {
let mut model = None;
let (client_name, model_name) = match value.split_once(':') {
let (client_name, model_name) = match model_id.split_once(':') {
Some((client_name, model_name)) => {
if model_name.is_empty() {
(client_name, None)
} else {
(client_name, Some(model_name))
}
}
None => (value, None),
None => (model_id, None),
};
match model_name {
Some(model_name) => {
if let Some(found) = models.iter().find(|v| v.id() == value) {
if let Some(found) = models.iter().find(|v| v.id() == model_id) {
model = Some((*found).clone());
} else if let Some(found) = models.iter().find(|v| v.client_name == client_name) {
let mut found = (*found).clone();
Expand All @@ -73,7 +82,11 @@ impl Model {
}

pub fn id(&self) -> String {
format!("{}:{}", self.client_name, self.data.name)
if self.data.name.is_empty() {
self.client_name.to_string()
} else {
format!("{}:{}", self.client_name, self.data.name)
}
}

pub fn client_name(&self) -> &str {
Expand Down
255 changes: 255 additions & 0 deletions src/config/bot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
use super::*;

use crate::{
client::Model,
function::{Functions, FUNCTION_ALL_MATCHER},
};

use anyhow::{Context, Result};
use std::{fs::read_to_string, path::Path};

use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize)]
pub struct Bot {
name: String,
config: BotConfig,
definition: BotDefinition,
#[serde(skip)]
functions: Functions,
#[serde(skip)]
rag: Option<Arc<Rag>>,
#[serde(skip)]
model: Model,
}

impl Bot {
pub async fn init(
config: &GlobalConfig,
name: &str,
abort_signal: AbortSignal,
) -> Result<Self> {
let definition_path = Config::bot_definition_file(name)?;
let functions_path = Config::bot_functions_file(name)?;
let rag_path = Config::bot_rag_file(name)?;
let embeddings_dir = Config::bot_embeddings_dir(name)?;
let definition = BotDefinition::load(&definition_path)?;
let functions = if functions_path.exists() {
Functions::init(&functions_path)?
} else {
Functions::default()
};
let bot_config = config
.read()
.bots
.iter()
.find(|v| v.name == name)
.cloned()
.unwrap_or_else(|| BotConfig::new(name));
let model = {
let config = config.read();
match bot_config.model_id.as_ref() {
Some(model_id) => Model::retrieve(&config, model_id)?,
None => config.current_model().clone(),
}
};

let render_options = config.read().get_render_options()?;
let mut markdown_render = MarkdownRender::init(render_options)?;
println!("{}", markdown_render.render(&definition.banner()));

let rag = if rag_path.exists() {
Some(Arc::new(Rag::load(config, "rag", &rag_path)?))
} else if embeddings_dir.is_dir() {
println!("The bot has an embeddings directory, RAG is initializing...");
let ans = Confirm::new("The bot attached embeddings, init RAG?")
.with_default(true)
.prompt()?;
if ans {
let doc_path = embeddings_dir.display().to_string();
Some(Arc::new(
Rag::init(config, "rag", &rag_path, &[doc_path], abort_signal).await?,
))
} else {
None
}
} else {
None
};

Ok(Self {
name: name.to_string(),
config: bot_config,
definition,
functions,
rag,
model,
})
}

pub fn export(&self) -> Result<String> {
let mut value = serde_json::json!(self);
value["functions_dir"] = Config::bot_functions_dir(&self.name)?
.display()
.to_string()
.into();
value["config_dir"] = Config::bot_config_dir(&self.name)?
.display()
.to_string()
.into();
let data = serde_yaml::to_string(&value)?;
Ok(data)
}

pub fn name(&self) -> &str {
&self.name
}

pub fn functions(&self) -> &Functions {
&self.functions
}

pub fn definition(&self) -> &BotDefinition {
&self.definition
}

pub fn rag(&self) -> Option<Arc<Rag>> {
self.rag.clone()
}
}

impl RoleLike for Bot {
fn to_role(&self) -> Role {
let mut role = Role::new("", &self.definition.instructions);
role.sync(self);
role
}

fn model(&self) -> &Model {
&self.model
}

fn temperature(&self) -> Option<f64> {
self.config.temperature
}

fn top_p(&self) -> Option<f64> {
self.config.top_p
}

fn function_matcher(&self) -> Option<String> {
if self.functions.is_empty() {
None
} else {
Some(FUNCTION_ALL_MATCHER.into())
}
}

fn set_model(&mut self, model: &Model) {
self.config.model_id = Some(model.id());
self.model = model.clone();
}

fn set_temperature(&mut self, value: Option<f64>) {
self.config.temperature = value;
}

fn set_top_p(&mut self, value: Option<f64>) {
self.config.top_p = value;
}

fn set_function_matcher(&mut self, _value: Option<String>) {}
}

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct BotConfig {
pub name: String,
#[serde(rename(serialize = "model", deserialize = "model"))]
pub model_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
}

impl BotConfig {
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
..Default::default()
}
}
}

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct BotDefinition {
pub name: String,
#[serde(default)]
pub description: String,
#[serde(default)]
pub version: String,
pub instructions: String,
#[serde(default)]
pub conversation_starters: Vec<String>,
}

impl BotDefinition {
pub fn load(path: &Path) -> Result<Self> {
let contents = read_to_string(path)
.with_context(|| format!("Failed to read bot index file at '{}'", path.display()))?;
let definition: Self = serde_yaml::from_str(&contents)
.with_context(|| format!("Failed to load bot at '{}'", path.display()))?;
Ok(definition)
}

fn banner(&self) -> String {
let BotDefinition {
name,
description,
version,
conversation_starters,
..
} = self;
let starters = if conversation_starters.is_empty() {
String::new()
} else {
let starters = conversation_starters
.iter()
.map(|v| format!("- {v}"))
.collect::<Vec<_>>()
.join("\n");
format!(
r#"

**Conversation Starters**
{starters}"#
)
};
format!(
r#"# {name} {version}
{description}{starters}
"#
)
}
}

pub fn list_bots() -> Vec<String> {
list_bots_impl().unwrap_or_default()
}

fn list_bots_impl() -> Result<Vec<String>> {
let base_dir = Config::functions_dir()?;
let contents = read_to_string(base_dir.join("bots.txt"))?;
let bots = contents
.split('\n')
.filter_map(|line| {
let line = line.trim();
if line.is_empty() {
None
} else {
Some(line.to_string())
}
})
.collect();
Ok(bots)
}
Loading