Update binary RAG pipeline (#136)
svilupp committed Apr 17, 2024
1 parent 1cda053 commit 46f6770
Showing 15 changed files with 810 additions and 51 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -6,14 +6,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

### Fixed

## [0.20.0]

### Added
- Added a few new open-weights models hosted by Fireworks.ai to the registry (DBRX Instruct, Mixtral 8x22b Instruct, Qwen 72b). If you're curious about how well they work, try them!
- Added basic support for downstream observability. Created custom callback infrastructure with `initialize_tracer` and `finalize_tracer`, plus the dedicated types `TracerMessage` and `TracerMessageLike`. See `?TracerMessage` and the corresponding `aigenerate` docstring for more information.
- Added `MultiCandidateChunks` which can hold candidates for retrieval across many indices (it's a flat structure to be similar to `CandidateChunks` and easy to reason about).
- JSON serialization support extended for `RAGResult`, `CandidateChunks`, and `MultiCandidateChunks` to increase observability of RAG systems
- Added a new search refiner `TavilySearchRefiner` - it will search the web via Tavily API to try to improve on the RAG answer (see `?refine!`).
- Introduced a few small utilities for manipulation of nested kwargs (necessary for RAG pipelines), check out `getpropertynested`, `setpropertynested`, `merge_kwargs_nested`.

### Updated
- [BREAKING] change to `CandidateChunks`: it can no longer be nested (ie, `cc.positions` being a list of several `CandidateChunks`). This is a breaking change for the `RAGTools` module only. We have introduced a new `MultiCandidateChunks` type that can refer to `CandidateChunks` across many indices.
- Changed default model for `RAGTools.CohereReranker` to "cohere-rerank-english-v3.0".

### Fixed
- `wrap_string` utility now correctly splits only on spaces. Previously it would also split on newlines, which removed the natural formatting of prompts/messages when displayed via `pprint`.

## [0.19.0]

2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
-version = "0.19.0"
+version = "0.20.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
6 changes: 3 additions & 3 deletions src/Experimental/Experimental.jl
@@ -11,13 +11,13 @@ Contains:
"""
module Experimental

+export APITools
+include("APITools/APITools.jl")

export RAGTools
include("RAGTools/RAGTools.jl")

export AgentTools
include("AgentTools/AgentTools.jl")

-export APITools
-include("APITools/APITools.jl")

end # module Experimental
4 changes: 4 additions & 0 deletions src/Experimental/RAGTools/RAGTools.jl
@@ -12,15 +12,19 @@ module RAGTools
using PromptingTools
using PromptingTools: pprint, AbstractMessage
using HTTP, JSON3
using JSON3: StructTypes
using AbstractTrees
using AbstractTrees: PreOrderDFS
const PT = PromptingTools
using PromptingTools.Experimental.APITools: create_websearch

# reexport
export pprint

## export trigrams, trigrams_hashed, text_to_trigrams, text_to_trigrams_hashed
## export STOPWORDS, tokenize, split_into_code_and_sentences
# export merge_kwargs_nested
export getpropertynested, setpropertynested
include("utils.jl")

# eg, cohere_api
117 changes: 116 additions & 1 deletion src/Experimental/RAGTools/generation.jl
@@ -146,6 +146,13 @@ Refines the answer using the same context previously provided via the provided p
"""
struct SimpleRefiner <: AbstractRefiner end

"""
TavilySearchRefiner <: AbstractRefiner
Refines the answer by executing a web search using the Tavily API. This method aims to enhance the answer's accuracy and relevance by incorporating information retrieved from the web.
"""
struct TavilySearchRefiner <: AbstractRefiner end

function refine!(
refiner::AbstractRefiner, index::AbstractChunkIndex, result::AbstractRAGResult;
kwargs...)
@@ -223,6 +230,112 @@ function refine!(
return result
end

"""
refine!(
refiner::TavilySearchRefiner, index::AbstractChunkIndex, result::AbstractRAGResult;
verbose::Bool = true,
model::AbstractString = PT.MODEL_CHAT,
include_answer::Bool = true,
max_results::Integer = 5,
include_domains::AbstractVector{<:AbstractString} = String[],
exclude_domains::AbstractVector{<:AbstractString} = String[],
template::Symbol = :RAGWebSearchRefiner,
cost_tracker = Threads.Atomic{Float64}(0.0),
kwargs...)
Refines the answer by executing a web search using the Tavily API. This method aims to enhance the answer's accuracy and relevance by incorporating information retrieved from the web.
Note: The web results and web answer (if requested) will be added to the context and sources!
# Returns
- Mutated `result` with `result.final_answer` and the full conversation saved in `result.conversations[:final_answer]`.
- In addition, the web results and web answer (if requested) are appended to the `result.context` and `result.sources` for correct highlighting and verification.
# Arguments
- `refiner::TavilySearchRefiner`: The method to use for refining the answer. Uses `aigenerate` with a web search template.
- `index::AbstractChunkIndex`: The index containing chunks and sources.
- `result::AbstractRAGResult`: The result containing the context and question to generate the answer for.
- `model::AbstractString`: The model to use for generating the answer. Defaults to `PT.MODEL_CHAT`.
- `include_answer::Bool`: If `true`, includes the answer from Tavily in the web search.
- `max_results::Integer`: The maximum number of results to return.
- `include_domains::AbstractVector{<:AbstractString}`: A list of domains to include in the search results. Default is an empty list.
- `exclude_domains::AbstractVector{<:AbstractString}`: A list of domains to exclude from the search results. Default is an empty list.
- `verbose::Bool`: If `true`, enables verbose logging.
- `template::Symbol`: The template to use for the `aigenerate` function. Defaults to `:RAGWebSearchRefiner`.
- `cost_tracker`: An atomic counter to track the cost of the operation.
# Example
```julia
refine!(TavilySearchRefiner(), index, result)
# See result.final_answer or pprint(result)
```
To enable this refiner in a full RAG pipeline, simply swap the component in the config:
```julia
cfg = RT.RAGConfig()
cfg.generator.refiner = RT.TavilySearchRefiner()
result = airag(cfg, index; question, return_all = true)
pprint(result)
```
"""
function refine!(
refiner::TavilySearchRefiner, index::AbstractChunkIndex, result::AbstractRAGResult;
verbose::Bool = true,
model::AbstractString = PT.MODEL_CHAT,
include_answer::Bool = true,
max_results::Integer = 5,
include_domains::AbstractVector{<:AbstractString} = String[],
exclude_domains::AbstractVector{<:AbstractString} = String[],
template::Symbol = :RAGWebSearchRefiner,
cost_tracker = Threads.Atomic{Float64}(0.0),
kwargs...)

## Checks
placeholders = only(aitemplates(template)).variables # only one template should be found
@assert (:query in placeholders)&&(:answer in placeholders) &&
(:search_results in placeholders) "Provided RAG Template $(template) is not suitable. It must have placeholders: `query`, `answer` and `search_results`."
##
(; answer, question) = result
## execute Tavily web search and format it
r = create_websearch(
question; include_answer, max_results, include_domains,
exclude_domains)
web_summary = get(r.response, "answer", "")
web_raw = get(r.response, "results", [])
web_sources = ["TOOL(TavilySearch): " * get(res, "url", "") for res in web_raw]
web_content = join(
["$(i). TavilySearch: " * get(res, "content", "")
for (i, res) in enumerate(web_raw)],
"\n\n")
search_results = """
Web Results Summary: $(web_summary)
**Raw Results:**
$(web_content)
"""
##
conv = aigenerate(template; query = question, search_results,
answer, model, verbose = false,
return_all = true,
kwargs...)
msg = conv[end]
result.final_answer = strip(msg.content)
result.conversations[:final_answer] = conv

## Attach the web sources to the context + sources (for reference)
result.sources = vcat(result.sources, web_sources)
result.context = vcat(result.context, web_content)

## Increment the cost
Threads.atomic_add!(cost_tracker, msg.cost)
verbose &&
@info "Done refining the answer. Cost: \$$(round(msg.cost,digits=3))"

return result
end
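
The result-formatting step in `refine!` above can be tried without calling the Tavily API. The sketch below is a minimal, self-contained illustration. It assumes a response shaped like the one `refine!` reads (an `"answer"` string plus a `"results"` list whose entries carry `"url"` and `"content"` keys). `mock_response` and `format_web_results` are hypothetical names used only for illustration and are not part of RAGTools.

```julia
# Hypothetical stand-in for `r.response` as read by `refine!` (assumed shape).
mock_response = Dict{String, Any}(
    "answer" => "Julia 1.0 was released in 2018.",
    "results" => [
        Dict("url" => "https://example.com/a", "content" => "First snippet."),
        Dict("url" => "https://example.com/b", "content" => "Second snippet.")
    ])

# Hypothetical helper that mirrors the formatting logic shown in `refine!`.
function format_web_results(response::AbstractDict)
    web_summary = get(response, "answer", "")
    web_raw = get(response, "results", [])
    # One source label per web result
    web_sources = ["TOOL(TavilySearch): " * get(res, "url", "") for res in web_raw]
    # All result snippets joined into a single numbered text block
    web_content = join(
        ["$(i). TavilySearch: " * get(res, "content", "")
         for (i, res) in enumerate(web_raw)],
        "\n\n")
    return (; web_summary, web_sources, web_content)
end

formatted = format_web_results(mock_response)
println(formatted.web_content)
```

Note that, as in the diff, `web_content` is a single joined string while `web_sources` has one entry per result.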

"""
NoPostprocessor <: AbstractPostprocessor
@@ -446,7 +559,7 @@ Eg, use `subtypes(AbstractRetriever)` to find the available options.
- If `return_all` is `false`, returns the generated message (`msg`).
- If `return_all` is `true`, returns the detail of the full pipeline in `RAGResult` (see the docs).
-See also `build_index`, `retrieve`, `generate!`, `RAGResult`
+See also `build_index`, `retrieve`, `generate!`, `RAGResult`, `getpropertynested`, `setpropertynested`, `merge_kwargs_nested`.
# Examples
@@ -498,6 +611,8 @@ kwargs = (
result = airag(cfg, index, question; kwargs...)
```
For easier manipulation of nested kwargs, see utilities `getpropertynested`, `setpropertynested`, `merge_kwargs_nested`.
"""
function airag(cfg::AbstractRAGConfig, index::AbstractChunkIndex;
question::AbstractString,
75 changes: 75 additions & 0 deletions src/Experimental/RAGTools/preparation.jl
@@ -28,6 +28,19 @@ Default embedder for `get_embeddings` functions. It passes individual documents
"""
struct BatchEmbedder <: AbstractEmbedder end

"""
BinaryBatchEmbedder <: AbstractEmbedder
Same as `BatchEmbedder` but reduces the embeddings matrix to a binary form (eg, `BitMatrix`).
Reference: [HuggingFace: Embedding Quantization](https://huggingface.co/blog/embedding-quantization#binary-quantization-in-vector-databases).
"""
struct BinaryBatchEmbedder <: AbstractEmbedder end

EmbedderEltype(::T) where {T} = EmbedderEltype(T)
EmbedderEltype(::Type{<:AbstractEmbedder}) = Float32
EmbedderEltype(::Type{BinaryBatchEmbedder}) = Bool
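
The three `EmbedderEltype` methods above follow a common Julia trait pattern: the instance method forwards to the type method, a generic fallback covers all embedders, and a more specific method overrides it for the binary embedder. The sketch below reproduces the pattern with hypothetical stand-in types (`AbstractDemoEmbedder`, `DemoBatchEmbedder`, `DemoBinaryEmbedder`, `demo_eltype`) so it runs without RAGTools.

```julia
# Hypothetical stand-ins for `AbstractEmbedder` and its subtypes.
abstract type AbstractDemoEmbedder end
struct DemoBatchEmbedder <: AbstractDemoEmbedder end
struct DemoBinaryEmbedder <: AbstractDemoEmbedder end

# Instance method forwards to the method defined on the type
demo_eltype(::T) where {T} = demo_eltype(T)
# Generic fallback for any embedder type
demo_eltype(::Type{<:AbstractDemoEmbedder}) = Float32
# More specific method wins for the binary embedder
demo_eltype(::Type{DemoBinaryEmbedder}) = Bool

# Possible use: allocate storage with the element type the embedder will produce
storage = zeros(demo_eltype(DemoBinaryEmbedder()), 2, 3)
```

A design note: keeping the element type in a trait function rather than a struct field lets downstream code choose container types by dispatch alone.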

### Tagging Types
"""
NoTagger <: AbstractTagger
@@ -165,6 +178,8 @@ end
"""
get_embeddings(embedder::BatchEmbedder, docs::AbstractVector{<:AbstractString};
verbose::Bool = true,
model::AbstractString = PT.MODEL_EMBEDDING,
truncate_dimension::Union{Int, Nothing} = nothing,
cost_tracker = Threads.Atomic{Float64}(0.0),
target_batch_size_length::Int = 80_000,
ntasks::Int = 4 * Threads.nthreads(),
@@ -185,6 +200,7 @@ Embeds a vector of `docs` using the provided model (kwarg `model`) in a batched
- `docs`: A vector of strings to be embedded.
- `verbose`: A boolean flag for verbose output. Default is `true`.
- `model`: The model to use for embedding. Default is `PT.MODEL_EMBEDDING`.
- `truncate_dimension`: The dimensionality of the embeddings to truncate to. Default is `nothing`.
- `cost_tracker`: A `Threads.Atomic{Float64}` object to track the total cost of the API calls. Useful to pass the total cost to the parent call.
- `target_batch_size_length`: The target length (in characters) of each batch of document chunks sent for embedding. Default is 80_000 characters. Speeds up embedding process.
- `ntasks`: The number of tasks to use for asyncmap. Default is 4 * Threads.nthreads().
@@ -193,6 +209,7 @@ Embeds a vector of `docs` using the provided model (kwarg `model`) in a batched
function get_embeddings(embedder::BatchEmbedder, docs::AbstractVector{<:AbstractString};
verbose::Bool = true,
model::AbstractString = PT.MODEL_EMBEDDING,
truncate_dimension::Union{Int, Nothing} = nothing,
cost_tracker = Threads.Atomic{Float64}(0.0),
target_batch_size_length::Int = 80_000,
ntasks::Int = 4 * Threads.nthreads(),
@@ -220,10 +237,68 @@ function get_embeddings(embedder::BatchEmbedder, docs::AbstractVector{<:Abstract
msg.content
end
embeddings = hcat(embeddings...) .|> Float32 # flatten, columns are documents
if !isnothing(truncate_dimension)
@assert truncate_dimension>0 "Truncated dimensionality must be positive (Provided: $(truncate_dimension))"
@assert truncate_dimension<=size(embeddings, 1) "Requested embeddings dimensionality is too high (Embeddings: $(size(embeddings)) vs dimensionality requested: $(truncate_dimension))"
## reduce + normalize again
embeddings = embeddings[1:truncate_dimension, :]
for i in axes(embeddings, 2)
embeddings[:, i] = _normalize(embeddings[:, i])
end
end
verbose && @info "Done embedding. Total cost: \$$(round(cost_tracker[],digits=3))"
return embeddings
end
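
The truncation branch above keeps the first `truncate_dimension` rows and then re-normalizes every column, since a truncated unit vector generally no longer has unit length. The sketch below shows the same idea on a toy matrix. It uses `normalize` from the `LinearAlgebra` standard library in place of the package-internal `_normalize` (an assumption that both perform L2 normalization), and `truncate_and_normalize` is a hypothetical helper name.

```julia
using LinearAlgebra: normalize, norm

# Hypothetical helper; RAGTools itself uses an internal `_normalize`.
function truncate_and_normalize(embeddings::AbstractMatrix{<:Real}, dim::Integer)
    @assert 0 < dim <= size(embeddings, 1) "Requested dimensionality is out of range"
    truncated = embeddings[1:dim, :]            # copy of the leading `dim` rows
    for i in axes(truncated, 2)
        truncated[:, i] = normalize(truncated[:, i])   # unit L2 norm per column
    end
    return truncated
end

emb = Float32[3 0; 4 1; 12 2]                   # toy 3-dimensional embeddings, 2 documents
emb2 = truncate_and_normalize(emb, 2)
column_norms = map(norm, eachcol(emb2))         # each entry should be close to 1
```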

"""
get_embeddings(embedder::BinaryBatchEmbedder, docs::AbstractVector{<:AbstractString};
verbose::Bool = true,
model::AbstractString = PT.MODEL_EMBEDDING,
truncate_dimension::Union{Int, Nothing} = nothing,
return_type::Type = BitMatrix,
cost_tracker = Threads.Atomic{Float64}(0.0),
target_batch_size_length::Int = 80_000,
ntasks::Int = 4 * Threads.nthreads(),
kwargs...)
Embeds a vector of `docs` using the provided model (kwarg `model`) in a batched manner and then returns the binary embeddings matrix - `BinaryBatchEmbedder`.
`BinaryBatchEmbedder` tries to batch embedding calls for roughly 80K characters per call (to avoid exceeding the API rate limit) to reduce network latency.
# Notes
- `docs` are assumed to be already chunked to reasonable sizes that fit within the embedding context limit.
- If you get errors about exceeding input sizes, first check the `max_length` in your chunks.
If that does NOT resolve the issue, try reducing the `target_batch_size_length` parameter (eg, 10_000) and number of tasks `ntasks=1`.
Some providers cannot handle large batch sizes.
# Arguments
- `docs`: A vector of strings to be embedded.
- `verbose`: A boolean flag for verbose output. Default is `true`.
- `model`: The model to use for embedding. Default is `PT.MODEL_EMBEDDING`.
- `truncate_dimension`: The dimensionality of the embeddings to truncate to. Default is `nothing`.
- `return_type`: The type of the returned embeddings matrix. Default is `BitMatrix`. Choose `BitMatrix` to minimize storage requirements, `Matrix{Bool}` to maximize performance in elementwise-ops.
- `cost_tracker`: A `Threads.Atomic{Float64}` object to track the total cost of the API calls. Useful to pass the total cost to the parent call.
- `target_batch_size_length`: The target length (in characters) of each batch of document chunks sent for embedding. Default is 80_000 characters. Speeds up embedding process.
- `ntasks`: The number of tasks to use for asyncmap. Default is 4 * Threads.nthreads().
"""
function get_embeddings(
embedder::BinaryBatchEmbedder, docs::AbstractVector{<:AbstractString};
verbose::Bool = true,
model::AbstractString = PT.MODEL_EMBEDDING,
truncate_dimension::Union{Int, Nothing} = nothing,
return_type::Type = BitMatrix,
cost_tracker = Threads.Atomic{Float64}(0.0),
target_batch_size_length::Int = 80_000,
ntasks::Int = 4 * Threads.nthreads(),
kwargs...)
emb = get_embeddings(BatchEmbedder(), docs; verbose, model, truncate_dimension,
cost_tracker, target_batch_size_length, ntasks, kwargs...)
# This will return BitMatrix to save space, for best performance use Matrix{Bool}, eg, map(>(0),emb)
emb = (emb .> 0) |> x -> x isa return_type ? x : return_type(x)
end
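
The binarization step above thresholds each embedding value at zero. The sketch below shows that step on a toy matrix together with the two return types mentioned in the docstring. The Hamming distance at the end is one common way to compare binary vectors and is included only as an illustration; it is an assumption, not necessarily how RAGTools scores binary embeddings. `hamming` is a hypothetical helper name.

```julia
emb = Float32[0.5 -0.2; -0.1 0.7; 0.3 0.4]     # toy embeddings, columns are documents

bits = emb .> 0                                 # broadcasting the comparison yields a BitMatrix
bools = Matrix{Bool}(bits)                      # same values stored as one Bool per element

# Hypothetical helper: number of positions where two binary vectors differ
hamming(a, b) = count(a .!= b)

d = hamming(bits[:, 1], bits[:, 2])
```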

### Tag Extraction

function get_tags(tagger::AbstractTagger, docs::AbstractVector{<:AbstractString};
24 changes: 24 additions & 0 deletions src/Experimental/RAGTools/rag_interface.jl
@@ -83,6 +83,30 @@
abstract type AbstractRAGConfig end

# supertype for RAGDetails, return_type for retrieve and generate (and optionally airag)
"""
AbstractRAGResult
Abstract type for the result of the RAG (Retrieval-Augmented Generation) process.
Implementations of this type should contain the necessary fields to represent the outcome of the RAG pipeline, including the original question, any rephrased versions of the question, the generated answer, and any additional context or metadata used or generated during the process.
# Fields
- [OPTIONAL] `question::AbstractString`: The original question posed to the RAG system.
- `rephrased_questions::AbstractVector{<:AbstractString}`: A vector of rephrased versions of the original question, generated during the retrieval phase to improve the quality of the results.
- [OPTIONAL] `answer::Union{Nothing, AbstractString}`: The initial answer generated based on the retrieved information and the question. This field may be `nothing` if the generation phase has not yet produced an answer.
- `final_answer::Union{Nothing, AbstractString}`: The final refined answer after any post-processing steps have been applied. This is considered the definitive answer produced by the RAG system.
- `context::Vector{<:AbstractString}`: A vector of strings representing the context used for generating the answer. This may include relevant information retrieved during the retrieval phase.
- `sources::Vector{<:AbstractString}`: The sources of the context information, providing traceability for the data used in generating the answer.
... some fields for search candidates (`::CandidateChunks`)
- [OPTIONAL] `conversations::Dict{Symbol,Vector{<:AbstractMessage}}`: A dictionary containing the history of AI-generated messages and interactions during the RAG process. Keys correspond to the names of functions in the RAG pipeline, providing insight into the decision-making process at each step.
If `rephrased_questions` is the primary field, it should be used instead of `question`.
If `final_answer` is the primary field, it should be used instead of `answer`.
`conversations` recording is optional but highly recommended for observability.
This abstract type serves as a blueprint for concrete implementations that store the results of the RAG process, facilitating debugging, analysis, and further processing of the generated answers.
"""

abstract type AbstractRAGResult end

# ## Preparation Stage

2 comments on commit 46f6770

@svilupp (Owner, Author)

@JuliaRegistrator register

Release notes:

Added

  • Added a few new open-weights models hosted by Fireworks.ai to the registry (DBRX Instruct, Mixtral 8x22b Instruct, Qwen 72b). If you're curious about how well they work, try them!
  • Added basic support for downstream observability. Created custom callback infrastructure with initialize_tracer and finalize_tracer, plus the dedicated types TracerMessage and TracerMessageLike. See ?TracerMessage and the corresponding aigenerate docstring for more information.
  • Added MultiCandidateChunks which can hold candidates for retrieval across many indices (it's a flat structure to be similar to CandidateChunks and easy to reason about).
  • JSON serialization support extended for RAGResult, CandidateChunks, and MultiCandidateChunks to increase observability of RAG systems
  • Added a new search refiner TavilySearchRefiner - it will search the web via Tavily API to try to improve on the RAG answer (see ?refine!).
  • Introduced a few small utilities for manipulation of nested kwargs (necessary for RAG pipelines), check out getpropertynested, setpropertynested, merge_kwargs_nested.

Updated

  • [BREAKING] change to CandidateChunks: it can no longer be nested (ie, cc.positions being a list of several CandidateChunks). This is a breaking change for the RAGTools module only. We have introduced a new MultiCandidateChunks type that can refer to CandidateChunks across many indices.
  • Changed default model for RAGTools.CohereReranker to "cohere-rerank-english-v3.0".

Fixed

  • wrap_string utility now correctly splits only on spaces. Previously it would split on newlines, which would remove natural formatting of prompts/messages when displayed via pprint

Commits

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/105122

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.20.0 -m "<description of version>" 46f67704c488b1665f1d0621da2e3d23ee16fef0
git push origin v0.20.0
