In [2]:
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import warnings; warnings.filterwarnings('ignore')
import collections
import string
import scipy
from scipy.io import loadmat
from scipy.io import savemat
from sklearn.cluster import KMeans
from sklearn.cluster import AffinityPropagation, AgglomerativeClustering
from sklearn import manifold
from mpl_toolkits.mplot3d import Axes3D
# This import registers the 3D projection, but is otherwise unused.
from mpl_toolkits.mplot3d import Axes3D
from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score
from scipy.sparse import csr_matrix

# Font-size presets (not referenced elsewhere in this notebook as shown).
large = 22; med = 16; small = 12
# matplotlib 3.6 renamed the bundled 'seaborn-*' styles to 'seaborn-v0_8-*'
# and 3.8 removed the old names; try the new name first, fall back for older
# matplotlib versions where only the old name exists.
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    plt.style.use('seaborn-whitegrid')
sns.set_style("white")

def check_iterable(item):
    """Wrap a single string in a list so it can be iterated as one item.

    Parameters
    ----------
    item : str or iterable
        A single name (e.g. ``'cancer'``) or an iterable of names.

    Returns
    -------
    list or iterable
        ``[item]`` when ``item`` is a string (including ``str`` subclasses
        such as ``numpy.str_``); otherwise ``item`` unchanged.
    """
    # isinstance (rather than type(...) == str) also catches str subclasses,
    # which would otherwise be iterated character by character by callers.
    if isinstance(item, str):
        return [item]
    return item

def getknn(a, k):
    """Return the element that would sit at index ``k`` of ``sorted(a)``.

    ``np.partition`` moves that order statistic into position ``k`` without
    fully sorting. For a row of a distance matrix whose self-distance (0) is
    the single smallest entry, this is the distance to the k-th nearest
    other point.
    """
    partitioned = np.partition(a, k)
    kth_value = partitioned[k]
    return kth_value

def load(file_path):
    """Load a distance matrix from a MATLAB ``.mat`` file and symmetrise it.

    The file must contain a variable named ``'distances'``. The matrix is
    returned as ``X + X.T`` in dense ``numpy.ndarray`` form. Both sparse
    matrices (what ``loadmat`` returns for MATLAB sparse variables) and dense
    arrays are accepted.

    Parameters
    ----------
    file_path : str
        Path to the ``.mat`` file.

    Returns
    -------
    numpy.ndarray
        The dense, symmetrised matrix.
    """
    X = loadmat(file_path)['distances']
    X = X + X.T
    # Sparse matrices expose .toarray(); dense ndarrays need no conversion.
    if hasattr(X, 'toarray'):
        X = X.toarray()
    return np.asarray(X)

# Root directory containing the Cancer/ and Deng/ data folders used below.
# Set accordingly before running the notebook (must be a string).
datapath = '.'

### For clustering analysis, known cluster labels must be provided as ground truth in order to measure accuracy with NMI (normalized mutual information) and ARI (adjusted Rand index).

