In [None]:
import matplotlib
import matplotlib.pyplot as plt 
import numpy as np
import pandas as pd
import os
import umap
import datashader as ds
import colorcet as cc
import igraph
import tqdm
from scipy import sparse
from scipy import stats
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import LatentDirichletAllocation
from statsmodels.stats.multitest import multipletests
from sklearn.cluster import KMeans
from scipy.spatial import Delaunay
import json
import itertools
import time

from matplotlib.collections import PolyCollection
from matplotlib.colors import ListedColormap

from dredFISH.Analysis import TissueGraph
from dredFISH.Visualization import Viz
from dredFISH.Utils.__init__plots import * 
from dredFISH.Utils import powerplots
from dredFISH.Utils import miscu
from dredFISH.Utils import tmgu

import importlib
importlib.reload(Viz)
importlib.reload(TissueGraph)

#### Load data

In [None]:
# Output directory for figures (presumably; it is not used in the cells shown here).
respath = '/bigstore/GeneralStorage/fangming/projects/dredfish/figures/'

In [None]:
# Root directory of the dataset; later cells read the analysis CSV and the TMG from it.
basepth = '/bigstore/GeneralStorage/Data/dredFISH/Dataset1-t5'
# Shell checks: list the directory (oldest entries first) and peek at the head of TMG.json.
!ls -alhtr $basepth
!head $basepth"/TMG.json"

In [None]:
# Load the analysis table stored with the dataset; the first CSV column becomes the index.
# NOTE(review): later cells use df[<type column>].values as node labels for the
# triangulation of XY, so rows are presumably one per cell in XY order -- confirm.
df = pd.read_csv(
    os.path.join(basepth, "default_analysis.csv"), index_col=0)
df

In [None]:
# Construct the TissueMultiGraph for the dataset rooted at `basepth`.
TMG = TissueGraph.TissueMultiGraph(basepath=basepth, 
                                   redo=False, # load existing 
                                  )

In [None]:
# spatial coordinates
layer = TMG.Layers[0]
XY = layer.XY
x, y = XY[:,0], XY[:,1]
###
# NOTE: only the convenience names x/y are swapped; XY itself keeps its original
# column order, so anything computed from XY (e.g. the Delaunay triangulation
# further down) is unaffected by this hack.
x, y = y, x # a temporary hack
###

# labels of the observations stored in the layer's AnnData object
cells = layer.adata.obs.index.values

# N as reported by the layer (reassigned to len(XY) in a later cell)
N = layer.N
# measured basis
ftrs_mat = layer.feature_mat

# umap_mat = umap.UMAP(n_neighbors=30, min_dist=0.1).fit_transform(ftrs_mat)



# Cell-cell interaction
- line (counts; distance)
- triangle (counts; distance)
- local environment (clustering)

In [None]:
def count_triplet_old(triplets, types):
    """Count triplets by the sorted combination of their node types (legacy).

    Instrumented variant kept for reference: it prints the elapsed time (in
    seconds) after the label-encoding step and after the counting step.

    Parameters
    ----------
    triplets : array-like of int, shape (n_triplets, 3)
        Node indices of each triplet (e.g. ``Delaunay(...).simplices``).
    types : array-like, shape (n_nodes,)
        Type label of every node.

    Returns
    -------
    unq_tri : ndarray, shape (n_unique, 3)
        Distinct type combinations found, as indices into ``np.unique(types)``;
        each row is sorted in ascending order.
    tri_counts : ndarray, shape (n_unique,)
        Number of triplets showing each combination.
    """
    ti = time.time()
    _, types_idx = np.unique(types, return_inverse=True)
    print(time.time()-ti)
    # Count the triplets that were passed in. The previous version read the
    # notebook-global `dd.simplices` and silently ignored `triplets`.
    tri_types = np.sort(types_idx[np.asarray(triplets)], axis=1)
    unq_tri, tri_counts = np.unique(tri_types, return_counts=True, axis=0)
    print(time.time()-ti)
    return unq_tri, tri_counts

    
