In [3]:
import pandas as pd
from scipy import sparse as sps
import os
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from ipywidgets import interact, interactive, fixed
import ipywidgets as widgets
import matplotlib.pyplot as plt
from scipy.cluster import hierarchy
%matplotlib inline
from scipy.stats import wasserstein_distance, energy_distance
from scipy.spatial.distance import squareform
import math
from sklearn.neighbors import radius_neighbors_graph, kneighbors_graph
import time
import sys

In [2]:
def readin_clones(o):
    clone_information = pd.read_csv(o, sep='\t', index_col=0)
    cellid_to_idx = dict(zip(clone_information.index, range(len(clone_information.index))))
    cloneid_to_idx = dict(zip(clone_information.columns, range(len(clone_information.columns))))
    #sparse representation of the clone gr
    #the rows of this correspond to each cell
    #the columns correspond to clone identity
    # if i,j = 1 then cell i belongs to clone group j
    clone_information = clone_information[clone_information.columns[clone_information.sum(axis=0)>0]]
    
    clonegrouping_spmtx = sps.csc_matrix(clone_information.to_numpy())
    return cellid_to_idx, cloneid_to_idx, clonegrouping_spmtx

def readin_coords(o):
    coordinate_information = pd.read_csv(o, sep='\t', index_col = 0)
    coords_mtx = coordinate_information.to_numpy() 
    return coords_mtx

def readin_metad(o):
    metadata_information = pd.read_csv(o, sep='\t', index_col = 0)
    time_vec = metadata_information['time'].to_numpy()
    return time_vec

def readin(data_folder):
    clone_file = os.path.join(data_folder, "clones.tsv.gz")
    coord_file = os.path.join(data_folder, "coordinates.tsv.gz")
    metad_file = os.path.join(data_folder, "metadata.tsv.gz")
    
    cellid_to_idx, cloneid_to_idx, clonegrouping_spmtx = readin_clones(clone_file)
    coords_mtx = readin_coords(coord_file)
    time_vec = readin_metad(metad_file)
    print("Number of cells: ", len(cellid_to_idx))
    print("Number of clones: ", len(cloneid_to_idx))
    print("Number of dimensions: ", coords_mtx.shape[1])
    print("Time Steps: ", np.unique(time_vec))
    return cellid_to_idx, cloneid_to_idx, clonegrouping_spmtx, coords_mtx, time_vec

In [3]:
def getdistance_wasserstein(clonegrouping_spmtx, coords_mtx, time_vec, choice="wasserstein"):
    time_steps = np.unique(time_vec)
    num_clones = clonegrouping_spmtx.shape[1]
    num_dim = coords_mtx.shape[1]
    distance_matrix = np.zeros((num_clones, num_clones))
    num_noninform = 0
    for i in range(num_clones):
        cells_in_i = clonegrouping_spmtx[:,i].nonzero()[0]
        coords_for_i = coords_mtx[cells_in_i]
        time_for_i = time_vec[cells_in_i]
        for j in range(i):
            
            dist = 0
            
            cells_in_j = clonegrouping_spmtx[:,j].nonzero()[0]
            coords_for_j = coords_mtx[cells_in_j]
            time_for_j = time_vec[cells_in_j]
            for t in time_steps:
                #print(t)
                #print(time_for_j)
                #print(time_for_i)
                ts_i = np.where(time_for_i == t)[0]
                ts_j = np.where(time_for_j == t)[0]
                #print(ts_i.size)
                #print(ts_j.size)
                # continue if we can't do anything
                if ts_i.size == 0 or ts_j.size == 0:
                    continue
                
                i_weight = ts_i.shape[0] / cells_in_i.shape[0]
                j_weight = ts_j.shape[0] / cells_in_j.shape[0]
                dists = []
                for d in range(num_dim):
                    if choice == "wasserstein":
                        dists.append(wasserstein_distance(coords_for_i[ts_i][:,d], coords_for_j[ts_j][:,d]))
                    elif choice == "energy":
                        dists.append(energy_distance(coords_for_i[ts_i][:,d], coords_for_j[ts_j][:,d]))
                    else: 
                        print("not supported")
                        sys.exit(2)
                ts_avg = np.mean(dists) * np.mean([i_weight, j_weight])
                dist += ts_avg
            if dist == 0:
                #print(time_for_j)
                #print(time_for_i)
                #L00k -> THESE VALUES SHOULD BE INFINITY
                num_noninform += 1
            distance_matrix[i][j] = dist
    print("Out of " + str(math.comb(num_clones, 2)) + " clonal distances, " + str(num_noninform) + " are noninformative")
    return squareform(distance_matrix + distance_matrix.transpose())