In [3]:
def load_labels(name):
    """Load the ground-truth cell-type labels for a dataset.

    Parameters
    ----------
    name : str
        Dataset key, either ``'cancer'`` or ``'deng'``.

    Returns
    -------
    list
        ``[y, cdict]`` where ``y`` is a ``numpy`` array of integer class codes
        and ``cdict`` maps each code to a human-readable class name.

    Raises
    ------
    ValueError
        If ``name`` is not a known dataset.

    Notes
    -----
    Entries whose raw label equals none of the known strings are skipped (no
    code is appended), so ``y`` can be shorter than the number of stored
    labels.
    """
    # Per dataset: (.mat path relative to datapath, variable name,
    #               (raw label string, class code) pairs, code -> display name).
    # The Deng raw strings keep trailing spaces because they are compared
    # with == exactly (presumably padding in the stored labels -- confirm).
    datasets = {
        'cancer': (
            '/Cancer/breast_cancer_5000_labels.mat',
            'y2',
            [("Tumor", 0), ("Stromal", 1), ("Immune", 2)],
            {0: 'Tumor', 1: 'Stromal', 2: 'Immune'},
        ),
        'deng': (
            '/Deng/deng_cells.mat',
            'labels',
            [("zygote", 0), ("2cell ", 1), ("4cell ", 2),
             ("8cell ", 3), ("16cell", 4), ("blast ", 5)],
            {0: 'Zygote', 1: '2 cells', 2: '4 cells', 3: '8 cells',
             4: '16 cell', 5: 'Blastocyst'},
        ),
    }
    if name not in datasets:
        raise ValueError(
            f"Unknown dataset name {name!r}; expected 'cancer' or 'deng'")

    rel_path, key, raw_to_code, cdict = datasets[name]
    cell_labels = loadmat(datapath + rel_path)[key]
    labels = []
    for item in cell_labels:
        for raw, code in raw_to_code:
            if item == raw:
                labels.append(code)
    y = np.array(labels)
    return [y, cdict]

### The following function loads the pre-computed Gromov-Wasserstein distance matrix. Make sure you run the gidm_example.py script beforehand.

In [4]:
def load_gw_distances_and_labels(name):
    """Load the pre-computed distance matrix and ground-truth labels.

    Parameters
    ----------
    name : str
        Dataset key, either ``'cancer'`` or ``'deng'``.

    Returns
    -------
    dict
        ``{'distances': M + M.T, 'labels': y}`` where ``M`` is the array stored
        in the ``.npy`` file and ``y`` comes from ``load_labels(name)``.

    Raises
    ------
    ValueError
        If ``name`` is not a known dataset.
    """
    file_paths = {
        'cancer': '/Cancer/cancerMP_21.npy',
        'deng': '/Deng/dengMP_21.npy',
    }
    if name not in file_paths:
        raise ValueError(
            f"Unknown dataset name {name!r}; expected 'cancer' or 'deng'")

    # NOTE(review): allow_pickle=True lets np.load run arbitrary code from a
    # crafted file -- only load matrices produced by gidm_example.py.
    raw = np.load(datapath + file_paths[name], allow_pickle=True)
    # Symmetrise (presumably the stored matrix holds one triangle -- confirm).
    distances = raw + raw.T
    return {'distances': distances, 'labels': load_labels(name)[0]}

In [5]:
def apply_gidm(X, neighbours):
    """Spectral embedding of a pairwise distance matrix with adaptive bandwidths.

    Steps:
      1. Per-point bandwidth ``ep[i]`` = entry at index ``neighbours`` of the
         sorted row ``X[i, :]`` (with a zero self-distance this is the
         distance to the ``neighbours``-th nearest other point).
      2. Kernel ``K[i, j] = exp(-X[i, j]**2 / (ep[i] * ep[j]))``.
      3. Symmetric normalisation ``D^{-1/2} K D^{-1/2}``, ``D`` = row sums of K.
      4. Eigendecomposition, returned in order of decreasing eigenvalue so
         column 0 of ``V`` is the leading eigenvector and columns 1, 2, ...
         are the successive coordinates used downstream.

    Parameters
    ----------
    X : array_like, shape (n, n)
        Pairwise distances.
    neighbours : int
        Order statistic used for the bandwidth; must be smaller than ``n``.

    Returns
    -------
    W : numpy.ndarray, shape (n,)
        Eigenvalues, largest first.
    V : numpy.ndarray, shape (n, n)
        Corresponding eigenvectors as columns.
    """
    X = np.asarray(X, dtype=float)
    # Row-wise k-th order statistic in one vectorised call
    # (same value as getknn(X[i, :], neighbours) for every row i).
    ep = np.partition(X, neighbours, axis=1)[:, neighbours]
    bw = np.outer(ep, ep)
    K = np.exp(-(X ** 2) / bw)
    # D^{-1/2} K D^{-1/2} by broadcasting instead of two dense matmuls with a
    # diagonal matrix (O(n^2) instead of O(n^3)).
    inv_sqrt_deg = 1.0 / np.sqrt(K.sum(axis=1))
    K = inv_sqrt_deg[:, None] * K * inv_sqrt_deg[None, :]
    # np.linalg.eig does not order its eigenvalues, but callers slice
    # V[:, 1:dims] as the leading non-trivial coordinates, so sort explicitly.
    # For a symmetric K (symmetric X) eigh gives real eigenpairs.
    if np.allclose(K, K.T):
        W, V = np.linalg.eigh(K)
    else:
        W, V = np.linalg.eig(K)
    order = np.argsort(-W.real, kind='stable')
    return W[order], V[:, order]

