# MNIST all_pairs_distance matrix and K-means clustering

Datasets used in this notebook:


1. X2.bin (binary Float32, 219,520,000 bytes)

    70,000 grayscale digit images stored sequentially; each image is 28×28 Float32 values (little-endian), stored top-down, row by row.


2. Y2_int8.bin (binary Int8, 70,000 bytes)

    70,000 Int8 labels of the corresponding digit images.
    

## Read MNIST binary digit image: X2.bin 

X2.bin (219,520,000 bytes)

70,000 grayscale digit images are stored sequentially; each image is 28×28 Float32 values (little-endian), stored top-down, row by row. Each gray value is between 0 and 1.

In [None]:
using Images, Colors

In [None]:
pwd()

In [None]:
cd("MNIST")

In [None]:
readdir()

In [None]:
# One column per image: 28*28 = 784 Float32 pixels for each of the 70,000 images
# (784 * 70_000 * 4 bytes = 219,520,000 bytes).
digits = Matrix{Float32}(undef, 28 * 28, 70_000)
read!("X2.bin", digits);

## Display digit images

In [None]:
digit = reshape(digits[:, 34], 28, 28);

In [None]:
typeof(digit), size(digit)

In [None]:
digit = Matrix{Gray{N0f8}}(digit)

In [None]:
digit'

In [None]:
Matrix{Gray{N0f8}}(reshape(digits[:, 100], 28, 28))'

## Read MNIST binary labels : Y2_int8.bin

In [None]:
# One Int8 label per image: 70,000 bytes in total.
labels = Array{Int8,1}(undef, 70_000)
read!("Y2_int8.bin", labels);

In [None]:
findall(==(9), labels[1:200]);

## Plot histogram of digit distribution

In [None]:
using Plots

In [None]:
histogram(labels, label="digits", bins=11, color=:gray)
xlabel!("digit")

In [None]:
histogram(labels, label="digits", bins=11, normalize=true, color=:gray)

In [None]:
# labels_digits[d + 1] holds the Int32 positions in `labels` whose value is digit d.
# A comprehension yields a concretely typed Vector{Vector{Int32}} instead of the
# Vector{Any} obtained by pushing onto an untyped `[]`.
labels_digits = [Vector{Int32}(findall(==(d), labels)) for d in 0:9];

In [None]:
sum(length.(labels_digits))

# Compute all-pairs distance matrix of MNIST images


In [None]:
"""
    dist(x, y, len)

Return the squared Euclidean distance between the first `len` entries of `x`
and `y`, i.e. the sum of `(x[i] - y[i])^2` for `i in 1:len`. No square root is
taken.

# Arguments
- `x`, `y`: 1-based indexable collections with at least `len` elements.
- `len`: number of leading elements to compare.

# Returns
The accumulated sum. It is a `Float32` for `Float32` inputs; wider element types
widen the result exactly as adding them to a `Float32` would.
"""
function dist(x, y, len)
    # Fix the accumulator type before the loop so it never changes inside it
    # (starting at a Float32 zero and adding e.g. Float64 terms is type-unstable).
    T = promote_type(Float32, eltype(x), eltype(y))
    # Abstract element types (e.g. Any) have no `zero`; keep the Float32 start there.
    acc = isconcretetype(T) ? zero(T) : zero(Float32)

    @simd for i in 1:len
        residue = x[i] - y[i]
        acc += residue * residue
    end

    return acc
end

In [None]:
"""
    all_pairs_dist(ndigit, len, data=digits)

Compute the symmetric all-pairs distance matrix of the first `ndigit` images.

# Arguments
- `ndigit`: number of images (leading columns of `data`) to compare.
- `len`: number of leading pixels of each image handed to `dist`.
- `data`: matrix with one image per column. Defaults to the global `digits`;
  passing it as an argument lets the loops below be compiled for its concrete type.

# Returns
An `ndigit × ndigit` `Matrix{Float32}` whose `[i, j]` entry is
`dist(data[:, i], data[:, j], len)`.
"""
function all_pairs_dist(ndigit, len, data=digits)

    distMat = Matrix{Float32}(undef, ndigit, ndigit)

    for i in 1:ndigit
        # Views avoid allocating a fresh copy of a column for every (i, j) pair.
        col_i = view(data, :, i)
        for j in i:ndigit
            # Only the upper triangle is evaluated; the value is mirrored below it.
            d = dist(col_i, view(data, :, j), len)
            distMat[i, j] = d
            distMat[j, i] = d
        end
    end

    return distMat
