In [1]:
import numpy as np

In [2]:
from time import time
from tqdm.notebook import tqdm

def distortion(result, x_data):
    """
    Calculate the distortion of a result.
    """
    centroids = np.array([np.mean(x_data[result[k]], axis=0) for k in range(len(result))])
    # compute distortiona
    distortion = 0
    for k in range(len(result)):
        distortion += np.sum(np.linalg.norm(x_data[result[k]] - centroids[k], axis=1, ord=2)**2)
    return distortion

def evaluate_algo(algo, dataset, n_runs=50):
    timepoints = []
    results = []

    timepoints.append(time())
    for _ in tqdm(range(n_runs)):
        result = algo.fit(dataset.x_data)
        results.append(result)
        timepoints.append(time())

    distortions = np.array([distortion(result, dataset.x_data) for result in results])
    timepoints = np.array(timepoints)
    time_since_start = timepoints[1:] - timepoints[0]
    best_result_so_far = np.minimum.accumulate(distortions)

    return distortions, time_since_start, best_result_so_far


In [3]:
from algorithms.other_algos import *
from datasets.cv_datasets import *

In [4]:
ds = MNISTDataset({})

ds.x_data = ds.x_data.astype(np.float32)

Loading MNIST dataset...
Data shape : (70000, 784)
Target shape : (70000,)
MNIST dataset loaded.


In [5]:
config_algo = {
    "k": 10
}

algos = [KMeansPlusPlusAlgorithm(config_algo), KKZ_Algorithm(config_algo), PCA_GuidedSearchAlgorithm(config_algo)]

In [6]:

for algo in algos:
    distortions, time_since_start, best_result_so_far = evaluate_algo(algo, ds, n_runs=1)
    print("Average distortion:", np.mean(distortions))
    print("Average time:", np.mean(time_since_start))
    print("Average best result so far:", np.mean(best_result_so_far))
    print("")

  0%|          | 0/1 [00:00<?, ?it/s]

Average distortion: 178432577024.0
Average time: 9.015884160995483
Average best result so far: 178432577024.0



  0%|          | 0/1 [00:00<?, ?it/s]

Average distortion: 210566150720.0
Average time: 2.80046010017395
Average best result so far: 210566150720.0



  0%|          | 0/1 [00:00<?, ?it/s]

Average distortion: 178483519488.0
Average time: 8.452669620513916
Average best result so far: 178483519488.0

