Skip to content

Commit

Permalink
adds encoding functions; fixes bugs; adds references
Browse files Browse the repository at this point in the history
  • Loading branch information
sadit committed Mar 15, 2023
1 parent 2cbfba4 commit 799e5ef
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/KCenters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ include("enet.jl")
include("dnet.jl")
include("utils.jl")
include("clustering.jl")
include("refs.jl")
include("proj.jl")

import Distances: evaluate
Expand Down
1 change: 0 additions & 1 deletion src/clustering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ function kcenters(dist::SemiMetric, X::AbstractDatabase, k::Integer; sel::Abstra
kcenters_(dist, X, initial, sel=sel, maxiters=maxiters, tol=tol, recall=recall, verbose=verbose)
end


function kcenters_(dist::SemiMetric, X::AbstractDatabase, C::AbstractDatabase; sel::AbstractCenterSelection=CentroidSelection(), maxiters=-1, tol=0.001, recall=1.0, verbose=true)
# Lloyd's algoritm
n = length(X)
Expand Down
12 changes: 7 additions & 5 deletions src/proj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using Polyester

export AbstractReferenceMapping, Knr, Perms, BinPerms, BinWalk
export encode_database, encode_database!, encode_object, encode_object!
export encode_database, encode_database!, encode_object, encode_object!, encode_object_res!

abstract type AbstractReferenceMapping end

Expand Down Expand Up @@ -46,7 +46,7 @@ function getpermscache(m)
c
end

function geteknrcache(k::Integer, pools)
function getknrcache(k::Integer, pools=nothing)
reuse!(KNR_CACHES[Threads.threadid()], k)
end

Expand Down Expand Up @@ -83,7 +83,7 @@ end

function encode_object!(knr::Knr, out::AbstractVector, obj)
k = length(out)
res = SimilaritySearch.getknnresult(k)
res = getknrcache(k)
search(knr.refs, obj, res)
k_ = length(res)
o = @view out[1:k_]
Expand All @@ -98,11 +98,13 @@ function encode_database(knr::Knr, S::AbstractDatabase; k::Integer=knr.k, minbat
encode_database!(knr, X, S; minbatch)
end

function encode_object(knr::Knr, obj, refs::AbstractSearchIndex; k::Integer=knr.k, minbatch=0)
function encode_object(knr::Knr, obj; k::Integer=knr.k)
X = Vector{knr.itype}(undef, k)
encode_object!(knr, X, S; minbatch)
encode_object!(knr, X, obj)
end

encode_object_res!(knr::Knr, res::KnnResult, obj) = search(knr.refs, obj, res).res
encode_object_res!(knr::Knr, obj; k=knr.k) = search(knr.refs, obj, getknrcache(k)).res

struct Perms{DistType<:SemiMetric,DbType<:AbstractDatabase} <: AbstractReferenceMapping
dist::DistType
Expand Down
90 changes: 90 additions & 0 deletions src/refs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
export references

"""
references(
weighting_centers::Function,
dist::SemiMetric, db::AbstractDatabase, k::Integer;
Δ=1.5,
sample=Δ*k + sqrt(length(db)),
maxiters=0,
tol=0.001,
initial=:rand)
references(dist::SemiMetric, db::AbstractDatabase, k::Integer; kwargs...)
Computes a set of `k` references from `db`, see the [`kcenters`](@ref) documentation.
More precisely, the references will be computed from a sample subset (`sample`);
it computes Δ k references and select the `k` elements using the best ones w.r.t. `weighting_centers(C, i)` function
(where `C` is a `ClusteringData` object and `i` the i-th center).
The set of references are meaninful under `dist` metric function but also may follow some
characteristics given by the `initial` selection strategy.
# Arguments
- `dist`: a distance function
- `db`: the database to be sampled
- `k`: the number of centers to compute
# Keyword arguments
- `sample::Real`: indicates the sampling size before computing the set of `k` references, defaults to `log(|db|) k`; `sample=0` means for no sampling.
- `Δ::Real`: expands the number of candidates to be selected as references
- `maxiters::Int`: number of iterationso of the Lloyd algorithm that should be applied on the initial computation of centers, that is, `maxiters > 0` applies `maxiters` iterations of the algorithm.
- `tol::Float64`: change tolerance to stop the Lloyd algorithm (error changes smaller than `tol` among iterations will stop the algorithm)
- `initial`: initial centers or a strategy to compute initial centers, symbols `:rand`, `:fft`, and `:dnet`.
There are several interactions between initial values and parameter interactions (described in `KCenters` object), for instance,
the `maxiters > 0` apply the Lloyd's algorithm to the initial computation of references.
- if `initial=:rand`:
- `maxiters = 0` will retrieve a simple random sampling
- `maxiters > 0' achieve kmeans-centroids, `maxiters` should be set appropiately for the the dataset
- if `initial=:dnet`:
- `maxiters = 0` computes a pure density-net
- `maxiters > 0` will compute a kmeans centroids but with an initialization based on the dnet
- if `initial=:fft`:
- `maxiters = 0` computes `k` centers with the farthest first traversal algorithm
- `maxiters > 0` will use the FFT based kcenters as initial points for the Lloyd algorithm
Note 1: `maxiters > 0` needs to compute centroids and these centroids should be _defined_
for the specific data model, and also be of use in the specific metric distance and error function.
Note 2: The error function is defined as the mean of distances of all objects in `db` to its associated nearest centers in each iteration.
Note 3: The computation of references on large databases can be prohibitive, in these cases please consider to work on a sample of `db`
"""
function references(
weighting_centers::Function,
dist::SemiMetric, db::AbstractDatabase, k::Integer;
Δ=1.5,
sample=Δ*k + sqrt(length(db)),
maxiters=0,
tol=0.001,
initial=:rand)

n = length(db)
if n == k
sample = 0
Δk = k
else
sample = ceil(Int, sample)
Δk = ceil(Int, Δ * k)
end

0 < k <= n || throw(ArgumentError("invalid relation between k and n, must follow 0 < k <= n"))
C = if sample > 0
s = unique(rand(eachindex(db), sample))
kcenters(dist, SubDatabase(db, s), Δk; initial, maxiters, tol)
else
kcenters(dist, db, Δk; initial, maxiters, tol)
end

W = [weighting_centers(C, i) for i in eachindex(C.centers)]
P = sortperm(W; rev=true)
C.centers[P[1:k]]
end

function references(dist::SemiMetric, db::AbstractDatabase, k::Integer; kwargs...)
references(dist, db, k; kwargs...) do C, i
f = C.freqs[i]
f * sign(C.dmax[i])
end
end
4 changes: 2 additions & 2 deletions test/proj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ using Random, SimilaritySearch, KCenters, StatsBase
Igold, _ = searchbatch(ExhaustiveSearch(; dist, db), queries, k)

refs = let
R = kcenters(dist, db, 64; initial=:rand)
refs = MatrixDatabase(R.centers)
R = references(dist, db, 64; initial=:rand)
refs = MatrixDatabase(R)
ExhaustiveSearch(; dist, db=refs)
end

Expand Down

2 comments on commit 799e5ef

@sadit
Copy link
Owner Author

@sadit sadit commented on 799e5ef Mar 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/79658

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.1 -m "<description of version>" 799e5ef628c1c3d96f83061d31dca97348039aa9
git push origin v0.8.1

Please sign in to comment.