From 751f337c28ba4d97367aece2d91665f684304039 Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Mon, 1 Jul 2024 10:31:23 +0100 Subject: [PATCH] Add RankGPT (#172) --- CHANGELOG.md | 5 + Project.toml | 5 +- ext/FlashRankPromptingToolsExt.jl | 2 +- src/Experimental/RAGTools/RAGTools.jl | 2 + src/Experimental/RAGTools/rank_gpt.jl | 166 ++++++++++ src/Experimental/RAGTools/retrieval.jl | 106 +++++++ templates/RAG/ranking/RAGRankGPT.json | 1 + test/Experimental/RAGTools/rank_gpt.jl | 389 ++++++++++++++++++++++++ test/Experimental/RAGTools/retrieval.jl | 3 +- test/Experimental/RAGTools/runtests.jl | 3 +- 10 files changed, 677 insertions(+), 5 deletions(-) create mode 100644 src/Experimental/RAGTools/rank_gpt.jl create mode 100644 templates/RAG/ranking/RAGRankGPT.json create mode 100644 test/Experimental/RAGTools/rank_gpt.jl diff --git a/CHANGELOG.md b/CHANGELOG.md index e8ba4391..5f02cde1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.34.0] + +### Added +- `RankGPT` implementation for RAGTools chunk re-ranking pipeline. See `?RAGTools.Experimental.rank_gpt` for more information and corresponding reranker type `?RankGPTReranker`. + ## [0.33.2] ### Fixed diff --git a/Project.toml b/Project.toml index 17d0dcf9..ae08c61d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.33.2" +version = "0.34.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -62,6 +62,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [targets] -test = ["Aqua", "FlashRank", "SparseArrays", "Statistics", "LinearAlgebra", "Markdown", "Snowball"] +test = ["Aqua", "FlashRank", "SparseArrays", "Statistics", "LinearAlgebra", "Markdown", "Snowball", "Unicode"] diff --git a/ext/FlashRankPromptingToolsExt.jl b/ext/FlashRankPromptingToolsExt.jl index 9d62c3e7..6707a0ff 100644 --- a/ext/FlashRankPromptingToolsExt.jl +++ b/ext/FlashRankPromptingToolsExt.jl @@ -60,7 +60,7 @@ function RT.rerank( kwargs...) @assert top_n>0 "top_n must be a positive integer." documents = index[candidates, :chunks] - @assert !(isempty(documents)) "The candidate chunks must not be empty for Cohere Reranker! Check the index IDs." + @assert !(isempty(documents)) "The candidate chunks must not be empty! Check the index IDs." is_multi_cand = candidates isa RT.MultiCandidateChunks index_ids = is_multi_cand ? candidates.index_ids : candidates.index_id diff --git a/src/Experimental/RAGTools/RAGTools.jl b/src/Experimental/RAGTools/RAGTools.jl index efd20c8e..733959c6 100644 --- a/src/Experimental/RAGTools/RAGTools.jl +++ b/src/Experimental/RAGTools/RAGTools.jl @@ -40,6 +40,8 @@ export build_index, get_chunks, get_embeddings, get_keywords, get_tags, SimpleIn KeywordsIndexer include("preparation.jl") +include("rank_gpt.jl") + export retrieve, SimpleRetriever, SimpleBM25Retriever, AdvancedRetriever export find_closest, find_tags, rerank, rephrase include("retrieval.jl") diff --git a/src/Experimental/RAGTools/rank_gpt.jl b/src/Experimental/RAGTools/rank_gpt.jl new file mode 100644 index 00000000..e43326e3 --- /dev/null +++ b/src/Experimental/RAGTools/rank_gpt.jl @@ -0,0 +1,166 @@ +# Implementation of RankGPT +# Is ChatGPT Good at Search? Investigating Large Language Models as Re-Ranking Agents by W. Sun et al. // https://arxiv.org/abs/2304.09542 +# https://github.com/sunnweiwei/RankGPT + +""" + RankGPTResult + +Results from the RankGPT algorithm. + +# Fields +- `question::String`: The question that was asked. +- `chunks::AbstractVector{T}`: The chunks that were ranked (=context). +- `positions::Vector{Int}`: The ranking of the chunks (referring to the `chunks`). +- `elapsed::Float64`: The time it took to rank the chunks. +- `cost::Float64`: The cumulative cost of the ranking. +- `tokens::Int`: The cumulative number of tokens used in the ranking. +""" +@kwdef mutable struct RankGPTResult{T <: AbstractString} + question::String + chunks::AbstractVector{T} + positions::Vector{Int} = collect(1:length(chunks)) + elapsed::Float64 = 0.0 + cost::Float64 = 0.0 + tokens::Int = 0 +end +Base.show(io::IO, result::RankGPTResult) = dump(io, result; maxdepth = 1) + +""" + create_permutation_instruction( + context::AbstractVector{<:AbstractString}; rank_start::Integer = 1, + rank_end::Integer = 100, max_length::Integer = 512, template::Symbol = :RAGRankGPT) + +Creates rendered template with injected `context` passages. +""" +function create_permutation_instruction( + context::AbstractVector{<:AbstractString}; rank_start::Integer = 1, + rank_end::Integer = 100, max_length::Integer = 512, template::Symbol = :RAGRankGPT) + ## + rank_end_adj = min(rank_end, length(context)) + num = rank_end_adj - rank_start + 1 + + messages = PT.render(PT.AITemplate(template)) + last_msg = pop!(messages) + rank = 0 + for ctx in context[rank_start:rank_end_adj] + rank += 1 + push!(messages, PT.UserMessage("[$rank] $(strip(ctx)[1:min(end, max_length)])")) + push!(messages, PT.AIMessage("Received passage [$rank].")) + end + push!(messages, last_msg) + + return messages, num +end + +""" + extract_ranking(str::AbstractString) + +Extracts the ranking from the response into a sorted array of integers. +""" +function extract_ranking(str::AbstractString) + nums = replace(str, r"[^0-9]" => " ") |> strip |> split + nums = parse.(Int, nums) + unique_idxs = unique(i -> nums[i], eachindex(nums)) + return nums[unique_idxs] +end + +""" + receive_permutation!( + curr_rank::AbstractVector{<:Integer}, response::AbstractString; + rank_start::Integer = 1, rank_end::Integer = 100) + +Extracts and heals the permutation to contain all ranking positions. +""" +function receive_permutation!( + curr_rank::AbstractVector{<:Integer}, response::AbstractString; + rank_start::Integer = 1, rank_end::Integer = 100) + @assert rank_start>=1 "rank_start must be greater than or equal to 1" + @assert rank_end>=rank_start "rank_end must be greater than or equal to rank_start" + new_rank = extract_ranking(response) + copied_rank = curr_rank[rank_start:min(end, rank_end)] |> copy + orig_rank = 1:length(copied_rank) + new_rank = vcat( + [r for r in new_rank if r in orig_rank], [r for r in orig_rank if r ∉ new_rank]) + for (j, rnk) in enumerate(new_rank) + curr_rank[rank_start + j - 1] = copied_rank[rnk] + end + return curr_rank +end + +""" + permutation_step!( + result::RankGPTResult; rank_start::Integer = 1, rank_end::Integer = 100, kwargs...) + +One sub-step of the RankGPT algorithm permutation ranking within the window of chunks defined by `rank_start` and `rank_end` positions. +""" +function permutation_step!( + result::RankGPTResult; rank_start::Integer = 1, rank_end::Integer = 100, kwargs...) + (; positions, chunks, question) = result + tpl, num = create_permutation_instruction(chunks; rank_start, rank_end) + msg = aigenerate(tpl; question, num, kwargs...) + result.positions = receive_permutation!( + positions, PT.last_output(msg); rank_start, rank_end) + result.cost += msg.cost + result.tokens += sum(msg.tokens) + result.elapsed += msg.elapsed + return result +end + +""" + rank_sliding_window!( + result::RankGPTResult; verbose::Int = 1, rank_start = 1, rank_end = 100, + window_size = 20, step = 10, model::String = "gpt4o", kwargs...) + +One single pass of the RankGPT algorithm permutation ranking across all positions between `rank_start` and `rank_end`. +""" +function rank_sliding_window!( + result::RankGPTResult; verbose::Int = 1, rank_start = 1, rank_end = 100, + window_size = 20, step = 10, model::String = "gpt4o", kwargs...) + @assert rank_start>=0 "rank_start must be greater than or equal to 0 (Provided: rank_start=$rank_start)" + @assert rank_end>=rank_start "rank_end must be greater than or equal to rank_start (Provided: rank_end=$rank_end, rank_start=$rank_start)" + @assert rank_end>=window_size>=step "rank_end must be greater than or equal to window_size, which must be greater than or equal to step (Provided: rank_end=$rank_end, window_size=$window_size, step=$step)" + end_pos = min(rank_end, length(result.chunks)) + start_pos = max(end_pos - window_size, 1) + while start_pos >= rank_start + (verbose >= 1) && @info "Ranking chunks in positions $start_pos to $end_pos" + permutation_step!(result; rank_start = start_pos, rank_end = end_pos, + model, verbose = (verbose >= 1), kwargs...) + (verbose >= 2) && @info "Current ranking: $(result.positions)" + end_pos -= step + start_pos -= step + end + return result +end + +""" + rank_gpt(chunks::AbstractVector{<:AbstractString}, question::AbstractString; + verbose::Int = 1, rank_start::Integer = 1, rank_end::Integer = 100, + window_size::Integer = 20, step::Integer = 10, + num_rounds::Integer = 1, model::String = "gpt4o", kwargs...) + +Ranks the `chunks` based on their relevance for `question`. Returns the ranking permutation of the chunks in the order they are most relevant to the question (the first is the most relevant). + +# Example +```julia +result = rank_gpt(chunks, question; rank_start=1, rank_end=25, window_size=8, step=4, num_rounds=3, model="gpt4o") +``` + +# Reference +[1] [Is ChatGPT Good at Search? Investigating Large Language Models as Re-Ranking Agents by W. Sun et al.](https://arxiv.org/abs/2304.09542) +[2] [RankGPT Github](https://github.com/sunnweiwei/RankGPT) +""" +function rank_gpt(chunks::AbstractVector{<:AbstractString}, question::AbstractString; + verbose::Int = 1, rank_start::Integer = 1, rank_end::Integer = 100, + window_size::Integer = 20, step::Integer = 10, + num_rounds::Integer = 1, model::String = "gpt4o", kwargs...) + result = RankGPTResult(; question, chunks) + for i in 1:num_rounds + (verbose >= 1) && @info "Round $i of $num_rounds of ranking process." + result = rank_sliding_window!( + result; verbose = verbose - 1, rank_start, rank_end, + window_size, step, model, kwargs...) + end + (verbose >= 1) && + @info "Final ranking done. Tokens: $(result.tokens), Cost: $(round(result.cost, digits=2)), Time: $(round(result.elapsed, digits=1))s" + return result +end \ No newline at end of file diff --git a/src/Experimental/RAGTools/retrieval.jl b/src/Experimental/RAGTools/retrieval.jl index 4a1dc634..9501bf91 100644 --- a/src/Experimental/RAGTools/retrieval.jl +++ b/src/Experimental/RAGTools/retrieval.jl @@ -595,6 +595,17 @@ struct FlashRanker{T} <: AbstractReranker model::T end +""" + RankGPTReranker <: AbstractReranker + +Rerank strategy using the RankGPT algorithm (calling LLMs). + +# Reference +[1] [Is ChatGPT Good at Search? Investigating Large Language Models as Re-Ranking Agents by W. Sun et al.](https://arxiv.org/abs/2304.09542) +[2] [RankGPT Github](https://github.com/sunnweiwei/RankGPT) +""" +struct RankGPTReranker <: AbstractReranker end + function rerank(reranker::AbstractReranker, index::AbstractDocumentIndex, question::AbstractString, candidates::AbstractCandidateChunks; kwargs...) throw(ArgumentError("Not implemented yet")) @@ -697,6 +708,101 @@ function rerank( CandidateChunks(index_ids, positions, scores) end +""" + rerank( + reranker::CohereReranker, index::AbstractDocumentIndex, question::AbstractString, + candidates::AbstractCandidateChunks; + verbose::Integer = 1, + api_key::AbstractString = PT.OPENAI_API_KEY, + top_n::Integer = length(candidates.scores), + model::AbstractString = PT.MODEL_CHAT, + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + + +Re-ranks a list of candidate chunks using the RankGPT algorithm. See https://github.com/sunnweiwei/RankGPT for more details. + +It uses LLM calls to rank the candidate chunks. + +# Arguments +- `reranker`: Using Cohere API +- `index`: The index that holds the underlying chunks to be re-ranked. +- `question`: The query to be used for the search. +- `candidates`: The candidate chunks to be re-ranked. +- `top_n`: The number of most relevant documents to return. Default is `length(documents)`. +- `model`: The model to use for reranking. Default is `rerank-english-v3.0`. +- `verbose`: A boolean flag indicating whether to print verbose logging. Default is `1`. +- `unique_chunks`: A boolean flag indicating whether to remove duplicates from the candidate chunks prior to reranking (saves compute time). Default is `true`. + +# Examples + +```julia +index = +question = "What are the best practices for parallel computing in Julia?" + +cfg = RAGConfig(; retriever = SimpleRetriever(; reranker = RT.RankGPTReranker())) +msg = airag(cfg, index; question, return_all = true) +``` +To get full verbosity of logs, set `verbose = 5` (anything higher than 3). +```julia +msg = airag(cfg, index; question, return_all = true, verbose = 5) +``` + + +# Reference +[1] [Is ChatGPT Good at Search? Investigating Large Language Models as Re-Ranking Agents by W. Sun et al.](https://arxiv.org/abs/2304.09542) +[2] [RankGPT Github](https://github.com/sunnweiwei/RankGPT) +""" +function rerank( + reranker::RankGPTReranker, index::AbstractDocumentIndex, question::AbstractString, + candidates::AbstractCandidateChunks; + api_key::AbstractString = PT.OPENAI_API_KEY, + model::AbstractString = PT.MODEL_CHAT, + verbose::Bool = false, + top_n::Integer = length(candidates.scores), + unique_chunks::Bool = true, + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + @assert top_n>0 "top_n must be a positive integer." + documents = index[candidates, :chunks] + @assert !(isempty(documents)) "The candidate chunks must not be empty! Check the index IDs." + + is_multi_cand = candidates isa MultiCandidateChunks + index_ids = is_multi_cand ? candidates.index_ids : candidates.index_id + positions = candidates.positions + ## Find unique only items + if unique_chunks + verbose && @info "Removing duplicates from candidate chunks prior to reranking" + unique_idxs = PT.unique_permutation(documents) + documents = documents[unique_idxs] + positions = positions[unique_idxs] + index_ids = is_multi_cand ? index_ids[unique_idxs] : index_ids + end + + ## Run re-ranker via RankGPT + rank_end = max(get(kwargs, :rank_end, length(documents)), length(documents)) + step = min(get(kwargs, :step, top_n), top_n, rank_end) + window_size = max(min(get(kwargs, :window_size, 20), rank_end), step) + verbose && + @info "RankGPT parameters: rank_end = $rank_end, step = $step, window_size = $window_size" + result = rank_gpt( + documents, question; verbose = verbose * 3, api_key, + model, kwargs..., rank_end, step, window_size) + + ## Unwrap re-ranked positions + ranked_positions = first(result.positions, top_n) + positions = positions[ranked_positions] + ## TODO: add reciprocal rank fusion and multiple passes + scores = ones(Float32, length(positions)) # no scores available + + verbose && @info "Reranking done in $(round(result.elapsed; digits=1)) seconds." + Threads.atomic_add!(cost_tracker, result.cost) + + return is_multi_cand ? + MultiCandidateChunks(index_ids[ranked_positions], positions, scores) : + CandidateChunks(index_ids, positions, scores) +end + ### Overall types for `retrieve` """ SimpleRetriever <: AbstractRetriever diff --git a/templates/RAG/ranking/RAGRankGPT.json b/templates/RAG/ranking/RAGRankGPT.json new file mode 100644 index 00000000..f4526052 --- /dev/null +++ b/templates/RAG/ranking/RAGRankGPT.json @@ -0,0 +1 @@ +[{"content":"Template Metadata","description":"RankGPT implementation to re-rank chunks by LLMs. Passages are injected in the middle - see the function. Placeholders: `num`, `question`","version":"1","source":"Based on https://github.com/sunnweiwei/RankGPT","_type":"metadatamessage"},{"content":"You are RankGPT, an intelligent assistant that can rank passages based on their relevancy to the query.","variables":[],"_type":"systemmessage"},{"content":"I will provide you with {{num}} passages, each indicated by number identifier []. \nRank the passages based on their relevance to query: {{question}}.","variables":["num","question"],"_type":"usermessage"},{"content":"Okay, please provide the passages.","status":null,"tokens":[-1,-1],"elapsed":-1.0,"cost":null,"log_prob":null,"finish_reason":null,"run_id":-14760,"sample_id":null,"_type":"aimessage"},{"content":"Search Query: {{question}}. Rank the {{num}} passages above based on their relevance to the search query. The passages should be listed in descending order using identifiers. The most relevant passages should be listed first. The output format should be [] > [], e.g., [1] > [2]. Only respond with the ranking results, do not say any word or explain.","variables":["question","num"],"_type":"usermessage"}] \ No newline at end of file diff --git a/test/Experimental/RAGTools/rank_gpt.jl b/test/Experimental/RAGTools/rank_gpt.jl new file mode 100644 index 00000000..55834f2a --- /dev/null +++ b/test/Experimental/RAGTools/rank_gpt.jl @@ -0,0 +1,389 @@ +using PromptingTools.Experimental.RAGTools: RankGPTResult, create_permutation_instruction, + extract_ranking, receive_permutation!, + permutation_step!, rank_sliding_window!, + rank_gpt +using PromptingTools: TestEchoOpenAISchema + +@testset "RankGPTResult" begin + # Test creation of RankGPTResult with default parameters + result = RankGPTResult(question = "What is AI?", chunks = ["chunk1", "chunk2"]) + @test result.question == "What is AI?" # Check question + @test result.chunks == ["chunk1", "chunk2"] # Check chunks + @test result.positions == [1, 2] # Check default positions + @test result.elapsed == 0.0 # Check default elapsed time + @test result.cost == 0.0 # Check default cost + @test result.tokens == 0 # Check default tokens + + # Test creation of RankGPTResult with custom positions + result = RankGPTResult( + question = "What is AI?", chunks = ["chunk1", "chunk2"], positions = [2, 1]) + @test result.positions == [2, 1] # Check custom positions + + # Test creation of RankGPTResult with custom elapsed time, cost, and tokens + result = RankGPTResult(question = "What is AI?", chunks = ["chunk1", "chunk2"], + elapsed = 5.0, cost = 10.0, tokens = 100) + @test result.elapsed == 5.0 # Check custom elapsed time + @test result.cost == 10.0 # Check custom cost + @test result.tokens == 100 # Check custom tokens + + # Test show method for RankGPTResult + io = IOBuffer() + show(io, result) + output = String(take!(io)) + @test occursin("question:", output) # Check if question is in the output + @test occursin("What is AI?", output) + @test occursin("chunks:", output) # Check if chunks are in the output + @test occursin("positions:", output) # Check if positions are in the output + @test occursin("elapsed:", output) # Check if elapsed time is in the output + @test occursin("cost:", output) # Check if cost is in the output + @test occursin("tokens:", output) # Check if tokens are in the output + + # Test creation of RankGPTResult with empty chunks + result = RankGPTResult(question = "What is AI?", chunks = String[]) + @test result.chunks == String[] # Check empty chunks + @test result.positions == [] # Check positions for empty chunks +end + +@testset "create_permutation_instruction" begin + # Test with basic context and default parameters + context = ["This is a test.", "Another test document."] + messages, num = create_permutation_instruction(context) + @test num == 2 # Check number of messages + @test length(messages) == 4 + 4 # Check total messages including AI responses + @test messages[begin] isa PT.SystemMessage # Check first message type + @test messages[4].content == "[1] This is a test." + @test messages[5].content == "Received passage [1]." + @test messages[6].content == "[2] Another test document." + @test messages[7].content == "Received passage [2]." + @test messages[end] isa PT.UserMessage # Check second message type + + # Test with custom rank_start and rank_end + messages, num = create_permutation_instruction(context; rank_start = 2, rank_end = 2) + @test num == 1 # Check number of messages + @test length(messages) == 4 + 2 * 1 # Check total messages including AI responses + @test messages[begin] isa PT.SystemMessage # Check first message type + @test messages[4].content == "[1] Another test document." + @test messages[5].content == "Received passage [1]." + @test messages[end] isa PT.UserMessage # Check second message type + + # Test with max_length parameter + long_context = ["This is a very long test document that exceeds the max length parameter."] + messages, num = create_permutation_instruction(long_context; max_length = 10) + @test num == 1 # Check number of messages + @test length(messages) == 4 + 2 * 1 # Check total messages including AI responses + @test length(messages[4].content) <= 10 + 5 # Check if content is truncated (+5 for the markers at the beginning) + + # Test with different template + @test_throws ErrorException create_permutation_instruction( + context; template = :AnotherTemplateNotExist) + + # Test with empty context + empty_context = String[] + messages, num = create_permutation_instruction(empty_context) + @test num == 0 # Check number of messages + @test length(messages) == 4 # Check total messages including AI responses +end + +@testset "extract_ranking" begin + @test extract_ranking("asdas1asdas") == [1] # Test single number + @test extract_ranking("[1] > [2] > [3]") == [1, 2, 3] # Test multiple numbers + @test extract_ranking("[3] > [2] > [1]") == [3, 2, 1] # Test numbers in reverse order + @test extract_ranking("[1], [2], [3]") == [1, 2, 3] # Test numbers with commas + @test extract_ranking("[1] > [2] > [2] > [3]") == [1, 2, 3] # Test duplicate numbers + @test extract_ranking("[1] > [2] > [3] > [3] > [2] > [1]") == [1, 2, 3] # Test multiple duplicates + @test extract_ranking("a1b2c3") == [1, 2, 3] # Test numbers with letters + @test extract_ranking("[1] > [2] > [3] > [4] > [5] > [6] > [7] > [8] > [9] > [10]") == + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # Test larger range of numbers + @test extract_ranking("10 9 8 7 6 5 4 3 2 1") == [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] # Test larger range in reverse order + @test extract_ranking("1 2 3 4 5 6 7 8 9 10 10 9 8 7 6 5 4 3 2 1") == + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # Test larger range with duplicates +end + +@testset "receive_permutation!" begin + # Test with basic ranking and response + curr_rank = [1, 2, 3] + response = "[1] > [2] > [3]" + @test receive_permutation!(curr_rank, response) == [1, 2, 3] # Basic case + + # Test with reversed ranking in response + curr_rank = [1, 2, 3] + response = "[3] > [2] > [1]" + @test receive_permutation!(curr_rank, response) == [3, 2, 1] # Reversed ranking + + # Test with missing ranks in response + curr_rank = [1, 2, 3, 4, 5] + response = "[5] > [3] > [1]" + @test receive_permutation!(curr_rank, response) == [5, 3, 1, 2, 4] # Missing ranks + + # Test with extra ranks in response + curr_rank = [1, 2, 3] + response = "[1] > [2] > [3] > [4] > [5]" + @test receive_permutation!(curr_rank, response) == [1, 2, 3] # Extra ranks + + # Test with duplicate ranks in response + curr_rank = [1, 2, 3] + response = "[1] > [2] > [2] > [3]" + @test receive_permutation!(curr_rank, response) == [1, 2, 3] # Duplicate ranks + + # Test with non-sequential ranks in response + curr_rank = [1, 2, 3, 4, 5] + response = "[5] > [1] > [3]" + @test receive_permutation!(curr_rank, response) == [5, 1, 3, 2, 4] # Non-sequential ranks + + # Test with rank_start and rank_end parameters + curr_rank = [1, 2, 3, 4, 5] + response = "[4] > [5]" + @test receive_permutation!(curr_rank, response; rank_start = 4, rank_end = 5) == + [1, 2, 3, 4, 5] # Rank start and end + + # Test with rank_start and rank_end parameters, non-sequential + curr_rank = [1, 2, 3, 4, 5] + response = "[2] > [1]" + @test receive_permutation!(curr_rank, response; rank_start = 4, rank_end = 5) == + [1, 2, 3, 5, 4] # Rank start and end, non-sequential + + # Test with rank_start and rank_end parameters, missing ranks + curr_rank = [1, 2, 3, 4, 5] + response = "[2]" + @test receive_permutation!(curr_rank, response; rank_start = 4, rank_end = 5) == + [1, 2, 3, 5, 4] # Rank start and end, missing ranks + + # Test with rank_start and rank_end parameters, duplicate ranks + curr_rank = [1, 2, 3, 4, 5] + response = "[2 ] > [2]" + @test receive_permutation!(curr_rank, response; rank_start = 4, rank_end = 5) == + [1, 2, 3, 5, 4] # Rank start and end, duplicate ranks +end + +@testset "permutation_step!" begin + # Mocking the aigenerate function + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "[1] > [2]"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + + # Simple case with default parameters + result = RankGPTResult(question = "What is AI?", chunks = ["chunk1", "chunk2"]) + @test permutation_step!(result; model = "mock-gen").positions == [1, 2] # Simple case + + # Case with more chunks + result = RankGPTResult( + question = "What is AI?", chunks = ["chunk1", "chunk2", "chunk3"]) + @test permutation_step!(result; model = "mock-gen").positions == [1, 2, 3] # More chunks + + # Case with rank_start and rank_end parameters + result = RankGPTResult(question = "What is AI?", + chunks = ["chunk1", "chunk2", "chunk3", "chunk4", "chunk5"]) + @test permutation_step!( + result; rank_start = 2, rank_end = 4, model = "mock-gen").positions == + [1, 2, 3, 4, 5] # Rank start and end + + # Case with non-sequential ranks in response + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "[3] > [1] > [2]"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + result = RankGPTResult( + question = "What is AI?", chunks = ["chunk1", "chunk2", "chunk3"]) + @test permutation_step!(result; model = "mock-gen").positions == [3, 1, 2] # Non-sequential ranks + + # Case with duplicate ranks in response + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "[2] > [2] > [1]"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + result = RankGPTResult( + question = "What is AI?", chunks = ["chunk1", "chunk2", "chunk3"]) + @test permutation_step!(result; model = "mock-gen").positions == [2, 1, 3] # Duplicate ranks + + # Case with missing ranks in response + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "[1] > [3]"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + result = RankGPTResult( + question = "What is AI?", chunks = ["chunk1", "chunk2", "chunk3"]) + @test permutation_step!(result; model = "mock-gen").positions == [1, 3, 2] # Missing ranks +end + +@testset "rank_sliding_window!" begin + # Mocking the aigenerate function + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "[1] > [2]"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + # Simple case with default parameters + result = RankGPTResult(question = "What is AI?", chunks = ["chunk1", "chunk2"]) + @test rank_sliding_window!(result; model = "mock-gen").positions == [1, 2] # Simple case + + # Case with more chunks + result = RankGPTResult( + question = "What is AI?", chunks = ["chunk1", "chunk2", "chunk3"]) + @test rank_sliding_window!(result; model = "mock-gen").positions == [1, 2, 3] # More chunks + + # Case with rank_start and rank_end parameters + result = RankGPTResult(question = "What is AI?", + chunks = ["chunk1", "chunk2", "chunk3", "chunk4", "chunk5"]) + @test rank_sliding_window!( + result; rank_start = 2, rank_end = 4, window_size = 2, + step = 2, model = "mock-gen").positions == + [1, 2, 3, 4, 5] # Rank start and end + + # Case with non-sequential ranks in response + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "[3] > [1] > [2]"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + result = RankGPTResult( + question = "What is AI?", chunks = ["chunk1", "chunk2", "chunk3"]) + @test rank_sliding_window!(result; model = "mock-gen").positions == [3, 1, 2] # Non-sequential ranks + + # Case with duplicate ranks in response + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "[2] > [2] > [1]"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + result = RankGPTResult( + question = "What is AI?", chunks = ["chunk1", "chunk2", "chunk3"]) + @test rank_sliding_window!(result; model = "mock-gen").positions == [2, 1, 3] # Duplicate ranks + + # Case with missing ranks in response + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "[1] > [3]"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + result = RankGPTResult( + question = "What is AI?", chunks = ["chunk1", "chunk2", "chunk3"]) + @test rank_sliding_window!(result; model = "mock-gen").positions == [1, 3, 2] # Missing ranks + + ## Wrong inputs + result = RankGPTResult( + question = "What is AI?", chunks = ["chunk1", "chunk2", "chunk3"]) + @test_throws AssertionError rank_sliding_window!( + result; rank_start = 2, rank_end = 4, window_size = 2, + step = 3) + @test_throws AssertionError rank_sliding_window!( + result; rank_start = 2, rank_end = 4, window_size = 5, + step = 1) + @test_throws AssertionError rank_sliding_window!( + result; rank_start = 2, rank_end = 4) +end + +@testset "rank_gpt" begin + response = Dict( + :choices => [ + Dict( + :message => Dict(:content => "[4] > [2] > [3] > [1]"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + # Test with basic chunks and question + result = rank_gpt(["chunk1", "chunk2"], "What is AI?"; model = "mock-gen") + @test result.question == "What is AI?" # Check question + @test result.chunks == ["chunk1", "chunk2"] # Check chunks + @test result.positions == [2, 1] # Check default positions + + # Test with custom rank_start and rank_end + result = rank_gpt(["chunk1", "chunk2", "chunk3", "chunk4"], + "What is AI?"; rank_start = 2, rank_end = 3, window_size = 3, step = 2, model = "mock-gen") + @test result.positions == [1, 2, 3, 4] # Check positions with custom rank_start and rank_end + result = rank_gpt(["chunk1", "chunk2", "chunk3", "chunk4"], + "What is AI?"; rank_start = 1, rank_end = 4, window_size = 4, + step = 2, model = "mock-gen") + @test result.positions == [4, 2, 3, 1] # Check positions with custom rank_start and rank_end + + # Test with window_size and step + result = rank_gpt( + ["chunk1", "chunk2", "chunk3", "chunk4"], "What is AI?"; window_size = 4, step = 4, model = "mock-gen") + @test result.positions == [4, 2, 3, 1] # Check positions with window_size and step + + # Test with multiple rounds + result = rank_gpt( + ["chunk1", "chunk2", "chunk3", "chunk4"], "What is AI?"; num_rounds = 2, model = "mock-gen", verbose = 0) + @test result.positions == [1, 2, 3, 4] # Check positions with multiple rounds (flips twice) + result = rank_gpt( + ["chunk1", "chunk2", "chunk3", "chunk4"], "What is AI?"; num_rounds = 3, model = "mock-gen", verbose = 0) + @test result.positions == [4, 2, 3, 1] # Check positions with multiple rounds (flips twice) + + # Test with non-sequential ranks in response + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "[3] > [1] > [2]"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + result = rank_gpt(["chunk1", "chunk2", "chunk3"], "What is AI?"; model = "mock-gen") + @test result.positions == [3, 1, 2] # Check non-sequential ranks + + # Test with duplicate ranks in response + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "[2] > [2] > [1]"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + result = rank_gpt(["chunk1", "chunk2", "chunk3"], "What is AI?"; model = "mock-gen") + @test result.positions == [2, 1, 3] # Check duplicate ranks + + # Test with missing ranks in response + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "[1] > [3]"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + result = rank_gpt(["chunk1", "chunk2", "chunk3"], "What is AI?"; model = "mock-gen") + @test result.positions == [1, 3, 2] # Check missing ranks +end \ No newline at end of file diff --git a/test/Experimental/RAGTools/retrieval.jl b/test/Experimental/RAGTools/retrieval.jl index e7720409..3ac17de7 100644 --- a/test/Experimental/RAGTools/retrieval.jl +++ b/test/Experimental/RAGTools/retrieval.jl @@ -6,7 +6,8 @@ using PromptingTools.Experimental.RAGTools: ContextEnumerator, NoRephraser, Simp NoTagFilter, AnyTagFilter, SimpleRetriever, AdvancedRetriever using PromptingTools.Experimental.RAGTools: AbstractRephraser, AbstractTagFilter, - AbstractSimilarityFinder, AbstractReranker + AbstractSimilarityFinder, AbstractReranker, + RankGPTReranker using PromptingTools.Experimental.RAGTools: find_closest, hamming_distance, find_tags, rerank, rephrase, retrieve, HasEmbeddings, MultiCandidateChunks, diff --git a/test/Experimental/RAGTools/runtests.jl b/test/Experimental/RAGTools/runtests.jl index 143eea05..4ef02776 100644 --- a/test/Experimental/RAGTools/runtests.jl +++ b/test/Experimental/RAGTools/runtests.jl @@ -1,5 +1,5 @@ using Test -using SparseArrays, LinearAlgebra +using SparseArrays, LinearAlgebra, Unicode using PromptingTools.Experimental.RAGTools using PromptingTools using PromptingTools.AbstractTrees @@ -12,6 +12,7 @@ using JSON3, HTTP include("utils.jl") include("types.jl") include("preparation.jl") + include("rank_gpt.jl") include("retrieval.jl") include("generation.jl") include("annotation.jl")