Skip to content

Commit

Permalink
feat: custom rag document loaders (#650)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden committed Jun 26, 2024
1 parent 34a6d13 commit 95bad97
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 73 deletions.
19 changes: 0 additions & 19 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ json-patch = { version = "2.0.0", default-features = false }
bitflags = "2.5.0"
path-absolutize = "3.1.1"
hnsw_rs = "0.3.0"
which = "6.0.1"
rayon = "1.10.0"

[dependencies.reqwest]
Expand Down
23 changes: 15 additions & 8 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,21 @@ agents:
dangerously_functions_filter: null

# ---- RAG ----
rag_embedding_model: null # Specifies the embedding model to use
rag_reranker_model: null # Specifies the rerank model to use
rag_top_k: 4 # Specifies the number of documents to retrieve
rag_chunk_size: null # Specifies the chunk size
rag_chunk_overlap: null # Specifies the chunk overlap
rag_min_score_vector_search: 0 # Specifies the minimum relevance score for vector-based searching
rag_min_score_keyword_search: 0 # Specifies the minimum relevance score for keyword-based searching
rag_min_score_rerank: 0 # Specifies the minimum relevance score for reranking
rag_embedding_model: null # Specifies the embedding model to use
rag_reranker_model: null # Specifies the rerank model to use
rag_top_k: 4 # Specifies the number of documents to retrieve
rag_chunk_size: null # Specifies the chunk size
rag_chunk_overlap: null # Specifies the chunk overlap
rag_min_score_vector_search: 0 # Specifies the minimum relevance score for vector-based searching
rag_min_score_keyword_search: 0 # Specifies the minimum relevance score for keyword-based searching
rag_min_score_rerank: 0 # Specifies the minimum relevance score for reranking
# Defines document loaders
rag_document_loaders:
# You can add more loaders; here is the syntax:
# <file-extension>: <command-to-load-the-file>
pdf: 'pdftotext $1 -' # Load .pdf file
docx: 'pandoc --to plain $1' # Load .docx file

# Defines the query structure using variables like __CONTEXT__ and __INPUT__ to tailor searches to specific needs
rag_template: |
Use the following context as your learned knowledge, inside <context></context> XML tags.
Expand Down
17 changes: 17 additions & 0 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ pub struct Config {
pub rag_min_score_vector_search: f32,
pub rag_min_score_keyword_search: f32,
pub rag_min_score_rerank: f32,
#[serde(default)]
pub rag_document_loaders: HashMap<String, String>,
pub rag_template: Option<String>,

pub highlight: bool,
Expand Down Expand Up @@ -174,6 +176,7 @@ impl Default for Config {
rag_min_score_vector_search: 0.0,
rag_min_score_keyword_search: 0.0,
rag_min_score_rerank: 0.0,
rag_document_loaders: Default::default(),
rag_template: None,

save_session: None,
Expand Down Expand Up @@ -229,6 +232,7 @@ impl Config {
config.setup_model()?;
config.setup_highlight();
config.setup_light_theme()?;
config.setup_rag_document_loaders();

Ok(config)
}
Expand Down Expand Up @@ -1440,6 +1444,19 @@ impl Config {
};
Ok(())
}

/// Seeds `rag_document_loaders` with built-in default loader commands for
/// common document types, keeping any loader the user already configured
/// for the same key.
fn setup_rag_document_loaders(&mut self) {
    let defaults = [
        ("pdf", "pdftotext $1 -"),
        ("docx", "pandoc --to plain $1"),
        ("url", "curl -fsSL $1"),
    ];
    for (key, command) in defaults {
        // `entry(...).or_insert_with` leaves user-provided entries untouched.
        self.rag_document_loaders
            .entry(key.to_string())
            .or_insert_with(|| command.to_string());
    }
}
}

#[derive(Debug, Clone, Deserialize, Default)]
Expand Down
77 changes: 36 additions & 41 deletions src/rag/loader.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
use super::*;

use anyhow::{bail, Result};
use anyhow::{bail, Context, Result};
use async_recursion::async_recursion;
use lazy_static::lazy_static;
use std::{fs::read_to_string, path::Path};
use which::which;
use std::{collections::HashMap, fs::read_to_string, path::Path};

lazy_static! {
static ref EXIST_PANDOC: bool = which("pandoc").is_ok();
static ref EXIST_PDFTOTEXT: bool = which("pdftotext").is_ok();
}

pub fn load(path: &str, extension: &str) -> Result<Vec<RagDocument>> {
match extension {
"docx" | "epub" => load_with_pandoc(path),
"pdf" => load_with_pdftotext(path),
_ => load_plain(path),
/// Loads the file at `path` into RAG documents.
///
/// If a loader command is registered for `loader_name` in `loaders`, the file
/// is converted by running that command; otherwise it is read as plain text.
pub fn load_file(
    loaders: &HashMap<String, String>,
    path: &str,
    loader_name: &str,
) -> Result<Vec<RagDocument>> {
    if let Some(loader_command) = loaders.get(loader_name) {
        load_with_command(path, loader_name, loader_command)
    } else {
        load_plain(path)
    }
}

Expand All @@ -25,21 +21,33 @@ fn load_plain(path: &str) -> Result<Vec<RagDocument>> {
Ok(vec![document])
}

/// Converts the PDF at `path` to plain text by shelling out to `pdftotext`,
/// wrapping the output in a single `RagDocument`.
fn load_with_pdftotext(path: &str) -> Result<Vec<RagDocument>> {
    // Fail fast with an actionable message when the binary is missing.
    if !*EXIST_PDFTOTEXT {
        bail!("Need to install pdftotext (part of the poppler package) to load the file.")
    }
    let text = run_external_tool("pdftotext", &[path, "-"])?;
    Ok(vec![RagDocument::new(text)])
}

fn load_with_pandoc(path: &str) -> Result<Vec<RagDocument>> {
if !*EXIST_PANDOC {
bail!("Need to install pandoc to load the file.")
/// Runs a user-configured loader command to convert the file at `path` into
/// a single `RagDocument`.
///
/// The command string is split shell-style; every `$1` token is replaced with
/// `path`, and the command's stdout becomes the document body. Fails when the
/// command string cannot be parsed or is empty, when the command cannot be
/// spawned, or when it exits with a non-zero status.
fn load_with_command(
    path: &str,
    loader_name: &str,
    loader_command: &str,
) -> Result<Vec<RagDocument>> {
    let cmd_args = shell_words::split(loader_command)
        .with_context(|| anyhow!("Invalid rag loader '{loader_name}': `{loader_command}`"))?;
    // Substitute the `$1` placeholder with the actual file path.
    let cmd_args: Vec<_> = cmd_args
        .into_iter()
        .map(|v| if v == "$1" { path.to_string() } else { v })
        .collect();
    // Guard against an empty command string; `split_at(1)` below would panic.
    if cmd_args.is_empty() {
        bail!("Invalid rag loader '{loader_name}': empty command");
    }
    // Re-joined form of the resolved command, used only in error messages.
    let cmd_eval = shell_words::join(&cmd_args);
    let (cmd, args) = cmd_args.split_at(1);
    let cmd = &cmd[0];
    let (success, stdout, stderr) =
        run_command_with_output(cmd, args, None).with_context(|| {
            format!("Unable to run `{cmd_eval}`, Perhaps '{cmd}' is not installed?")
        })?;
    if !success {
        // Prefer the tool's own stderr; fall back to a generic message.
        let err = if !stderr.is_empty() {
            stderr
        } else {
            format!("The command `{cmd_eval}` exited with non-zero.")
        };
        bail!("{err}")
    }
    let document = RagDocument::new(stdout);
    Ok(vec![document])
}

Expand Down Expand Up @@ -114,19 +122,6 @@ fn is_valid_extension(suffixes: Option<&Vec<String>>, path: &Path) -> bool {
true
}

/// Runs `cmd` with `args`, returning the captured stdout on success.
///
/// On a non-zero exit status, fails with the captured stderr when present,
/// otherwise with a generic non-zero-exit message.
fn run_external_tool(cmd: &str, args: &[&str]) -> Result<String> {
    let (success, stdout, stderr) = run_command_with_output(cmd, args, None)?;
    if !success {
        if stderr.is_empty() {
            bail!("`{cmd}` exited with non-zero.")
        }
        bail!("{stderr}")
    }
    Ok(stdout)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
13 changes: 9 additions & 4 deletions src/rag/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use inquire::{required, validator::Validation, Select, Text};
use path_absolutize::Absolutize;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::{fmt::Debug, io::BufReader, path::Path};
use tokio::sync::mpsc;

Expand Down Expand Up @@ -59,9 +60,10 @@ impl Rag {
paths = add_doc_paths()?;
};
debug!("doc paths: {paths:?}");
let loaders = config.read().rag_document_loaders.clone();
let (stop_spinner_tx, set_spinner_message_tx) = run_spinner("Starting").await;
tokio::select! {
ret = rag.add_paths(&paths, Some(set_spinner_message_tx)) => {
ret = rag.add_paths(loaders, &paths, Some(set_spinner_message_tx)) => {
let _ = stop_spinner_tx.send(());
ret?;
}
Expand Down Expand Up @@ -221,6 +223,7 @@ impl Rag {

pub async fn add_paths<T: AsRef<Path>>(
&mut self,
loaders: HashMap<String, String>,
paths: &[T],
progress_tx: Option<mpsc::UnboundedSender<String>>,
) -> Result<()> {
Expand Down Expand Up @@ -260,13 +263,15 @@ impl Rag {
self.data.chunk_overlap,
&separator,
);
let documents = load(&path, &extension)
let documents = load_file(&loaders, &path, &extension)
.with_context(|| format!("Failed to load file at '{path}'"))?;
let split_options = SplitterChunkHeaderOptions::default().with_chunk_header(&format!(
"<document_metadata>\npath: {path}\n</document_metadata>\n\n"
));
let documents = splitter.split_documents(&documents, &split_options);
rag_files.push(RagFile { path, documents });
if !documents.is_empty() {
let documents = splitter.split_documents(&documents, &split_options);
rag_files.push(RagFile { path, documents });
}
progress(
&progress_tx,
format!("Loading files [{}/{file_paths_len}]", rag_files.len()),
Expand Down

0 comments on commit 95bad97

Please sign in to comment.