In [None]:
import matplotlib
import matplotlib.pyplot as plt 
import numpy as np
import pandas as pd
import os
import umap
import datashader as ds
import colorcet as cc
import igraph
import tqdm
from scipy import sparse
from scipy import stats
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import LatentDirichletAllocation
from statsmodels.stats.multitest import multipletests
from sklearn.cluster import KMeans
from scipy.spatial import Delaunay
import json

from matplotlib.collections import PolyCollection
from matplotlib.colors import ListedColormap

from dredFISH.Analysis import TissueGraph
from dredFISH.Visualization import Viz
from dredFISH.Utils.__init__plots import * 
from dredFISH.Utils import powerplots
from dredFISH.Utils import miscu
from dredFISH.Utils import tmgu

import importlib
# Re-execute the two project modules imported above; presumably so that edits
# to their source files are picked up without restarting the kernel.
importlib.reload(Viz)
importlib.reload(TissueGraph)

#### Load data

In [None]:
# Directory path, presumably where figures are meant to be written; none of the
# cells visible in this part of the notebook reads it.
respath = '/bigstore/GeneralStorage/fangming/projects/dredfish/figures/'

In [None]:
# Root directory of the dataset; the analysis table and the TissueMultiGraph
# below are both loaded from here.
basepth = '/bigstore/GeneralStorage/Data/dredFISH/Dataset1-t5'
# Shell escapes: list the directory contents, then show the first lines of TMG.json.
!ls -alhtr $basepth
!head $basepth"/TMG.json"

In [None]:
# Analysis table stored next to the dataset; its first CSV column is used as
# the row index.
df = pd.read_csv(os.path.join(basepth, "default_analysis.csv"),
                 index_col=0)
df

In [None]:
# redo=False: load the existing results rather than recomputing them
TMG = TissueGraph.TissueMultiGraph(
    basepath=basepth,
    redo=False,
)

In [None]:
# First layer of the multigraph and its spatial coordinates
layer = TMG.Layers[0]
XY = layer.XY

# NOTE: temporary hack -- the two coordinate columns are deliberately swapped,
# so x is taken from column 1 and y from column 0.
y, x = XY[:, 0], XY[:, 1]

# index labels of the layer's observation table, and N as reported by the layer
cells = layer.adata.obs.index.values
N = layer.N

# measured basis
ftrs_mat = layer.feature_mat

# umap_mat = umap.UMAP(n_neighbors=30, min_dist=0.1).fit_transform(ftrs_mat)



# Cell-cell interaction
- line (counts; distance)
- triangle (counts; distance)
- local environment (clustering)

In [None]:
def symmetrize(X, indices=None):
    """
    Return X + X.T with the diagonal divided by two.

    Off-diagonal entries therefore hold the sum of the two mirrored cells,
    while the diagonal (doubled by the addition, then halved) keeps the
    magnitude it had in X. The input is not modified.

    Parameters
    ----------
    X : array-like, shape (n, n)
        Square matrix; anything `np.asarray` accepts.
    indices : sequence, optional
        If given, the result is wrapped in a DataFrame that uses these labels
        for both rows and columns; otherwise a plain ndarray is returned.

    Example
    -------
        a = np.ones((3,3))
        symmetrize(a)   # 1 on the diagonal, 2 everywhere else
    """
    arr = np.asarray(X)
    n_rows, n_cols = arr.shape
    assert n_rows == n_cols

    sym = arr + arr.T
    diag = np.arange(n_rows)
    sym[diag, diag] = sym[diag, diag] / 2

    if indices is None:
        return sym
    return pd.DataFrame(sym, index=indices, columns=indices)

