## ARI plots

In [None]:
import matplotlib.pyplot as plt
import pickle
from tqdm import tqdm
import pickle

from sklearn.utils import check_random_state
import random

from sklearn.cluster import AffinityPropagation
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.cluster import AgglomerativeClustering
# ward
from sklearn.cluster import HDBSCAN
from scanpy.tl import leiden
# scanpy's louvain requires the `louvain` package to be installed
from scanpy.tl import louvain
import numpy as np
import anndata as ad
import scanpy

from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import fowlkes_mallows_score
from sklearn.metrics import homogeneity_completeness_v_measure
from itertools import combinations
import studenttmixture
from studenttmixture import EMStudentMixture

import matplotlib as mpl

# Hide the top and right axes spines by default for every figure in this notebook.
mpl.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
})

In [None]:
# Load the pickled dataset bundle; later cells read its 'Xs' and 'ys' entries,
# both keyed by embedding dimensionality.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# only open files that come from a trusted source.
with open(
    'non_isotrop_early_stop_scaled_2kl.pkl',
    'rb'
) as f:
    together = pickle.load(f)

In [None]:
# Cluster every embedding with each method, once per random seed.
#   preds[dim][method][seed] -> predicted cluster labels for together['Xs'][dim]
#   labels[dim]              -> reference labels together['ys'][dim]
# NOTE(review): `seeds` is not defined in the cells visible here; it is assumed
# to be an iterable of integer random seeds defined elsewhere -- confirm.
preds = {}
labels = {}
for dim in sorted(together['Xs'].keys()):
    # 'leiden' stays first so the key order of preds[dim] is unchanged.
    preds[dim] = {
        'leiden': {},
        'KMeans': {},
        'gmm_spherical': {},
        'gmm_full': {},
        'tmm_alt': {},
    }
    X = together['Xs'][dim]
    labels[dim] = together['ys'][dim]
    # The Student-t mixture is fed float64 data; the conversion depends only on
    # `dim`, so do it once here instead of twice per seed.
    X_f64 = X.astype('float64')

    # The neighbour graph is built once per dimensionality and reused for
    # every Leiden run below; only the Leiden seed changes.
    adata = ad.AnnData(X)
    scanpy.pp.neighbors(adata, use_rep='X')

    for s in seeds:
        preds[dim]['KMeans'][s] = KMeans(random_state=s, n_clusters=10).fit_predict(X)
        preds[dim]['gmm_spherical'][s] = GaussianMixture(
            random_state=s, n_components=10, covariance_type='spherical', init_params='k-means++'
        ).fit_predict(X)
        preds[dim]['gmm_full'][s] = GaussianMixture(
            random_state=s, n_components=10, covariance_type='full', init_params='k-means++'
        ).fit_predict(X)

        tmm_em = EMStudentMixture(n_components=10, tol=1e-5, random_state=s, fixed_df=False, init_type='k++')
        tmm_em.fit(X_f64)
        preds[dim]['tmm_alt'][s] = tmm_em.predict(X_f64)

        leiden(adata, random_state=s)
        preds[dim]['leiden'][s] = adata.obs['leiden'].tolist()

In [None]:
# ari[dim][method][seed]: adjusted Rand index of one clustering run against
# the reference labels of that dimensionality.
ari = {}
for dim, truth in labels.items():
    ari[dim] = {
        method: {s: adjusted_rand_score(truth, runs[s]) for s in seeds}
        for method, runs in preds[dim].items()
    }

In [None]:
# ari_wo[dim][method]: adjusted Rand index between the predictions of every
# unordered pair of seeds (agreement of a method with itself across seeds).
ari_wo = {}
for dim in labels:
    ari_wo[dim] = {
        method: [
            adjusted_rand_score(runs[s1], runs[s2])
            for s1, s2 in combinations(seeds, 2)
        ]
        for method, runs in preds[dim].items()
    }

