Add BM25 Index (#157)
svilupp committed May 28, 2024
1 parent af8b238 commit 8fd485a
Showing 18 changed files with 1,549 additions and 201 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.27.0]

### Added
- Added a keyword-based (BM25) search similarity to RAGTools, useful both as a baseline for evaluation and for improved performance via a hybrid index combining embeddings and BM25. It currently supports retrieval only (no end-to-end `airag` support yet). See `?RT.KeywordsIndexer` and `?RT.BM25Similarity` for more information. To build one, use `build_index(KeywordsIndexer(), texts)`, or convert an existing embeddings-based index with `ChunkKeywordsIndex(index)`; see the sketch below.
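
A minimal, illustrative sketch of the new workflow (the `texts` below are placeholders, and the `MultiIndex` constructor form is assumed from the docstrings rather than verified here):

```julia
using PromptingTools.Experimental.RAGTools
using LinearAlgebra, SparseArrays, Snowball  # the BM25/keywords functionality lives in package extensions

texts = ["BM25 is a keyword-based ranking function.",
    "Embeddings capture semantic similarity."]

# Build a keyword (BM25) index directly from texts
keyword_index = build_index(KeywordsIndexer(), texts)

# Or convert an existing embeddings-based index, e.g. one built with `build_index(SimpleIndexer(), texts)`
# keyword_index = ChunkKeywordsIndex(embeddings_index)

# Combine embeddings and keywords into a hybrid index
# hybrid = MultiIndex([embeddings_index, keyword_index])
```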

### Updated
- For naming consistency, `ChunkIndex` in RAGTools has been renamed to `ChunkEmbeddingsIndex` (with `ChunkIndex` kept as an alias for backwards compatibility). There are now two main index types: `ChunkEmbeddingsIndex` and `ChunkKeywordsIndex` (BM25-based), which can be combined into a `MultiIndex` to serve as a hybrid index.

## [0.26.2]

### Fixed
9 changes: 6 additions & 3 deletions Project.toml
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.26.2"
version = "0.27.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -21,12 +21,15 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
GoogleGenAI = "903d41d1-eaca-47dd-943b-fee3930375ab"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Snowball = "fb8f903a-0164-4e73-9ffe-431110250c3b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[extensions]
GoogleGenAIPromptingToolsExt = ["GoogleGenAI"]
MarkdownPromptingToolsExt = ["Markdown"]
RAGToolsExperimentalExt = ["SparseArrays", "LinearAlgebra"]
RAGToolsExperimentalExt = ["SparseArrays", "LinearAlgebra", "Unicode"]
SnowballPromptingToolsExt = ["Snowball"]

[compat]
AbstractTrees = "0.4"
@@ -56,4 +59,4 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[targets]
test = ["Aqua", "SparseArrays", "Statistics", "LinearAlgebra", "Markdown"]
test = ["Aqua", "SparseArrays", "Statistics", "LinearAlgebra", "Markdown", "Snowball"]
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -10,6 +10,7 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PromptingTools = "670122d1-24a8-4d70-bfce-740807c42192"
Snowball = "fb8f903a-0164-4e73-9ffe-431110250c3b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
99 changes: 96 additions & 3 deletions ext/RAGToolsExperimentalExt.jl
@@ -1,14 +1,19 @@
module RAGToolsExperimentalExt

using PromptingTools, SparseArrays
using LinearAlgebra: normalize
using PromptingTools, SparseArrays, Unicode
using LinearAlgebra
const PT = PromptingTools

using PromptingTools.Experimental.RAGTools
const RT = PromptingTools.Experimental.RAGTools

# forward to LinearAlgebra.normalize
RT._normalize(arr::AbstractArray) = normalize(arr)
RT._normalize(arr::AbstractArray) = LinearAlgebra.normalize(arr)

# Forward to Unicode.normalize
function RT._unicode_normalize(text::AbstractString; kwargs...)
    Unicode.normalize(text; kwargs...)
end
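
# Hedged illustration of what this forwarder enables; the keyword arguments below are
# standard Unicode.normalize options shown for illustration, not necessarily what RAGTools passes:
#   RT._unicode_normalize("Café Crème"; casefold = true, stripmark = true)  # -> "cafe creme"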

"""
RT.build_tags(
@@ -40,4 +45,92 @@ function RT.build_tags(
return tags_, tags_vocab_
end

"""
document_term_matrix(documents::AbstractVector{<:AbstractVector{<:AbstractString}})
Builds a sparse matrix of term frequencies and document lengths from the given vector of documents wrapped in type `DocumentTermMatrix`.
Expects a vector of preprocessed (tokenized) documents, where each document is a vector of strings (clean tokens).
Returns: `DocumentTermMatrix`
# Example
```
documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
dtm = document_term_matrix(documents)
```
"""
function RT.document_term_matrix(documents::AbstractVector{<:AbstractVector{<:AbstractString}})
    T = eltype(documents) |> eltype
    vocab = convert(Vector{T}, unique(vcat(documents...)))
    vocab_lookup = Dict{T, Int}(t => i for (i, t) in enumerate(vocab))
    N = length(documents)
    doc_freq = zeros(Int, length(vocab))
    term_freq = spzeros(Float32, N, length(vocab))
    doc_lengths = zeros(Float32, N)
    for di in eachindex(documents)
        unique_terms = Set{eltype(vocab)}()
        doc = documents[di]
        for t in doc
            doc_lengths[di] += 1
            tid = vocab_lookup[t]
            term_freq[di, tid] += 1
            if !(t in unique_terms)
                doc_freq[tid] += 1
                push!(unique_terms, t)
            end
        end
    end
    idf = @. log(1.0f0 + (N - doc_freq + 0.5f0) / (doc_freq + 0.5f0))
    sumdl = sum(doc_lengths)
    doc_rel_length = sumdl == 0 ? zeros(Float32, N) : doc_lengths ./ (sumdl / N)
    RT.DocumentTermMatrix(term_freq, vocab, vocab_lookup, idf, doc_rel_length)
end

function RT.document_term_matrix(documents::AbstractVector{<:AbstractString})
    RT.document_term_matrix(RT.preprocess_tokens(documents))
end

"""
RT.bm25(dtm::DocumentTermMatrix, query::Vector{String}; k1::Float32=1.2f0, b::Float32=0.75f0)
Scores all documents in `dtm` based on the `query`.
References: https://opensourceconnections.com/blog/2015/10/16/bm25-the-next-generation-of-lucene-relevation/
# Example
```
documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
dtm = document_term_matrix(documents)
query = ["this"]
scores = bm25(dtm, query)
# Returns array with 3 scores (one for each document)
```
"""
function RT.bm25(dtm::RT.DocumentTermMatrix, query::AbstractVector{<:AbstractString};
        k1::Float32 = 1.2f0, b::Float32 = 0.75f0)
    scores = zeros(Float32, size(dtm.tf, 1))
    ## Identify non-zero items to leverage the sparsity
    nz_rows = rowvals(dtm.tf)
    nz_vals = nonzeros(dtm.tf)
    for i in eachindex(query)
        t = query[i]
        t_id = get(dtm.vocab_lookup, t, nothing)
        t_id === nothing && continue
        idf = dtm.idf[t_id]
        # Scan only documents that have this token
        @inbounds @simd for j in nzrange(dtm.tf, t_id)
            ## index into the sparse matrix
            di, tf = nz_rows[j], nz_vals[j]
            doc_len = dtm.doc_rel_length[di]
            tf_top = (tf * (k1 + 1.0f0))
            tf_bottom = (tf + k1 * (1.0f0 - b + b * doc_len))
            score = idf * tf_top / tf_bottom
            ## @info "di: $di, tf: $tf, doc_len: $doc_len, idf: $idf, tf_top: $tf_top, tf_bottom: $tf_bottom, score: $score"
            scores[di] += score
        end
    end
    scores
end
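
# For reference, the loop above accumulates the standard BM25 score; restated here with the
# code's variable names (a descriptive note only, not additional functionality):
#
#   score(q, d) = sum over t in q of  idf(t) * tf(t, d) * (k1 + 1) / (tf(t, d) + k1 * (1 - b + b * doc_rel_length(d)))
#   idf(t)      = log(1 + (N - doc_freq(t) + 0.5) / (doc_freq(t) + 0.5))
#
# where N is the number of documents and doc_rel_length(d) is the document length divided by the average document length.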

end # end of module
62 changes: 62 additions & 0 deletions ext/SnowballPromptingToolsExt.jl
@@ -0,0 +1,62 @@
module SnowballPromptingToolsExt

using PromptingTools
const PT = PromptingTools

using PromptingTools.Experimental.RAGTools
const RT = PromptingTools.Experimental.RAGTools

using Snowball

# forward to Stemmer.stem
RT._stem(stemmer::Snowball.Stemmer, text::AbstractString) = Snowball.stem(stemmer, text)

"""
get_keywords(processor::KeywordsProcessor, docs::AbstractVector{<:AbstractString};
verbose::Bool = true,
stemmer = nothing,
stopwords::Set{String} = Set(STOPWORDS),
return_keywords::Bool = false,
kwargs...)
Generate a `DocumentTermMatrix` from a vector of `docs` using the provided `stemmer` and `stopwords`.
# Arguments
- `docs`: A vector of strings to be embedded.
- `verbose`: A boolean flag for verbose output. Default is `true`.
- `stemmer`: A stemmer to use for stemming. Default is `nothing`.
- `stopwords`: A set of stopwords to remove. Default is `Set(STOPWORDS)`.
- `return_keywords`: A boolean flag for returning the keywords. Default is `false`. Useful for query processing in search time.
"""
function RT.get_keywords(
        processor::RT.KeywordsProcessor, docs::AbstractVector{<:AbstractString};
        verbose::Bool = true,
        stemmer = nothing,
        stopwords::Set{String} = Set(RT.STOPWORDS),
        return_keywords::Bool = false,
        kwargs...)
    ## check if extension is available
    ext = Base.get_extension(PromptingTools, :RAGToolsExperimentalExt)
    if isnothing(ext)
        error("You need to also import LinearAlgebra and SparseArrays to use this function")
    end
    ## ext = Base.get_extension(PromptingTools, :SnowballPromptingToolsExt)
    ## if isnothing(ext)
    ##     error("You need to also import Snowball.jl to use this function")
    ## end
    ## Preprocess text into tokens
    stemmer = !isnothing(stemmer) ? stemmer : Snowball.Stemmer("english")
    # Single-threaded as stemmer is not thread-safe
    keywords = RT.preprocess_tokens(docs, stemmer; stopwords, min_length = 3)

    ## Early exit if we only want keywords (search time)
    return_keywords && return keywords

    ## Create DTM
    dtm = RT.document_term_matrix(keywords)

    verbose && @info "Done processing DocumentTermMatrix."
    return dtm
end

end # end of module
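
A minimal usage sketch of this extension (hedged: it assumes `KeywordsProcessor()` constructs with defaults and that the companion extension packages are loaded):

```julia
using PromptingTools
using PromptingTools.Experimental.RAGTools
using Snowball, LinearAlgebra, SparseArrays  # activates SnowballPromptingToolsExt and RAGToolsExperimentalExt
const RT = PromptingTools.Experimental.RAGTools

docs = ["The quick brown fox jumps over the lazy dog.",
    "Foxes are wild animals; dogs are domesticated."]

# Document side: stem, drop stopwords, and build the DocumentTermMatrix
dtm = RT.get_keywords(RT.KeywordsProcessor(), docs)

# Query side: return only the processed tokens (used at search time)
query_tokens = RT.get_keywords(RT.KeywordsProcessor(), ["wild fox"]; return_keywords = true)
```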
9 changes: 5 additions & 4 deletions src/Experimental/RAGTools/RAGTools.jl
Expand Up @@ -32,14 +32,15 @@ include("api_services.jl")

include("rag_interface.jl")

export ChunkIndex, CandidateChunks, RAGResult
# export MultiIndex # not ready yet
export ChunkIndex, ChunkKeywordsIndex, ChunkEmbeddingsIndex, CandidateChunks, RAGResult
export MultiIndex
include("types.jl")

export build_index, get_chunks, get_embeddings, get_tags
export build_index, get_chunks, get_embeddings, get_keywords, get_tags, SimpleIndexer,
KeywordsIndexer
include("preparation.jl")

export retrieve, SimpleRetriever, AdvancedRetriever
export retrieve, SimpleRetriever, SimpleBM25Retriever, AdvancedRetriever
export find_closest, find_tags, rerank, rephrase
include("retrieval.jl")

