In [None]:
import matplotlib
import matplotlib.pyplot as plt 
import numpy as np
import pandas as pd
import os
import umap
import datashader as ds
import colorcet as cc
import igraph
import tqdm
from scipy import sparse
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.cluster import KMeans

from matplotlib.collections import PolyCollection
from matplotlib.colors import ListedColormap

from dredFISH.Analysis import TissueGraph
from dredFISH.Visualization import Viz
from dredFISH.Utils.__init__plots import * 
from dredFISH.Utils import powerplots
from dredFISH.Utils import miscu
from dredFISH.Utils import tmgu

import importlib
# Reload the two project modules so that the versions currently on disk are
# the ones used by the rest of the notebook.
importlib.reload(Viz)
importlib.reload(TissueGraph)

#### Load data

In [None]:
# Path of a figures directory; not referenced by any other cell visible here.
respath = '/bigstore/GeneralStorage/fangming/projects/dredfish/figures/'

In [None]:
# Root directory of the dataset; the CSV table and the TissueMultiGraph below
# are both loaded from it.
basepth = '/bigstore/GeneralStorage/Data/dredFISH/Dataset1'
# Shell commands: list the directory contents and show the first lines of TMG.json.
!ls -alhtr $basepth
!head $basepth"/TMG.json"

In [None]:
# Read the analysis table stored in the dataset directory and display it.
# Later cells read the columns 'type_r0.1', 'x', 'y', 'umap_x' and 'umap_y'
# from it and add new 'type_reg_k*' columns to it.
df = pd.read_csv(
    os.path.join(basepth, "analysis_dev.csv"))
df

In [None]:
# Build the TissueMultiGraph object for the dataset directory.
TMG = TissueGraph.TissueMultiGraph(basepath=basepth, 
                                   redo=False, # load existing 
                                  )

In [None]:
def get_local_typeabundance(SG, types):
    """
    Relative abundance of each type among the spatial neighbors of every node.

    SG - a spatial neighborhood graph (undirected); only `SG.get_edgelist()`
         is used, and it must yield (node, node) index pairs
    types - type labels on the nodes, one per node, ordered by node index

    return - dense (n_nodes, n_types) float array. Row i holds the fraction of
             node i's neighbors that belong to each type; columns follow the
             order of `np.unique(types)`. Nodes without neighbors get a row of
             zeros.
    """
    types = np.asarray(types)
    # the number of nodes is taken from the labels, not from a notebook global
    n_nodes = len(types)
    ctg, ctg_idx = np.unique(types, return_inverse=True)

    # reshape keeps an empty edge list two-dimensional so column slicing works
    edges = np.asarray(SG.get_edgelist(), dtype=int).reshape(-1, 2)

    # an undirected edge (a, b) contributes b's type to row a and a's type to row b
    rows = np.hstack([edges[:, 0], edges[:, 1]])
    cols = np.hstack([ctg_idx[edges[:, 1]], ctg_idx[edges[:, 0]]])
    ones = np.ones(len(rows), dtype=int)

    # count: repeated (row, col) pairs are summed when the matrix is densified
    env_mat = sparse.coo_matrix((ones, (rows, cols)), shape=(n_nodes, len(ctg))).toarray()

    # normalize each row to sum to 1; rows with no neighbors are left at zero
    totals = env_mat.sum(axis=1, keepdims=True)
    env_mat = np.divide(env_mat, totals,
                        out=np.zeros(env_mat.shape, dtype=float),
                        where=totals > 0)

    return env_mat

In [None]:
# Pull the per-cell data out of the first layer of the tissue multigraph.
layer = TMG.Layers[0]

# spatial coordinates
XY = layer.XY
###
# a temporary hack: x is read from column 1 and y from column 0 (axes swapped)
x, y = XY[:,1], XY[:,0]
###

# cell identifiers and number of nodes in the layer
cells = layer.adata.obs.index.values
N = layer.N

# measured basis
ftrs_mat = layer.feature_mat

# umap_mat = umap.UMAP(n_neighbors=30, min_dist=0.1).fit_transform(ftrs_mat)

# types
types = df['type_r0.1'].values

