In [142]:
import pandas as pd
from scipy import sparse as sps
import os
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from ipywidgets import interact, interactive, fixed
import ipywidgets as widgets
import matplotlib.pyplot as plt
from scipy.cluster import hierarchy
%matplotlib inline
from scipy.stats import wasserstein_distance, energy_distance
from scipy.spatial.distance import squareform
import math
from sklearn.neighbors import radius_neighbors_graph
import time
import sys

In [136]:
def readin_clones(o):
    clone_information = pd.read_csv(o, sep='\t', index_col=0)
    cellid_to_idx = dict(zip(clone_information.index, range(len(clone_information.index))))
    cloneid_to_idx = dict(zip(clone_information.columns, range(len(clone_information.columns))))
    #sparse representation of the clone gr
    #the rows of this correspond to each cell
    #the columns correspond to clone identity
    # if i,j = 1 then cell i belongs to clone group j
    clone_information = clone_information[clone_information.columns[clone_information.sum(axis=0)>0]]
    
    clonegrouping_spmtx = sps.csc_matrix(clone_information.to_numpy())
    return cellid_to_idx, cloneid_to_idx, clonegrouping_spmtx

def readin_coords(o):
    coordinate_information = pd.read_csv(o, sep='\t', index_col = 0)
    coords_mtx = coordinate_information.to_numpy() 
    return coords_mtx

def readin_metad(o):
    metadata_information = pd.read_csv(o, sep='\t', index_col = 0)
    time_vec = metadata_information['time'].to_numpy()
    return time_vec

def readin(data_folder):
    clone_file = os.path.join(data_folder, "clones.tsv.gz")
    coord_file = os.path.join(data_folder, "coordinates.tsv.gz")
    metad_file = os.path.join(data_folder, "metadata.tsv.gz")
    
    cellid_to_idx, cloneid_to_idx, clonegrouping_spmtx = readin_clones(clone_file)
    coords_mtx = readin_coords(coord_file)
    time_vec = readin_metad(metad_file)
    print("Number of cells: ", len(cellid_to_idx))
    print("Number of clones: ", len(cloneid_to_idx))
    print("Number of dimensions: ", coords_mtx.shape[1])
    print("Time Steps: ", np.unique(time_vec))
    return cellid_to_idx, cloneid_to_idx, clonegrouping_spmtx, coords_mtx, time_vec

In [25]:
def getdistance_wasserstein(clonegrouping_spmtx, coords_mtx, time_vec):
    time_steps = np.unique(time_vec)
    num_clones = clonegrouping_spmtx.shape[1]
    num_dim = coords_mtx.shape[1]
    distance_matrix = np.zeros((num_clones, num_clones))
    num_noninform = 0
    for i in range(num_clones):
        cells_in_i = clonegrouping_spmtx[:,i].nonzero()[0]
        coords_for_i = coords_mtx[cells_in_i]
        time_for_i = time_vec[cells_in_i]
        for j in range(i):
            
            dist = 0
            
            cells_in_j = clonegrouping_spmtx[:,j].nonzero()[0]
            coords_for_j = coords_mtx[cells_in_j]
            time_for_j = time_vec[cells_in_j]
            for t in time_steps:
                #print(t)
                #print(time_for_j)
                #print(time_for_i)
                ts_i = np.where(time_for_i == t)[0]
                ts_j = np.where(time_for_j == t)[0]
                #print(ts_i.size)
                #print(ts_j.size)
                # continue if we can't do anything
                if ts_i.size == 0 or ts_j.size == 0:
                    continue
                
                i_weight = ts_i.shape[0] / cells_in_i.shape[0]
                j_weight = ts_j.shape[0] / cells_in_j.shape[0]
                dists = []
                for d in range(num_dim):
                    dists.append(wasserstein_distance(coords_for_i[ts_i][:,d], coords_for_j[ts_j][:,d]))
                ts_avg = np.mean(dists) * np.mean([i_weight, j_weight])
                dist += ts_avg
            if dist == 0:
                #print(time_for_j)
                #print(time_for_i)
                
                #L00k -> THESE VALUES SHOULD BE INFINITY
                num_noninform += 1
            distance_matrix[i][j] = dist
    print("Out of " + str(math.comb(num_clones, 2)) + " clonal distances, " + str(num_noninform) + " are noninformative")
    return squareform(distance_matrix + distance_matrix.transpose())

In [148]:
def getdistance_mnn(clonegrouping_spmtx, coords_mtx, time_vec, radius):
    time_steps = np.unique(time_vec)
    num_clones = clonegrouping_spmtx.shape[1]
    num_dim = coords_mtx.shape[1]
    distance_matrix = np.zeros((num_clones, num_clones))
    num_noninform = 0
    for i in range(num_clones):
        cells_in_i = clonegrouping_spmtx[:,i].nonzero()[0]
        coords_for_i = coords_mtx[cells_in_i]
        time_for_i = time_vec[cells_in_i]
        for j in range(i):
            
            dist = 0
            
            cells_in_j = clonegrouping_spmtx[:,j].nonzero()[0]
            coords_for_j = coords_mtx[cells_in_j]
            time_for_j = time_vec[cells_in_j]
            
            rng = radius_neighbors_graph(np.vstack((coords_for_i, coords_for_j)), radius=radius, mode="connectivity")
            nz = rng.nonzero()
            isize = cells_in_i.shape[0]
            denom = 1
            numer = 1
            #denom = 0
            #numer = 0
            for coords in range(nz[0].shape[0]):
                if nz[0][coords] < isize and nz[1][coords] < isize:
                    numer += 1
                elif nz[0][coords] >= isize and nz[1][coords] >= isize:
                    numer += 1
                elif (nz[0][coords] >= isize and nz[1][coords] < isize) or \
                (nz[0][coords] < isize and nz[1][coords] >= isize):
                    denom += 1
                else:
                    print("error")
                    sys.exit(2)
            #distance_matrix[i][j] += math.log(numer/denom)
            if np.isclose((numer/denom),1): continue
            else: distance_matrix[i][j] += (numer/denom)
    return squareform(distance_matrix + distance_matrix.transpose())   

