Skip to content

Commit

Permalink
Update RAG performance
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp committed Jun 18, 2024
1 parent c5ac64f commit eff6682
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 61 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.31.0]

### Breaking Changes
- The return type of `RAGTools.find_tags(::NoTagger,...)` is now `::Nothing` instead of `CandidateChunks`/`MultiCandidateChunks` with all documents.
- `Base.getindex(::MultiIndex, ::MultiCandidateChunks)` now always returns sorted chunks for consistency with the behavior of other `getindex` methods on `*Chunks`.

### Updated
- Cosine similarity search now uses `partialsortperm` for better performance on large datasets.
- Skip unnecessary work when the tagging functionality in the RAG pipeline is disabled (`find_tags` with `NoTagger` always returns `nothing` which improves the compiled code).
- Changed the default behavior of `getindex(::MultiIndex, ::MultiCandidateChunks)` to always return sorted chunks for consistency with other similar functions. Note that you should always use re-rankering anyway (see `FlashRank.jl`).

## [0.30.0]

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.30.0"
version = "0.31.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
59 changes: 25 additions & 34 deletions src/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,14 @@ function find_closest(
top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...)
# emb is an embedding matrix where the first dimension is the embedding dimension
scores = query_emb' * emb |> vec
positions = scores |> sortperm |> x -> last(x, top_k) |> reverse
top_k_min = min(top_k, length(scores))
positions = partialsortperm(scores, 1:top_k_min, rev = true)
if minimum_similarity > -1.0
mask = scores[positions] .>= minimum_similarity
positions = positions[mask]
else
## we want to materialize the view
positions = collect(positions)
end
return positions, scores[positions]
end
Expand Down Expand Up @@ -374,7 +378,8 @@ function find_closest(
## First pass, both in binary with Hamming, get rescore_multiplier times top_k
binary_query_emb = map(>(0), query_emb)
scores = hamming_distance(emb, binary_query_emb)
positions = scores |> sortperm |> x -> first(x, top_k * rescore_multiplier)
num_candidates = min(top_k * rescore_multiplier, length(scores))
positions = partialsortperm(scores, 1:num_candidates)

## Second pass, rescore with float embeddings and return top_k
new_positions, scores = find_closest(CosineSimilarity(), @view(emb[:, positions]),
Expand Down Expand Up @@ -415,7 +420,8 @@ function find_closest(
## First pass, both in binary with Hamming, get rescore_multiplier times top_k
bit_query_emb = pack_bits(query_emb .> 0)
scores = hamming_distance(emb, bit_query_emb)
positions = scores |> sortperm |> x -> first(x, top_k * rescore_multiplier)
num_candidates = min(top_k * rescore_multiplier, length(scores))
positions = partialsortperm(scores, 1:num_candidates)

## Second pass, rescore with float embeddings and return top_k
unpacked_emb = unpack_bits(@view(emb[:, positions]))
Expand All @@ -442,11 +448,15 @@ function find_closest(
query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[];
top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...)
bm_scores = bm25(dtm, query_tokens)
positions = bm_scores |> sortperm |> x -> last(x, top_k) |> reverse
top_k_min = min(top_k, length(bm_scores))
positions = partialsortperm(bm_scores, 1:top_k_min, rev = true)

if minimum_similarity > -1.0
mask = scores[positions] .>= minimum_similarity
positions = positions[mask]
else
# materialize the vector
positions = positions |> collect
end
return positions, bm_scores[positions]
end
Expand All @@ -455,7 +465,8 @@ end

function find_tags(::AbstractTagFilter, index::AbstractDocumentIndex,
tag::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
Union{AbstractString, Regex}}
Union{
AbstractString, Regex, Nothing}}
throw(ArgumentError("Not implemented yet for type $(typeof(filter)) and index $(typeof(index))"))
end

Expand Down Expand Up @@ -492,24 +503,19 @@ end

"""
find_tags(method::NoTagFilter, index::AbstractChunkIndex,
tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
Union{
AbstractString, Regex, Nothing}}
tags; kwargs...)
Returns all chunks in the index, ie, no filtering.
Returns all chunks in the index, ie, no filtering, so we simply return `nothing` (easier for dispatch).
"""
function find_tags(method::NoTagFilter, index::AbstractChunkIndex,
tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
Union{
AbstractString, Regex}}
return CandidateChunks(
index.id, collect(1:length(index.chunks)), zeros(Float32, length(index.chunks)))
AbstractString, Regex, Nothing}}
return nothing
end

function find_tags(method::NoTagFilter, index::AbstractChunkIndex,
tags::Nothing; kwargs...)
return CandidateChunks(
index.id, collect(1:length(index.chunks)), zeros(Float32, length(index.chunks)))
end