def count_edgetype(edges, node_types, typeorder=None):
    """
    Count edges by the unordered pair of node types they connect.

    Parameters
    ----------
    edges : array-like, shape (n_edges, 2)
        Each row holds the two node indices of one edge; the indices are used
        to look up `node_types`.
    node_types : array-like, length n_nodes
        Type label of every node.
    typeorder : sequence, optional
        Labels defining the row/column order of the result. Labels listed here
        that no edge touches get all-zero rows and columns. When omitted, the
        sorted set of labels found on either end of any edge is used.

    Returns
    -------
    pandas.DataFrame
        Symmetric table. Off-diagonal entry (a, b) is the number of edges that
        join type a and type b in either orientation; diagonal entry (a, a) is
        the number of a-a edges.

    Raises
    ------
    AssertionError
        If `edges` does not have two columns, or if the final table does not
        account for every edge exactly once (for example when `typeorder`
        omits a label that occurs on an edge).
    """
    EL = np.asarray(edges)
    assert EL.shape[1] == 2
    types = np.asarray(node_types)

    # directed tally: rows = type of the first endpoint, columns = type of the second
    countmat = pd.DataFrame(types[EL]).groupby([0,1]).size().unstack().fillna(0)

    # A label may occur only as a first endpoint or only as a second one, which
    # leaves the tally non-square or with mismatched row/column labels. Put both
    # axes on the same sorted label set before folding.
    labels = np.union1d(countmat.index.values, countmat.columns.values)
    countmat = countmat.reindex(index=labels, columns=labels, fill_value=0).astype(int)

    # fold (a, b) and (b, a) together
    countmat = symmetrize(countmat, indices=labels)
    if typeorder is not None:
        # reorder; requested labels without any edge become zero rows/columns
        countmat = countmat.reindex(index=typeorder, columns=typeorder, fill_value=0)

    # make sure the math is correct: the upper triangle must hold each edge once
    assert np.triu(countmat).sum() == len(EL)
    return countmat

def squareform_with_diag(vec, n_t):
    """
    Expand a flattened upper triangle (diagonal included) into the full
    symmetric n_t x n_t matrix.

    `vec` must hold n_t*(n_t+1)/2 values in the order produced by
    `np.triu_indices(n_t)`; the lower triangle is filled in by mirroring.
    """
    assert n_t*(n_t+1)/2 == len(vec)
    upper = np.zeros((n_t, n_t))
    rows, cols = np.triu_indices(n_t)
    upper[rows, cols] = vec
    return symmetrize(upper)

def test_edge_enrichment(edges, node_types, typeorder=None, n_repeat=100, random_state=0):
    """
    Permutation test for over- or under-representation of edges between node types.

    The observed type-pair edge counts are compared with the counts obtained
    after shuffling the type labels across nodes `n_repeat` times (the edge
    list itself is kept fixed).

    Parameters
    ----------
    edges : array-like, shape (n_edges, 2)
        Node-index pairs, one row per edge.
    node_types : array-like, length n_nodes
        Type label of every node.
    typeorder : sequence, optional
        Row/column order of all returned matrices. Defaults to the sorted
        unique labels in `node_types`.
    n_repeat : int
        Number of label shuffles.
    random_state : int or None
        Seed for the shuffles. A private generator is used, so the global
        NumPy random state is left untouched; the draws are the same ones the
        legacy global generator would produce for this seed.

    Returns
    -------
    countmat : pandas.DataFrame
        Observed symmetric count table.
    pvals : ndarray, shape (n_t, n_t)
        Two-sided empirical p-values, floored at 2/n_repeat.
    qvals : ndarray, shape (n_t, n_t)
        Benjamini-Hochberg adjusted p-values, computed on the upper triangle
        (diagonal included) and mirrored back into a full matrix.
    enr : pandas.DataFrame
        (observed + 1) / (mean shuffled + 1).
    countmats : ndarray, shape (n_t, n_t, n_repeat + 1)
        Slice 0 is the observed table; slices 1..n_repeat are the shuffles.
    """
    # house keeping
    EL = np.asarray(edges)
    assert EL.shape[1] == 2
    types = np.asarray(node_types)
    n = len(types)
    rng = np.random.RandomState(random_state)

    # One fixed label order for every table, so the stacked matrices line up.
    # The matrix size follows this order rather than the labels that happen to
    # be present, so a caller-supplied `typeorder` of any length is consistent.
    if typeorder is None:
        typeorder = np.unique(types)
    n_t = len(typeorder)

    # counts and shuffled counts
    countmats = np.zeros((n_t, n_t, n_repeat+1))
    countmat = count_edgetype(EL, types, typeorder=typeorder)
    countmats[:,:,0] = countmat.values
    for i in range(1, n_repeat+1):
        shuffled = types[rng.choice(n, size=n, replace=False)]  # permute labels over nodes
        countmats[:,:,i] = count_edgetype(EL, shuffled, typeorder=typeorder).values

    # rank of the observed count among all n_repeat+1 tables ~ [1, n_repeat+1]
    # (ties share their average rank); rescaled to a percentile ~ [0, 1]
    prctls = (stats.rankdata(countmats, axis=2)[:,:,0]-1)/n_repeat
    # two-sided; the clip keeps p-values from being exactly zero
    pvals = 2*np.clip(np.minimum(prctls, 1-prctls), 1/n_repeat, None)

    # the matrix is symmetric, so only the upper triangle holds distinct tests
    _, qvals, _, _ = multipletests(pvals[np.triu_indices(n_t)],
                                   method='fdr_bh')
    qvals = squareform_with_diag(qvals, n_t)

    # +1 pseudocount on both sides avoids division by zero for empty pairs
    enr = (countmat+1)/(np.mean(countmats[:,:,1:], axis=2)+1)
    return countmat, pvals, qvals, enr, countmats