In [133]:
cellid_to_idx, cloneid_to_idx, clonegrouping_spmtx, coords_mtx, time_vec = readin("/home/luak/rotations/pinello/post/MEGATRON/preprocess/morrislab/output")
squareform_distance_matrix = getdistance_mnn(clonegrouping_spmtx, coords_mtx, time_vec, 1)

Number of cells:  4498
Number of clones:  781
Number of dimensions:  2
Time Steps:  [ 6  9 12 15 21 28]


In [16]:
def plot_dendrogram(Z, cloneids, num_clusters):
    plt.title('Hierarchical Clustering Dendrogram (truncated)')
    plt.xlabel('Clone Id or (cluster size)')
    plt.ylabel('distance')
    hierarchy.dendrogram(
        Z,
        truncate_mode='lastp',  # show only the last p merged clusters
        p=num_clusters,  # show only the last p merged clusters
        leaf_rotation=90.,
        leaf_font_size=10.,
        show_contracted=True,  # to get a distribution impression in truncated branches
        labels=cloneids
    )
    plt.show()

#plot_dendrogram(Z, [*cloneid_to_idx], 5)

def visualize(Z, cloneids, num_clusters=5):
    plot_dendrogram(Z, cloneids, num_clusters)
    plt.clf()
    
    clusters = fcluster(Z, num_clusters, criterion="maxclust")
    fig, ax = plt.subplots()
    
    interact_manual(plot_meta_clones_with_alltags, 
                    Z=fixed(Z), 
                    cloneids=fixed(cloneids),
                    big_clones=fixed(chunky), 
                    num_clusters=widgets.IntSlider(min=1,max=len(Z)-1,
                                              step=1))

In [None]:
cellid_to_idx, cloneid_to_idx, clonegrouping_spmtx, coords_mtx, time_vec = readin("/home/luak/rotations/pinello/post/MEGATRON/preprocess/morrislab/output")
squareform_distance_matrix = getdistance_wasserstein(clonegrouping_spmtx, coords_mtx, time_vec)
Z = linkage(squareform_distance_matrix)
# visualize(Z, cloneids)

In [154]:
import pickle
cellid_to_idx, cloneid_to_idx, clonegrouping_spmtx, coords_mtx, time_vec = readin("/home/luak/rotations/pinello/post/MEGATRON/preprocess/kleinlab/output_subset")
squareform_distance_matrix = getdistance_wasserstein(clonegrouping_spmtx, coords_mtx, time_vec)
with open('arya_wasserstein.pkl', 'wb') as f: 
    pickle.dump(squareform_distance_matrix, f)
Z = linkage(squareform_distance_matrix, method="ward")
clone_clusters = fcluster(Z, 5, criterion='maxclust')
print(clone_clusters)
# visualize(Z, cloneids)
squareform_distance_matrix = getdistance_mnn(clonegrouping_spmtx, coords_mtx, time_vec, 1)
Z = linkage(squareform_distance_matrix, method="ward")
clone_clusters = fcluster(Z, 5, criterion='maxclust')
print(clone_clusters)

Number of cells:  3221
Number of clones:  5864
Number of dimensions:  2
Time Steps:  [2. 4. 6.]
Out of 66430 clonal distances, 0 are noninformative
[4 5 1 4 3 5 2 3 3 3 4 4 2 3 3 3 3 3 1 2 1 5 4 2 4 4 5 3 5 4 4 4 5 3 5 1 3
 1 1 3 1 3 3 3 3 4 2 1 3 1 3 4 3 1 2 1 2 4 4 3 3 4 4 3 3 3 5 4 4 2 4 3 5 5
 1 5 3 3 5 4 1 3 2 5 3 3 5 5 3 2 3 4 2 5 2 3 4 2 1 1 1 3 1 2 4 1 3 3 4 1 3
 1 4 3 4 3 4 2 3 3 1 3 1 4 3 1 5 2 4 4 2 3 3 1 3 3 2 5 2 4 5 4 2 3 1 3 5 4
 3 5 3 2 3 1 4 5 4 3 5 5 4 4 5 5 2 3 2 1 5 4 5 5 5 3 5 3 3 3 1 2 3 1 1 5 3
 5 5 3 5 1 5 3 3 1 3 5 1 3 1 3 1 5 5 3 4 3 3 2 1 3 4 3 3 3 3 3 2 2 2 3 3 3
 3 3 3 3 1 3 5 5 1 5 1 4 2 5 5 4 5 3 2 3 4 3 1 1 3 5 3 5 4 1 3 1 3 2 1 3 5
 5 4 5 3 3 2 5 4 1 3 1 3 5 5 3 3 4 4 5 5 2 4 1 3 5 4 5 2 4 3 1 5 5 1 2 1 1
 3 1 3 5 3 4 1 4 1 3 5 1 3 3 5 1 3 1 2 3 3 3 1 3 3 5 4 1 3 2 5 2 4 3 3 5 3
 5 1 1 3 4 1 3 3 5 2 5 1 5 3 1 4 4 3 3 3 1 1 5 3 1 5 2 5 3 5 5 4]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1