# local env: per-cell type composition of its neighbors in the layer's spatial graph
env_mat = get_local_typeabundance(layer.SG, types)

In [None]:
%%time
# region types: cluster the cells by their local type-abundance rows (env_mat)
# with k-means, once for every k in k_kms, and store each labeling on df.
k_kms = [5, 10] #[2,5,10,20,50,100]
for k_km in tqdm.tqdm(k_kms):
    # random_state is fixed so repeated runs give the same clustering
    kmeans = KMeans(n_clusters=k_km, random_state=1)
    reg_clsts = kmeans.fit_predict(env_mat)
    # integer cluster ids become string labels 't0', 't1', ... in column 'type_reg_k<k>'
    df[f'type_reg_k{k_km}'] = np.char.add('t', np.array(reg_clsts).astype(str))

In [None]:
# One figure per k: the same region-type coloring drawn on df's x/y columns
# (left panel) and on its umap_x/umap_y columns (right panel).
panel_coords = [('x', 'y'), ('umap_x', 'umap_y')]
for k_km in k_kms:
    hue = f'type_reg_k{k_km}'
    hue_order = np.sort(np.unique(df[hue]))
    ntypes = len(hue_order)

    fig, axs = plt.subplots(1, 2, figsize=(8*2,6))
    fig.suptitle(f"{hue}; n={ntypes}")

    for ax, (xcol, ycol) in zip(axs, panel_coords):
        sns.scatterplot(data=df, x=xcol, y=ycol,
                        hue=hue, hue_order=hue_order,
                        s=0.5, edgecolor=None,
                        legend=False,
                        ax=ax)
        # ax.legend(loc='upper left', bbox_to_anchor=(0, -0.1), ncol=5)
        ax.set_aspect('equal')
        ax.axis('off')

    fig.subplots_adjust(wspace=0)
    plt.show()

# Get from region types to regions; and visualize region boundaries 

In [None]:
def plot_poly(TMG, ctg_idx, ctg_colors):
    """
    Draw the cell polygons of the first TMG geometry, colored by category.

    - TMG: object whose `Geoms[0]` provides 'poly' (one 2-D coordinate array
      per cell) and 'BoundingBox' (exposing `.exterior.xy`)
    - ctg_idx: category assigments of cells, as integer positions into
      ctg_colors (one per polygon)
    - ctg_colors: color assignment of categories

    Shows the figure with plt.show(); returns nothing.
    """
    geom = TMG.Geoms[0]
    polys = geom['poly']
    bdbox = np.array(geom['BoundingBox'].exterior.xy).T

    # a hack: swap the two coordinate columns of every polygon and of the bounding box
    polys = [np.column_stack([poly[:,1], poly[:,0]]) for poly in polys]
    bdbox = np.column_stack([bdbox[:,1], bdbox[:,0]])
    # end of the hack

    # plot limits come from the bounding box
    mx = np.max(bdbox, axis=0)
    mn = np.min(bdbox, axis=0)

    n_ctg = len(ctg_colors)

    fig, ax = plt.subplots(figsize=(12,10))
    p = PolyCollection(polys, edgecolors='none', cmap=ListedColormap(ctg_colors))
    p.set_array(np.asarray(ctg_idx))
    # Pin the color scale so category k always lands on ctg_colors[k]. The
    # default autoscaling stretches over the observed min/max of ctg_idx and
    # shifts colors whenever the first or last category has no cells.
    p.set_clim(-0.5, n_ctg - 0.5)

    ax.add_collection(p)
    ax.set_aspect('equal')
    ax.set_xlim([mn[0], mx[0]])
    ax.set_ylim([mn[1], mx[1]])
    ax.axis('off')
    ax.set_title(f"n={len(ctg_colors)} region types")

    plt.show()

