From 9da5906d9d36fea594a8a2e4e17d6b6b4a141a68 Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Thu, 18 Apr 2024 09:50:12 +0100 Subject: [PATCH] Fix truncate_dimension (#137) --- CHANGELOG.md | 5 +++++ Project.toml | 2 +- src/Experimental/RAGTools/preparation.jl | 9 ++++++--- test/Experimental/RAGTools/preparation.jl | 4 ++++ 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cefb0f87..1e2e70d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.20.1] + +### Fixed +- Fixed `truncate_dimension` to ignore when 0 is provided (previously it would throw an error). + ## [0.20.0] ### Added diff --git a/Project.toml b/Project.toml index 6894bb07..4309bccc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.20.0" +version = "0.20.1" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/Experimental/RAGTools/preparation.jl b/src/Experimental/RAGTools/preparation.jl index 596d7255..904b2edf 100644 --- a/src/Experimental/RAGTools/preparation.jl +++ b/src/Experimental/RAGTools/preparation.jl @@ -200,7 +200,7 @@ Embeds a vector of `docs` using the provided model (kwarg `model`) in a batched - `docs`: A vector of strings to be embedded. - `verbose`: A boolean flag for verbose output. Default is `true`. - `model`: The model to use for embedding. Default is `PT.MODEL_EMBEDDING`. -- `truncate_dimension`: The dimensionality of the embeddings to truncate to. Default is `nothing`. +- `truncate_dimension`: The dimensionality of the embeddings to truncate to. Default is `nothing`, `0` will also do nothing. - `cost_tracker`: A `Threads.Atomic{Float64}` object to track the total cost of the API calls. Useful to pass the total cost to the parent call. - `target_batch_size_length`: The target length (in characters) of each batch of document chunks sent for embedding. Default is 80_000 characters. Speeds up embedding process. - `ntasks`: The number of tasks to use for asyncmap. Default is 4 * Threads.nthreads(). @@ -237,14 +237,17 @@ function get_embeddings(embedder::BatchEmbedder, docs::AbstractVector{<:Abstract msg.content end embeddings = hcat(embeddings...) .|> Float32 # flatten, columns are documents - if !isnothing(truncate_dimension) - @assert truncate_dimension>0 "Truncated dimensionality must be non-negative (Provided: $(truncate_dimension))" + # truncate_dimension=0 means that we skip it + if !isnothing(truncate_dimension) && truncate_dimension > 0 @assert truncate_dimension<=size(embeddings, 1) "Requested embeddings dimensionality is too high (Embeddings: $(size(embeddings)) vs dimensionality requested: $(truncate_dimension))" ## reduce + normalize again embeddings = embeddings[1:truncate_dimension, :] for i in axes(embeddings, 2) embeddings[:, i] = _normalize(embeddings[:, i]) end + elseif !isnothing(truncate_dimension) && truncate_dimension == 0 + # do nothing + verbose && @info "Truncate_dimension set to 0. Skipping truncation" end verbose && @info "Done embedding. Total cost: \$$(round(cost_tracker[],digits=3))" return embeddings diff --git a/test/Experimental/RAGTools/preparation.jl b/test/Experimental/RAGTools/preparation.jl index f36093f8..bdaedced 100644 --- a/test/Experimental/RAGTools/preparation.jl +++ b/test/Experimental/RAGTools/preparation.jl @@ -64,6 +64,10 @@ end output = get_embeddings( BatchEmbedder(), docs; model = "mock-emb", truncate_dimension = 100) @test size(output) == (100, 2) + ## value of 0 for truncation, skips the step + output = get_embeddings( + BatchEmbedder(), docs; model = "mock-emb", truncate_dimension = 0) + @test size(output) == (128, 2) # Unknown type struct RandomEmbedder123 <: AbstractEmbedder end