In [None]:
# Cluster three overlapping row subsets of every embedding with a fixed seed.
#   preds_bootstrap[dim][method][subset_id] -> predicted labels for that subset
# Every subset starts with rows 0-3999, so the first 4000 predictions of any
# two subsets refer to the same samples.
# NOTE(review): the hard-coded ranges assume each together['Xs'][dim] has at
# least 10000 rows -- confirm.
preds_bootstrap = {}
idxs1 = list(range(6000))
idxs2 = list(range(4000)) + list(range(6000, 8000))
idxs3 = list(range(4000)) + list(range(8000, 10000))
idxs_all = [idxs1, idxs2, idxs3]
for dim in tqdm(sorted(together['Xs'].keys())):
    preds_bootstrap[dim] = {
        'tmm_alt': {},
        'KMeans': {},
        'gmm_spherical': {},
        'gmm_full': {},
        'leiden': {},
    }
    s = 42
    for id_, idx in enumerate(idxs_all):
        X = together['Xs'][dim][idx]
        # Convert once; the Student-t mixture uses the float64 view for both
        # fitting and prediction.
        X_f64 = X.astype('float64')

        tmm_em = EMStudentMixture(n_components=10, tol=1e-5, random_state=s, fixed_df=False, init_type='k++')
        tmm_em.fit(X_f64)
        preds_bootstrap[dim]['tmm_alt'][id_] = tmm_em.predict(X_f64)

        preds_bootstrap[dim]['KMeans'][id_] = KMeans(random_state=s, n_clusters=10).fit_predict(X)
        preds_bootstrap[dim]['gmm_spherical'][id_] = GaussianMixture(
            random_state=s, n_components=10, covariance_type='spherical', init_params='k-means++'
        ).fit_predict(X)
        preds_bootstrap[dim]['gmm_full'][id_] = GaussianMixture(
            random_state=s, n_components=10, covariance_type='full', init_params='k-means++'
        ).fit_predict(X)

        # Unlike the per-seed experiment, the neighbour graph has to be rebuilt
        # here because the rows differ between subsets.
        adata = ad.AnnData(X)
        scanpy.pp.neighbors(adata, use_rep='X')
        leiden(adata, random_state=s)
        preds_bootstrap[dim]['leiden'][id_] = adata.obs['leiden'].tolist()

In [None]:
# ari_boot[dim][method]: adjusted Rand index between every pair of the three
# subset runs, compared on their first 4000 predictions only.
ari_boot = {}
for dim in labels:
    ari_boot[dim] = {}
    for method, runs in preds_bootstrap[dim].items():
        ari_boot[dim][method] = [
            adjusted_rand_score(runs[a][:4000], runs[b][:4000])
            for a, b in combinations([0, 1, 2], 2)
        ]

In [None]:
# Internal method key -> label shown in the figure legend.
method_names = {
    'KMeans': 'KMeans',
    'gmm_full': 'GMM full',
    'gmm_spherical': 'GMM spherical',
    'leiden': 'Leiden',
    'tmm_alt': 'TMM full',
}
# Methods to plot, in a fixed (sorted) order. Derived from method_names rather
# than from preds_bootstrap[dim], which depended on whatever value `dim` was
# left holding by a previously executed cell; every plotted method needs a
# display name anyway.
use_clust = sorted(method_names)

In [None]:
def _plot_mean_std(ax, xs, scores, methods, display_names, ylabel, xticks, fontsize):
    """Draw one mean curve with a +/- one standard deviation band per method.

    Args:
        ax            : matplotlib axes to draw on
        xs            : x positions (dimensionalities), also the keys of `scores`
        scores        : scores[x][method] is a sequence of ARI values
        methods       : method keys to plot, in plotting order
        display_names : method key -> legend label
        ylabel        : y-axis label of the panel
        xticks        : tick positions for the x axis
        fontsize      : font size for tick and axis labels
    """
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    for method in methods:
        samples = [np.asarray(scores[x][method]) for x in xs]
        means = np.array([v.mean() for v in samples])
        stds = np.array([v.std() for v in samples])
        ax.plot(xs, means, '-o', label=display_names[method])
        ax.fill_between(xs, means + stds, means - stds, alpha=0.5)
    ax.tick_params(axis='both', labelsize=fontsize)
    ax.set_xlabel('dimensionality', fontsize=fontsize,)
    ax.set_ylabel(ylabel, fontsize=fontsize,)
    ax.set_xticks(xticks)


fontsize = 14
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 4), sharey=True)
axes = axes.flatten()

xs = sorted(list(labels.keys()))

# Panel A stores one score per seed in a dict; flatten it to the same
# "sequence of scores" layout that panels B and C already use.
ari_per_seed = {
    x: {method: list(ari[x][method].values()) for method in use_clust}
    for x in xs
}

panels = [
    ('A', ari_per_seed, 'ARI'),
    ('B', ari_wo, 'ARI between seeds'),
    ('C', ari_boot, 'ARI dataset intersection'),
]
for ax, (letter, scores, ylabel) in zip(axes, panels):
    _plot_mean_std(ax, xs, scores, use_clust, method_names, ylabel, list(preds.keys()), fontsize)