# alpha = 5
# cmat_lo = np.percentile(countmat_shuffs, alpha/2, axis=2)
# cmat_hi = np.percentile(countmat_shuffs, 100-alpha/2, axis=2)
# np.logical_or(countmat < cmat_lo, countmat > cmat_hi)

In [None]:
# Type label per row of the analysis table
typecol = 'ktype_L3'
types = df[typecol].values

# The display order of the types is read from a separate metadata file
f = '/bigstore/GeneralStorage/fangming/projects/dredfish/data_dump/analysis_meta_Mar31.json'
with open(f) as fh:
    meta = json.load(fh)
typeorder = meta['l3_clsts']

# number of coordinate rows
N = len(XY)

In [None]:
%%time
# edgelist
# layer.SG.simplify() # precaution

EL = np.asarray(layer.SG.get_edgelist())

# types = df['type_r0.1'].values
# typeorder = np.sort(np.unique(types))

types = df['ktype_L3'].values
typeorder = meta['l3_clsts']

countmat, pvals, qvals, enr, countmats = test_edge_enrichment(EL, types, typeorder=typeorder, n_repeat=1000, random_state=0)

In [None]:
# Slice 0 of countmats is the observed table; averaging the remaining slices
# gives the expected count under label shuffling.
countmat_exp = countmats[:, :, 1:].mean(axis=2)

In [None]:
# fig, axs = plt.subplots(1,11,figsize=(11*3,1*3))
# for i in range(11):
#     # sns.heatmap(np.log10(countmats[:,:,i]+1), 
#     sns.heatmap(countmats[:,:,i], 
#                 ax=axs[i], 
#                 xticklabels=False, 
#                 yticklabels=False, 
#                 cmap='coolwarm',
#                 cbar=False,)

# countmats[:,:,1], countmats[:,:,2], countmats[:,:,3], countmats[:,:,4]
# np.min(pvals), np.min(qvals)

In [None]:
# instances: observed vs. expected edge counts, side by side on one colour scale
fig, axs = plt.subplots(1, 2, figsize=(9*2,10))
cbar_ax = fig.add_axes([0.93, 0.4, 0.02, 0.2])

# the colour limits are taken from the observed counts and reused for both panels
val = np.log10(countmat+1)
vmax = np.percentile(val, 95)
vmin = np.percentile(val, 5)

# (title, counts, show y tick labels)
panels = [
    ('Observed', countmat, True),
    ('Expected', countmat_exp, False),
]
for ax, (title, counts, show_ylabels) in zip(axs, panels):
    ax.set_title(title)
    val = np.log10(counts+1)
    sns.heatmap(
        val,
        xticklabels=True,
        yticklabels=show_ylabels,
        cmap='coolwarm',
        vmax=vmax,
        vmin=vmin,
        ax=ax,
        cbar_ax=cbar_ax,
        cbar_kws=dict(shrink=0.5, label='log10(instances+1)'),
    )

fig.subplots_adjust(wspace=0.05)
plt.show()


In [None]:
# log2 observed/expected ratio; the colour range is clipped to the 5th-95th percentile
val = np.log2(enr)
vmin, vmax = np.percentile(val, [5, 95])