In [6]:
def compute_score_average(V, dims, y, method, scoring='nmi'):
    """Cluster the embedding ``V[:, 1:dims]`` and score it against ``y``.

    Column 0 of ``V`` is skipped. Despite the name, a single clustering run is
    scored; nothing is averaged.

    Parameters
    ----------
    V : numpy.ndarray, shape (n, m)
        Embedding coordinates, one row per sample.
    dims : int
        Exclusive upper column bound; columns ``1 .. dims-1`` are used.
    y : array_like, shape (n,)
        Ground-truth class labels.
    method : {'afp', 'kmeans', 'agglomerative'}
        ``'afp'`` = AffinityPropagation. ``'kmeans'`` and ``'agglomerative'``
        use as many clusters as there are distinct values in ``y``.
    scoring : str, default 'nmi'
        ``'nmi'`` selects normalized mutual information; any other value
        (including None) selects the adjusted Rand index.

    Returns
    -------
    float
        The selected score for the predicted cluster labels.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    features = V[:, 1:dims]
    if method == 'afp':
        model = AffinityPropagation(random_state=5)
    elif method == 'kmeans':
        model = KMeans(n_clusters=len(np.unique(y)), n_init=10, random_state=5)
    elif method == 'agglomerative':
        model = AgglomerativeClustering(n_clusters=len(np.unique(y)))
    else:
        raise ValueError(f"Unsupported clustering method {method!r}; "
                         "expected 'afp', 'kmeans' or 'agglomerative'")

    predicted = model.fit(features).labels_
    metric = normalized_mutual_info_score if scoring == 'nmi' else adjusted_rand_score
    return metric(y, predicted)

In [20]:
def do_clustering(V, dims, y, method):
    """Fit a clustering model on the embedding ``V[:, 1:dims]``.

    Parameters
    ----------
    V : numpy.ndarray, shape (n, m)
        Embedding coordinates, one row per sample (column 0 is skipped).
    dims : int
        Exclusive upper column bound; columns ``1 .. dims-1`` are used.
    y : array_like
        Ground-truth labels. Only used to choose the number of clusters for
        ``'kmeans'`` and ``'agglomerative'``.
    method : {'afp', 'kmeans', 'agglomerative'}
        Clustering algorithm.

    Returns
    -------
    estimator
        The fitted scikit-learn estimator; ``labels_`` holds the cluster ids.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    if method == 'afp':
        model = AffinityPropagation(random_state=5)
    elif method == 'kmeans':
        model = KMeans(n_clusters=len(np.unique(y)), n_init=10, random_state=5)
    elif method == 'agglomerative':
        model = AgglomerativeClustering(n_clusters=len(np.unique(y)))
    else:
        raise ValueError(f"Unsupported clustering method {method!r}; "
                         "expected 'afp', 'kmeans' or 'agglomerative'")
    return model.fit(V[:, 1:dims])

