# K Nearest Neighbors

[Reference 1](https://scikit-learn.org/stable/modules/neighbors.html)  
[Reference 2](https://en.wikipedia.org/wiki/Nearest_neighbor_search)  
[Reference 3](https://booking.ai/k-nearest-neighbours-from-slow-to-fast-thanks-to-maths-bec682357ccd)

### Naive Approach  
No transformation of the original dataset (no training step)  
For prediction, iterate through the original dataset and find the K nearest data points under a given metric  

We can use the Euclidean distance:
$$\text{dist}(X_1, X_2) = \|X_1 - X_2\|$$

Another option is cosine similarity, which ignores vector magnitude:
$$\text{sim}(X_1, X_2) = \frac{X_1 \cdot X_2}{\|X_1\| \|X_2\|}$$
which computes the cosine of the angle $\theta$ between the two vectors  
It equals 1 for $\theta = 0$ and is less than 1 for $\theta \in (0, \pi]$ (reaching $-1$ at $\theta = \pi$)

In [1]:
include("../tools.jl")
import .JuTools

In [2]:
import Statistics
import Random
import LinearAlgebra

In [3]:
"""
    cosine_sim(X1, X2)

Return the cosine of the angle between the 1-d vectors `X1` and `X2`, i.e.
`dot(X1, X2) / (‖X1‖ ‖X2‖)`, nominally in `[-1, 1]`.

If either vector has zero norm the angle is undefined. In that case return `0.0` (no
similarity) instead of the `NaN` produced by `0 / 0`. Julia orders `NaN` above every
number, so a `NaN` similarity would otherwise win any descending sort of similarities.
"""
function cosine_sim(X1::Array{T} where T<:Number, X2::Array{T} where T<:Number)::AbstractFloat
    @assert size(X1) == size(X2)
    @assert ndims(X1) == ndims(X2) == 1
    product = LinearAlgebra.dot(X1, X2)
    norm_product = LinearAlgebra.norm(X1, 2) * LinearAlgebra.norm(X2, 2)
    # a zero vector has no direction: treat it as dissimilar rather than returning NaN
    iszero(norm_product) && return zero(norm_product)
    return product / norm_product
end

cosine_sim (generic function with 1 method)

In [4]:
# Load a feature matrix and label vector via the project helper; their shapes are printed below
X_data, Y_data = JuTools.data_generate_linear_2d()
println(size(X_data))
println(size(Y_data))

(1000, 2)
(1000,)


In [5]:
# Peek at the first two rows of the feature matrix
X_data[1:2, :]

2×2 Array{Float64,2}:
 74.9  35.6
 78.7  58.9

In [6]:
# Cosine similarity between rows 1 and 2
cosine_sim(X_data[1, :], X_data[2, :])

0.9803061697627

In [7]:
# Cosine similarity between rows 1 and 3
cosine_sim(X_data[1, :], X_data[3, :])

0.8586498516465619

In [8]:
# define majority vote function
"""
    majority_vote(y)

Return the most frequent value in the 1-d vector `y`.

Ties are broken deterministically in favour of the tied value that appears first in
`y`. A caller that passes labels ordered from nearest to farthest neighbor therefore
gets the closest neighbor's label on a tie.

Throws `ArgumentError` when `y` is empty.
"""
function majority_vote(y::Array{T} where T<:Number)::Number
    @assert ndims(y) == 1
    isempty(y) && throw(ArgumentError("majority_vote needs at least one vote"))
    counts = Dict{eltype(y), Int}()
    for label in y
        counts[label] = get(counts, label, 0) + 1
    end
    top = maximum(values(counts))
    # scan in input order so that ties go to the earliest-appearing label
    return y[findfirst(label -> counts[label] == top, y)]
end

majority_vote (generic function with 1 method)

In [9]:
# Clear majority for label 1
majority_vote([1,1,0])

1

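When two labels receive the same number of votes (next cell), the returned label is decided by the tie-breaking order rather than by the counts alone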

In [10]:
# Two-way tie: labels 1 and 0 each receive two votes
majority_vote([1,1,0,0])

1

In [11]:
# Clear majority for label 0
majority_vote([1,1,0,0,0])

0

In [12]:
# Split into train/test sets with the project helper (ratio=0.3 gives the 700/300 split printed below)
X_train, X_test, Y_train, Y_test = JuTools.split_data(X_data, Y_data, shuffle=true, ratio=0.3)
println(size(X_train))
println(size(X_test))
println(size(Y_train))
println(size(Y_test))

(700, 2)
(300, 2)
(700,)
(300,)


In [13]:
# define predict function, naive approach
"""
    predict_naive(X_predict, K, X_data, Y_data)

Brute-force K-nearest-neighbor classifier using cosine similarity (larger is nearer).

`X_predict` is either a single 1-d feature vector or a matrix with one query per row.
`X_data` is an (n, d) training matrix and `Y_data` holds its n labels. `K` may be any
value in `1:n`. For each query the labels of the `K` most similar training rows are
combined with `majority_vote`. Returns an `Array{Number}` with one prediction per
query row.
"""
function predict_naive(X_predict::Array{T} where T<:Number, K::Integer, X_data::Array{T} where T<:Number, Y_data::Array{T} where T<:Number)::Array
    @assert ndims(X_data) == 2
    @assert ndims(Y_data) == 1
    @assert size(X_data)[1] == size(Y_data)[1]
    @assert 0 < ndims(X_predict) <= 2
    # K == n is valid: every training row then votes
    @assert 0 < K <= size(X_data)[1]
    if ndims(X_predict) < 2
        X_predict = reshape(X_predict, (1, size(X_predict)[1]))
    end
    @assert size(X_predict)[2] == size(X_data)[2]
    result = Array{Number}(undef, size(X_predict)[1])
    for i in axes(X_predict, 1)
        vec_predict = X_predict[i, :]
        sims = [cosine_sim(vec_predict, X_data[j, :]) for j in axes(X_data, 1)]
        # indices of the K largest similarities; equal similarities are ordered by
        # ascending row index, as a stable descending sort would order them
        nearest = partialsortperm(sims, 1:K, rev=true)
        result[i] = majority_vote(Y_data[nearest])
    end
    return result
end

predict_naive (generic function with 1 method)

In [14]:
# Predict the test set using the 5 most cosine-similar training rows
Y_predict = predict_naive(X_test, 5, X_train, Y_train)

300-element Array{Number,1}:
 0.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 ⋮
 1.0
 1.0
 1.0
 1.0
 0.0
 1.0
 1.0
 1.0
 1.0
 0.0
 1.0
 1.0

In [15]:
# Accuracy of the cosine-similarity predictions on the test set
JuTools.compute_accuracy(Y_predict, Y_test)

0.8933333333333333

In [16]:
# what about dist similarity?
"""
    dist_sim(X1, X2)

Euclidean distance `‖X1 - X2‖` between two equal-length 1-d vectors.
Despite the name this is a distance: smaller values mean more similar vectors.
"""
function dist_sim(X1::Array{T} where T<:Number, X2::Array{T} where T<:Number)::AbstractFloat
    @assert size(X1) == size(X2)
    @assert ndims(X1) == ndims(X2) == 1
    gap = X1 .- X2
    squared_total = sum(d -> d^2, gap)
    return sqrt(squared_total)
end

dist_sim (generic function with 1 method)

In [17]:
# Euclidean distance between rows 1 and 2
dist_sim(X_data[1, :], X_data[2, :])

23.60783768158363

In [18]:
# Euclidean distance between rows 2 and 3
dist_sim(X_data[2, :], X_data[3, :])

41.87242529398077

Euclidean distance is strongly affected by the scale of the data!

In [19]:
"""
    predict_naive_fun(X_predict, K, X_data, Y_data)

Brute-force K-nearest-neighbor classifier using Euclidean distance (`dist_sim`,
smaller is nearer).

`X_predict` is either a single 1-d feature vector or a matrix with one query per row.
`X_data` is an (n, d) training matrix and `Y_data` holds its n labels. `K` may be any
value in `1:n`. For each query the labels of the `K` closest training rows are
combined with `majority_vote`. Returns an `Array{Number}` with one prediction per
query row.
"""
function predict_naive_fun(X_predict::Array{T} where T<:Number, K::Integer, X_data::Array{T} where T<:Number, Y_data::Array{T} where T<:Number)::Array
    @assert ndims(X_data) == 2
    @assert ndims(Y_data) == 1
    @assert size(X_data)[1] == size(Y_data)[1]
    @assert 0 < ndims(X_predict) <= 2
    # K == n is valid: every training row then votes
    @assert 0 < K <= size(X_data)[1]
    if ndims(X_predict) < 2
        X_predict = reshape(X_predict, (1, size(X_predict)[1]))
    end
    @assert size(X_predict)[2] == size(X_data)[2]
    result = Array{Number}(undef, size(X_predict)[1])
    for i in axes(X_predict, 1)
        vec_predict = X_predict[i, :]
        dists = [dist_sim(vec_predict, X_data[j, :]) for j in axes(X_data, 1)]
        # indices of the K smallest distances; equal distances are ordered by
        # ascending row index, as a stable ascending sort would order them
        nearest = partialsortperm(dists, 1:K)
        result[i] = majority_vote(Y_data[nearest])
    end
    return result
end

predict_naive_fun (generic function with 1 method)

In [20]:
# Accuracy of the Euclidean-distance variant with K = 5 on the test set
JuTools.compute_accuracy(predict_naive_fun(X_test, 5, X_train, Y_train), Y_test)

0.9066666666666666

It scores better here, plausibly because X_data has only 2 dimensions, a setting where Euclidean distance works well  
Eventually we'll use cosine similarity in the implementation

Although `predict_naive` may be slow on large datasets (every query is compared against every training row), it is easy to implement and works as expected

### K-Dimensional Tree (K-d tree) Approach
A space-partitioning technique  
Each data row is treated as a point in `k`-dimensional space  
[Wikipedia](https://en.wikipedia.org/wiki/K-d_tree)

In [21]:
# One node of a K-d tree: a single training point, its label, and two subtrees.
# The split axis is not stored; build and search derive it from the node's depth
# as mod(depth, n_axes) + 1.
mutable struct KdTree
    X_data::Array{T} where T<:Number # 1d feature vector of this node's point
    Y_data::Number                   # label of this node's point
    child_l::Union{KdTree,Nothing}   # points sorted before this one on the split axis, or nothing
    child_r::Union{KdTree,Nothing}   # points sorted after this one on the split axis, or nothing
end

In [22]:
# K-d tree generator function
"""
    create_kdtree(X_data, Y_data)

Build a K-d tree over the rows of `X_data` (one point per row) labelled by `Y_data`.

At depth `depth` the rows are sorted on feature `mod(depth, d) + 1`, cycling through
the `d` columns. The row at index `div(n, 2) + 1` becomes the node, and the rows
before and after it form the left and right subtrees. Features and labels are
`hcat`-ed before sorting, so both end up with a common promoted element type.
"""
function create_kdtree(X_data::Array{T} where T<:Number, Y_data::Array{T} where T<:Number)::KdTree
    @assert ndims(X_data) == 2
    @assert ndims(Y_data) == 1
    @assert size(X_data)[1] == size(Y_data)[1]
    n_axes = size(X_data)[2]

    # Recursively build the subtree for one slice of rows at the given depth.
    function build(X_part::Array, Y_part::Array, depth::Integer)::KdTree
        axis = mod(depth, n_axes) + 1 # Julia arrays are 1-based
        ordered = sortslices(hcat(X_part, Y_part), by=row -> row[axis], dims=1)
        X_sorted = ordered[:, 1:end-1]
        Y_sorted = ordered[:, end]
        n_rows = size(X_sorted)[1]
        mid = div(n_rows, 2) + 1
        node = KdTree(X_sorted[mid, :], Y_sorted[mid], nothing, nothing)
        mid > 1 && (node.child_l = build(X_sorted[1:mid-1, :], Y_sorted[1:mid-1], depth + 1))
        mid < n_rows && (node.child_r = build(X_sorted[mid+1:end, :], Y_sorted[mid+1:end], depth + 1))
        return node
    end

    return build(X_data, Y_data, 0)
end

create_kdtree (generic function with 1 method)

In [23]:
# Build a K-d tree over the full dataset and show the root node's point and label
kdtree = create_kdtree(X_data, Y_data)
println(kdtree.X_data)
println(kdtree.Y_data)

[51.3, 63.2]
1.0


In [24]:
# Small 1-feature example (six points, all labelled 1) to inspect the tree's shape
kdtree_test = create_kdtree(reshape([30,5,10,70,50,35], (6, 1)), [1,1,1,1,1,1])
println(kdtree_test)

KdTree([35], 1, KdTree([10], 1, KdTree([5], 1, nothing, nothing), KdTree([30], 1, nothing, nothing)), KdTree([70], 1, KdTree([50], 1, nothing, nothing), nothing))


In [25]:
# inspired from https://stackoverflow.com/questions/1627305/nearest-neighbor-k-d-tree-wikipedia-proof/37107030#37107030
# note that for kdtree search, we use euclidean distance
"""
    predict_kdtree(X_predict, kdtree; K=5)

Classify each row of `X_predict` (or a single 1-d feature vector). Each query gets a
majority vote over the labels of its `K` nearest points stored in `kdtree`, measured by
Euclidean distance (`dist_sim`).

If the tree holds fewer than `K` points, all of them vote. Votes are passed to
`majority_vote` ordered from nearest to farthest. Returns an `Array{Number}` with one
prediction per query row.
"""
function predict_kdtree(X_predict::Array{T} where T<:Number, kdtree::KdTree; K::Integer=5)::Array
    @assert K > 0
    @assert 0 < ndims(X_predict) <= 2
    if ndims(X_predict) == 1
        X_predict = reshape(X_predict, (1, size(X_predict)[1]))
    end
    @assert size(X_predict)[2] == size(kdtree.X_data)[1]
    n_axes = size(X_predict)[2]

    # Pruning radius: the distance of the worst kept neighbor once K are kept. While
    # fewer than K have been found any subtree may still contribute, so the radius is Inf.
    search_radius(best_dist::Vector{Float64}) = length(best_dist) < K ? Inf : maximum(best_dist)

    # Offer `node` as a neighbor candidate, keeping only the K smallest distances.
    function offer!(X_vec::Array, node::KdTree, best_dist::Vector{Float64}, best_y::Vector{Number})
        distance = dist_sim(node.X_data, X_vec)
        if length(best_dist) < K
            push!(best_dist, distance)
            push!(best_y, node.Y_data)
        else
            worst_dist, worst_i = findmax(best_dist)
            if distance < worst_dist
                best_dist[worst_i] = distance
                best_y[worst_i] = node.Y_data
            end
        end
        return nothing
    end

    # Depth-first search. Descend first into the side of the split that contains X_vec.
    # Visit the other side only if the splitting plane is within the current radius.
    function search!(X_vec::Array, node::KdTree, depth::Integer,
                     best_dist::Vector{Float64}, best_y::Vector{Number})
        offer!(X_vec, node, best_dist, best_y)
        axis = mod(depth, n_axes) + 1 # Julia arrays are 1-based
        split_val = node.X_data[axis]
        if X_vec[axis] < split_val
            near, far = node.child_l, node.child_r
        else
            near, far = node.child_r, node.child_l
        end
        if near !== nothing
            search!(X_vec, near, depth + 1, best_dist, best_y)
        end
        if far !== nothing && abs(X_vec[axis] - split_val) <= search_radius(best_dist)
            search!(X_vec, far, depth + 1, best_dist, best_y)
        end
        return nothing
    end

    result = Array{Number}(undef, size(X_predict)[1])
    for row in 1:size(X_predict)[1]
        best_dist = Float64[]
        best_y = Number[]
        search!(X_predict[row, :], kdtree, 0, best_dist, best_y)
        # order votes nearest-first so any tie-breaking favours closer neighbors
        by_distance = sortperm(best_dist)
        result[row] = majority_vote(best_y[by_distance])
    end
    return result
end

predict_kdtree (generic function with 1 method)

In [26]:
# Build a K-d tree on the training set and report test accuracy with K = 10
kdtree_train = create_kdtree(X_train, Y_train)
println(JuTools.compute_accuracy(predict_kdtree(X_test, kdtree_train, K=10), Y_test))

0.9133333333333333


### Ball Tree Approach
Another space-partitioning approach, based on nested hyperspheres  
Often more efficient than a K-d tree when searching, particularly in higher dimensions  
[Wikipedia](https://en.wikipedia.org/wiki/Ball_tree)

In [27]:
# One node of a ball tree: a training point and its label, plus the hypersphere
# (pivot, radius) that encloses every point of the subtree rooted at this node.
# create_balltree sets pivot to the mean of those points; nodes built from a single
# row have pivot === nothing and radius 0.0.
mutable struct BallTree
    X_data::Array{T} where T<:Number # 1d feature vector of this node's point
    Y_data::Number                   # label of this node's point
    pivot::Union{Array{T},Nothing} where T<:Number # defines pivot point of hypersphere
    radius::AbstractFloat                            # defines radius of hypersphere
    child_l::Union{BallTree,Nothing} # subtree of points sorted before this one, or nothing
    child_r::Union{BallTree,Nothing} # subtree of points sorted after this one, or nothing
end

In [28]:
# inspired from https://gist.github.com/jakevdp/5216193
# ball tree generator function
"""
    create_balltree(X_data, Y_data)

Recursively build a ball tree from the rows of `X_data` (one point per row) and their
labels `Y_data`.

A single-row input becomes a leaf with `pivot = nothing` and `radius = 0.0`. Otherwise
the node's ball is centred on the mean of all rows, and its radius is the largest
Euclidean distance from that centre. The rows are then sorted along the feature with
the widest value range. The row at index `div(n, 2) + 1` becomes this node's point, and
the rows before and after it become the left and right subtrees. Features and labels
are `hcat`-ed before sorting, so both share a promoted element type.
"""
function create_balltree(X_data::Array{T} where T<:Number, Y_data::Array{T} where T<:Number)::BallTree
    @assert ndims(X_data) == 2
    @assert ndims(Y_data) == 1
    @assert size(X_data)[1] == size(Y_data)[1]
    n_rows = size(X_data)[1]
    n_rows == 1 && return BallTree(X_data[1, :], Y_data[1], nothing, 0.0, nothing, nothing)

    # centre of the ball: the mean of all rows
    pivot = vec(sum(X_data, dims=1)) ./ n_rows
    # radius: the largest distance from the centre, with a running maximum starting at 0.0
    radius = foldl(1:n_rows; init=0.0) do current, row
        candidate = dist_sim(pivot, X_data[row, :])
        candidate > current ? candidate : current
    end

    # split along the feature with the widest range (first such feature; 1 if all are flat)
    split_axis, widest = 1, 0.0
    for (axis, column) in enumerate(eachcol(X_data))
        spread = abs(maximum(column) - minimum(column))
        if spread > widest
            split_axis, widest = axis, spread
        end
    end

    ordered = sortslices(hcat(X_data, Y_data), by=row -> row[split_axis], dims=1)
    X_sorted, Y_sorted = ordered[:, 1:end-1], ordered[:, end]
    mid = div(n_rows, 2) + 1
    node = BallTree(X_sorted[mid, :], Y_sorted[mid], pivot, radius, nothing, nothing)
    if mid > 1
        node.child_l = create_balltree(X_sorted[1:mid-1, :], Y_sorted[1:mid-1])
    end
    if mid < n_rows
        node.child_r = create_balltree(X_sorted[mid+1:end, :], Y_sorted[mid+1:end])
    end
    return node
end

create_balltree (generic function with 1 method)

In [29]:
# Build a ball tree over the full dataset and show the root node's point, label, pivot and radius
balltree = create_balltree(X_data, Y_data)
println(balltree.X_data)
println(balltree.Y_data)
println(balltree.pivot)
println(balltree.radius)

[69.1, 50.5]
1.0
[50.810399999999994, 50.01620000000001]
69.01783197551197


In [30]:
# Same small 1-feature example as for the K-d tree, to inspect the ball tree's shape
balltree_test = create_balltree(reshape([30,5,10,70,50,35], (6, 1)), [1,1,1,1,1,1])
println(balltree_test)

BallTree([35], 1, [33.333333333333336], 36.666666666666664, BallTree([10], 1, [15.0], 15.0, BallTree([5], 1, nothing, 0.0, nothing, nothing), BallTree([30], 1, nothing, 0.0, nothing, nothing)), BallTree([70], 1, [60.0], 10.0, BallTree([50], 1, nothing, 0.0, nothing, nothing), nothing))


In [31]:
# search function for ball tree
# similar to kdtree search
"""
    predict_balltree(X_predict, balltree; K=5)

Classify each row of `X_predict` (or a single 1-d feature vector). Each query gets a
majority vote over the labels of its `K` nearest points stored in `balltree`, measured by
Euclidean distance (`dist_sim`).

A subtree is skipped once `K` neighbors are kept and the query's distance to the
subtree's ball (distance to the pivot minus the radius) is no smaller than the worst
kept distance. If the tree holds fewer than `K` points, all of them vote. Votes are
passed to `majority_vote` ordered from nearest to farthest. Returns an `Array{Number}`
with one prediction per query row.
"""
function predict_balltree(X_predict::Array{T} where T<:Number, balltree::BallTree; K::Integer=5)::Array
    @assert K > 0
    @assert 0 < ndims(X_predict) <= 2
    if ndims(X_predict) == 1
        X_predict = reshape(X_predict, (1, size(X_predict)[1]))
    end
    @assert size(X_predict)[2] == size(balltree.X_data)[1]

    # Pruning radius: the distance of the worst kept neighbor once K are kept; Inf before
    # that, so no subtree is skipped while the neighbor list is still incomplete.
    search_radius(best_dist::Vector{Float64}) = length(best_dist) < K ? Inf : maximum(best_dist)

    # Offer `node` as a neighbor candidate, keeping only the K smallest distances.
    function offer!(X_vec::Array, node::BallTree, best_dist::Vector{Float64}, best_y::Vector{Number})
        distance = dist_sim(node.X_data, X_vec)
        if length(best_dist) < K
            push!(best_dist, distance)
            push!(best_y, node.Y_data)
        else
            worst_dist, worst_i = findmax(best_dist)
            if distance < worst_dist
                best_dist[worst_i] = distance
                best_y[worst_i] = node.Y_data
            end
        end
        return nothing
    end

    # Depth-first search. Skip balls that cannot contain a closer point. When both
    # children exist, visit first the one whose own point is nearer to X_vec.
    function search!(X_vec::Array, node::BallTree, best_dist::Vector{Float64}, best_y::Vector{Number})
        if node.pivot !== nothing &&
                dist_sim(X_vec, node.pivot) - node.radius >= search_radius(best_dist)
            return nothing
        end
        offer!(X_vec, node, best_dist, best_y)
        near, far = node.child_l, node.child_r
        if near !== nothing && far !== nothing &&
                !(dist_sim(X_vec, near.X_data) < dist_sim(X_vec, far.X_data))
            near, far = far, near
        end
        near === nothing || search!(X_vec, near, best_dist, best_y)
        far === nothing || search!(X_vec, far, best_dist, best_y)
        return nothing
    end

    result = Array{Number}(undef, size(X_predict)[1])
    for row in 1:size(X_predict)[1]
        best_dist = Float64[]
        best_y = Number[]
        search!(X_predict[row, :], balltree, best_dist, best_y)
        # order votes nearest-first so any tie-breaking favours closer neighbors
        by_distance = sortperm(best_dist)
        result[row] = majority_vote(best_y[by_distance])
    end
    return result
end

predict_balltree (generic function with 1 method)

In [32]:
# Build a ball tree on the training set and report test accuracy with K = 20
balltree_train = create_balltree(X_train, Y_train)
println(JuTools.compute_accuracy(predict_balltree(X_test, balltree_train, K=20), Y_test))

0.9033333333333333