fig, ax = plt.subplots(figsize=(12,10))
sns.heatmap(
    val,
    xticklabels=True,
    yticklabels=True,
    cmap='coolwarm',
    vmax=vmax,
    vmin=vmin,
    ax=ax,
    cbar_kws={'shrink': 0.5, 'label': 'log2(obs./exp.)'},
)
ax.set_title('Enrichment of Delaunay edges')
plt.show()

In [None]:
# FDR-adjusted p-values; the colour scale is capped at the significance
# threshold `alpha`, which later cells reuse.
alpha = 0.01
val = qvals

fig, ax = plt.subplots(figsize=(12,10))
sns.heatmap(
    val,
    xticklabels=True,
    yticklabels=True,
    cmap='rocket',
    vmax=alpha,
    ax=ax,
    cbar_kws={'shrink': 0.5, 'label': 'FDR'},
)
ax.set_title('Enrichment of Delaunay edges')
plt.show()

In [None]:
# log2 enrichment with non-significant pairs zeroed out; the colour range
# spans the full data range
val = np.log2(enr)
val[qvals > alpha] = 0 # insignificant

vmin, vmax = np.percentile(val, [0, 100])

fig, ax = plt.subplots(figsize=(12,10))
sns.heatmap(
    val,
    xticklabels=True,
    yticklabels=True,
    cmap='coolwarm',
    vmax=vmax,
    vmin=vmin,
    center=0,
    ax=ax,
    cbar_kws={'shrink': 0.5, 'label': 'log2(obs./exp.)\n(only shown FDR<0.01)'},
)
ax.set_title('Enrichment of Delaunay edges')
plt.show()

In [None]:
# log2 enrichment with non-significant pairs zeroed out; symmetric colour
# range taken from the larger of the 5th/95th percentile magnitudes
val = np.log2(enr)
val[qvals > alpha] = 0 # insignificant

vmin, vmax = np.percentile(val, [5, 95])
vlim = max(vmax, -vmin)

fig, ax = plt.subplots(figsize=(12,10))
sns.heatmap(
    val,
    xticklabels=True,
    yticklabels=True,
    cmap='coolwarm',
    vmax=vlim,
    vmin=-vlim,
    center=0,
    ax=ax,
    cbar_kws={'shrink': 0.5, 'label': 'log2(obs./exp.)\n(only shown FDR<0.01)'},
)
ax.set_title('Enrichment of Delaunay edges')
plt.show()

In [None]:
# Expected edge counts per type pair: the mean over the shuffled tables
# (slice 0 of countmats is the observed table and is excluded), on a log10 scale.
exp = np.mean(countmats[:,:,1:], axis=2)
val = np.log10(exp+1)

fig, ax = plt.subplots(figsize=(12,10))
vmax = np.percentile(val, 95)
vmin = np.percentile(val, 5)
sns.heatmap(val, 
            xticklabels=True,
            yticklabels=True,
            cmap='coolwarm', 
            vmax=vmax,
            vmin=vmin,
            ax=ax,
            # The label and title used to be copied from the enrichment plot;
            # they now describe what is drawn, matching the 'Expected' panel
            # of the observed/expected figure.
            cbar_kws=dict(shrink=0.5, label='log10(instances+1)')
           )
ax.set_title('Expected')
plt.show()

In [None]:
# Hierarchically clustered view of the thresholded log2 enrichment
val = np.log2(enr)
val[qvals > alpha] = 0 # insignificant

vmin, vmax = np.percentile(val, [5, 95])
vlim = max(vmax, -vmin)

g = sns.clustermap(
    val,
    xticklabels=True,
    yticklabels=True,
    cmap='coolwarm',
    vmax=vlim,
    vmin=-vlim,
    cbar_kws={'shrink': 0.5, 'label': 'log2(obs./exp.)\n(only shown FDR<0.01)'},
    figsize=(12,12),
)

# Hide the first two axes of the figure (presumably the two dendrograms --
# confirm) and put the title on the third.
for hidden_ax in g.fig.axes[:2]:
    hidden_ax.set_visible(False)
g.fig.axes[2].set_title('Enrichment of Delaunay edges')
plt.show()