def get_scores_gidm(data_name, neighbours, scoring=None, method='kmeans',
                    dims=(3, 5, 10, 15, 20, 50, 100, 200)):
    """Embed a dataset with GIDM and score clusterings at several dimensions.

    Parameters
    ----------
    data_name : str
        Dataset key passed to ``load_gw_distances_and_labels``.
    neighbours : int
        Bandwidth neighbour index passed to ``apply_gidm``.
    scoring : str or None
        Passed through to ``compute_score_average``.
    method : str
        Clustering method passed through to ``compute_score_average``.
    dims : sequence of int, default (3, 5, 10, 15, 20, 50, 100, 200)
        Column bounds to evaluate; each entry yields one score.

    Returns
    -------
    list
        ``[scores, W, V]``: one score per entry of ``dims``, followed by the
        eigenvalues and eigenvectors returned by ``apply_gidm``.
    """
    data = load_gw_distances_and_labels(data_name)
    W, V = apply_gidm(data['distances'], neighbours)
    y = data['labels']
    scores = [compute_score_average(V, int(d), y, method, scoring) for d in dims]
    return [scores, W, V]

In [8]:
def plot_table_gidm(data_names, neighbors, scoring=None, method='kmeans', save_plot=False):
    %matplotlib qt5
    data_names=check_iterable(data_names)
    nmigw=[get_scores_gw(name, neighbors, scoring, method)[0] for name in data_names]
    ylabels=[string.capwords(name) for name in data_names]
    xlabels=['3','5','10','15','20','50','100', '200']

    plt.figure(figsize=(20,15), dpi= 400)
    gwnmi_pl=sns.heatmap(nmigw, xticklabels=xlabels, yticklabels=ylabels, cmap='YlGnBu', cbar=False, annot=True, vmin=0, vmax=1, square=True, linewidths=.5, annot_kws={"fontsize":6})
    gwnmi_pl.set_yticklabels(gwnmi_pl.get_yticklabels(), rotation = 0)

    # Decorations
    plt.title(scoring.upper()+' scores GIDM, k='+str(neighbors), fontsize=12)
    plt.xticks(fontsize=8)
    plt.yticks(fontsize=8)
    if save_plot==True:
        plt.savefig('../scores_table/GIDM'+method+'_'+scoring+'_'+str(neighbors)+'knn.jpg', dpi=350)
        plt.close()
    else:
        plt.show()

In [21]:
def plot_clusters3D(name, neighbours, dims, method='afp', save_plot=False):
    labels, cdict = load_labels(name)
    %matplotlib qt5
    fig = plt.figure(figsize=(20,10), dpi=150)
    V = get_scores_gidm(name, neighbours)[2]
    clust=do_clustering(V,dims,labels,method)
    
    ax = fig.add_subplot(1, 3, 1, projection='3d')
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])
    ax.zaxis.set_ticklabels([])
    ax.set_xlabel('$\phi_1$')
    ax.set_ylabel('$\phi_2$')
    ax.set_zlabel('$\phi_3$')
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.set_title('GIDM')
    scat=ax.scatter(V[:, 1], V[:, 2], V[:, 3], c=clust.labels_, cmap='Paired', label=labels)

    handles, lab = scat.legend_elements() 
    ax3 = fig.add_subplot(1, 3, 3)
    ax3.xaxis.set_ticklabels([])
    ax3.yaxis.set_ticklabels([])
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 + box.height * 0.1, box.width, box.height])
    ax3.legend(handles, cdict.values(), title="Cell type", loc=6, ncol=1)
    ax3.axis('off')
    
    #plt.tight_layout() # No overlap of subplots

    if save_plot==True:
        plt.savefig('../visual/'+name+'_vis3D_'+str(neighbours)+'knn_'+str(min(dims))+'-'+str(max(dims))+'dims.jpg', dpi=350)
        plt.close()
    plt.show()

In [10]:
# Datasets evaluated in the cells below.
datasets = ["cancer", "deng"]

In [12]:
plot_table_gidm(datasets, neighbors=30, scoring='ari', method='afp', save_plot=False)

In [22]:
plot_clusters3D(name='cancer', neighbours=30, dims=50, method='afp', save_plot=False)