## Multi-index implementation
function find_tags(method::AnyTagFilter, index::AbstractMultiIndex,
tag::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
Expand Down Expand Up @@ -539,24 +545,8 @@ end
function find_tags(method::NoTagFilter, index::AbstractMultiIndex,
tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
Union{
AbstractString, Regex}}
indexes_ = indexes(index)
length_ = sum(x -> length(x.chunks), indexes_)
index_ids = [fill(x.id, length(x.chunks)) for x in indexes_] |> Base.Splat(vcat)

return MultiCandidateChunks(
index_ids, collect(1:length_),
zeros(Float32, length_))
end
function find_tags(method::NoTagFilter, index::AbstractMultiIndex,
tags::Nothing; kwargs...)
indexes_ = indexes(index)
length_ = sum(x -> length(x.chunks), indexes_)
index_ids = [fill(x.id, length(x.chunks)) for x in indexes_] |> Base.Splat(vcat)

return MultiCandidateChunks(
index_ids, collect(1:length_),
zeros(Float32, length_))
AbstractString, Regex, Nothing}}
return nothing
end

### Reranking
Expand Down Expand Up @@ -966,6 +956,7 @@ function retrieve(retriever::AbstractRetriever,
filter, index, tags; verbose = (verbose > 1), filter_kwargs_...)

## Combine the two sets of candidates, looks for intersection (hard filter)!
# With tagger=NoTagger() get_tags returns `nothing` find_tags simply passes it through to skip the intersection
filtered_candidates = isnothing(tag_candidates) ? emb_candidates :
(emb_candidates & tag_candidates)
## TODO: Future implementation should be to apply tag filtering BEFORE the find_closest,
Expand Down
7 changes: 4 additions & 3 deletions src/Experimental/RAGTools/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ function Base.getindex(ci::AbstractDocumentIndex,
end
function Base.getindex(ci::AbstractChunkIndex,
candidate::CandidateChunks{TP, TD},
field::Symbol = :chunks) where {TP <: Integer, TD <: Real}
field::Symbol = :chunks; sorted::Bool = true) where {TP <: Integer, TD <: Real}
@assert field in [:chunks, :embeddings, :chunkdata, :sources] "Only `chunks`, `embeddings`, `chunkdata`, `sources` fields are supported for now"
field = field == :embeddings ? :chunkdata : field
len_ = length(chunks(ci))
Expand All @@ -504,7 +504,8 @@ function Base.getindex(ci::AbstractChunkIndex,
end
function Base.getindex(mi::MultiIndex,
candidate::CandidateChunks{TP, TD},
field::Symbol = :chunks) where {TP <: Integer, TD <: Real}
field::Symbol = :chunks; sorted::Bool = true) where {TP <: Integer, TD <: Real}
## Always sorted!
@assert field in [:chunks, :sources] "Only `chunks`, `sources` fields are supported for now"
valid_index = findfirst(x -> x.id == candidate.index_id, indexes(mi))
if isnothing(valid_index) && field == :chunks
Expand Down Expand Up @@ -549,7 +550,7 @@ end
# Getindex on Multiindex, pool the individual hits
function Base.getindex(mi::MultiIndex,
candidate::MultiCandidateChunks{TP, TD},
field::Symbol = :chunks; sorted::Bool = false) where {TP <: Integer, TD <: Real}
field::Symbol = :chunks; sorted::Bool = true) where {TP <: Integer, TD <: Real}
@assert field in [:chunks, :sources, :scores] "Only `chunks`, `sources`, and `scores` fields are supported for now"
if sorted
# values can be either of chunks or sources
Expand Down
3 changes: 2 additions & 1 deletion test/Experimental/RAGTools/evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ end
embeddings = zeros(128, 3),
tags = vcat(trues(2, 2), falses(1, 2)),
tags_vocab = ["yes", "no"])
index.embeddings[1, 1] = 1

# Test for successful Q&A extraction from document chunks
qa_evals = build_qa_evals(chunks(index),
Expand Down Expand Up @@ -193,7 +194,7 @@ end
api_kwargs = (; url = "http://localhost:$(PORT)"),
parameters_dict = Dict(:key1 => "value1", :key2 => 2))
@test result.retrieval_score == 1.0
@test result.retrieval_rank == 2
@test result.retrieval_rank == 1
@test result.answer_score == 5
@test result.parameters == Dict(:key1 => "value1", :key2 => 2)

Expand Down
74 changes: 53 additions & 21 deletions test/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -432,12 +432,14 @@ end

