In [None]:
from init_tree import init_tree_by_name
from load_datasets import load_by_name

import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances

In [None]:
# Dataset to evaluate; the available names are listed in load_datasets.load_by_name.
dataset_name = '20newsgroups'
# Location of the dataset on disk; only some datasets need it, otherwise leave as None.
path = None

In [None]:
data, clustering_true = load_by_name(dataset_name, path)
# Number of clusters = number of distinct ground-truth labels.
k = len(np.unique(clustering_true))

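# Helper functions

The *price of explainability* of a clustering is its k-means cost divided by the cost of the reference clustering. Both costs are computed against the same fixed centers, and each cluster is charged to the single center that serves it best. A value of 1 means both clusterings cost the same.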

In [None]:
def kmeans_assignment(data, labels, centers):
    """Charge every cluster in `labels` to the one center that serves it best.

    For each distinct label, the squared Euclidean distances of its points to
    every center are summed, and the whole cluster is assigned to the center
    with the smallest total (ties go to the lowest center index). Several
    clusters may end up on the same center.

    Args:
        data: (n_samples, n_features) data matrix.
        labels: array holding one cluster label per row of `data`.
        centers: (n_centers, n_features) array of candidate centers.

    Returns:
        Tuple ``(assigned, cost)``. ``assigned[i]`` is the index of the center
        chosen for the cluster of point ``i``. ``cost`` is the summed squared
        distance of all points to their assigned centers.
    """
    assert data.shape[0] == labels.size
    assert data.shape[1] == centers.shape[1]

    # sq_dist[i, j] = squared Euclidean distance from point i to center j.
    sq_dist = pairwise_distances(data, centers, metric='sqeuclidean')
    assert sq_dist.shape[0] == data.shape[0]
    assert sq_dist.shape[1] == centers.shape[0]

    best_center = {}
    total_cost = 0
    for cluster_id in np.unique(labels):
        # Cost of serving this whole cluster from each candidate center.
        per_center_cost = sq_dist[labels == cluster_id].sum(axis=0)
        chosen = per_center_cost.argmin()
        best_center[cluster_id] = chosen
        total_cost += per_center_cost[chosen]

    assigned = np.array([best_center[lbl] for lbl in labels])
    return assigned, total_cost

In [None]:
def price_of_explainability(data, y_pred, y_true, centers):
    """Ratio of the k-means cost of `y_pred` to that of `y_true`.

    Both clusterings are scored with `kmeans_assignment` against the same
    `centers`, so the result measures how much cost `y_pred` adds (or saves)
    relative to the reference clustering `y_true`.
    """
    _, explained_cost = kmeans_assignment(data, y_pred, centers)
    _, reference_cost = kmeans_assignment(data, y_true, centers)
    return explained_cost / reference_cost

# Train

In [None]:
# Method name -> [price of explainability]; becomes the evaluation table below.
results = {}

# Baseline name -> reference clustering that the explainable trees are trained to mimic.
clique_baselines = {}

def add_result(name, prediction, baseline, centers):
    """Store the price of explainability of `prediction` w.r.t. `baseline` under `name`.

    Uses the notebook-level `data`. An existing entry with the same `name` is
    replaced.
    """
    score = price_of_explainability(data, prediction, baseline, centers)
    results[name] = [score]

In [None]:
# n_init is pinned explicitly: scikit-learn changed its default (10 -> 'auto',
# i.e. a single run for k-means++ init, as of 1.4), so leaving it implicit makes
# the reference clustering depend on the installed scikit-learn version.
kmeans = KMeans(n_clusters=k, n_init=10, random_state=1234)
clustering_kmeans = kmeans.fit_predict(data)
# K-Means labels serve as the reference clustering the explainable trees imitate.
clique_baselines['KMeans'] = clustering_kmeans

In [None]:
# Reference centers: the mean of the points in each K-Means cluster. Every
# price-of-explainability score below is measured against these centers.
centers = np.zeros((k, data.shape[1]))
for cluster_id in range(k):
    members = clustering_kmeans == cluster_id
    centers[cluster_id] = data[members, :].mean(axis=0)

In [None]:
# Sanity check: K-Means scored against itself, so the ratio is 1 by construction.
add_result('K Means', prediction=clustering_kmeans, baseline=clustering_kmeans, centers=centers)

## SpEx Clique

In [None]:
for name, reference in clique_baselines.items():
    clique_global = init_tree_by_name('clique')
    clique_global.train(data, reference)
    y_pred = clique_global.predict(data)
    # Score against the clustering the tree was trained to explain (not a fixed
    # K-Means array), and give each baseline its own column so a second baseline
    # does not silently overwrite the first.
    result_name = 'SpEx Clique' if len(clique_baselines) == 1 else f'SpEx Clique ({name})'
    add_result(result_name, y_pred, reference, centers)

## EMN

In [None]:
# Instantiate the EMN tree model; it is trained in the next cell.
emn = init_tree_by_name('emn')

In [None]:
# The int64 cast suggests EMN expects int64 labels (TODO confirm). Store the cast
# copy under a new name instead of rebinding `clustering_kmeans`. This avoids
# hidden state: earlier cells and `clique_baselines` keep the same array.
clustering_kmeans_int64 = clustering_kmeans.astype(np.int64)
emn.train(data, clustering_kmeans_int64, centers)
y_pred = emn.predict(data)
add_result('EMN', y_pred, clustering_kmeans_int64, centers)

## CART

In [None]:
for name, reference in clique_baselines.items():
    vanilla_cart = init_tree_by_name('cart')
    vanilla_cart.train(data, reference)
    y_pred = vanilla_cart.predict(data)
    # Score against the clustering CART was trained to imitate (not a fixed
    # K-Means array), and give each baseline its own column so a second baseline
    # does not silently overwrite the first.
    result_name = 'CART' if len(clique_baselines) == 1 else f'CART ({name})'
    add_result(result_name, y_pred, reference, centers)

# Evaluation

In [None]:
# Row labels of the evaluation table; one per entry in each `results` list.
metric_names = ['Price of Explainability']

In [None]:
# One column per method (keys of `results`), one row per metric.
valuation_df = pd.DataFrame(data=results, index=metric_names)
valuation_df