# NearestNeighbors.jl

In [154]:
import Pkg
Pkg.activate(@__DIR__)

[32m[1m  Activating[22m[39m environment at `~/JuliaProjects/KPLMCenters.jl/Project.toml`


In [155]:
using KPLMCenters
using Random
using LinearAlgebra
using Statistics
using Distances
using BenchmarkTools

In [156]:
using NearestNeighbors

In [157]:
rng = MersenneTwister(1234)
signal = 10^4
noise = 10^3
dimension = 3
noise_min = -7
noise_max = 7
σ = 0.01
points = infinity_symbol(rng, signal, noise, σ, dimension, noise_min, noise_max)
points_matrix = hcat(points...)
size(points_matrix)

(3, 11000)

In [158]:
function update_means_and_weights_1( points, k)
    
    n_points = length(points)
    n_centers = 100
    centers = [points[i] for i in 1:n_centers]
    μ = [zeros(3) for i in 1:n_centers]
    weights = zeros(n_centers)
    Σ = [diagm(ones(3)) for i in 1:n_centers]
    dists = zeros(n_points)
    idxs = zeros(Int, n_points)

    for i = 1:n_centers

        invΣ = inv(Σ[i])
        for (j, x) in enumerate(points)
            dists[j] = mahalanobis(x, μ[i], invΣ)
        end
   
        idxs .= sortperm(dists)
        μ[i] .= mean(view(points, idxs[1:k]))
        weights[i] = mean([mahalanobis(points[j], μ[i], inv(Σ[i])) for j in idxs[1:k]]) + log(det(Σ[i]))

    end
    
    return μ, weights
    
end


update_means_and_weights_1 (generic function with 1 method)

In [159]:
@time update_means_and_weights_1( points, 20);

  0.511285 seconds (2.68 M allocations: 271.496 MiB, 3.43% gc time, 46.27% compilation time)


In [160]:
function update_means_and_weights_2( points_matrix, k)
    
    @show dimension, n_points = size(points_matrix)
    n_centers = 100
    centers = [points_matrix[:,i] for i in 1:n_centers]
    μ = [zeros(3) for i in 1:n_centers]
    weights = zeros(n_centers)
    Σ = [diagm(ones(3)) for i in 1:n_centers]
    
    for i = 1:n_centers
        
        invΣ = inv(Σ[i])
        metric = Mahalanobis(invΣ)
        balltree = BallTree(points_matrix, metric)
        idxs, dists = knn(balltree, centers[i], k)
        μ[i] .= vec(mean(points_matrix[:, idxs[1]], dims=2))
        weights[i] = mean([sqmahalanobis(points[j], μ[i], invΣ) for j in idxs[1]]) + log(det(Σ[i]))

    end
    
    return μ, weights
    
end

update_means_and_weights_2 (generic function with 1 method)

In [161]:
@time update_means_and_weights_2( points_matrix, 20);

(dimension, n_points) = size(points_matrix) = (3, 11000)
  0.735302 seconds (1.85 M allocations: 256.059 MiB, 2.78% gc time, 21.51% compilation time)


# Distances.jl

Some examples of optimized function of the package

Does not work for us because the metric changes for every center

In [162]:
Σ = Matrix(I, 3, 3)
metric = SqMahalanobis(Σ)
n_points = 10^4
n_centers = 10
X = randn(3, n_points)
Y = rand(3, n_centers)

@btime r = pairwise(metric, X, Y);

  583.090 μs (19 allocations: 1.07 MiB)


In [163]:
dists = zeros(n_points,n_centers)
function with_loop!( dists, X, Y)
    n_centers = size(Y)[2]
    for i in 1:n_centers
        center = Y[:,i]
        dists[:,i] = [sqmahalanobis(x, center, Σ) for x in eachcol(X)]
    end
    return dists
end
@btime with_loop!( dists, X, Y);

  15.883 ms (400060 allocations: 28.23 MiB)


In [164]:
@btime pairwise!(dists, metric, X, Y)

  543.430 μs (17 allocations: 313.83 KiB)


10000×10 Matrix{Float64}:
 17.9395    17.2349    18.0525    …  16.3431    10.8027    11.4172
  3.09788    2.89455    2.55896       2.55415    2.15282    2.0233
  1.3579     0.731615   1.53942       1.13611    1.35744    1.31772
  6.1911     4.51484    6.90762       5.33457    2.68303    3.17039
  0.401943   0.240069   0.861076      0.573069   1.83388    1.66919
  3.79854    2.62916    5.26274   …   3.79537    3.48451    3.76875
  1.61506    2.07164    2.20385       1.81358    2.66264    2.498
  0.223735   0.804479   0.11112       0.237399   1.19482    0.892705
  1.94499    1.9664     2.70217       1.91895    1.81659    1.83466
  1.566      2.67946    0.874114      1.4829     2.52658    2.06435
  3.32212    2.13003    4.44869   …   2.9842     1.71702    2.05902
  2.84012    3.76298    2.99067       3.60504    7.22948    6.62796
  7.31387    6.49664    7.4125        6.21924    2.91625    3.28581
  ⋮                               ⋱                        
 20.2946    17.5762    23.076    