# Tests KNN

This notebook is not really useful for the `siamese` network itself, but it was my testing notebook for the `K-NN` optimization. I think it is worth keeping in case you want to further improve it, run tests, or add new features.

`KNN` is my old `numpy` implementation, while `TFKNN` is the one used in `distance/knn.py`: it is optimized in pure `tensorflow`, with the core `knn` decision rule decorated as a `tf.function`.

It is also quite impressive to see that the `tensorflow` version is about 10 times faster than the `numpy` implementation (≈8 s vs ≈78 s in the benchmark at the end of this notebook)!

In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf

from tqdm import tqdm

from utils.plot_utils import plot_embedding
from utils.thread_utils import ThreadPool
from utils.embeddings import load_embedding, embeddings_to_np
from utils.distance.distance_method import distance

class KNN(object):
    """
        Reference `numpy` implementation of the `K-Nearest Neighbors` algorithm
        (kept to benchmark / validate the `tensorflow` implementation)
        
        It also supports :
            - Plotting embeddings / predictions
            - A `use_mean` mode where the prediction is the nearest `centroid`*
        
        * A `centroid` is the mean point of all points belonging to a given label
    """
    def __init__(self, embeddings, ids = None, k = 5, use_mean = False, 
                 method = 'euclidian', ** kwargs):
        """
            Arguments :
                - embeddings    : the labelled points
                    If str          : call `load_embedding()` on it
                    If pd.DataFrame : use its 'id' column as `ids` and call `embeddings_to_np` on it
                    Else            : 2D matrix (converted with `np.asarray`)
                - ids   : label of each embedding (overridden if `embeddings` is a DataFrame)
                - k / use_mean  : default configuration for `predict` (`k` is forced to 1 when `use_mean`)
                - method        : distance method passed to `distance()`
                - kwargs        : ignored
        """
        if isinstance(embeddings, str):
            embeddings = load_embedding(embeddings)
        
        if isinstance(embeddings, pd.DataFrame):
            ids = embeddings['id'].values
            embeddings = embeddings_to_np(embeddings)
        
        self.ids        = np.array(ids)
        # ndarray so that boolean-mask indexing works in `get_mean_embeddings` / `get_embeddings`
        self.embeddings = np.asarray(embeddings)
        
        self.k          = k if not use_mean else 1
        self.method     = method
        self.use_mean   = use_mean
        
        self.mean_ids, self.mean_embeddings = self.get_mean_embeddings()

    def get_mean_embeddings(self, embeddings = None, ids = None):
        """
            Compute the mean embedding (centroid) of each unique id
            
            Arguments :
                - embeddings / ids  : points and their labels (default : `self.embeddings` / `self.ids`)
                    If only `embeddings` is given, `self.ids` is used as labels
            Return :
                - (unique_ids, centroids) : the sorted unique ids and the matching mean embeddings
        """
        if embeddings is None:
            embeddings, ids = self.embeddings, self.ids
        elif ids is None:
            ids = self.ids
        
        embeddings, ids = np.asarray(embeddings), np.asarray(ids)
        
        uniques = np.unique(ids)
        # the mask is built from the `ids` argument (not `self.ids`) so custom inputs work
        return uniques, np.array([
            np.mean(embeddings[ids == unique_id], axis = 0)
            for unique_id in uniques
        ])
    
    def get_embeddings(self, ids = None, use_mean = False):
        """
            Return the (embeddings, ids) whose id belongs to `ids`
            
            Arguments :
                - ids       : a single id or a list of ids to keep (None keeps everything)
                - use_mean  : select among the centroids instead of the raw points
        """
        if ids is not None and not isinstance(ids, (list, tuple, np.ndarray)): ids = [ids]
        if use_mean:
            res_embeddings, res_ids = self.mean_embeddings, self.mean_ids
        else:
            res_embeddings, res_ids = self.embeddings, self.ids
        
        if ids is not None:
            # vectorized membership test ; always a boolean mask (even when `res_ids` is empty)
            mask = np.isin(res_ids, np.asarray(ids))
            res_embeddings = res_embeddings[mask]
            res_ids = res_ids[mask]
        
        return res_embeddings, res_ids
    
    def distance(self, x, ids = None, use_mean = False):
        """ Compute the distance between `x` and the embeddings of `ids` ; return (distances, ids) """
        embeddings, ids = self.get_embeddings(ids, use_mean)
        return distance(x, embeddings, method = self.method), ids
        
    def predict(self, x, possible_ids = None, k = None, use_mean = None, plot = False, ** kwargs):
        """
            Predict the id of each vector of `x` with the k-nn decision rule
            
            Arguments :
                - x : 1D vector or 2D matrix (one vector per row) ; a tf.Tensor is converted to numpy
                - possible_ids  : id (or list of ids) allowed for the prediction
                    For a 2D `x`, a list whose length equals `len(x)` is interpreted as one entry per row,
                    otherwise the same list is used for every row
                - k / use_mean  : k-nn parameters (default : `self.k` / `self.use_mean`, `k = 1` when `use_mean`)
                - plot / kwargs : whether to plot the result (kwargs are forwarded to `self.plot`)
            Return :
                - the predicted id (1D `x`) or an np.ndarray of ids (2D `x`)
                - -1 if `possible_ids` is an empty list, -2 if several ids are tied
        """
        if isinstance(x, tf.Tensor): x = x.numpy()
        assert isinstance(x, np.ndarray) and x.ndim in (1, 2)

        if use_mean is None: use_mean = self.use_mean
        if use_mean: k = 1
        elif not k: k = self.k
        
        if possible_ids is not None and not isinstance(possible_ids, (list, tuple, np.ndarray)): 
            possible_ids = [possible_ids]
        
        if x.ndim == 2:
            if possible_ids is None: possible_ids = [None] * len(x)
            elif len(possible_ids) != len(x):
                # a single list of candidates shared by every row
                possible_ids = [possible_ids] * len(x)

            assert len(possible_ids) == len(x)
            
            # one recursive (1D) `predict` call per row, dispatched through `ThreadPool`
            pool = ThreadPool(target = self.predict)
            for xi, ids_i in zip(x, possible_ids):
                pool.append(kwargs = {
                    'x' : xi, 'possible_ids' : ids_i, 'plot' : False,
                    'k' : k, 'use_mean' : use_mean, ** kwargs
                })
            pool.start(tqdm = lambda x: x)
            
            pred = np.array(pool.result())
        else:
            if possible_ids is not None and len(possible_ids) == 0: return -1
            
            dists, ids = self.distance(x, ids = possible_ids, use_mean = use_mean)
            nearest_ids = ids[np.argsort(dists)[:k]]
            
            # majority vote : a unique winner is returned, a tie (or no neighbour) gives -2
            candidates, counts = np.unique(nearest_ids, return_counts = True)
            winners = candidates[counts == counts.max()] if len(counts) > 0 else candidates
            pred = winners[0] if len(winners) == 1 else -2
        
        if plot:
            self.plot(x, pred, ** kwargs)
            
        return pred
    
    def plot(self, x = None, x_ids = None, marker_kwargs = None, ** kwargs):
        """
            Plot the labelled points + centroids (big markers) + optional new points `x`
            
            Arguments :
                - x     : new point(s) to display (vector, matrix or DataFrame with an optional 'id' column)
                - x_ids : id(s) of `x` (e.g. predictions) ; if None, `x` gets an unused fake id
                - marker_kwargs : per-marker style overrides (the caller's dict is not modified)
                - kwargs        : forwarded to `plot_embedding`
        """
        # shallow copy : `setdefault` below must not mutate the caller's dict
        marker_kwargs = {} if marker_kwargs is None else dict(marker_kwargs)

        # Original points
        embeddings, ids = self.embeddings, self.ids
        marker = ['o'] * len(embeddings)
        
        # Means as big points
        embeddings = np.concatenate([embeddings, self.mean_embeddings], axis = 0)
        ids = np.concatenate([ids, self.mean_ids], axis = 0)
        
        marker += ['O'] * len(self.mean_ids)
        marker_kwargs.setdefault('O', {
            'marker'    : 'o',
            'linewidth' : kwargs.get('linewidth', 2.5) * 3
        })
        
        # New data points to plot
        if x is not None:
            if isinstance(x, pd.DataFrame):
                if 'id' in x and x_ids is None:
                    x_ids = x['id'].values
                x = embeddings_to_np(x)
            
            x = np.asarray(x)
            if x.ndim == 1: x = np.expand_dims(x, 0)
            if x_ids is not None:
                # scalars (python, numpy or 0-d arrays) become 1-element arrays
                x_ids = np.atleast_1d(np.asarray(x_ids))
            else:
                fake_id = 0
                while fake_id in ids: fake_id += 1
                x_ids = np.array([fake_id] * len(x))
                marker_kwargs.setdefault('x', {'c' : 'w'})
            
            assert len(x_ids) == len(x)
            
            embeddings = np.concatenate([embeddings, x], axis = 0)
            ids = np.concatenate([ids, x_ids], axis = 0)
            marker += ['x'] * len(x)
        
        plot_embedding(
            embeddings, ids = ids, marker = np.array(marker), 
            marker_kwargs = marker_kwargs, ** kwargs
        )



