In [None]:
import matplotlib
import matplotlib.pyplot as plt 
from matplotlib.collections import LineCollection
import numpy as np
import pandas as pd
import os
import umap
import datashader as ds
import colorcet as cc
import igraph
import tqdm
from sklearn.neighbors import NearestNeighbors
import leidenalg as la

from dredFISH.Analysis import TissueGraph
from dredFISH.Analysis import Classification
from dredFISH.Visualization import Viz

from dredFISH.Utils.__init__plots import * 
from dredFISH.Utils import powerplots

import importlib
importlib.reload(Viz)
importlib.reload(Classification)
importlib.reload(TissueGraph)

# Tasks
- streamline code
- save results (data frame)
- save figure panels (to be assembled into 1 PDF)

In [None]:
def split_half(x, y, line_seg, consistency='large_x'):
    """
    Split points into two groups according to the side of a line they fall on.

    Parameters
    ----------
    x, y : array-like of shape (n,)
        Point coordinates.
    line_seg : [[p1x, p1y], [p2x, p2y]]
        Two points defining the splitting line.
    consistency : None or 'large_x'
        None: return the raw side-of-line mask; which side is True depends
        on the order of the two points in `line_seg`.
        'large_x': always select the half with the larger mean x (the right
        half), regardless of the order of the two points.

    Returns
    -------
    numpy.ndarray of bool, shape (n,)
        True for the selected points.

    Raises
    ------
    ValueError
        If `consistency` is not one of the supported options (previously this
        case silently returned None).
    """
    if consistency is not None and consistency != 'large_x':
        raise ValueError(
            f"Unknown consistency option: {consistency!r} (expected None or 'large_x')"
        )

    # accept lists / Series as well as arrays
    x = np.asarray(x)
    y = np.asarray(y)

    [[p1x, p1y], [p2x, p2y]] = line_seg
    vx = p2x-p1x
    vy = p2y-p1y
    vn = np.array([-vy, vx]) # normal to the line
    v = np.vstack([x-p1x, y-p1y]).T

    cond = v.dot(vn) < 0 # sign splits points into the two sides of the line

    if consistency is None:
        return cond

    # consistency == 'large_x': select the side with the bigger mean x
    if cond.all() or not cond.any():
        # one side is empty, so there is nothing to compare; keep the mask as is
        # (same result as before, without the mean-of-empty-slice warning)
        return cond
    if np.mean(x[cond]) < np.mean(x[~cond]):
        cond = ~cond
    return cond

def adjust_XY_byline(line_seg, XY):
    """
    Rotate coordinates about the origin so that the given line becomes vertical.

    The line direction is first oriented to point up (non-negative y), so the
    result does not depend on the order of the two points in `line_seg`.

    Parameters
    ----------
    line_seg : [[p1x, p1y], [p2x, p2y]]
        Two distinct points defining the line.
    XY : array of shape (n, 2)
        Coordinates to rotate (columns are x and y).

    Returns
    -------
    XYnew : same shape as XY
        Rotated coordinates, computed as ``XY.dot(R.T)``.

    Raises
    ------
    ValueError
        If the two points of `line_seg` coincide (no direction is defined).
    """
    [[p1x, p1y], [p2x, p2y]] = line_seg
    # line direction
    v = np.array([p2x-p1x, p2y-p1y], dtype=float)
    norm = np.linalg.norm(v, 2)
    if norm == 0:
        raise ValueError("line_seg endpoints must be distinct")
    v = v/norm
    # always points up
    if v[1] < 0:
        v = -v
    # read the components AFTER the flip; using the pre-flip vx gave the
    # wrong rotation sign for lines specified top-to-bottom
    vx, vy = v
    # signed angle from the line direction to the +y axis
    # (positive = counter clockwise); arctan2 avoids arccos domain issues
    theta = np.arctan2(vx, vy)

    # rotate counter clock wise by theta
    R = np.array([
        [np.cos(theta), -np.sin(theta),],
        [np.sin(theta),  np.cos(theta),],
        ])
    XYnew = XY.dot(R.T)

    return XYnew