# No filter tag -- give everything
cc = find_tags(NoTagFilter(), index, "julia")
@test cc.positions == [1, 2]
@test cc.scores == [0.0, 0.0]
@test isnothing(cc)
# @test cc.positions == [1, 2]
# @test cc.scores == [0.0, 0.0]

cc = find_tags(NoTagFilter(), index, nothing)
@test cc.positions == [1, 2]
@test cc.scores == [0.0, 0.0]
@test isnothing(cc)
# @test cc.positions == [1, 2]
# @test cc.scores == [0.0, 0.0]

# Unknown type
struct RandomTagFilter123 <: AbstractTagFilter end
Expand All @@ -456,12 +458,14 @@ end
multi_index = MultiIndex(id = :multi, indexes = [index1, index2])

mcc = find_tags(NoTagFilter(), multi_index, "julia")
@test mcc.positions == [1, 2, 3, 4]
@test mcc.scores == [0.0, 0.0, 0.0, 0.0]
@test mcc == nothing
# @test mcc.positions == [1, 2, 3, 4]
# @test mcc.scores == [0.0, 0.0, 0.0, 0.0]

mcc = find_tags(NoTagFilter(), multi_index, nothing)
@test mcc.positions == [1, 2, 3, 4]
@test mcc.scores == [0.0, 0.0, 0.0, 0.0]
@test mcc == nothing
# @test mcc.positions == [1, 2, 3, 4]
# @test mcc.scores == [0.0, 0.0, 0.0, 0.0]

multi_index2 = MultiIndex(id = :multi2, indexes = [index, index2])
mcc2 = find_tags(AnyTagFilter(), multi_index2, "julia")
Expand Down Expand Up @@ -538,7 +542,7 @@ end

@testset "retrieve" begin
# test with a mock server
PORT = rand(20000:40000)
PORT = rand(20000:40001)
PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema())
PT.register_model!(; name = "mock-emb2", schema = PT.CustomOpenAISchema())
PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema())
Expand Down Expand Up @@ -609,8 +613,16 @@ end
@test result.rephrased_questions == [question]
@test result.answer == nothing
@test result.final_answer == nothing
@test result.reranked_candidates.positions == [2, 1, 4, 3]
@test result.context == ["chunk2", "chunk1", "chunk4", "chunk3"]
## there are two equivalent orderings
@test Set(result.reranked_candidates.positions[1:2]) == Set([2, 1])
@test Set(result.reranked_candidates.positions[3:4]) == Set([3, 4])
@test result.reranked_candidates.scores[1:2] == ones(2)
@test length(result.context) == 4
@test length(unique(result.context)) == 4
@test result.context[1] in ["chunk2", "chunk1"]
@test result.context[2] in ["chunk2", "chunk1"]
@test result.context[3] in ["chunk3", "chunk4"]
@test result.context[4] in ["chunk3", "chunk4"]
@test result.sources isa Vector{String}

# Reduce number of candidates
Expand All @@ -620,8 +632,10 @@ end
embedder_kwargs = (; model = "mock-emb"),
tagger_kwargs = (; model = "mock-meta"), api_kwargs = (;
url = "http://localhost:$(PORT)"))
@test result.emb_candidates.positions == [2, 1, 4]
@test result.reranked_candidates.positions == [2, 1]
## the last item is 3 or 4
@test result.emb_candidates.positions[3] in [3, 4]
@test Set(result.reranked_candidates.positions[1:2]) == Set([2, 1])
@test result.emb_candidates.scores[1:2] == ones(2)

# with default dispatch
result = retrieve(index, question;
Expand All @@ -630,8 +644,9 @@ end
embedder_kwargs = (; model = "mock-emb"),
tagger_kwargs = (; model = "mock-meta"), api_kwargs = (;
url = "http://localhost:$(PORT)"))
@test result.emb_candidates.positions == [2, 1, 4]
@test result.reranked_candidates.positions == [2, 1]
@test result.emb_candidates.positions[3] in [3, 4]
@test result.emb_candidates.scores[1:2] == ones(2)
@test Set(result.reranked_candidates.positions[1:2]) == Set([2, 1])

## AdvancedRetriever
adv = AdvancedRetriever()
Expand All @@ -645,20 +660,29 @@ end
@test result.rephrased_questions == [question, "Query: test question\n\nPassage:"] # from the template we use
@test result.answer == nothing
@test result.final_answer == nothing
@test result.reranked_candidates.positions == [2, 1, 4, 3]
@test result.context == ["chunk2", "chunk1", "chunk4", "chunk3"]
## there are two equivalent orderings
@test Set(result.reranked_candidates.positions[1:2]) == Set([2, 1])
@test Set(result.reranked_candidates.positions[3:4]) == Set([3, 4])
@test result.reranked_candidates.scores[1:2] == ones(2)
@test length(result.context) == 4
@test length(unique(result.context)) == 4
@test result.context[1] in ["chunk2", "chunk1"]
@test result.context[2] in ["chunk2", "chunk1"]
@test result.context[3] in ["chunk3", "chunk4"]
@test result.context[4] in ["chunk3", "chunk4"]
@test result.sources isa Vector{String}