In [2]:
import numpy as np
import pandas as pd
import tensorflow as tf

from utils.thread_utils import ThreadPool
from utils.plot_utils import plot_embedding
from utils.embeddings import load_embedding, compute_mean_embeddings, embeddings_to_np
from utils.distance.distance_method import distance

class TFKNN(object):
    """
        Tensorflow implementation of the `K-Nearest Neighbors` algorithm
        
        It also has some additional features such as : 
            - Plotting embeddings / predictions
            - Use a `use_mean` version where the prediction is the nearest `centroid`*
            
        * A `centroid` is the mean point of all points belonging to a given label
    """
    def __init__(self, embeddings, ids = None, k = 5, use_mean = False, 
                 method = 'euclidian', ** kwargs):
        """
            Constructor for the KNN class
            
            Arguments : 
                - embeddings    : the embeddings to use as labelled points
                    If str  : call `load_embedding()` on it
                    If pd.DataFrame : use the 'id' column for `ids` and call `embeddings_to_np` on it
                    Else    : must be a np.ndarray / tf.Tensor 2D matrix
                - ids   : ids of the embeddings (if embeddings is a DataFrame, ids = embeddings['id'].values)
                - k / use_mean  : default configuration for the `predict` method
                - method        : distance method to use
                - kwargs        : ignored
        """
        if isinstance(embeddings, str):
            embeddings = load_embedding(embeddings)
        
        if isinstance(embeddings, pd.DataFrame):
            ids = embeddings['id'].values
            embeddings = embeddings_to_np(embeddings)
        
        assert ids is not None and len(embeddings) == len(ids), \
            "`ids` must be provided and have the same length as `embeddings`"
        
        self.ids    = np.array(ids)
        self.embeddings = tf.cast(embeddings, tf.float32)
        
        self.k          = tf.cast(k, dtype = tf.int32)
        self.use_mean   = use_mean
        self.method     = method
        
        self.mean_ids, self.mean_embeddings = self.get_mean_embeddings()
    
    def get_mean_embeddings(self):
        """ Compute the mean embeddings for each id """
        return compute_mean_embeddings(self.embeddings, self.ids)
    
    def get_embeddings(self, ids = None, use_mean = False):
        """
            Return all embeddings (and their ids) whose id belongs to `ids`
            
            Arguments :
                - ids       : a single id or a list / array / tensor of ids (None keeps everything)
                - use_mean  : select among the centroids instead of the raw points
            Return :
                - (embeddings, ids) : a 2D `[n, dim]` selection and the matching 1D ids
        """
        if ids is not None and not isinstance(ids, (list, tuple, np.ndarray, tf.Tensor)): ids = [ids]
        if use_mean:
            embeddings, res_ids = self.mean_embeddings, self.mean_ids
        else:
            embeddings, res_ids = self.embeddings, self.ids
        
        if ids is not None:
            # A boolean mask keeps `embeddings` 2D. The previous `tf.where` + `tf.gather` approach
            # produced `[n, 1, dim]` tensors and failed on an empty `ids` list (`tf.concat([])`).
            mask = np.isin(np.asarray(res_ids), np.asarray(ids))
            embeddings = tf.boolean_mask(embeddings, mask)
            res_ids = tf.boolean_mask(res_ids, mask)
        
        return embeddings, res_ids
    
    def distance(self, x, ids = None, use_mean = False):
        """ Compute distance between x and embeddings for given ids """
        embeddings, ids = self.get_embeddings(ids, use_mean)
        return distance(tf.cast(x, tf.float32), embeddings, method = self.method), ids
    
    def predict(self, x, possible_ids = None, k = None, use_mean = None,
                plot = False, tqdm = lambda x: x, ** kwargs):
        """
            Predict ids for each `x` vector based on the `k-nn` decision procedure
            
            Arguments :
                - x : the 1D / 2D matrix of embeddings vector(s) to predict label
                - possible_ids  : a list of `possible ids` (other ids are not taken into account for the k-nn)
                    For a 2D `x`, a list whose length equals `len(x)` is interpreted as one entry per row
                - k / use_mean  : k-nn metaparameter (if not provided use self.k / self.use_mean)
                - tqdm  : progress bar if `x` is a matrix
                - plot / kwargs : whether to plot the prediction result or not
            Return :
                - int32 tf.Tensor of predicted id(s) (1D tensor for a 2D `x`)
                    -1 : no labelled embedding matches `possible_ids`
                    -2 : several ids are tied among the `k` nearest neighbors
            
            If x is a matrix, call `self.predict` for each vector in the matrix through `ThreadPool`
        """
        if use_mean is None: use_mean = self.use_mean
        if use_mean: k = 1
        elif k is None: k = self.k
        else: k = tf.cast(k, tf.int32)
        
        if possible_ids is not None and not isinstance(possible_ids, (list, tuple, np.ndarray)):
            possible_ids = [possible_ids]
        
        x = tf.cast(x, tf.float32)
        if tf.rank(x) == 2:
            if len(x) == 0: return tf.zeros((0, ), dtype = tf.int32)
            
            if possible_ids is None: possible_ids = [None] * len(x)
            elif len(possible_ids) != len(x):
                possible_ids = [possible_ids] * len(x)

            assert len(possible_ids) == len(x)
            
            pool = ThreadPool(target = self.predict)
            for xi, ids_i in zip(x, possible_ids):
                pool.append(kwargs = {
                    'x' : xi, 'possible_ids' : ids_i, 'k' : k, 'use_mean' : use_mean
                })
            pool.start(tqdm = tqdm)
            
            # per-row results may be 0-D tensors (rejected by `tf.concat`) : flatten them first
            pred = tf.concat([tf.reshape(p, [-1]) for p in pool.result()], axis = 0)
        else:
            embeddings, ids = self.get_embeddings(possible_ids, use_mean)
            # no candidate left : same `-1` convention as the numpy implementation
            if len(ids) == 0: return tf.constant(-1, dtype = tf.int32)
            
            pred = knn(x, embeddings, ids, k, self.method)
        
        if plot:
            self.plot(x, pred, ** kwargs)
        
        return pred
        
    def plot(self, x = None, x_ids = None, marker_kwargs = None, ** kwargs):
        """
            Plot the labelled datasets + centroids + possible `x` to predict (with their predicted labels)
            
            `x` / `x_ids` may be numpy arrays, tf.Tensor (e.g. `predict` outputs) or scalars ;
            the caller's `marker_kwargs` dict is not modified
        """
        # shallow copy : `setdefault` below must not mutate the caller's dict
        marker_kwargs = {} if marker_kwargs is None else dict(marker_kwargs)

        # Original points
        embeddings, ids = self.embeddings, self.ids
        marker = ['o'] * len(embeddings)
        
        # Means as big points
        embeddings = np.concatenate([embeddings, self.mean_embeddings], axis = 0)
        ids = np.concatenate([ids, self.mean_ids], axis = 0)
        
        marker += ['O'] * len(self.mean_ids)
        marker_kwargs.setdefault('O', {
            'marker'    : 'o',
            'linewidth' : kwargs.get('linewidth', 2.5) * 3
        })
        
        # New data points to plot
        if x is not None:
            if isinstance(x, pd.DataFrame):
                if 'id' in x and x_ids is None:
                    x_ids = x['id'].values
                x = embeddings_to_np(x)
            
            x = np.asarray(x)
            if x.ndim == 1: x = np.expand_dims(x, 0)
            if x_ids is not None:
                # handles scalars, lists and tensors (np.array([tensor]) would add a dimension)
                x_ids = np.atleast_1d(np.asarray(x_ids))
            else:
                fake_id = 0
                while fake_id in ids: fake_id += 1
                x_ids = np.array([fake_id] * len(x))
                marker_kwargs.setdefault('x', {'c' : 'w'})
            
            assert len(x_ids) == len(x)
            
            embeddings = np.concatenate([embeddings, x], axis = 0)
            ids = np.concatenate([ids, x_ids], axis = 0)
            marker += ['x'] * len(x)
        
        plot_embedding(
            embeddings, ids = ids, marker = np.array(marker), 
            marker_kwargs = marker_kwargs, ** kwargs
        )