In [None]:
def leiden(G, cells,
           resolution=1, seed=0, n_iteration=2,
           **kwargs,
          ):
    """Run Leiden clustering on graph G and return one label per cell.

    `cells` are expected to be in the same order as the graph vertices; only
    their count is used here. Labels are integers starting at 1, numbered in
    the order the clusters come out of the partition; a position that no
    cluster covers keeps the label 0. Extra keyword arguments are forwarded
    to `la.find_partition`.
    """
    # modularity with a resolution parameter
    partition = la.find_partition(
        G,
        la.RBConfigurationVertexPartition,
        resolution_parameter=resolution,
        seed=seed,
        n_iterations=n_iteration,
        **kwargs,
    )

    # turn the partition (an iterable of vertex-index groups) into flat labels
    labels = [0] * len(cells)
    for cluster_id, members in enumerate(partition, start=1):
        for vertex in members:
            labels[vertex] = cluster_id
    return labels

#### Load data

In [None]:
# directory used by the plotting functions below to build figure output paths
respath = '/bigstore/GeneralStorage/fangming/projects/dredfish/figures/'
# dataset directory: the TMG is loaded from here and the analysis CSV is written here
basepth = '/bigstore/GeneralStorage/Data/dredFISH/Dataset1-t3'
# quick look at the directory content and at the beginning of TMG.json
!ls -alhtr $basepth
!head $basepth"/TMG.json"

In [None]:
# load TMG - with a cell layer
# The TissueMultiGraph is read from `basepth`; redo=False asks for the existing
# result to be loaded (per the inline note) instead of being rebuilt.
TMG = TissueGraph.TissueMultiGraph(basepath=basepth, 
                                   redo=False, # load existing 
                                  )
# display the object as the cell output
TMG

In [None]:
# Pull what is needed out of the first TMG layer
layer = TMG.Layers[0]

# spatial coordinates
### a temporary hack: the two coordinate columns are swapped,
### so the layer's column 1 is used as x and column 0 as y
XY = layer.XY[:, [1, 0]]
x, y = XY[:, 0], XY[:, 1]
###

N = layer.N

# measured basis
ftrs_mat = layer.feature_mat

# collect the coordinates and the 24 basis columns (b0..b23) in a data frame
df = pd.DataFrame()
df['x'], df['y'] = x, y
for i in range(24):
    df[f'b{i}'] = ftrs_mat[:, i]
df

In [None]:
# define a line to split things into hemi-coronal sections

# try: the same line, written in both directions
line_segs = [
    [(550, -6000), (200, 2000)],
    [(200, 2000), (550, -6000)],
]

# split and adjust, both based on the first segment
cond = split_half(x, y, line_segs[0])
XYnew = adjust_XY_byline(line_segs[0], XY)
xnew, ynew = XYnew[:,0], XYnew[:,1]

# 2 x 2 layout with wide left panels (A, C) and narrow right panels (B, D)
mosaic = "\nAAB\nCCD\n"
fig = plt.figure(figsize=(20,20), constrained_layout=True)
axs_dict = fig.subplot_mosaic(mosaic)

# one entry per panel, in the order the axes come out of axs_dict:
# (x values, y values, whether to overlay the split line)
panels = [
    (x, y, True),                      # all cells, original coordinates
    (x[cond], y[cond], True),          # selected half, original coordinates
    (xnew, ynew, False),               # all cells, rotated coordinates
    (xnew[cond], ynew[cond], False),   # selected half, rotated coordinates
]
for ax, (px, py, with_line) in zip(axs_dict.values(), panels):
    ax.scatter(px, py, s=0.1)
    if with_line:
        ax.add_collection(LineCollection(line_segs, linewidth=1, colors='r'))
    ax.set_aspect('equal')

plt.show()

In [None]:
# add results: rotated coordinates and the half-selection flag (1 = selected by split_half)
df['x2'], df['y2'] = XYnew[:, 0], XYnew[:, 1]
df['semi'] = cond.astype(int)

