Skip to content

Commit

Permalink
feat: rag load websites (#655)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden committed Jun 26, 2024
1 parent 03b4003 commit 5985551
Show file tree
Hide file tree
Showing 8 changed files with 363 additions and 174 deletions.
4 changes: 3 additions & 1 deletion config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ rag_min_score_rerank: 0 # Specifies the minimum relevance sc
rag_document_loaders:
# You can add more loaders; each entry uses the syntax:
# <file-extension>: <command-to-load-the-file>
pdf: 'pdftotext $1 -' # Load .pdf file
pdf: 'pdftotext $1 -' # Load .pdf file, see https://poppler.freedesktop.org
docx: 'pandoc --to plain $1' # Load .docx file
url: 'curl -fsSL $1' # Load url
# recursive_url: 'crawler $1 $2' # Load websites

# Defines the query structure using variables like __CONTEXT__ and __INPUT__ to tailor searches to specific needs
rag_template: |
Expand Down
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,9 @@ async fn shell_execute(config: &GlobalConfig, shell: &Shell, mut input: Input) -
let client = input.create_client()?;
config.write().before_chat_completion(&input)?;
let ret = if *IS_STDOUT_TERMINAL {
let (stop_spinner_tx, _) = run_spinner("Generating").await;
let spinner = create_spinner("Generating").await;
let ret = client.chat_completions(input.clone()).await;
let _ = stop_spinner_tx.send(());
spinner.stop();
ret
} else {
client.chat_completions(input.clone()).await
Expand Down
201 changes: 180 additions & 21 deletions src/rag/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,43 @@ use super::*;

use anyhow::{bail, Context, Result};
use async_recursion::async_recursion;
use std::{collections::HashMap, fs::read_to_string, path::Path};
use serde_json::Value;
use std::{collections::HashMap, env, fs::read_to_string, path::Path};

pub fn load_file(
pub const RECURSIVE_URL_LOADER: &str = "recursive_url";

pub fn load(
loaders: &HashMap<String, String>,
path: &str,
loader_name: &str,
) -> Result<Vec<RagDocument>> {
match loaders.get(loader_name) {
Some(loader_command) => load_with_command(path, loader_name, loader_command),
None => load_plain(path),
if loader_name == RECURSIVE_URL_LOADER {
let loader_command = loaders
.get(loader_name)
.with_context(|| format!("RAG document loader '{loader_name}' not configured"))?;
let contents = run_loader_command(path, loader_name, loader_command)?;
let output = match parse_json_documents(&contents) {
Some(v) => v,
None => vec![RagDocument::new(contents)],
};
Ok(output)
} else {
match loaders.get(loader_name) {
Some(loader_command) => load_with_command(path, loader_name, loader_command),
None => load_plain(path, loader_name),
}
}
}

fn load_plain(path: &str) -> Result<Vec<RagDocument>> {
fn load_plain(path: &str, loader_name: &str) -> Result<Vec<RagDocument>> {
let contents = read_to_string(path)?;
let document = RagDocument::new(contents);
if loader_name == "json" {
if let Some(documents) = parse_json_documents(&contents) {
return Ok(documents);
}
}
let mut document = RagDocument::new(contents);
document.metadata.insert("path".into(), path.to_string());
Ok(vec![document])
}

Expand All @@ -26,29 +47,135 @@ fn load_with_command(
loader_name: &str,
loader_command: &str,
) -> Result<Vec<RagDocument>> {
let cmd_args = shell_words::split(loader_command)
.with_context(|| anyhow!("Invalid rag loader '{loader_name}': `{loader_command}`"))?;
let contents = run_loader_command(path, loader_name, loader_command)?;
let mut document = RagDocument::new(contents);
document.metadata.insert("path".into(), path.to_string());
Ok(vec![document])
}

fn run_loader_command(path: &str, loader_name: &str, loader_command: &str) -> Result<String> {
let cmd_args = shell_words::split(loader_command).with_context(|| {
anyhow!("Invalid rag document loader '{loader_name}': `{loader_command}`")
})?;
let mut use_stdout = true;
let outpath = env::temp_dir()
.join(format!("aichat-{}", sha256(path)))
.display()
.to_string();
let cmd_args: Vec<_> = cmd_args
.into_iter()
.map(|v| if v == "$1" { path.to_string() } else { v })
.map(|mut v| {
if v.contains("$1") {
v = v.replace("$1", path);
}
if v.contains("$2") {
use_stdout = false;
v = v.replace("$2", &outpath);
}
v
})
.collect();
let cmd_eval = shell_words::join(&cmd_args);
debug!("run `{cmd_eval}`");
let (cmd, args) = cmd_args.split_at(1);
let cmd = &cmd[0];
let (success, stdout, stderr) =
run_command_with_output(cmd, args, None).with_context(|| {
if use_stdout {
let (success, stdout, stderr) =
run_command_with_output(cmd, args, None).with_context(|| {
format!("Unable to run `{cmd_eval}`, Perhaps '{cmd}' is not installed?")
})?;
if !success {
let err = if !stderr.is_empty() {
stderr
} else {
format!("The command `{cmd_eval}` exited with non-zero.")
};
bail!("{err}")
}
Ok(stdout)
} else {
let status = run_command(cmd, args, None).with_context(|| {
format!("Unable to run `{cmd_eval}`, Perhaps '{cmd}' is not installed?")
})?;
if !success {
let err = if !stderr.is_empty() {
stderr
} else {
format!("The command `{cmd_eval}` exited with non-zero.")
};
bail!("{err}")
if status != 0 {
bail!("The command `{cmd_eval}` exited with non-zero.")
}
let contents =
read_to_string(&outpath).context("Failed to read file generated by the loader")?;
Ok(contents)
}
}

fn parse_json_documents(data: &str) -> Option<Vec<RagDocument>> {
let value: Value = serde_json::from_str(data).ok()?;
let items = match value {
Value::Array(v) => v,
_ => return None,
};
if items.is_empty() {
return None;
}
match &items[0] {
Value::String(_) => {
let documents: Vec<_> = items
.into_iter()
.flat_map(|item| {
if let Value::String(content) = item {
Some(RagDocument::new(content))
} else {
None
}
})
.collect();
Some(documents)
}
Value::Object(obj) => {
let key = [
"page_content",
"pageContent",
"content",
"html",
"markdown",
"text",
"data",
]
.into_iter()
.map(|v| v.to_string())
.find(|key| obj.get(key).and_then(|v| v.as_str()).is_some())?;
let documents: Vec<_> = items
.into_iter()
.flat_map(|item| {
if let Value::Object(mut obj) = item {
if let Some(page_content) = obj.get(&key).and_then(|v| v.as_str()) {
let page_content = page_content.to_string();
obj.remove(&key);
let metadata: IndexMap<_, _> = obj
.into_iter()
.map(|(k, v)| {
if let Value::String(v) = v {
(k, v)
} else {
(k, v.to_string())
}
})
.collect();
return Some(RagDocument {
page_content,
metadata,
});
}
}
None
})
.collect();
if documents.is_empty() {
None
} else {
Some(documents)
}
}
_ => None,
}
let document = RagDocument::new(stdout);
Ok(vec![document])
}

pub fn parse_glob(path_str: &str) -> Result<(String, Vec<String>)> {
Expand Down Expand Up @@ -146,4 +273,36 @@ mod tests {
("C:\\dir".into(), vec!["md".into(), "txt".into()])
);
}

#[test]
fn test_parse_json_documents() {
    // An array of plain strings yields one document per string.
    assert_eq!(
        parse_json_documents(r#"["foo", "bar"]"#).unwrap(),
        vec![RagDocument::new("foo"), RagDocument::new("bar")]
    );

    // Objects carrying a recognized content key are split the same way.
    assert_eq!(
        parse_json_documents(r#"[{"content": "foo"}, {"content": "bar"}]"#).unwrap(),
        vec![RagDocument::new("foo"), RagDocument::new("bar")]
    );

    // Extra fields are preserved as string metadata.
    let mut metadata = IndexMap::new();
    metadata.insert("k1".into(), "1".into());
    assert_eq!(
        parse_json_documents(r#"[{"k1": 1, "data": "foo" }]"#).unwrap(),
        vec![RagDocument::new("foo").with_metadata(metadata)]
    );

    // Non-array JSON and arrays without a recognized content key are rejected.
    assert!(parse_json_documents(r#""hello""#).is_none());
    assert!(parse_json_documents(r#"{"key":"value"}"#).is_none());
    assert!(parse_json_documents(r#"[{"key":"value"}]"#).is_none());
}
}
Loading

0 comments on commit 5985551

Please sign in to comment.