# A single legend, placed just outside the right-most panel.
axes[2].legend(fontsize=fontsize, frameon=False, bbox_to_anchor=(1.01, 1.0))
for ax, (letter, scores, ylabel) in zip(axes, panels):
    ax.annotate(letter, xy=(0.03, 0.9), xycoords='axes fraction', fontsize=23, fontweight='bold',)
plt.tight_layout()
plt.savefig('fig2_with_tmm_mixture_05_v4.png', dpi=300, bbox_inches='tight', transparent=True)

### Density-peak analysis (clustering by fast search and find of density peaks)

In [None]:
from scipy.spatial.distance import cdist

In [None]:
def local_density(max_id, distances, dc, guass=True, cutoff=False):
    """
    Compute the local density of every point.

    Points are addressed by 1-based ids 1..max_id; slot 0 of the result is a
    placeholder that is set to -1 and never updated.

    Args:
        max_id    : largest point id (ids are consecutive, starting at 1)
        distances : pairwise distances, indexable as distances[i, j]
        dc        : cutoff distance of the density kernel
        guass     : use the Gaussian kernel exp(-(d / dc) ** 2)
                    (cannot be combined with cutoff)
        cutoff    : use the hard-threshold kernel, 1 if d < dc else 0
                    (cannot be combined with guass)

    Returns:
        float32 vector of length max_id + 1 whose entry i is the local density
        of point i (entry 0 is -1)

    Raises:
        ValueError: if both or neither of guass / cutoff are selected
    """
    # Imported here because the visible top-of-file imports do not include it.
    import math

    # Exactly one kernel must be selected. The previous assert parsed as
    # `(guass and not cutoff and guass) or cutoff` and therefore accepted
    # guass=True together with cutoff=True.
    if bool(guass) == bool(cutoff):
        raise ValueError("exactly one of 'guass' and 'cutoff' must be True")

    if guass:
        exp = math.exp

        def kernel(dij):
            return exp(-(dij / dc) ** 2)
    else:
        def kernel(dij):
            return 1 if dij < dc else 0

    rho = [-1] + [0] * max_id
    for i in tqdm(range(1, max_id)):
        for j in range(i + 1, max_id + 1):
            # Each unordered pair contributes the same amount to both points.
            contribution = kernel(distances[i, j])
            rho[i] += contribution
            rho[j] += contribution
    return np.array(rho, np.float32)

def min_distance(max_id, max_dis, distances, rho):
    """
    For every point, find the distance to the closest point of higher local
    density, and which point that is.

    Args:
        max_id    : largest point id (ids are consecutive)
        max_dis   : largest pairwise distance, used as the initial value
        distances : pairwise distances, indexable as distances[i, j]
        rho       : local density vector indexed by point id (ids start at 1)

    Returns:
        (delta, nneigh): float32 vector of distances and float32 vector of the
        corresponding neighbour ids; the densest point gets the largest delta
    """
    # Point ids ordered from highest to lowest density.
    order = np.argsort(-rho)
    n_slots = len(rho)
    delta = [0.0] + [float(max_dis)] * (n_slots - 1)
    nneigh = [0] * n_slots

    densest = order[0]
    delta[densest] = -1.
    for rank in tqdm(range(1, max_id)):
        point = order[rank]
        # Everything earlier in `order` sorts ahead of `point` by density.
        for denser in order[:rank]:
            gap = distances[point, denser]
            if gap < delta[point]:
                delta[point] = gap
                nneigh[point] = denser
    # The densest point has no denser neighbour; give it the largest delta.
    delta[densest] = max(delta)
    return np.array(delta, np.float32), np.array(nneigh, np.float32)

In [None]:
# Density-peak statistics for several cutoff percentiles and dimensionalities.
#   dc_perc[perc][dim]    -> cutoff distance (perc-th percentile of all pairwise distances)
#   rhos_perc[perc][dim]  -> local densities from local_density
#   delta_perc[perc][dim] -> distance to the nearest denser point from min_distance
#   nn_perc[perc][dim]    -> id of that nearest denser point
rhos_perc = {}
nn_perc = {}
delta_perc = {}
dc_perc = {}