In [None]:
# save
# checkpoint: write the data frame built so far (with header and index)
# to analysis_dev_v2.csv inside the dataset directory
df.to_csv(os.path.join(basepth, "analysis_dev_v2.csv"), header=True, index=True)

In [None]:
# UMAP embedding of the measured basis (fixed random_state for repeatability)
reducer = umap.UMAP(n_neighbors=30, min_dist=0.1, random_state=0)
umap_mat = reducer.fit_transform(ftrs_mat)

# add the first two embedding columns to the df
df['umap_x'], df['umap_y'] = umap_mat[:, 0], umap_mat[:, 1]
df

In [None]:
%%time
# create known cell type classifier and train and predict
allen_classifier = Classification.KnownCellTypeClassifier(
    layer, 
    tax_name='Allen_types',
    ref='allen_smrt_dpnmf',
    ref_levels=['class_label', 'neighborhood_label', 'subclass_label'], #, 'cluster_label'], 
    model='knn',
)
allen_classifier.train(verbose=True)
type_mat = allen_classifier.classify()

# add to a df
for i in range(3):
    df[f'ktype_L{i+1}'] = type_mat[:,i]
df

In [None]:
%%time
# feature graph to generete cell types
G = layer.FG
cells = layer.adata.obs.index.values
resolutions = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1,2,5,10]

for r in tqdm.tqdm(resolutions):
    types = leiden(G, cells, resolution=r)
    # add to a df 
    df[f'type_r{r}'] = np.char.add('t', np.array(types).astype(str))
df

In [None]:
# save
# write the data frame again (with header and index) to the same
# analysis_dev_v2.csv path in the dataset directory, replacing the earlier file
df.to_csv(os.path.join(basepth, "analysis_dev_v2.csv"), header=True, index=True)

# Viz

In [None]:
def plot_basis_spatial(df, pmode='full'):
    """Show the 24 basis columns (b0..b23) as rasterized maps over x/y.

    Each panel is a datashader mean aggregation of one basis column on a
    canvas sized by `powerplots.PlotScale`, drawn on a 4 x 6 grid with a
    fixed color range of [-3, 3].

    Parameters
    ----------
    df : data frame with columns 'x', 'y' and 'b0'..'b23'
    pmode : 'full', 'left_half' or 'right_half'
        Chooses the panel width and the side the panel titles are aligned to.

    Raises
    ------
    ValueError
        If `pmode` is not one of the three modes.
    """
    # per-mode settings: (panel width, title alignment)
    mode_settings = {
        'full':       (6, 'left'),
        'left_half':  (3, 'left'),
        'right_half': (3, 'right'),
    }
    if pmode not in mode_settings:
        raise ValueError("No such mode")
    panel_x, title_loc = mode_settings[pmode]

    # settings shared by all modes
    nx, ny = 6, 4
    panel_y = 5
    wspace, hspace = 0.05, 0
    title_y = 0.9

    # only needed by the (currently disabled) save call at the end
    path = os.path.join(respath, 'basis_space.pdf')

    x_extent = df['x'].max()-df['x'].min()
    y_extent = df['y'].max()-df['y'].min()
    P = powerplots.PlotScale(x_extent, y_extent,
                             # npxlx=300,
                             pxl_scale=20,
                            )
    print(P.npxlx, P.npxly)

    fig, axs = plt.subplots(ny, nx, figsize=(nx*panel_x, ny*panel_y))
    for i, ax in zip(range(24), axs.flat):
        column = f'b{i}'
        aggdata = ds.Canvas(P.npxlx, P.npxly).points(df, 'x', 'y', agg=ds.mean(column))
        ax.imshow(aggdata, origin='lower', aspect='equal', cmap='coolwarm',
                  vmin=-3, vmax=3, interpolation='none')
        ax.set_title(column, loc=title_loc, y=title_y)
        ax.axis('off')
    fig.subplots_adjust(wspace=wspace, hspace=hspace)
    # powerplots.savefig_autodate(fig, path)
    plt.show()