@tf.function(experimental_relax_shapes = True)
def knn(x, embeddings, ids, k, distance_metric):
    """
        Compute the k-nn decision procedure for a given x based on a list of labelled embeddings
        
        Arguments :
            - x         : the vector to classify
            - embeddings    : the labelled candidate embeddings
            - ids       : the id of each candidate (cast to int32 ; `bincount` assumes non-negative ids)
            - k         : number of neighbors (clamped to the number of candidates)
            - distance_metric   : method forwarded to `distance()`
        
        Return the majoritary id in the `k` nearest neighbors or `-2` if there is an equality
    """
    # reshape (instead of squeeze) keeps a 1D vector even when there is a single candidate
    distances = tf.reshape(distance(x, embeddings, method = distance_metric), [-1])
    
    # `top_k` requires k <= number of candidates
    k = tf.minimum(tf.cast(k, tf.int32), tf.size(distances))
    _, k_nearest_idx = tf.nn.top_k(-distances, k)
    
    nearest_ids = tf.cast(tf.reshape(tf.gather(ids, k_nearest_idx), [-1]), tf.int32)
    counts = tf.math.bincount(nearest_ids)

    # all ids reaching the maximal vote count
    winners = tf.reshape(tf.where(counts == tf.reduce_max(counts)), [-1])

    return tf.cond(
        tf.size(winners) == 1,
        lambda: tf.cast(winners[0], tf.int32),
        lambda: tf.constant(-2, dtype = tf.int32)
    )


