Skip to content

Commit

Permalink
Fix truncate_dimension (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp committed Apr 18, 2024
1 parent 46f6770 commit 9da5906
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 4 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.20.1]

### Fixed
- Fixed `truncate_dimension` to ignore when 0 is provided (previously it would throw an error).

## [0.20.0]

### Added
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.20.0"
version = "0.20.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
9 changes: 6 additions & 3 deletions src/Experimental/RAGTools/preparation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ Embeds a vector of `docs` using the provided model (kwarg `model`) in a batched
- `docs`: A vector of strings to be embedded.
- `verbose`: A boolean flag for verbose output. Default is `true`.
- `model`: The model to use for embedding. Default is `PT.MODEL_EMBEDDING`.
- `truncate_dimension`: The dimensionality of the embeddings to truncate to. Default is `nothing`.
- `truncate_dimension`: The dimensionality of the embeddings to truncate to. Default is `nothing`, `0` will also do nothing.
- `cost_tracker`: A `Threads.Atomic{Float64}` object to track the total cost of the API calls. Useful to pass the total cost to the parent call.
- `target_batch_size_length`: The target length (in characters) of each batch of document chunks sent for embedding. Default is 80_000 characters. Speeds up embedding process.
- `ntasks`: The number of tasks to use for asyncmap. Default is 4 * Threads.nthreads().
Expand Down Expand Up @@ -237,14 +237,17 @@ function get_embeddings(embedder::BatchEmbedder, docs::AbstractVector{<:Abstract
msg.content
end
embeddings = hcat(embeddings...) .|> Float32 # flatten, columns are documents
if !isnothing(truncate_dimension)
@assert truncate_dimension>0 "Truncated dimensionality must be non-negative (Provided: $(truncate_dimension))"
# truncate_dimension=0 means that we skip it
if !isnothing(truncate_dimension) && truncate_dimension > 0
@assert truncate_dimension<=size(embeddings, 1) "Requested embeddings dimensionality is too high (Embeddings: $(size(embeddings)) vs dimensionality requested: $(truncate_dimension))"
## reduce + normalize again
embeddings = embeddings[1:truncate_dimension, :]
for i in axes(embeddings, 2)
embeddings[:, i] = _normalize(embeddings[:, i])
end
elseif !isnothing(truncate_dimension) && truncate_dimension == 0
# do nothing
verbose && @info "Truncate_dimension set to 0. Skipping truncation"
end
verbose && @info "Done embedding. Total cost: \$$(round(cost_tracker[],digits=3))"
return embeddings
Expand Down
4 changes: 4 additions & 0 deletions test/Experimental/RAGTools/preparation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ end
output = get_embeddings(
BatchEmbedder(), docs; model = "mock-emb", truncate_dimension = 100)
@test size(output) == (100, 2)
## value of 0 for truncation, skips the step
output = get_embeddings(
BatchEmbedder(), docs; model = "mock-emb", truncate_dimension = 0)
@test size(output) == (128, 2)

# Unknown type
struct RandomEmbedder123 <: AbstractEmbedder end
Expand Down

2 comments on commit 9da5906

@svilupp
Copy link
Owner Author

@svilupp svilupp commented on 9da5906 Apr 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

Fixed

  • Fixed truncate_dimension to ignore when 0 is provided (previously it would throw an error).

Commits

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/105148

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.20.1 -m "<description of version>" 9da5906d9d36fea594a8a2e4e17d6b6b4a141a68
git push origin v0.20.1

Please sign in to comment.