def plot_basis_umap(df):
    """Show the 24 basis columns (b0..b23) as rasterized maps over the UMAP embedding.

    Each panel is a datashader mean aggregation of one basis column over
    'umap_x'/'umap_y', on a canvas sized by `powerplots.PlotScale`, drawn on
    a 4 x 6 grid with a fixed color range of [-3, 3].

    Parameters
    ----------
    df : data frame with columns 'umap_x', 'umap_y' and 'b0'..'b23'
    """
    xcol, ycol = 'umap_x', 'umap_y'
    # only needed by the (currently disabled) save call at the end
    path = os.path.join(respath, 'basis_umap.pdf')

    x_extent = df[xcol].max()-df[xcol].min()
    y_extent = df[ycol].max()-df[ycol].min()
    P = powerplots.PlotScale(x_extent, y_extent, npxlx=300)
    print(P.npxlx, P.npxly)

    ncols, nrows = 6, 4
    fig, axs = plt.subplots(nrows, ncols, figsize=(ncols*5, nrows*4))
    for i, ax in zip(range(24), axs.flat):
        column = f'b{i}'
        aggdata = ds.Canvas(P.npxlx, P.npxly).points(df, xcol, ycol, agg=ds.mean(column))
        ax.imshow(aggdata, origin='lower', aspect='equal', cmap='coolwarm',
                  vmin=-3, vmax=3, interpolation='none')
        ax.set_title(column, loc='left', y=0.9)
        ax.axis('off')
    fig.subplots_adjust(wspace=0.05, hspace=0.1)
    # powerplots.savefig_autodate(fig, path)
    plt.show()

In [None]:
%%time

# basis maps in space for all cells (default pmode='full')
plot_basis_spatial(df)

In [None]:
%%time
# basis maps on the UMAP embedding for all cells
plot_basis_umap(df)

In [None]:
%%time
dfsub = df[df['semi']==0]
plot_basis_spatial(dfsub, pmode='left_half')
# plot_basis_umap(dfsub)

dfsub = df[df['semi']==1]
plot_basis_spatial(dfsub, pmode='right_half')


In [None]:
%%time
hue = 'type_r1'
hue_order = np.sort(np.unique(df[hue]))
ntypes = len(hue_order)
        
fig, axs = plt.subplots(1, 2, figsize=(8*2,6))
fig.suptitle(f"{hue}; n={ntypes}")
ax = axs[0]
sns.scatterplot(data=df, x='x', y='y', 
                hue=hue, hue_order=hue_order, 
                s=0.5, edgecolor=None, 
                legend=False,
                ax=ax)
# ax.legend(loc='upper left', bbox_to_anchor=(0, -0.1), ncol=5)
ax.set_aspect('equal')
ax.axis('off')

ax = axs[1]
sns.scatterplot(data=df, x='umap_x', y='umap_y', 
                hue=hue, hue_order=hue_order, 
                s=0.5, edgecolor=None, 
                legend=False,
                ax=ax)
# ax.legend(loc='upper left', bbox_to_anchor=(0, -0.1), ncol=5)
ax.set_aspect('equal')
ax.axis('off')
fig.subplots_adjust(wspace=0)
plt.show()

In [None]:
# one figure per resolution: cell types in space (left) and on the UMAP (right)
for r in resolutions:
    hue = f'type_r{r}'
    hue_order = np.sort(np.unique(df[hue]))
    ntypes = len(hue_order)

    fig, axs = plt.subplots(1, 2, figsize=(8*2,6))
    fig.suptitle(f"{hue}; n={ntypes}")
    # both panels share every setting except the coordinate columns
    for ax, (xcol, ycol) in zip(axs, [('x', 'y'), ('umap_x', 'umap_y')]):
        sns.scatterplot(data=df, x=xcol, y=ycol,
                        hue=hue, hue_order=hue_order,
                        s=0.5, edgecolor=None,
                        legend=False,
                        ax=ax)
        # ax.legend(loc='upper left', bbox_to_anchor=(0, -0.1), ncol=5)
        ax.set_aspect('equal')
        ax.axis('off')
    fig.subplots_adjust(wspace=0)
    plt.show()