In [3]:
import time

def predict(knn, embeddings, possible_ids = None):
    """
        Time `knn.predict` on `embeddings` and print the elapsed wall-clock time
        
        Arguments :
            - knn           : an object with a `predict(x, possible_ids = ...)` method (e.g. `KNN` / `TFKNN`)
            - embeddings    : the vectors to classify
            - possible_ids  : candidate ids (default : ids 0 to 6, as in the original benchmark)
        Return : the predictions returned by `knn.predict`
    """
    if possible_ids is None: possible_ids = list(range(7))
    
    # perf_counter is the high-resolution clock intended for measuring durations
    t0 = time.perf_counter()
    pred = knn.predict(embeddings, possible_ids = possible_ids)
    elapsed = time.perf_counter() - t0
    print("Prediction takes {:.3f} sec !".format(elapsed))
    return pred

# Fixed seed so the benchmark (and the agreement check below) is reproducible
np.random.seed(42)

N_LABELLED      = 25000     # number of labelled embeddings
N_QUERIES       = 2500      # number of embeddings to classify
EMBEDDING_DIM   = 128
N_IDS           = 10        # labels are drawn uniformly in [0, N_IDS)

n = N_LABELLED

embeddings = np.random.random(size = (n, EMBEDDING_DIM))
ids = np.random.randint(0, N_IDS, size = (n,))

val_embeddings = np.random.random(size = (N_QUERIES, EMBEDDING_DIM))


np_knn = KNN(embeddings, ids)
tf_knn = TFKNN(embeddings, ids)

np_pred = predict(np_knn, val_embeddings)
tf_pred = predict(tf_knn, val_embeddings)
print(np_pred[:5])
print(tf_pred[:5])
# NOTE : with uniformly random data many predictions are likely ties (-2),
# so this agreement check is only a weak sanity test
print(np.array_equal(np_pred, np.asarray(tf_pred)))

Prediction takes 78.444 sec !
Prediction takes 8.078 sec !
[-2 -2 -2 -2 -2]
tf.Tensor([-2 -2 -2 -2 -2], shape=(5,), dtype=int32)
True