In [None]:
def get_zone_from_types(SG, type_of_cells):
    """
    Split cells into zones: spatially connected groups of cells of one type.

    SG - spatial neighborhood graph; only `SG.get_edgelist()` is used
    type_of_cells - array of type labels, indexed by node id

    return - array with one zone index per cell: the connected-component
             membership of the graph that remains after dropping every edge
             whose two ends have different types
    """
    n_cells = len(type_of_cells)

    # look up the type at both ends of every edge
    all_edges = np.asarray(SG.get_edgelist())
    end_types = type_of_cells[all_edges]
    is_same_type = end_types[:,0] == end_types[:,1]

    # graph over the same nodes that keeps only the within-type edges
    zone_graph = igraph.Graph(n=n_cells, edges=all_edges[is_same_type], directed=False)

    # each connected component is one zone; membership gives its index per cell
    return np.asarray(zone_graph.components().membership)

In [None]:
def remove_small_zones(SG, type_of_cells, th=2):
    """
    Relabel cells that sit in small zones with the type of a neighboring larger zone.

    A zone is a connected group of same-type cells (computed with
    get_zone_from_types). A zone counts as small when it has at most `th`
    cells. Every small zone that touches at least one larger zone takes over
    the type of one such neighbor; small zones that only touch other small
    zones keep their type.

    SG - spatial neighborhood graph of the cells (presumably an igraph graph:
         it is copied, contracted and simplified here)
    type_of_cells - array of type labels, indexed by node id
    th - size threshold; zones with <= th cells are small

    return - array of updated type labels, one per cell

    Side effect: prints the boolean result of an internal sanity check.
    """
    # cell -> zone (continuity)
    zone_of_cells = get_zone_from_types(SG, type_of_cells)
    # zone stats: idx = position of the first cell of each zone, invs = zone position
    # of each cell, cnts = number of cells in each zone
    unq_zones, idx, invs, cnts, = np.unique(zone_of_cells, return_index=True, return_inverse=True, return_counts=True)
    # zone -> types (reindexing): the type of a zone is read from its first cell
    type_of_zones = type_of_cells[idx]
    # # sanity check
    # print(unq_zones.shape, idx.shape, invs.shape, cnts.shape)
    # print(np.all(unq_zones[invs] == zone_of_cells)) # use invs to recover the original
    # print(np.all(type_of_zones[zone_of_cells] == type_of_cells))
    # print(type_of_cells.shape, zone_of_cells.shape, type_of_zones.shape)
    
    # cell graph to zone graph: merge all cells of a zone into one vertex, then
    # simplify (with igraph defaults this drops duplicate edges and self-loops)
    ZSG = SG.copy()
    ZSG.contract_vertices(zone_of_cells)
    ZSG.simplify()

    # trim to edges between bad and good
    zsg_edges = np.asarray(ZSG.get_edgelist())
    # per edge end: True when that zone is small (bad)
    zsg_edges_bytype = cnts[zsg_edges] <= th
    # xor keeps the edges with exactly one small end
    zsg_edges_difftype = zsg_edges[np.logical_xor(zsg_edges_bytype[:,0], zsg_edges_bytype[:,1])]

    # sanity check: prints True when every kept edge has exactly one small end
    print(np.all(np.sum(cnts[zsg_edges_difftype] <= th, axis=1) == 1))

    # edges with the second node bad are swapped to the first
    # (afterwards e1 always holds the small zone and e2 the larger neighbor)
    e1, e2 = zsg_edges_difftype[:,0].copy(), zsg_edges_difftype[:,1].copy()
    cond = cnts[e2]<=th # e2 is bad
    e1sub, e2sub = e1[cond], e2[cond]
    idxsub = np.arange(len(e1))[cond]
    np.put(e1, idxsub, e2sub)
    np.put(e2, idxsub, e1sub)

    # # sanity check
    # print(np.all(cnts[e1]<=th), np.sum(cnts[e2]<=th))

    # uniq: keep the first kept edge of every small zone, so each small zone
    # gets exactly one target zone
    e1u, e1ui = np.unique(e1, return_index=True)
    e2u = e2[e1ui]

    # small zone id -> target zone id
    zone_remap = pd.Series(e2u, index=e1u)
    # extend to all zones; zones without a target map to themselves.
    # NOTE(review): fillna aligns on index labels, so this relies on the zone
    # ids being 0..n_zones-1 (id == position) - confirm.
    zones_u = zone_remap.reindex(unq_zones).fillna(pd.Series(unq_zones)).astype(int)
    # new type of every zone, then broadcast back to the cells
    type_of_zones_u = type_of_zones[zones_u.values]
    type_of_cells_u = type_of_zones_u[zone_of_cells]
    
    return type_of_cells_u

