
Commit

Add RankGPT (#172)
svilupp committed Jul 1, 2024
1 parent b6d10f5 commit 751f337
Showing 10 changed files with 677 additions and 5 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.34.0]

### Added
- `RankGPT` implementation for the RAGTools chunk re-ranking pipeline. See `?PromptingTools.Experimental.RAGTools.rank_gpt` for more information and the corresponding reranker type `?RankGPTReranker`.

## [0.33.2]

### Fixed
5 changes: 3 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.33.2"
version = "0.34.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -62,6 +62,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[targets]
test = ["Aqua", "FlashRank", "SparseArrays", "Statistics", "LinearAlgebra", "Markdown", "Snowball"]
test = ["Aqua", "FlashRank", "SparseArrays", "Statistics", "LinearAlgebra", "Markdown", "Snowball", "Unicode"]
2 changes: 1 addition & 1 deletion ext/FlashRankPromptingToolsExt.jl
@@ -60,7 +60,7 @@ function RT.rerank(
kwargs...)
@assert top_n>0 "top_n must be a positive integer."
documents = index[candidates, :chunks]
@assert !(isempty(documents)) "The candidate chunks must not be empty for Cohere Reranker! Check the index IDs."
@assert !(isempty(documents)) "The candidate chunks must not be empty! Check the index IDs."

is_multi_cand = candidates isa RT.MultiCandidateChunks
index_ids = is_multi_cand ? candidates.index_ids : candidates.index_id
2 changes: 2 additions & 0 deletions src/Experimental/RAGTools/RAGTools.jl
@@ -40,6 +40,8 @@ export build_index, get_chunks, get_embeddings, get_keywords, get_tags, SimpleIn
KeywordsIndexer
include("preparation.jl")

include("rank_gpt.jl")

export retrieve, SimpleRetriever, SimpleBM25Retriever, AdvancedRetriever
export find_closest, find_tags, rerank, rephrase
include("retrieval.jl")
166 changes: 166 additions & 0 deletions src/Experimental/RAGTools/rank_gpt.jl
@@ -0,0 +1,166 @@
# Implementation of RankGPT
# Is ChatGPT Good at Search? Investigating Large Language Models as Re-Ranking Agents by W. Sun et al. // https://arxiv.org/abs/2304.09542
# https://github.com/sunnweiwei/RankGPT

"""
RankGPTResult
Results from the RankGPT algorithm.
# Fields
- `question::String`: The question that was asked.
- `chunks::AbstractVector{T}`: The chunks that were ranked (=context).
- `positions::Vector{Int}`: The ranking of the chunks (referring to the `chunks`).
- `elapsed::Float64`: The time it took to rank the chunks.
- `cost::Float64`: The cumulative cost of the ranking.
- `tokens::Int`: The cumulative number of tokens used in the ranking.
"""
@kwdef mutable struct RankGPTResult{T <: AbstractString}
question::String
chunks::AbstractVector{T}
positions::Vector{Int} = collect(1:length(chunks))
elapsed::Float64 = 0.0
cost::Float64 = 0.0
tokens::Int = 0
end
Base.show(io::IO, result::RankGPTResult) = dump(io, result; maxdepth = 1)
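# Construction sketch (illustrative values): `positions` defaults to the identity permutation,
# e.g. RankGPTResult(; question = "q?", chunks = ["a", "b"]).positions == [1, 2].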

"""
create_permutation_instruction(
context::AbstractVector{<:AbstractString}; rank_start::Integer = 1,
rank_end::Integer = 100, max_length::Integer = 512, template::Symbol = :RAGRankGPT)
Creates rendered template with injected `context` passages.
"""
function create_permutation_instruction(
context::AbstractVector{<:AbstractString}; rank_start::Integer = 1,
rank_end::Integer = 100, max_length::Integer = 512, template::Symbol = :RAGRankGPT)
##
rank_end_adj = min(rank_end, length(context))
num = rank_end_adj - rank_start + 1

messages = PT.render(PT.AITemplate(template))
last_msg = pop!(messages)
rank = 0
for ctx in context[rank_start:rank_end_adj]
rank += 1
push!(messages, PT.UserMessage("[$rank] $(strip(ctx)[1:min(end, max_length)])"))
push!(messages, PT.AIMessage("Received passage [$rank]."))
end
push!(messages, last_msg)

return messages, num
end
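# Illustration (assumed 3-passage context): the rendered template becomes the system prompt, then
# alternating UserMessage("[i] <passage>") / AIMessage("Received passage [i].") pairs for i = 1:3,
# followed by the template's final user message; `num` is returned as 3.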

"""
extract_ranking(str::AbstractString)
Extracts the ranking from the response into an ordered array of integers (duplicates removed, response order preserved).
"""
function extract_ranking(str::AbstractString)
nums = replace(str, r"[^0-9]" => " ") |> strip |> split
nums = parse.(Int, nums)
unique_idxs = unique(i -> nums[i], eachindex(nums))
return nums[unique_idxs]
end
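# Worked example (illustrative): duplicates are dropped and the response order is preserved, so
# extract_ranking("[2] > [1] > [3] > [1]") == [2, 1, 3].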

"""
receive_permutation!(
curr_rank::AbstractVector{<:Integer}, response::AbstractString;
rank_start::Integer = 1, rank_end::Integer = 100)
Extracts and heals the permutation to contain all ranking positions.
"""
function receive_permutation!(
curr_rank::AbstractVector{<:Integer}, response::AbstractString;
rank_start::Integer = 1, rank_end::Integer = 100)
@assert rank_start>=1 "rank_start must be greater than or equal to 1"
@assert rank_end>=rank_start "rank_end must be greater than or equal to rank_start"
new_rank = extract_ranking(response)
copied_rank = curr_rank[rank_start:min(end, rank_end)] |> copy
orig_rank = 1:length(copied_rank)
new_rank = vcat(
    [r for r in new_rank if r in orig_rank], [r for r in orig_rank if r ∉ new_rank])
for (j, rnk) in enumerate(new_rank)
curr_rank[rank_start + j - 1] = copied_rank[rnk]
end
return curr_rank
end
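# Worked example (illustrative): an incomplete response over three chunks is healed by appending
# the missing position, so receive_permutation!([10, 20, 30], "[3] > [1]") == [30, 10, 20].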

"""
permutation_step!(
result::RankGPTResult; rank_start::Integer = 1, rank_end::Integer = 100, kwargs...)
One sub-step of the RankGPT algorithm permutation ranking within the window of chunks defined by `rank_start` and `rank_end` positions.
"""
function permutation_step!(
result::RankGPTResult; rank_start::Integer = 1, rank_end::Integer = 100, kwargs...)
(; positions, chunks, question) = result
tpl, num = create_permutation_instruction(chunks; rank_start, rank_end)
msg = aigenerate(tpl; question, num, kwargs...)
result.positions = receive_permutation!(
positions, PT.last_output(msg); rank_start, rank_end)
result.cost += msg.cost
result.tokens += sum(msg.tokens)
result.elapsed += msg.elapsed
return result
end

"""
rank_sliding_window!(
result::RankGPTResult; verbose::Int = 1, rank_start = 1, rank_end = 100,
window_size = 20, step = 10, model::String = "gpt4o", kwargs...)
One single pass of the RankGPT algorithm permutation ranking across all positions between `rank_start` and `rank_end`.
"""
function rank_sliding_window!(
result::RankGPTResult; verbose::Int = 1, rank_start = 1, rank_end = 100,
window_size = 20, step = 10, model::String = "gpt4o", kwargs...)
@assert rank_start>=0 "rank_start must be greater than or equal to 0 (Provided: rank_start=$rank_start)"
@assert rank_end>=rank_start "rank_end must be greater than or equal to rank_start (Provided: rank_end=$rank_end, rank_start=$rank_start)"
@assert rank_end>=window_size>=step "rank_end must be greater than or equal to window_size, which must be greater than or equal to step (Provided: rank_end=$rank_end, window_size=$window_size, step=$step)"
end_pos = min(rank_end, length(result.chunks))
start_pos = max(end_pos - window_size, 1)
while start_pos >= rank_start
(verbose >= 1) && @info "Ranking chunks in positions $start_pos to $end_pos"
permutation_step!(result; rank_start = start_pos, rank_end = end_pos,
model, verbose = (verbose >= 1), kwargs...)
(verbose >= 2) && @info "Current ranking: $(result.positions)"
end_pos -= step
start_pos -= step
end
return result
end
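# Window-sweep illustration (assuming 10 chunks, rank_end = 10, window_size = 4, step = 2):
# the loop ranks positions 6:10, then 4:8, then 2:6, sliding from the tail toward the head.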

"""
rank_gpt(chunks::AbstractVector{<:AbstractString}, question::AbstractString;
verbose::Int = 1, rank_start::Integer = 1, rank_end::Integer = 100,
window_size::Integer = 20, step::Integer = 10,
num_rounds::Integer = 1, model::String = "gpt4o", kwargs...)
Ranks the `chunks` based on their relevance for `question`. Returns the ranking permutation of the chunks in the order they are most relevant to the question (the first is the most relevant).
# Example
```julia
result = rank_gpt(chunks, question; rank_start=1, rank_end=25, window_size=8, step=4, num_rounds=3, model="gpt4o")
```
# Reference
[1] [Is ChatGPT Good at Search? Investigating Large Language Models as Re-Ranking Agents by W. Sun et al.](https://arxiv.org/abs/2304.09542)
[2] [RankGPT Github](https://github.com/sunnweiwei/RankGPT)
"""
function rank_gpt(chunks::AbstractVector{<:AbstractString}, question::AbstractString;
verbose::Int = 1, rank_start::Integer = 1, rank_end::Integer = 100,
window_size::Integer = 20, step::Integer = 10,
num_rounds::Integer = 1, model::String = "gpt4o", kwargs...)
result = RankGPTResult(; question, chunks)
for i in 1:num_rounds
(verbose >= 1) && @info "Round $i of $num_rounds of ranking process."
result = rank_sliding_window!(
result; verbose = verbose - 1, rank_start, rank_end,
window_size, step, model, kwargs...)
end
(verbose >= 1) &&
@info "Final ranking done. Tokens: $(result.tokens), Cost: $(round(result.cost, digits=2)), Time: $(round(result.elapsed, digits=1))s"
return result
end
106 changes: 106 additions & 0 deletions src/Experimental/RAGTools/retrieval.jl
@@ -595,6 +595,17 @@ struct FlashRanker{T} <: AbstractReranker
model::T
end

"""
RankGPTReranker <: AbstractReranker
Rerank strategy using the RankGPT algorithm (calling LLMs).
# Reference
[1] [Is ChatGPT Good at Search? Investigating Large Language Models as Re-Ranking Agents by W. Sun et al.](https://arxiv.org/abs/2304.09542)
[2] [RankGPT Github](https://github.com/sunnweiwei/RankGPT)
"""
struct RankGPTReranker <: AbstractReranker end
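# Usage sketch (mirrors the `rerank` docstring below): SimpleRetriever(; reranker = RankGPTReranker())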

function rerank(reranker::AbstractReranker,
index::AbstractDocumentIndex, question::AbstractString, candidates::AbstractCandidateChunks; kwargs...)
throw(ArgumentError("Not implemented yet"))
@@ -697,6 +708,101 @@ function rerank(
CandidateChunks(index_ids, positions, scores)
end

"""
rerank(
reranker::RankGPTReranker, index::AbstractDocumentIndex, question::AbstractString,
candidates::AbstractCandidateChunks;
verbose::Bool = false,
api_key::AbstractString = PT.OPENAI_API_KEY,
top_n::Integer = length(candidates.scores),
model::AbstractString = PT.MODEL_CHAT,
cost_tracker = Threads.Atomic{Float64}(0.0),
kwargs...)
Re-ranks a list of candidate chunks using the RankGPT algorithm. See https://github.com/sunnweiwei/RankGPT for more details.
It uses LLM calls to rank the candidate chunks.
# Arguments
- `reranker`: The reranking strategy; uses the RankGPT algorithm (LLM calls) to rank the candidates.
- `index`: The index that holds the underlying chunks to be re-ranked.
- `question`: The query to be used for the search.
- `candidates`: The candidate chunks to be re-ranked.
- `top_n`: The number of most relevant documents to return. Default is `length(candidates.scores)` (all candidates).
- `model`: The model to use for reranking. Default is `PT.MODEL_CHAT`.
- `verbose`: A boolean flag indicating whether to print verbose logging. Default is `false`.
- `unique_chunks`: A boolean flag indicating whether to remove duplicates from the candidate chunks prior to reranking (saves compute time). Default is `true`.
# Examples
```julia
index = <some index>
question = "What are the best practices for parallel computing in Julia?"
cfg = RAGConfig(; retriever = SimpleRetriever(; reranker = RT.RankGPTReranker()))
msg = airag(cfg, index; question, return_all = true)
```
To get full verbosity of logs, set `verbose = 5` (anything higher than 3).
```julia
msg = airag(cfg, index; question, return_all = true, verbose = 5)
```
# Reference
[1] [Is ChatGPT Good at Search? Investigating Large Language Models as Re-Ranking Agents by W. Sun et al.](https://arxiv.org/abs/2304.09542)
[2] [RankGPT Github](https://github.com/sunnweiwei/RankGPT)
"""
function rerank(
reranker::RankGPTReranker, index::AbstractDocumentIndex, question::AbstractString,
candidates::AbstractCandidateChunks;
api_key::AbstractString = PT.OPENAI_API_KEY,
model::AbstractString = PT.MODEL_CHAT,
verbose::Bool = false,
top_n::Integer = length(candidates.scores),
unique_chunks::Bool = true,
cost_tracker = Threads.Atomic{Float64}(0.0),
kwargs...)
@assert top_n>0 "top_n must be a positive integer."
documents = index[candidates, :chunks]
@assert !(isempty(documents)) "The candidate chunks must not be empty! Check the index IDs."

is_multi_cand = candidates isa MultiCandidateChunks
index_ids = is_multi_cand ? candidates.index_ids : candidates.index_id
positions = candidates.positions
## Find unique only items
if unique_chunks
verbose && @info "Removing duplicates from candidate chunks prior to reranking"
unique_idxs = PT.unique_permutation(documents)
documents = documents[unique_idxs]
positions = positions[unique_idxs]
index_ids = is_multi_cand ? index_ids[unique_idxs] : index_ids
end

## Run re-ranker via RankGPT
rank_end = max(get(kwargs, :rank_end, length(documents)), length(documents))
step = min(get(kwargs, :step, top_n), top_n, rank_end)
window_size = max(min(get(kwargs, :window_size, 20), rank_end), step)
verbose &&
@info "RankGPT parameters: rank_end = $rank_end, step = $step, window_size = $window_size"
result = rank_gpt(
documents, question; verbose = verbose * 3, api_key,
model, kwargs..., rank_end, step, window_size)

## Unwrap re-ranked positions
ranked_positions = first(result.positions, top_n)
positions = positions[ranked_positions]
## TODO: add reciprocal rank fusion and multiple passes
scores = ones(Float32, length(positions)) # no scores available

verbose && @info "Reranking done in $(round(result.elapsed; digits=1)) seconds."
Threads.atomic_add!(cost_tracker, result.cost)

return is_multi_cand ?
MultiCandidateChunks(index_ids[ranked_positions], positions, scores) :
CandidateChunks(index_ids, positions, scores)
end
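# Direct-call sketch (illustrative; assumes `index` and `candidates` from an earlier retrieval step):
# rerank(RankGPTReranker(), index, "What are the best practices for parallel computing in Julia?",
#     candidates; top_n = 5)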

### Overall types for `retrieve`
"""
SimpleRetriever <: AbstractRetriever
1 change: 1 addition & 0 deletions templates/RAG/ranking/RAGRankGPT.json
@@ -0,0 +1 @@
[{"content":"Template Metadata","description":"RankGPT implementation to re-rank chunks by LLMs. Passages are injected in the middle - see the function. Placeholders: `num`, `question`","version":"1","source":"Based on https://github.com/sunnweiwei/RankGPT","_type":"metadatamessage"},{"content":"You are RankGPT, an intelligent assistant that can rank passages based on their relevancy to the query.","variables":[],"_type":"systemmessage"},{"content":"I will provide you with {{num}} passages, each indicated by number identifier []. \nRank the passages based on their relevance to query: {{question}}.","variables":["num","question"],"_type":"usermessage"},{"content":"Okay, please provide the passages.","status":null,"tokens":[-1,-1],"elapsed":-1.0,"cost":null,"log_prob":null,"finish_reason":null,"run_id":-14760,"sample_id":null,"_type":"aimessage"},{"content":"Search Query: {{question}}. Rank the {{num}} passages above based on their relevance to the search query. The passages should be listed in descending order using identifiers. The most relevant passages should be listed first. The output format should be [] > [], e.g., [1] > [2]. Only respond with the ranking results, do not say any word or explain.","variables":["question","num"],"_type":"usermessage"}]

2 comments on commit 751f337

@svilupp (Owner, Author) commented on 751f337, Jul 1, 2024


@JuliaRegistrator register

Release notes:

Added

  • RankGPT implementation for the RAGTools chunk re-ranking pipeline. See ?PromptingTools.Experimental.RAGTools.rank_gpt for more information and the corresponding reranker type ?RankGPTReranker.

Commits

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/110121

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.34.0 -m "<description of version>" 751f337c28ba4d97367aece2d91665f684304039
git push origin v0.34.0
