From 799e5ef628c1c3d96f83061d31dca97348039aa9 Mon Sep 17 00:00:00 2001 From: "Eric S. Tellez" Date: Wed, 15 Mar 2023 12:55:52 -0700 Subject: [PATCH] adds encoding functions; fixes bugs; adds references --- src/KCenters.jl | 1 + src/clustering.jl | 1 - src/proj.jl | 12 ++++--- src/refs.jl | 90 +++++++++++++++++++++++++++++++++++++++++++++++ test/proj.jl | 4 +-- 5 files changed, 100 insertions(+), 8 deletions(-) create mode 100644 src/refs.jl diff --git a/src/KCenters.jl b/src/KCenters.jl index a446735..0c91c57 100644 --- a/src/KCenters.jl +++ b/src/KCenters.jl @@ -10,6 +10,7 @@ include("enet.jl") include("dnet.jl") include("utils.jl") include("clustering.jl") +include("refs.jl") include("proj.jl") import Distances: evaluate diff --git a/src/clustering.jl b/src/clustering.jl index fd34b9e..952c49c 100644 --- a/src/clustering.jl +++ b/src/clustering.jl @@ -93,7 +93,6 @@ function kcenters(dist::SemiMetric, X::AbstractDatabase, k::Integer; sel::Abstra kcenters_(dist, X, initial, sel=sel, maxiters=maxiters, tol=tol, recall=recall, verbose=verbose) end - function kcenters_(dist::SemiMetric, X::AbstractDatabase, C::AbstractDatabase; sel::AbstractCenterSelection=CentroidSelection(), maxiters=-1, tol=0.001, recall=1.0, verbose=true) # Lloyd's algoritm n = length(X) diff --git a/src/proj.jl b/src/proj.jl index 9d871e1..d762bd1 100644 --- a/src/proj.jl +++ b/src/proj.jl @@ -5,7 +5,7 @@ using Polyester export AbstractReferenceMapping, Knr, Perms, BinPerms, BinWalk -export encode_database, encode_database!, encode_object, encode_object! +export encode_database, encode_database!, encode_object, encode_object!, encode_object_res! 
abstract type AbstractReferenceMapping end @@ -46,7 +46,7 @@ function getpermscache(m) c end -function geteknrcache(k::Integer, pools) +function getknrcache(k::Integer, pools=nothing) reuse!(KNR_CACHES[Threads.threadid()], k) end @@ -83,7 +83,7 @@ end function encode_object!(knr::Knr, out::AbstractVector, obj) k = length(out) - res = SimilaritySearch.getknnresult(k) + res = getknrcache(k) search(knr.refs, obj, res) k_ = length(res) o = @view out[1:k_] @@ -98,11 +98,13 @@ function encode_database(knr::Knr, S::AbstractDatabase; k::Integer=knr.k, minbat encode_database!(knr, X, S; minbatch) end -function encode_object(knr::Knr, obj, refs::AbstractSearchIndex; k::Integer=knr.k, minbatch=0) +function encode_object(knr::Knr, obj; k::Integer=knr.k) X = Vector{knr.itype}(undef, k) - encode_object!(knr, X, S; minbatch) + encode_object!(knr, X, obj) end +encode_object_res!(knr::Knr, res::KnnResult, obj) = search(knr.refs, obj, res).res +encode_object_res!(knr::Knr, obj; k=knr.k) = search(knr.refs, obj, getknrcache(k)).res struct Perms{DistType<:SemiMetric,DbType<:AbstractDatabase} <: AbstractReferenceMapping dist::DistType diff --git a/src/refs.jl b/src/refs.jl new file mode 100644 index 0000000..176738f --- /dev/null +++ b/src/refs.jl @@ -0,0 +1,90 @@ +export references + +""" + references( + weighting_centers::Function, + dist::SemiMetric, db::AbstractDatabase, k::Integer; + Δ=1.5, + sample=Δ*k + sqrt(length(db)), + maxiters=0, + tol=0.001, + initial=:rand) + references(dist::SemiMetric, db::AbstractDatabase, k::Integer; kwargs...) + +Computes a set of `k` references from `db`; see the [`kcenters`](@ref) documentation. + +More precisely, the references are computed from a sample subset (`sample`); +it computes `Δ * k` candidate centers and selects the `k` with the highest `weighting_centers(C, i)` values +(where `C` is a `ClusteringData` object and `i` the i-th center); the second form scores center `i` as `C.freqs[i] * sign(C.dmax[i])`.
+The set of references is meaningful under the `dist` metric function but may also follow some +characteristics given by the `initial` selection strategy. + +# Arguments + +- `dist`: a distance function +- `db`: the database to be sampled +- `k`: the number of centers to compute + +# Keyword arguments +- `sample::Real`: the sampling size used before computing the set of `k` references, defaults to `Δ*k + sqrt(length(db))`; `sample=0` means no sampling. Sample indices are drawn with replacement and deduplicated, so the effective sample can be smaller. +- `Δ::Real`: expands the number of candidate centers to about `Δ * k`, from which the `k` references are selected +- `maxiters::Int`: number of iterations of the Lloyd algorithm that should be applied on the initial computation of centers, that is, `maxiters > 0` applies `maxiters` iterations of the algorithm. +- `tol::Float64`: change tolerance to stop the Lloyd algorithm (error changes smaller than `tol` among iterations will stop the algorithm) +- `initial`: initial centers or a strategy to compute initial centers, symbols `:rand`, `:fft`, and `:dnet`. +There are several interactions between the `initial` strategy and the other parameters (described in `KCenters` object); for instance, +`maxiters > 0` applies Lloyd's algorithm to the initial computation of references. + +- if `initial=:rand`: + - `maxiters = 0` will retrieve a simple random sampling + - `maxiters > 0` achieves kmeans-centroids; `maxiters` should be set appropriately for the dataset +- if `initial=:dnet`: + - `maxiters = 0` computes a pure density-net + - `maxiters > 0` will compute kmeans centroids but with an initialization based on the dnet +- if `initial=:fft`: + - `maxiters = 0` computes `k` centers with the farthest first traversal algorithm + - `maxiters > 0` will use the FFT based kcenters as initial points for the Lloyd algorithm + +Note 1: `maxiters > 0` needs to compute centroids, and these centroids should be _defined_ +for the specific data model, and also be of use in the specific metric distance and error function.
+ +Note 2: The error function is defined as the mean of the distances of all objects in `db` to their associated nearest centers in each iteration. + +Note 3: The computation of references on large databases can be prohibitive; in these cases please consider working on a sample of `db`. +""" +function references( + weighting_centers::Function, + dist::SemiMetric, db::AbstractDatabase, k::Integer; + Δ=1.5, + sample=Δ*k + sqrt(length(db)), + maxiters=0, + tol=0.001, + initial=:rand) + + n = length(db) + if n == k + sample = 0 + Δk = k + else + sample = ceil(Int, sample) + Δk = ceil(Int, Δ * k) + end + + 0 < k <= n || throw(ArgumentError("invalid relation between k and n, must follow 0 < k <= n")) + C = if sample > 0 + s = unique(rand(eachindex(db), sample)) + kcenters(dist, SubDatabase(db, s), Δk; initial, maxiters, tol) + else + kcenters(dist, db, Δk; initial, maxiters, tol) + end + + W = [weighting_centers(C, i) for i in eachindex(C.centers)] + P = sortperm(W; rev=true) + C.centers[P[1:k]] +end + +function references(dist::SemiMetric, db::AbstractDatabase, k::Integer; kwargs...) + references(dist, db, k; kwargs...) do C, i + f = C.freqs[i] + f * sign(C.dmax[i]) + end +end diff --git a/test/proj.jl b/test/proj.jl index 190f49f..d324d83 100644 --- a/test/proj.jl +++ b/test/proj.jl @@ -13,8 +13,8 @@ using Random, SimilaritySearch, KCenters, StatsBase Igold, _ = searchbatch(ExhaustiveSearch(; dist, db), queries, k) refs = let - R = kcenters(dist, db, 64; initial=:rand) - refs = MatrixDatabase(R.centers) + R = references(dist, db, 64; initial=:rand) + refs = MatrixDatabase(R) ExhaustiveSearch(; dist, db=refs) end