def test_triplet_enrichment_old(triplets, node_types, 
                         n_repeat=100, 
                         random_state=0):
    """Legacy permutation test for node-type triplet enrichment.

    Same procedure as ``test_triplet_enrichment`` but triplet types are
    reported as concatenated label strings (``att``) instead of index codes.
    Node labels are shuffled ``n_repeat`` times while the triplets stay fixed;
    the observed count of every type combination is compared with the
    shuffled counts.

    Parameters
    ----------
    triplets : array-like of int, shape (n_triplets, 3)
        Node indices of each triplet.
    node_types : array-like, shape (n_nodes,)
        Type label of every node.
    n_repeat : int
        Number of label shuffles (must be >= 1).
    random_state : int or None
        Seed of the private random generator used for shuffling.

    Returns
    -------
    att : ndarray of str
        All possible type triplets; the three labels are concatenated without
        a separator, in sorted-label order.
    pvals, qvals : ndarray
        Two-sided empirical p-values and their Benjamini-Hochberg adjustment.
    enr : ndarray
        ``(obs + 1) / (exp + 1)``.
    obs, exp : ndarray
        Observed counts and mean shuffled counts.
    countmats : ndarray, shape (n_possible, n_repeat + 1)
        Column 0 holds observed counts, columns 1.. the shuffled counts.
    """
    # house keeping
    triplets = np.asarray(triplets)
    assert triplets.shape[1] == 3
    if n_repeat < 1:
        raise ValueError("n_repeat must be >= 1")
    # private generator: leaves NumPy's global random state untouched
    rng = np.random.RandomState(random_state)

    # basic stats
    unq_types, types_idx = np.unique(node_types, return_inverse=True)
    n_t = len(unq_types)
    n = len(types_idx)

    # all possible (sorted) type triplets; num_codes is strictly increasing
    # because combinations_with_replacement yields tuples in lexicographic order
    tri_codes = np.asarray(
        list(itertools.combinations_with_replacement(range(n_t), r=3)),
        dtype=np.int64).reshape(-1, 3)
    num_codes = (tri_codes[:,0]*n_t + tri_codes[:,1])*n_t + tri_codes[:,2]
    # `att` was referenced but never defined in the previous version (NameError)
    a = unq_types[tri_codes].astype(str)
    att = np.char.add(np.char.add(a[:,0], a[:,1]), a[:,2]) # all triplet types

    def _count(labels_idx):
        # counts of every possible type triplet for one labelling of the nodes
        srt = np.sort(labels_idx[triplets], axis=1)
        codes, counts = np.unique((srt[:,0]*n_t + srt[:,1])*n_t + srt[:,2],
                                  return_counts=True)
        col = np.zeros(len(num_codes), dtype=int)
        col[np.searchsorted(num_codes, codes)] = counts
        return col

    # observed counts (column 0) and shuffled counts (columns 1..n_repeat);
    # the stray debugging `break` that stopped after one shuffle is removed
    countmats = np.zeros((len(num_codes), n_repeat+1), dtype=int)
    countmats[:,0] = _count(types_idx)
    for i in range(1, n_repeat+1):
        countmats[:,i] = _count(types_idx[rng.permutation(n)])

    # rank ~ [1, n_rep+1]; percentiles ~ [0,1]
    prctls = (stats.rankdata(countmats, axis=1)[:,0]-1)/n_repeat
    pvals = 2*np.clip(np.minimum(prctls, 1-prctls), 1/n_repeat, None)
    pvals = np.minimum(pvals, 1.0)  # only matters when n_repeat == 1
    rej, qvals, _, _ = multipletests(pvals, # indep tests
                                     method='fdr_bh')

    obs = countmats[:,0]
    exp = np.mean(countmats[:,1:], axis=1)
    enr = (obs+1)/(exp+1)
    return att, pvals, qvals, enr, obs, exp, countmats
        

In [None]:
def count_triplet(triplets, types_idx):
    """Count triplets by the sorted combination of their node-type codes.

    Parameters
    ----------
    triplets : array-like of int, shape (n_triplets, 3)
        Node indices of each triplet.
    types_idx : array-like, shape (n_nodes,)
        Type code of every node (e.g. the inverse indices from ``np.unique``).

    Returns
    -------
    unq_tri : ndarray, shape (n_unique, 3)
        Distinct type-code combinations found; each row sorted ascending.
    tri_counts : ndarray, shape (n_unique,)
        Number of triplets showing each combination.
    """
    # Use the `triplets` argument. The previous version read the notebook-global
    # `dd.simplices`, so the function ignored its input and failed outside the notebook.
    tri_types = np.sort(np.asarray(types_idx)[np.asarray(triplets)], axis=1)
    unq_tri, tri_counts = np.unique(tri_types, return_counts=True, axis=0)
    return unq_tri, tri_counts
    