In [105]:
from itertools import combinations

colored = np.array([0, 1, 2])
combinations
empty = np.array([3, 4, 5])
cross = np.array([6, 7])
X = [[0],[0],[1],[0],[1],[1],[0],[1]]
A = radius_neighbors_graph(X, 0.1, mode='connectivity', include_self=True)
#print(A.toarray())
#colored.boom
color_neigh = A[colored]
empty_neigh = A[empty]
cross_neigh = A[cross]

stack = sps.vstack((color_neigh,cross_neigh))
nz = stack.nonzero()

allneighbors_in_both = np.stack((nz[0],nz[1]), axis=-1)
#print(allneighbors_in_both)

in_colored = np.array(np.meshgrid(colored, colored)).T.reshape(-1,2)
#print(in_colored)

in_cross = np.array(np.meshgrid(cross, cross)).T.reshape(-1,2)
#print(in_cross)

in_both = np.array(np.meshgrid(colored, cross)).T.reshape(-1,2)
#print(in_both)

in_all = np.vstack((in_colored, in_cross, in_both))
#print(in_all)
x= 0
#print(x)

allvalues_in_comp = allneighbors_in_both[(allneighbors_in_both[:, None] == in_all).all(-1).any(-1)]
values_in_colored = allneighbors_in_both[(allneighbors_in_both[:, None] == in_colored).all(-1).any(-1)]
i, j = values_in_colored.T
x += (A[i,j].sum())
#print(x)

for coords in range(nz[0].shape[0]):
    #print(nz[0][coords])
    #print(nz[1][coords])
    if (nz[0][coords] in colored or nz[0][coords] in cross) and (nz[1][coords] in cross or nz[1][coords] in colored):
        #print(nz[0][coords], nz[1][coords])
        continue
    else:
        pass

0
5.0
0 0
0 1
0 6
1 0
1 1
1 6
2 2
2 7


In [27]:
def getdistance_mnn(clonegrouping_spmtx, coords_mtx, time_vec, dist="kneighbors", radius=1.0, neighbors=5, mode="distance"):
    time_steps = np.unique(time_vec)
    num_clones = clonegrouping_spmtx.shape[1]
    num_dim = coords_mtx.shape[1]
    distance_matrix = np.zeros((num_clones, num_clones))
    num_noninform = 0
    
    if dist=="kneighbors":
        rng = kneighbors_graph(coords_mtx, neighbors, mode=mode, include_self=True)
    else:
        rng = radius_neighbors_graph(coords_mtx, radius=radius, mode=mode, include_self=True)
    
    
    for i in range(num_clones):
        
        # get the coords for all cells in i
        cells_in_i = clonegrouping_spmtx[:,i].nonzero()[0]
        coords_for_i = coords_mtx[cells_in_i]
        time_for_i = time_vec[cells_in_i]
        
        # get the neighborhood of clone i
        i_neighbors = rng[cells_in_i]
        
        ts_i_dict = {}
        for t in time_steps:
            ts_i = np.where(time_for_i == t)[0]
            if ts_i.size == 0:
                continue
            i_weight = ts_i.shape[0] / cells_in_i.shape[0]
            poss_only_in_i = np.array(np.meshgrid(ts_i, ts_i)).T.reshape(-1,2)
            ts_i_dict[t] = poss_only_in_i

        for j in range(i):
            
            dist = 0
            
            # get cells in clone j
            cells_in_j = clonegrouping_spmtx[:,j].nonzero()[0]
            coords_for_j = coords_mtx[cells_in_j]
            time_for_j = time_vec[cells_in_j]
            
            # get neighborhood of j
            j_neighbors = rng[cells_in_j]

            # stack neighborhoods and extract nonzero values
            stacked = sps.vstack((i_neighbors,j_neighbors), format='dok')
            nz = stacked.nonzero()
            allneighbors_in_both = np.stack((nz[0],nz[1]), axis=-1)

            ts_dists = []
            for t in time_steps:
                ts_i = np.where(time_for_i == t)[0]
                ts_j = np.where(time_for_j == t)[0]

                # continue if we can't do anything
                if ts_i.size == 0 or ts_j.size == 0:
                    continue
                
                denom = 1
                numer = 1
                
                poss_only_in_i = ts_i_dict[t]
                j_weight = ts_j.shape[0] / cells_in_j.shape[0]
                
                poss_only_in_j = np.array(np.meshgrid(ts_j, ts_j)).T.reshape(-1,2)
                poss_in_both = np.array(np.meshgrid(ts_i, ts_j)).T.reshape(-1,2)

                poss_in_all = np.vstack((poss_only_in_i, poss_only_in_j, poss_in_both))
                obs_in_all = allneighbors_in_both[(allneighbors_in_both[:, None] == poss_in_all).all(-1).any(-1)]

                #print(in_all)
                
                obs_in_i = obs_in_all[(obs_in_all[:, None] == poss_only_in_i).all(-1).any(-1)]
                obs_in_j = obs_in_all[(obs_in_all[:, None] == poss_only_in_j).all(-1).any(-1)]
                obs_in_both = obs_in_all[(obs_in_all[:, None] == poss_in_both).all(-1).any(-1)]

                
                x,y = obs_in_i.T
                numer += stacked[x,y].sum()
                x,y = obs_in_j.T
                numer += stacked[x,y].sum()
                x,y = obs_in_both.T
                denom += stacked[x,y].sum()
                ts_dists.append((abs((numer/denom)-1)))#* np.mean([i_weight, j_weight])
            distance_matrix[i][j] += np.mean(ts_dists)
            #distance_matrix[i][j] += (numer/denom) - 1
            #if np.isclose((numer/denom),1): continue
            #else: distance_matrix[i][j] += (numer/denom)
    
    return squareform(distance_matrix + distance_matrix.transpose())   

