Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: rag load websites #655

Merged
merged 4 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ rag_min_score_rerank: 0 # Specifies the minimum relevance sc
rag_document_loaders:
# You can add more loaders, here is the syntax:
# <file-extension>: <command-to-load-the-file>
pdf: 'pdftotext $1 -' # Load .pdf file
pdf: 'pdftotext $1 -' # Load .pdf file, see https://poppler.freedesktop.org
docx: 'pandoc --to plain $1' # Load .docx file
url: 'curl -fsSL $1' # Load url
# recursive_url: 'crawler $1 $2' # Load websites

# Defines the query structure using variables like __CONTEXT__ and __INPUT__ to tailor searches to specific needs
rag_template: |
Expand Down
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,9 @@ async fn shell_execute(config: &GlobalConfig, shell: &Shell, mut input: Input) -
let client = input.create_client()?;
config.write().before_chat_completion(&input)?;
let ret = if *IS_STDOUT_TERMINAL {
let (stop_spinner_tx, _) = run_spinner("Generating").await;
let spinner = create_spinner("Generating").await;
let ret = client.chat_completions(input.clone()).await;
let _ = stop_spinner_tx.send(());
spinner.stop();
ret
} else {
client.chat_completions(input.clone()).await
Expand Down
201 changes: 180 additions & 21 deletions src/rag/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,43 @@ use super::*;

use anyhow::{bail, Context, Result};
use async_recursion::async_recursion;
use std::{collections::HashMap, fs::read_to_string, path::Path};
use serde_json::Value;
use std::{collections::HashMap, env, fs::read_to_string, path::Path};

pub fn load_file(
pub const RECURSIVE_URL_LOADER: &str = "recursive_url";

pub fn load(
loaders: &HashMap<String, String>,
path: &str,
loader_name: &str,
) -> Result<Vec<RagDocument>> {
match loaders.get(loader_name) {
Some(loader_command) => load_with_command(path, loader_name, loader_command),
None => load_plain(path),
if loader_name == RECURSIVE_URL_LOADER {
let loader_command = loaders
.get(loader_name)
.with_context(|| format!("RAG document loader '{loader_name}' not configured"))?;
let contents = run_loader_command(path, loader_name, loader_command)?;
let output = match parse_json_documents(&contents) {
Some(v) => v,
None => vec![RagDocument::new(contents)],
};
Ok(output)
} else {
match loaders.get(loader_name) {
Some(loader_command) => load_with_command(path, loader_name, loader_command),
None => load_plain(path, loader_name),
}
}
}

fn load_plain(path: &str) -> Result<Vec<RagDocument>> {
fn load_plain(path: &str, loader_name: &str) -> Result<Vec<RagDocument>> {
let contents = read_to_string(path)?;
let document = RagDocument::new(contents);
if loader_name == "json" {
if let Some(documents) = parse_json_documents(&contents) {
return Ok(documents);
}
}
let mut document = RagDocument::new(contents);
document.metadata.insert("path".into(), path.to_string());
Ok(vec![document])
}

Expand All @@ -26,29 +47,135 @@ fn load_with_command(
loader_name: &str,
loader_command: &str,
) -> Result<Vec<RagDocument>> {
let cmd_args = shell_words::split(loader_command)
.with_context(|| anyhow!("Invalid rag loader '{loader_name}': `{loader_command}`"))?;
let contents = run_loader_command(path, loader_name, loader_command)?;
let mut document = RagDocument::new(contents);
document.metadata.insert("path".into(), path.to_string());
Ok(vec![document])
}

fn run_loader_command(path: &str, loader_name: &str, loader_command: &str) -> Result<String> {
let cmd_args = shell_words::split(loader_command).with_context(|| {
anyhow!("Invalid rag document loader '{loader_name}': `{loader_command}`")
})?;
let mut use_stdout = true;
let outpath = env::temp_dir()
.join(format!("aichat-{}", sha256(path)))
.display()
.to_string();
let cmd_args: Vec<_> = cmd_args
.into_iter()
.map(|v| if v == "$1" { path.to_string() } else { v })
.map(|mut v| {
if v.contains("$1") {
v = v.replace("$1", path);
}
if v.contains("$2") {
use_stdout = false;
v = v.replace("$2", &outpath);
}
v
})
.collect();
let cmd_eval = shell_words::join(&cmd_args);
debug!("run `{cmd_eval}`");
let (cmd, args) = cmd_args.split_at(1);
let cmd = &cmd[0];
let (success, stdout, stderr) =
run_command_with_output(cmd, args, None).with_context(|| {
if use_stdout {
let (success, stdout, stderr) =
run_command_with_output(cmd, args, None).with_context(|| {
format!("Unable to run `{cmd_eval}`, Perhaps '{cmd}' is not installed?")
})?;
if !success {
let err = if !stderr.is_empty() {
stderr
} else {
format!("The command `{cmd_eval}` exited with non-zero.")
};
bail!("{err}")
}
Ok(stdout)
} else {
let status = run_command(cmd, args, None).with_context(|| {
format!("Unable to run `{cmd_eval}`, Perhaps '{cmd}' is not installed?")
})?;
if !success {
let err = if !stderr.is_empty() {
stderr
} else {
format!("The command `{cmd_eval}` exited with non-zero.")
};
bail!("{err}")
if status != 0 {
bail!("The command `{cmd_eval}` exited with non-zero.")
}
let contents =
read_to_string(&outpath).context("Failed to read file generated by the loader")?;
Ok(contents)
}
}

fn parse_json_documents(data: &str) -> Option<Vec<RagDocument>> {
let value: Value = serde_json::from_str(data).ok()?;
let items = match value {
Value::Array(v) => v,
_ => return None,
};
if items.is_empty() {
return None;
}
match &items[0] {
Value::String(_) => {
let documents: Vec<_> = items
.into_iter()
.flat_map(|item| {
if let Value::String(content) = item {
Some(RagDocument::new(content))
} else {
None
}
})
.collect();
Some(documents)
}
Value::Object(obj) => {
let key = [
"page_content",
"pageContent",
"content",
"html",
"markdown",
"text",
"data",
]
.into_iter()
.map(|v| v.to_string())
.find(|key| obj.get(key).and_then(|v| v.as_str()).is_some())?;
let documents: Vec<_> = items
.into_iter()
.flat_map(|item| {
if let Value::Object(mut obj) = item {
if let Some(page_content) = obj.get(&key).and_then(|v| v.as_str()) {
let page_content = page_content.to_string();
obj.remove(&key);
let metadata: IndexMap<_, _> = obj
.into_iter()
.map(|(k, v)| {
if let Value::String(v) = v {
(k, v)
} else {
(k, v.to_string())
}
})
.collect();
return Some(RagDocument {
page_content,
metadata,
});
}
}
None
})
.collect();
if documents.is_empty() {
None
} else {
Some(documents)
}
}
_ => None,
}
let document = RagDocument::new(stdout);
Ok(vec![document])
}

pub fn parse_glob(path_str: &str) -> Result<(String, Vec<String>)> {
Expand Down Expand Up @@ -146,4 +273,36 @@ mod tests {
("C:\\dir".into(), vec!["md".into(), "txt".into()])
);
}

#[test]
fn test_parse_json_documents() {
    // Top-level array of strings: each element becomes a document.
    assert_eq!(
        parse_json_documents(r#"["foo", "bar"]"#).unwrap(),
        vec![RagDocument::new("foo"), RagDocument::new("bar")]
    );

    // Array of objects: a recognized content key supplies page_content.
    assert_eq!(
        parse_json_documents(r#"[{"content": "foo"}, {"content": "bar"}]"#).unwrap(),
        vec![RagDocument::new("foo"), RagDocument::new("bar")]
    );

    // Remaining object fields are stringified into document metadata.
    let mut expected_metadata = IndexMap::new();
    expected_metadata.insert("k1".into(), "1".into());
    assert_eq!(
        parse_json_documents(r#"[{"k1": 1, "data": "foo" }]"#).unwrap(),
        vec![RagDocument::new("foo").with_metadata(expected_metadata.clone())]
    );

    // Non-array input, bare objects, and objects without a recognized
    // content key are all rejected.
    for data in [r#""hello""#, r#"{"key":"value"}"#, r#"[{"key":"value"}]"#] {
        assert!(parse_json_documents(data).is_none());
    }
}
}
Loading