In [None]:
# Pick the k=10 region-type labeling, index its categories, and build one
# "Set2" color per category (the palette is displayed as the cell output).
k_km = 10
region_types = df[f'type_reg_k{k_km}'].values
# ctg: sorted unique labels; ctg_idx: position of each cell's label in ctg
ctg, ctg_idx = np.unique(region_types, return_inverse=True)
colors = sns.color_palette("Set2", len(ctg))
colors

In [None]:
# Draw the cell polygons colored by the current ctg_idx / colors pair.
plot_poly(TMG, ctg_idx, colors)

In [None]:
# cell -> zone (continuity): connected same-type groups in the layer's spatial graph
SG = layer.SG
region_zones = get_zone_from_types(SG, region_types)
# count: per zone, the position of its first cell (idx) and its number of cells
# (cnts); invs gives, per cell, the position of its zone in unq_zones
unq_zones, idx, invs, cnts, = np.unique(region_zones, return_index=True, return_inverse=True, return_counts=True)

In [None]:
# Zone sizes sorted in ascending order on a log10 scale, with the size
# threshold th drawn as a dashed horizontal line and labeled.
th = 2

fig, ax = plt.subplots()
ax.plot(np.log10(np.sort(cnts)))
ax.axhline(np.log10(th), linestyle='--', color='k')
# label placed slightly above the threshold line
ax.text(0, np.log10(th*1.3), f"n={th}")
ax.set_xlabel('zones')
ax.set_ylabel('log10(# cells in zone)')

In [None]:
# Color every cell by whether its zone has more than th cells.
# Note: this overwrites ctg, ctg_idx and colors from the region-type cell above.
th = 2
# per cell: True when the cell's zone holds more than th cells
binary_types = cnts[invs] > th
ctg, ctg_idx = np.unique(binary_types, return_inverse=True)
colors = sns.color_palette("tab10", len(ctg))
# not the last expression of the cell, so this line displays nothing
colors

plot_poly(TMG, ctg_idx, colors)

In [None]:
# update region_types: cells in zones of at most 2 cells take the type of a
# neighboring larger zone (see remove_small_zones)
region_types_u = remove_small_zones(SG, region_types, th=2)

# cell -> zone (continuity), recomputed on the updated labels
region_zones_u = get_zone_from_types(SG, region_types_u)
# count: same outputs as for region_zones, now for the updated zones
unq_zones_u, idx_u, invs_u, cnts_u, = np.unique(region_zones_u, return_index=True, return_inverse=True, return_counts=True)

In [None]:
# Category index and one "Set2" color per category for the updated labels
# (the palette is displayed as the cell output).
ctg_u, ctg_idx_u = np.unique(region_types_u, return_inverse=True)
colors_u = sns.color_palette("Set2", len(ctg_u))
colors_u

In [None]:
# Draw the cell polygons colored by the updated region types.
plot_poly(TMG, ctg_idx_u, colors_u)

In [None]:
# Sorted zone sizes (log10) after the small-zone update.
# `th` is not set in this cell; it is read from an earlier cell.
fig, ax = plt.subplots()
ax.plot(np.log10(np.sort(cnts_u)))
# ax.axhline(np.log10(th), linestyle='--', color='k')
# NOTE(review): the threshold line above is commented out but its text label
# is still drawn - confirm whether the label should remain.
ax.text(0, np.log10(th*1.3), f"n={th}")
ax.set_xlabel('zones')
ax.set_ylabel('log10(# cells in zone)')

In [None]:
# Same small-zone mask as before, computed on the updated zones: color every
# cell by whether its zone has more than th cells.
# Note: this overwrites ctg, ctg_idx and colors again.
th = 2
# per cell: True when the cell's updated zone holds more than th cells
binary_types = cnts_u[invs_u] > th
ctg, ctg_idx = np.unique(binary_types, return_inverse=True)
colors = sns.color_palette("tab10", len(ctg))
# not the last expression of the cell, so this line displays nothing
colors

plot_poly(TMG, ctg_idx, colors)