def test_triplet_enrichment(triplets, node_types, 
                         n_repeat=100, 
                         random_state=0):
    """Permutation test for enrichment/depletion of node-type triplets.

    The triplets (e.g. Delaunay triangles) are kept fixed while the node
    labels are shuffled ``n_repeat`` times. For every possible unordered
    combination of three types, the observed number of triplets is compared
    with the counts obtained under shuffling.

    Parameters
    ----------
    triplets : array-like of int, shape (n_triplets, 3)
        Node indices of each triplet.
    node_types : array-like, shape (n_nodes,)
        Type label of every node.
    n_repeat : int
        Number of label shuffles (must be >= 1).
    random_state : int or None
        Seed of the private random generator used for shuffling.

    Returns
    -------
    unq_types : ndarray
        Sorted unique node types.
    tri_codes : ndarray of int, shape (n_possible, 3)
        Every possible type triplet as indices into ``unq_types`` (each row
        non-decreasing). All per-triplet outputs below follow this row order.
    pvals : ndarray
        Two-sided empirical p-values from the rank of the observed count among
        observed + shuffled counts (ties get average ranks).
    qvals : ndarray
        Benjamini-Hochberg adjusted p-values.
    enr : ndarray
        ``(obs + 1) / (exp + 1)``.
    obs, exp : ndarray
        Observed counts and mean shuffled counts.
    countmats : ndarray of int, shape (n_possible, n_repeat + 1)
        Column 0 holds observed counts, columns 1.. the shuffled counts.

    Raises
    ------
    ValueError
        If ``n_repeat`` is smaller than 1.
    """
    # house keeping
    triplets = np.asarray(triplets)
    assert triplets.ndim == 2 and triplets.shape[1] == 3, \
        "triplets must have shape (n_triplets, 3)"
    if n_repeat < 1:
        raise ValueError("n_repeat must be >= 1")
    # private generator: leaves NumPy's global random state untouched
    rng = np.random.RandomState(random_state)

    unq_types, types_idx = np.unique(node_types, return_inverse=True)
    n = len(types_idx)
    n_t = len(unq_types)

    # all possible (sorted) type triplets and their base-n_t integer codes;
    # combinations_with_replacement yields tuples in lexicographic order, so
    # num_codes is strictly increasing and can be used with searchsorted
    tri_codes = np.asarray(
        list(itertools.combinations_with_replacement(range(n_t), r=3)),
        dtype=np.int64).reshape(-1, 3)
    num_codes = (tri_codes[:,0]*n_t + tri_codes[:,1])*n_t + tri_codes[:,2]
    n_codes = len(num_codes)

    def _count(labels_idx):
        # counts of every possible type triplet for one labelling of the nodes;
        # uses the `triplets` argument directly (no notebook globals)
        srt = np.sort(labels_idx[triplets], axis=1)
        codes = (srt[:,0]*n_t + srt[:,1])*n_t + srt[:,2]
        return np.bincount(np.searchsorted(num_codes, codes), minlength=n_codes)

    # observed counts (column 0) and shuffled counts (columns 1..n_repeat),
    # written straight into an integer matrix
    countmats = np.zeros((n_codes, n_repeat+1), dtype=int)
    countmats[:,0] = _count(types_idx)
    for i in range(1, n_repeat+1):
        countmats[:,i] = _count(types_idx[rng.permutation(n)])

    # rank ~ [1, n_rep+1]; percentiles ~ [0,1]
    prctls = (stats.rankdata(countmats, axis=1)[:,0]-1)/n_repeat
    pvals = 2*np.clip(np.minimum(prctls, 1-prctls), 1/n_repeat, None)
    pvals = np.minimum(pvals, 1.0)  # only matters when n_repeat == 1
    rej, qvals, _, _ = multipletests(pvals, # indep tests
                                     method='fdr_bh')

    obs = countmats[:,0]
    exp = np.mean(countmats[:,1:], axis=1)
    enr = (obs+1)/(exp+1)
    return unq_types, tri_codes, pvals, qvals, enr, obs, exp, countmats
        

In [None]:
# number of points in XY
N = len(XY)

# from meta
f = '/bigstore/GeneralStorage/fangming/projects/dredfish/data_dump/analysis_meta_Mar31.json'
with open(f, 'r') as fh:
    meta = json.load(fh)
    
# Delaunay triangulation of the points in XY; the rows of dd.simplices (point
# indices per simplex) are the triplets tested in the cells below.
dd = Delaunay(XY)
dd.simplices

In [None]:
%%time
# First case: triplet enrichment for the labels in column 'type_r0.1'.
# NOTE(review): assumes the rows of df are in the same order as the points in XY -- confirm.
triplets = dd.simplices 
typecol = 'type_r0.1'
types = df[typecol].values

# 1000 label shuffles with a fixed seed
(unq_types, tri_codes, pvals, qvals, enr, obs, exp, countmats) = test_triplet_enrichment(
    triplets, types, n_repeat=1000, random_state=0)



In [None]:
alpha = 0.01