end

# k-means clustering of N MNIST images using distM distance matrix


## Clusters Initialization functions

In [None]:
"""
    rand_cluster(n_cluster, n_obj)

Random initial cluster allocation: draw one cluster index from `1:n_cluster`
for each of the `n_obj` objects and return the draws as a `Vector{Int8}`.
"""
function rand_cluster(n_cluster, n_obj)
    draws = map(_ -> rand(1:n_cluster), 1:n_obj)
    return Vector{Int8}(draws)
end

In [None]:
rand_cluster(10, 50);

In [None]:
"""
    even_cluster(n_cluster, n_obj)

Even initial cluster allocation, returned as a `Vector{Int8}` of length `n_obj`.

The first `n_cluster * (n_obj ÷ n_cluster)` objects are split into consecutive
runs of `n_obj ÷ n_cluster` objects, run `c` going to cluster `c`. Any objects
left over after those runs are assigned one each to clusters 1, 2, ... in order.
"""
function even_cluster(n_cluster, n_obj)
    assignment = Vector{Int8}(undef, n_obj)
    block = n_obj ÷ n_cluster  # objects per full run (integer division)

    # Full runs: positions (c-1)*block + 1 .. c*block belong to cluster c.
    for c in 1:n_cluster, offset in 1:block
        assignment[(c - 1) * block + offset] = c
    end

    # Leftover objects: the k-th leftover goes to cluster k.
    leftovers = (n_cluster * block + 1):n_obj
    for (k, idx) in enumerate(leftovers)
        assignment[idx] = k
    end

    return assignment
end

In [None]:
even_cluster(5, 22)

In [None]:
"""
    label_cluster(n_cluster, n_obj, label_source=labels)

Initial cluster allocation taken from given labels: object `obj` is placed in
cluster `label_source[obj] + 1`, so label 0 maps to cluster 1.

# Arguments
- `n_cluster`: not used by this function; it is accepted so the signature
  matches the other initialisers.
- `n_obj`: number of objects (leading entries of `label_source`) to allocate.
- `label_source`: indexable collection of integer labels. Defaults to the
  global `labels`; passing it explicitly avoids relying on that global.

# Returns
A `Vector{Int8}` of length `n_obj`.
"""
function label_cluster(n_cluster, n_obj, label_source=labels)

    cluster = Vector{Int8}(undef, n_obj)

    for obj in 1:n_obj
        cluster[obj] = label_source[obj] + 1
    end

    return cluster
end

## Clusters Initialization functions test

# K-means functions

In [None]:
"""
    k_means!(niter, n_cluster, n_obj, cluster, dist)

Cluster `n_obj` objects using only a precomputed pairwise distance matrix,
updating `cluster` in place.

In every pass each object is reassigned to the cluster whose current members
have the smallest mean distance to it (an object counts as a member of its own
cluster). All reassignments of a pass are computed from the allocation at the
start of that pass. Empty clusters are never chosen.

# Arguments
- `niter`: maximum number of passes.
- `n_cluster`: number of clusters; cluster indices are `1:n_cluster`.
- `n_obj`: number of objects to be clustered.
- `cluster`: integer vector of length `n_obj` holding the initial cluster index
  of every object; overwritten with the final allocation.
- `dist`: `n_obj × n_obj` matrix with `dist[a, b]` the distance between objects
  `a` and `b` (inside this function the name shadows the `dist` function).

# Returns
The pass number at which the allocation stopped changing, or `0` if that did
not happen within `niter` passes.
"""
function k_means!(niter, n_cluster, n_obj, cluster, dist)

    mean_vec = Vector{Float32}(undef, n_cluster)
    # Same element type as `cluster`, so the scratch vector can hold any index
    # the caller's vector can.
    cluster_new = similar(cluster, n_obj)

    for n in 1:niter

        # Membership only changes between passes, so collect it once per pass
        # instead of once per (object, cluster) pair.
        members = [findall(==(i), cluster) for i in 1:n_cluster]

        # For each object find the closest cluster.
        for obj in 1:n_obj
            for i in 1:n_cluster
                v = members[i]
                if isempty(v)
                    # An empty cluster must lose against every real mean, whatever
                    # the scale of the distances.
                    mean_vec[i] = typemax(Float32)
                else
                    mean_vec[i] = sum(dist[obj, v]) / length(v)
                end
            end

            # Index of the first smallest mean becomes the object's new cluster.
            cluster_new[obj] = argmin(mean_vec)
        end

        # Converged: report the pass at which nothing moved.
        cluster == cluster_new && return n

        copyto!(cluster, 1, cluster_new, 1, n_obj)
    end

    return 0