# Multi-index retriever
index_keywords = ChunkKeywordsIndex(index, index_id = :TestChunkIndexX)
index_keywords = ChunkIndex(; id = :AA, index.chunks, index.sources, index.embeddings)
# Create MultiIndex instance
multi_index = MultiIndex(id = :multi, indexes = [index, index_keywords])

# Create MultiFinder instance
finder = MultiFinder([RT.CosineSimilarity(), RT.BM25Similarity()])

retriever = SimpleRetriever(; processor = RT.KeywordsProcessor(), finder)
result = retrieve(retriever, multi_index, question;
result = retrieve(SimpleRetriever(), multi_index, question;
reranker = NoReranker(), # we need to disable cohere as we cannot test it
rephraser_kwargs = (; model = "mock-gen"),
embedder_kwargs = (; model = "mock-emb"),
Expand All @@ -668,9 +692,17 @@ end
@test result.rephrased_questions == [question]
@test result.answer == nothing
@test result.final_answer == nothing
@test result.reranked_candidates.positions == [2, 1, 4, 3]
@test result.context == ["chunk2", "chunk1", "chunk4", "chunk3"]
@test result.sources == ["source2", "source1", "source4", "source3"]
## there are two equivalent orderings
@test Set(result.reranked_candidates.positions[1:4]) == Set([2, 1])
@test result.reranked_candidates.positions[5] in [3, 4]
@test result.reranked_candidates.scores[1:4] == ones(4)
@test length(result.context) == 5 # because the second index duplicates, so we have more
@test length(unique(result.context)) == 3 # only 3 unique chunks because 1,2,1,2,3
@test all([result.context[i] in ["chunk2", "chunk1"] for i in 1:4])
@test result.context[5] in ["chunk3", "chunk4"]
@test length(unique(result.sources)) == 3
@test all([result.sources[i] in ["source2", "source1"] for i in 1:4])
@test result.sources[5] in ["source3", "source4"]

# clean up
close(echo_server)
Expand Down
7 changes: 6 additions & 1 deletion test/Experimental/RAGTools/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -539,11 +539,16 @@ end
index_ids = [Symbol("TestChunkIndex"), Symbol("TestChunkIndex2")],
positions = [1, 3], # Assuming chunks_data has only 3 elements, position 4 is out of bounds
scores = [0.5, 0.7])
@test mi[mc1] == ["First chunk", "6"]
## sorted=true by default
@test mi[mc1] == ["6", "First chunk"]
@test Base.getindex(mi, mc1, :chunks; sorted = true) == ["6", "First chunk"]
@test Base.getindex(mi, mc1, :sources; sorted = true) ==
["other_source3", "test_source1"]
@test Base.getindex(mi, mc1, :scores; sorted = true) == [0.7, 0.5]
@test Base.getindex(mi, mc1, :chunks; sorted = false) == ["First chunk", "6"]
@test Base.getindex(mi, mc1, :sources; sorted = false) ==
["test_source1", "other_source3"]
@test Base.getindex(mi, mc1, :scores; sorted = false) == [0.5, 0.7]
end

@testset "RAGResult" begin
Expand Down

2 comments on commit eff6682

@svilupp
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

Breaking Changes

  • The return type of RAGTools.find_tags(::NoTagger,...) is now ::Nothing instead of CandidateChunks/MultiCandidateChunks with all documents.
  • Base.getindex(::MultiIndex, ::MultiCandidateChunks) now always returns sorted chunks for consistency with the behavior of other getindex methods on *Chunks.

Updated

  • Cosine similarity search now uses partialsortperm for better performance on large datasets.
  • Skip unnecessary work when the tagging functionality in the RAG pipeline is disabled (find_tags with NoTagger always returns nothing which improves the compiled code).
  • Changed the default behavior of getindex(::MultiIndex, ::MultiCandidateChunks) to always return sorted chunks for consistency with other similar functions. Note that you should always use re-rankering anyway (see FlashRank.jl).

Commits

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/109243

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.31.0 -m "<description of version>" eff668260c85b7e5860a506f45f2687a251de8f2
git push origin v0.31.0

Please sign in to comment.