In [22]:
#cellid_to_idx, cloneid_to_idx, clonegrouping_spmtx, coords_mtx, time_vec = readin("/home/luak/rotations/pinello/post/MEGATRON/preprocess/morrislab/output")
#squareform_distance_matrix = getdistance_mnn(clonegrouping_spmtx, coords_mtx, time_vec, 1)

In [23]:
def plot_dendrogram(Z, cloneids, num_clusters):
    plt.title('Hierarchical Clustering Dendrogram (truncated)')
    plt.xlabel('Clone Id or (cluster size)')
    plt.ylabel('distance')
    hierarchy.dendrogram(
        Z,
        truncate_mode='lastp',  # show only the last p merged clusters
        p=num_clusters,  # show only the last p merged clusters
        leaf_rotation=90.,
        leaf_font_size=10.,
        show_contracted=True,  # to get a distribution impression in truncated branches
        labels=cloneids
    )
    plt.show()

#plot_dendrogram(Z, [*cloneid_to_idx], 5)

def visualize(Z, cloneids, num_clusters=5):
    plot_dendrogram(Z, cloneids, num_clusters)
    plt.clf()
    
    clusters = fcluster(Z, num_clusters, criterion="maxclust")
    fig, ax = plt.subplots()
    
    interact_manual(plot_meta_clones_with_alltags, 
                    Z=fixed(Z), 
                    cloneids=fixed(cloneids),
                    big_clones=fixed(chunky), 
                    num_clusters=widgets.IntSlider(min=1,max=len(Z)-1,
                                              step=1))

In [26]:
#import pickle
#cellid_to_idx, cloneid_to_idx, clonegrouping_spmtx, coords_mtx, time_vec = readin("/home/luak/rotations/pinello/post/MEGATRON/preprocess/kleinlab/output_subset")
#squareform_distance_matrix = getdistance_wasserstein(clonegrouping_spmtx, coords_mtx, time_vec)
#with open('arya_wasserstein.pkl', 'wb') as f: 
#    pickle.dump(squareform_distance_matrix, f)
#Z = linkage(squareform_distance_matrix, method="ward")
#clone_clusters = fcluster(Z, 5, criterion='maxclust')
#print(clone_clusters)
## visualize(Z, cloneids)
#squareform_distance_matrix = getdistance_mnn(clonegrouping_spmtx, coords_mtx, time_vec, 1)
#Z = linkage(squareform_distance_matrix, method="ward")
#clone_clusters = fcluster(Z, 5, criterion='maxclust')
#print(clone_clusters)