end

# Simple K means clustering tests

## Simple test 1

In [None]:
# Toy problem: 8 objects, 4 clusters, every entry of the distance matrix starts at 100.
N = 8
ncluster = 4
distM = fill(100f0, N, N);

In [None]:
# Every object is at distance 0 from itself.
for i in 1:N
    distM[i, i] = 0
end

# Two tight groups, objects 1-4 and 5-7: distinct members are 1 apart.
for group in (1:4, 5:7), i in group, j in group
    i == j || (distM[i, j] = 1)
end


In [None]:
cluster = [1,2,1,3,2,3,4,4];

In [None]:
k_means!(10, ncluster, N, cluster, distM)

In [None]:
findall(==(1), cluster)

In [None]:
findall(==(2), cluster)

In [None]:
findall(==(3), cluster)

In [None]:
findall(==(4), cluster)

## Simple test 2

In [None]:
# Same toy problem as before, now with 3 clusters.
N = 8
ncluster = 3
distM = fill(100f0, N, N);

In [None]:
# Every object is at distance 0 from itself.
for i in 1:N
    distM[i, i] = 0
end

# Two tight groups, objects 1-4 and 5-7: distinct members are 1 apart.
for group in (1:4, 5:7), i in group, j in group
    i == j || (distM[i, j] = 1)
end


In [None]:
# Fixed alternative: cluster = [2,2,3,1,1,2,3,1];
# Random start: one index from 1:ncluster per object, stored as Int8.
cluster = rand_cluster(ncluster, N)

In [None]:
k_means!(10, ncluster, N, cluster, distM)

In [None]:
findall(==(1), cluster)

In [None]:
findall(==(2), cluster)

In [None]:
findall(==(3), cluster)

# K means clustering tests using MNIST Dataset

In [None]:
using Statistics

In [None]:
len = 28 * 28
N = 200
ncluster = 10

In [None]:
@time "$N digits time:" distM = all_pairs_dist(N, len);

In [None]:
sum([distM[i,j] for i in 1:N for j in i:N])

In [None]:
𝛍 = Vector{Float32}(undef, ncluster);

In [None]:
cluster = rand_cluster(ncluster, N);
# cluster = label_cluster(ncluster, N);
#cluster = even_cluster(ncluster, N);

In [None]:
# Sizes of clusters 1..ncluster on one line, each followed by a space.
for c in 1:ncluster
    print(count(==(c), cluster), " ")
end

In [None]:
# 𝛍[c] = mean of distM over member pairs (a, b) with a ≤ b of cluster c
# (self-pairs included); an empty cluster gets 0.
for c in 1:ncluster
    members = findall(==(c), cluster)
    m = length(members)
    𝛍[c] = m > 0 ? mean([distM[members[a], members[b]] for a in 1:m for b in a:m]) : 0
end

In [None]:
sum(𝛍), mean(𝛍)

In [None]:
@time k_means!(100, ncluster, N, cluster, distM)

In [None]:
# Sizes of clusters 1..ncluster on one line, each followed by a space.
for c in 1:ncluster
    print(count(==(c), cluster), " ")
end