percentiles = [0.01, 0.1, 0.3, 0.5, 1, 2, 3, 5]
# Create the per-percentile dicts up front so their key order stays the
# percentile order even though the loops below run dimensionality-first.
for perc in percentiles:
    dc_perc[perc] = {}
    nn_perc[perc] = {}
    delta_perc[perc] = {}
    rhos_perc[perc] = {}

for dim in [2, 4, 8, 16, 32, 64]:
    # NOTE(review): `Xs` is not defined in the cells visible here; presumably
    # it is together['Xs'] -- confirm.
    X = Xs[dim]
    # X = np.load(f'test_features{dim}.npy')

    # The distance matrix and its maximum depend only on `dim`, so compute them
    # once instead of once per percentile.
    dist = cdist(X, X)
    max_dis = dist.max()
    max_id = dist.shape[0]
    # All cutoffs in one pass over the n*n distances (np.percentile flattens).
    cutoffs = np.percentile(dist, percentiles)

    for perc, dc in zip(percentiles, cutoffs):
        # local_density / min_distance use ids 1..max_id-1 and keep slot 0 as a
        # placeholder, so row 0 of `dist` does not take part.
        rho_new = local_density(max_id - 1, dist, dc)
        delta_new, nneigh_new = min_distance(max_id - 1, max_dis, dist, rho_new)

        dc_perc[perc][dim] = dc
        nn_perc[perc][dim] = nneigh_new
        delta_perc[perc][dim] = delta_new
        rhos_perc[perc][dim] = rho_new

In [None]:
# Quick look: one row of rho-delta scatter plots per cutoff percentile,
# one panel per dimensionality.
for p, cutoff_by_dim in dc_perc.items():
    fig, axes = plt.subplots(nrows=1, ncols=6, figsize=(16, 4))
    axes = axes.flatten()
    for idx, dim in enumerate(cutoff_by_dim):
        ax = axes[idx]
        for side in ('top', 'right'):
            ax.spines[side].set_visible(False)
        # Entry 0 of both vectors is skipped (placeholder slot of local_density).
        ax.scatter(rhos_perc[p][dim][1:], delta_perc[p][dim][1:], s=7)
        ax.set_title(f'dim={dim}')
        if idx == 0:
            ax.set_ylabel('delta (dist)')
            ax.set_xlabel('rho (num neighbors)')
    plt.suptitle(f'Threshold percentile={p}')
    plt.tight_layout()
    plt.show()