# One row per possible type triplet: log2 enrichment, adjusted p-value and a
# readable "a, b, c" label; `sig` flags rows with q below alpha.
res = pd.DataFrame({
    'log2(enr)': np.log2(enr),
    'q': qvals,
    'tri': [", ".join(labels) for labels in unq_types[tri_codes]],
})
res['sig'] = res['q'] < alpha
# most enriched triplets first
res = res.sort_values('log2(enr)', ascending=False)
res.head()

In [None]:
# 
# Bar plot of log2 enrichment for every triplet type, coloured by significance.
# NOTE(review): `sns` is not imported explicitly in this notebook; presumably it
# comes from the star import of dredFISH.Utils.__init__plots -- confirm.
fig, ax = plt.subplots(figsize=(12,4))
sns.barplot(data=res, x='tri', y='log2(enr)', hue='sig', 
            edgecolor='none',
            ax=ax)
# vertical tick labels so the triplet names stay readable
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
plt.show()

# second case

In [None]:
%%time
# Second case: the same test using the labels in column 'ktype_L3'.
# NOTE(review): assumes the rows of df are in the same order as the points in XY -- confirm.
triplets = dd.simplices 
typecol = 'ktype_L3'
types = df[typecol].values

# 1000 label shuffles with a fixed seed
(unq_types, tri_codes, pvals, qvals, enr, obs, exp, countmats) = test_triplet_enrichment(
    triplets, types, n_repeat=1000, random_state=0)



In [None]:
alpha = 0.01

# One row per possible type triplet. Each row of tri_codes is non-decreasing,
# so comparing neighbouring positions is enough to count distinct types.
res = pd.DataFrame({
    'log2(enr)': np.log2(enr),
    'q': qvals,
    'tri': [", ".join(labels) for labels in unq_types[tri_codes]],
    # True unless all three members share one type
    'diff-tri': tri_codes[:,2] > tri_codes[:,0],
    # number of distinct types (1, 2 or 3), kept as str
    'ndiff': (1+(tri_codes[:,1] > tri_codes[:,0]) + (tri_codes[:,2] > tri_codes[:,1])).astype(str), # categorical
})

res['sig'] = res['q'] < alpha
# most enriched triplets first
res = res.sort_values('log2(enr)', ascending=False)
print(res.shape)
res.head(10)

In [None]:
# check how many sig cases
# The FDR threshold in the messages is taken from `alpha`, so the printed text
# cannot drift from the cut-off used to compute res['sig'].
print(f"Num. possible triplets: {len(res)}")
print(f"Num. significant (FDR<{alpha}) triplets: {res['sig'].sum()}")
print(f"Num. significant (FDR<{alpha}) triplets with >2 FC: {np.logical_and(res['sig'], np.abs(res['log2(enr)']) > 1).sum()}")

In [None]:
# significant triplets whose enrichment changes by more than 2-fold in either
# direction, ordered from most depleted to most enriched
cond = res['sig'] & (res['log2(enr)'].abs() > 1)
res_sgst = res.loc[cond].sort_values('log2(enr)')
res_sgst

In [None]:
# one colour per value of the 'ndiff' column ('1', '2', '3')
colors = sns.color_palette('Set2', 3)
palette = {str(rank): shade for rank, shade in enumerate(colors, start=1)}
palette

In [None]:
# Show the 10 most depleted and the 10 most enriched triplets of res_sgst
# (which is sorted by ascending log2(enr)). When there are 20 rows or fewer the
# whole table is shown; positional indexing with a fixed 10 + 10 pattern would
# repeat rows (10-19 rows) or raise IndexError (fewer than 10 rows).
n_each = 10
if len(res_sgst) > 2*n_each:
    toplot = pd.concat([res_sgst.iloc[:n_each], res_sgst.iloc[-n_each:]])
else:
    toplot = res_sgst

fig, ax = plt.subplots(figsize=(4,12))
sns.barplot(data=toplot,
            y='tri', x='log2(enr)', hue='ndiff', dodge=False,
            edgecolor='none',
            palette=palette,
            ax=ax)
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
ax.legend(title='Num. uniq types')
plt.show()

In [None]:
# top non-self cases
# Keep triplets containing more than one type, take the (up to) 50 rows with the
# largest log2(enr) -- res_sgst is sorted ascending -- and reverse them so the
# strongest enrichment is drawn first.
toplot = res_sgst[res_sgst['diff-tri']].iloc[-50:].iloc[::-1]

fig, ax = plt.subplots(figsize=(4,12))
# bars coloured by the number of distinct types in the triplet
sns.barplot(data=toplot,
            y='tri', x='log2(enr)', hue='ndiff', dodge=False,
            edgecolor='none',
            palette=palette,
            ax=ax)
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=10)
ax.legend(title='Num. uniq types')
plt.show()