In [None]:
# 𝛍[c] = mean of distM over member pairs (a, b) with a ≤ b of cluster c
# (self-pairs included); an empty cluster gets 0.
for c in 1:ncluster
    members = findall(==(c), cluster)
    m = length(members)
    𝛍[c] = m > 0 ? mean([distM[members[a], members[b]] for a in 1:m for b in a:m]) : 0
end

In [None]:
sum(𝛍), mean(𝛍)

In [None]:
@time k_means!(10, ncluster, N, cluster, distM)

In [None]:
# Size of each cluster 1..ncluster, one per line.
for c in 1:ncluster
    println(count(==(c), cluster))
end

In [None]:
# Total number of objects assigned to clusters 1..ncluster.
# A single expression avoids updating the global `s` from inside a top-level
# loop, which only works under the notebook/REPL soft-scope rule.
# `init=0` keeps the result 0 when ncluster is 0.
s = sum(i -> count(==(i), cluster), 1:ncluster; init=0)

In [None]:
# 𝛍[c] = mean of distM over member pairs (a, b) with a ≤ b of cluster c
# (self-pairs included); an empty cluster gets 0.
for c in 1:ncluster
    members = findall(==(c), cluster)
    m = length(members)
    𝛍[c] = m > 0 ? mean([distM[members[a], members[b]] for a in 1:m for b in a:m]) : 0
end

In [None]:
sum(𝛍), mean(𝛍)

In [None]:
@time k_means!(100, ncluster, N, cluster, distM)

In [None]:
# Size of each cluster 1..ncluster, one per line.
for c in 1:ncluster
    println(count(==(c), cluster))
end

In [None]:
# 𝛍[c] = mean of distM over member pairs (a, b) with a ≤ b of cluster c
# (self-pairs included); an empty cluster gets 0.
for c in 1:ncluster
    members = findall(==(c), cluster)
    m = length(members)
    𝛍[c] = m > 0 ? mean([distM[members[a], members[b]] for a in 1:m for b in a:m]) : 0
end

In [None]:
sum(𝛍), mean(𝛍)

In [None]:
@time k_means!(100, ncluster, N, cluster, distM)

In [None]:
# Size of each cluster 1..ncluster, one per line.
for c in 1:ncluster
    println(count(==(c), cluster))
end

# Parallel computation of all-pairs distance matrix

In [None]:
len = 28 * 28
N = 3000
@time "$N digits time:" distM = all_pairs_dist(N, len);

In [None]:
using Base.Threads

In [None]:
"""
    all_pairs_dist_threads(ndigit, len, data=digits)

Multi-threaded computation of the symmetric all-pairs distance matrix of the
first `ndigit` images.

# Arguments
- `ndigit`: number of images (leading columns of `data`) to compare.
- `len`: number of leading pixels of each image handed to `dist`.
- `data`: matrix with one image per column. Defaults to the global `digits`;
  passing it as an argument lets the loops below be compiled for its concrete type.

# Returns
An `ndigit × ndigit` `Matrix{Float32}` whose `[i, j]` entry is
`dist(data[:, i], data[:, j], len)`.
"""
function all_pairs_dist_threads(ndigit, len, data=digits)

    distMat = Matrix{Float32}(undef, ndigit, ndigit)

    # Iteration i writes only entries [i, j] and [j, i] with j ≥ i, so no two
    # iterations write the same entry. Small i means more inner work
    # (triangular loop).
    @threads for i in 1:ndigit
        # Views avoid allocating a fresh copy of a column for every (i, j) pair.
        col_i = view(data, :, i)
        for j in i:ndigit
            d = dist(col_i, view(data, :, j), len)
            distMat[i, j] = d
            distMat[j, i] = d
        end
    end

    return distMat
end

In [None]:
len = 28 * 28
N = 3000
@time "$N digits time:" distM2 = all_pairs_dist_threads(N, len);

In [None]:
distM == distM2

## Computing the all-pairs distance matrix is very computationally intensive!

The computational complexity is O(N² × len): there are N(N+1)/2 image pairs to evaluate, and each distance takes `len` operations.

In [None]:
# This distM requires about 19GB memory space and takes more than 136 mins 
# to compute
N = 70_000
len = 28*28
@time "$N digits time:" distM = all_pairs_dist(N, len);