In [None]:
# Appendix figure: rho-delta scatter plots for the cutoff percentiles below 1.
# One row per percentile (lettered A-D), one column per dimensionality.
fig, axes = plt.subplots(nrows=4, ncols=6, figsize=(16, 12))
axes = axes.flatten()
letters = ['A', 'B', 'C', 'D']
panels = [(p, dim) for p in [0.01, 0.1, 0.3, 0.5] for dim in dc_perc[p]]
for idx, (p, dim) in enumerate(panels):
    ax = axes[idx]
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    # Entry 0 of both vectors is skipped (placeholder slot of local_density).
    ax.scatter(rhos_perc[p][dim][1:], delta_perc[p][dim][1:], s=10)
    if idx < 6:
        # Column titles only on the first row.
        ax.set_title(f'dim={dim}', fontsize=18, fontweight='bold',)
    if idx % 6 == 0:
        # First panel of a row carries the percentile / cutoff label and the row letter.
        ax.set_ylabel(f'{p}-th pth'+r'$\approx$'+f'{dc_perc[p][dim]:.3f}', fontsize=14)
        ax.annotate(letters[idx // 6], xy=(0.1, 0.9), xycoords='axes fraction', fontsize=23, fontweight='bold',)
    ax.tick_params(axis='both', which='major', labelsize=14)

# Shared axis captions for the whole grid.
fig.text(-0.01, 0.5, r'$\delta$ (distance to the point with more neighbors)', va='center', rotation='vertical', fontsize=16)
fig.text(0.5, -0.01, r'$\rho$ (local density)', va='center', fontsize=16)

plt.tight_layout()
plt.savefig('appendix_rho_delta_below_1.png', dpi=300, bbox_inches='tight', transparent=True)
plt.show()

In [None]:
# Appendix figure: rho-delta scatter plots for the cutoff percentiles 1 and above.
# One row per percentile (lettered A-D), one column per dimensionality.
fig, axes = plt.subplots(nrows=4, ncols=6, figsize=(16, 12))
axes = axes.flatten()
letters = ['A', 'B', 'C', 'D']
panels = [(p, dim) for p in [1, 2, 3, 5] for dim in dc_perc[p]]
for idx, (p, dim) in enumerate(panels):
    ax = axes[idx]
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    # Entry 0 of both vectors is skipped (placeholder slot of local_density).
    ax.scatter(rhos_perc[p][dim][1:], delta_perc[p][dim][1:], s=10)
    if idx < 6:
        # Column titles only on the first row.
        ax.set_title(f'dim={dim}', fontsize=18, fontweight='bold',)
    if idx % 6 == 0:
        # First panel of a row carries the percentile / cutoff label and the row letter.
        ax.set_ylabel(f'{p}-th pth'+r'$\approx$'+f'{dc_perc[p][dim]:.3f}', fontsize=14)
        ax.annotate(letters[idx // 6], xy=(0.1, 0.9), xycoords='axes fraction', fontsize=23, fontweight='bold',)
    ax.tick_params(axis='both', which='major', labelsize=14)

# Shared axis captions for the whole grid.
fig.text(-0.01, 0.5, r'$\delta$ (distance to the point with more neighbors)', va='center', rotation='vertical', fontsize=16)
fig.text(0.5, -0.01, r'$\rho$ (local density)', va='center', fontsize=16)

plt.tight_layout()
plt.savefig('appendix_rho_delta_above_1.png', dpi=300, bbox_inches='tight', transparent=True)
plt.show()

In [None]:
# Quick look: the 30 largest values of rho * delta, in descending order, for
# every cutoff percentile (one figure each) and dimensionality (one panel each).
for p in dc_perc.keys():
    fig, axes = plt.subplots(nrows=1, ncols=6, figsize=(16, 4))
    axes = axes.flatten()
    for idx, dim in enumerate(dc_perc[p].keys()):
        axes[idx].spines['top'].set_visible(False)
        axes[idx].spines['right'].set_visible(False)
        # Kept in its own name: assigning this to `together` would overwrite
        # the dataset dict loaded from the pickle file and break re-running
        # the earlier cells.
        gamma = np.sort(rhos_perc[p][dim] * delta_perc[p][dim])[::-1][:30]
        # x positions follow the number of values actually available.
        axes[idx].scatter(np.arange(gamma.size), gamma, s=7)
        axes[idx].set_title(f'dim={dim}')
        if idx == 0:
            axes[idx].set_ylabel('delta* rho')
    plt.suptitle(f'Threshold percentile={p}')
    plt.tight_layout()
    plt.show()

In [None]:
# Appendix figure: the 30 largest values of rho * delta, in descending order,
# for the cutoff percentiles below 1. One row per percentile (lettered A-D),
# one column per dimensionality.
fig, axes = plt.subplots(nrows=4, ncols=6, figsize=(16, 12))
axes = axes.flatten()
idx = 0
letters = ['A', 'B', 'C', 'D']
for p in [0.01, 0.1, 0.3, 0.5]:
    for dim in dc_perc[p].keys():
        axes[idx].spines['top'].set_visible(False)
        axes[idx].spines['right'].set_visible(False)
        # Kept in its own name: assigning this to `together` would overwrite
        # the dataset dict loaded from the pickle file and break re-running
        # the earlier cells.
        gamma = np.sort(rhos_perc[p][dim] * delta_perc[p][dim])[::-1][:30]
        # x positions follow the number of values actually available.
        axes[idx].scatter(np.arange(gamma.size), gamma, s=10)
        if idx < 6:
            # Column titles only on the first row.
            axes[idx].set_title(f'dim={dim}', fontsize=18, fontweight='bold',)
        if idx % 6 == 0:
            # First panel of a row carries the percentile / cutoff label and
            # the row letter; the very first panel gets extra label padding.
            label_kwargs = {'labelpad': 27} if idx == 0 else {}
            axes[idx].set_ylabel(f'{p}-th pth'+r'$\approx$'+f'{dc_perc[p][dim]:.3f}', fontsize=14, **label_kwargs)
            axes[idx].annotate(letters[idx // 6], xy=(0.1, 0.9), xycoords='axes fraction', fontsize=23, fontweight='bold',)
        axes[idx].tick_params(axis='both', which='major', labelsize=14)
        idx += 1

# Shared y caption for the whole grid.
fig.text(-0.02, 0.5, r'$\rho \cdot \delta$', va='center', rotation='vertical', fontsize=16)
plt.tight_layout()
plt.savefig('appendix_gamma_below_1.png', dpi=300, bbox_inches='tight', transparent=True)
plt.show()