In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import importlib
import os
import glob

import sys
# sys.path.append('/home/rwollman/MyProjects/AH/Repos/dredFISH')
from dredFISH.Analysis.TissueGraph import *
from dredFISH.Analysis.Classification import *
from dredFISH.Visualization.Viz import *
from dredFISH.Analysis import Classification
from dredFISH.Utils import geomu, fileu

# White figure background so plots stay readable on dark notebook themes.
plt.rcParams.update({'figure.facecolor': 'white'})

In [None]:
# Dataset path = <parent of the notebook directory>/<name of the notebook directory's grandparent>.
pth = os.getcwd()
notebook_parent = os.path.dirname(pth)
dataset_name = os.path.basename(os.path.dirname(notebook_parent))
inputpath = os.path.join(notebook_parent, dataset_name)

In [None]:
# Show the resolved dataset path.
inputpath

In [None]:
config = fileu.load_config_module(inputpath)
# Readouts to use: every bitmap entry (in bitmap order) whose readout has an
# encoding weight, excluding anything listed in bad_hybes.
bad_hybes = []
hybes = []
for entry in config.bitmap:
    readout = entry[0]
    if readout in config.encoding_weight_bias and readout not in bad_hybes:
        hybes.append(readout)
hybes

In [None]:
# Inspect the bitmap entries (readout, hybe, ...) loaded from the config.
config.bitmap

In [None]:
# Build the TissueMultiGraph (redo=True) and only the cell layer for the selected
# hybes; the spatial and feature graphs are built later, in the "Run TMG" section.
TMG = TissueMultiGraph(inputpath=inputpath, redo=True)
TMG.create_cell_layer(
    build_spatial_graph=False,
    build_feature_graph=False,
    hybes=hybes,
    norm='none',
)

In [None]:
# Feature (column) labels of the cell layer as created.
TMG.Layers[0].adata.var.index

In [None]:
# Relabel the cell-layer features with the readout names in `hybes`.
# NOTE(review): assumes create_cell_layer kept the columns in `hybes` order — confirm.
TMG.Layers[0].adata.var.index = hybes

In [None]:
# Inspect which per-cell QC columns the cell layer provides; the next cell
# zero-fills any of the expected QC columns that are missing.
TMG.Layers[0].adata.obs.columns

In [None]:
# Zero-fill every expected per-cell QC column that the cell layer did not provide,
# so the QC filters below can always index them.
columns = ['dapi','polyt','polyt_raw','nonspecific_encoding','nonspecific_encoding_raw','nonspecific_readout','nonspecific_readout_raw','size']
cell_obs = TMG.Layers[0].adata.obs
missing_columns = [col for col in columns if col not in cell_obs.columns]
for missing_column in missing_columns:
    cell_obs[missing_column] = 0

### Remove Outlier Cells

In [None]:
# Remove Bad Sections
# Section names of the form '...Well<id>-...' are dropped when <id> is listed in
# bad_wells. Names without 'Well' are never treated as bad; previously they
# raised IndexError even when bad_wells was empty.
bad_wells = []
bad_sections = [
    section
    for section in TMG.Layers[0].adata.obs['Slice'].unique()
    if 'Well' in section and section.split('Well')[1].split('-')[0] in bad_wells
]
print(bad_sections)
to_keep = ~np.isin(TMG.Layers[0].adata.obs['Slice'], bad_sections)
print(f"before filtering: {TMG.N[0]} cells")
TMG.Layers[0].filter(to_keep)
print(f"after filtering: {TMG.N[0]} cells")

In [None]:
# Correct Non Specific signal
# NOTE: this overwrites adata.X, so re-running the cell subtracts the background twice.
FISHbasis = TMG.Layers[0].adata.X.copy()
non_specific_readout = np.array(TMG.Layers[0].adata.obs['nonspecific_readout'])
# Encoding background in excess of the readout background, clipped at zero.
non_specific_encoding = np.array(TMG.Layers[0].adata.obs['nonspecific_encoding'])-non_specific_readout
non_specific_encoding[non_specific_encoding<0] = 0
# The readout background is subtracted uniformly from every bit.
FISHbasis_corrected = FISHbasis-non_specific_readout[:,None]
# The encoding background is subtracted from each bit in proportion to that bit's
# encoding weight, relative to the mean weight over all entries of
# encoding_weight_bias (instead of a hard-coded 24 entries).
encoding_weights = config.encoding_weight_bias
mean_encoding_weight = np.sum(list(encoding_weights.values()))/len(encoding_weights)
for b,bit in enumerate(TMG.Layers[0].adata.var.index):
    bit_scaling_factor = encoding_weights[bit]/mean_encoding_weight
    FISHbasis_corrected[:,b] = FISHbasis_corrected[:,b]-(non_specific_encoding*bit_scaling_factor)
FISHbasis_corrected[FISHbasis_corrected<0] = 0
TMG.Layers[0].adata.X = FISHbasis_corrected

In [None]:
# Keep only cells that belong to large spatial connected components.
XY = TMG.Layers[0].XY
in_large_comp = geomu.in_graph_large_connected_components(XY,large_comp_def = 0.01,plot_comp = True,max_dist = 100)
# Report the kept fraction. Mid-cell expressions are not displayed, so print it.
print(f"fraction in large components: {np.sum(in_large_comp)/in_large_comp.shape[0]:.4f}")
print(f"before filtering: {TMG.N[0]} cells")
TMG.Layers[0].filter(in_large_comp)
print(f"after filtering: {TMG.N[0]} cells")

In [None]:
# Filter on DAPI: drop roughly the dimmest 5% and brightest 0.1% of cells.
num_values = np.asarray(TMG.Layers[0].adata.obs['dapi'], dtype=float)
vmin,vmax = np.percentile(num_values,[5,99.9])
print(vmin,vmax)
ValueDistributions(TMG,num_values=num_values,title='dapi',log=True,min_line = vmin,max_line=vmax,figsize = (15,4)).show()

if vmin < vmax:
    to_keep = np.logical_and(num_values>vmin,
                             num_values<vmax)
else:
    # Degenerate distribution (e.g. 'dapi' was missing and zero-filled): the strict
    # bounds would discard every cell, so keep them all.
    print("dapi: percentile bounds coincide, skipping this filter")
    to_keep = np.ones(num_values.shape[0], dtype=bool)
print(f"before filtering: {TMG.N[0]} cells")
TMG.Layers[0].filter(to_keep)
print(f"after filtering: {TMG.N[0]} cells")

In [None]:
# Filter on background-subtracted polyT: bounds are the 5th / 99.9th percentiles of
# the positive values.
num_values = np.array(TMG.Layers[0].adata.obs['polyt'], dtype=float) - np.array(TMG.Layers[0].adata.obs['nonspecific_readout'], dtype=float)
positive_values = num_values[num_values>0]
if positive_values.size == 0:
    # polyT is required for the staining correction below; fail with a clear message.
    raise ValueError("no cells with polyt above nonspecific_readout; cannot filter or normalize on polyT")
vmin,vmax = np.percentile(positive_values,[5,99.9])
print(vmin,vmax)
ValueDistributions(TMG,num_values=num_values,title='polyt',log=True,min_line = vmin,max_line=vmax,figsize = (15,4)).show()

if vmin < vmax:
    to_keep = np.logical_and(num_values>vmin,
                             num_values<vmax)
else:
    # Coinciding bounds would make the strict filter discard every cell.
    print("polyt: percentile bounds coincide, skipping this filter")
    to_keep = np.ones(num_values.shape[0], dtype=bool)
print(f"before filtering: {TMG.N[0]} cells")
TMG.Layers[0].filter(to_keep)
print(f"after filtering: {TMG.N[0]} cells")

In [None]:
# Filter on total corrected signal per cell: bounds are the 5th / 99.9th percentiles
# of the positive sums.
num_values = np.asarray(TMG.Layers[0].adata.X.sum(1), dtype=float).ravel()
positive_values = num_values[num_values>0]
if positive_values.size == 0:
    raise ValueError("no cells with positive summed signal; cannot filter on 'sum'")
vmin,vmax = np.percentile(positive_values,[5,99.9])
print(vmin,vmax)
ValueDistributions(TMG,num_values=num_values,title='sum',log=True,min_line = vmin,max_line=vmax,figsize = (15,4)).show()

if vmin < vmax:
    to_keep = np.logical_and(num_values>vmin,
                             num_values<vmax)
else:
    # Coinciding bounds would make the strict filter discard every cell.
    print("sum: percentile bounds coincide, skipping this filter")
    to_keep = np.ones(num_values.shape[0], dtype=bool)
print(f"before filtering: {TMG.N[0]} cells")
TMG.Layers[0].filter(to_keep)
print(f"after filtering: {TMG.N[0]} cells")

In [None]:
# Filter on cell size: drop roughly the smallest 1% and largest 0.1% of cells.
num_values = np.asarray(TMG.Layers[0].adata.obs['size'], dtype=float)
vmin,vmax = np.percentile(num_values,[1,99.9])
print(vmin,vmax)
ValueDistributions(TMG,num_values=num_values,title='size',log=False,min_line = vmin,max_line=vmax,figsize = (15,4)).show()

if vmin < vmax:
    to_keep = np.logical_and(num_values>vmin,
                             num_values<vmax)
else:
    # Degenerate distribution (e.g. 'size' was missing and zero-filled): the strict
    # bounds would discard every cell, so keep them all.
    print("size: percentile bounds coincide, skipping this filter")
    to_keep = np.ones(num_values.shape[0], dtype=bool)
print(f"before filtering: {TMG.N[0]} cells")
TMG.Layers[0].filter(to_keep)
print(f"after filtering: {TMG.N[0]} cells")

In [None]:
# Filter on excess encoding background: drop roughly the top 1% of cells.
num_values = np.array(TMG.Layers[0].adata.obs['nonspecific_encoding'], dtype=float) - np.array(TMG.Layers[0].adata.obs['nonspecific_readout'], dtype=float)
vmin,vmax = np.percentile(num_values,[0,99])
print(vmin,vmax)
ValueDistributions(TMG,num_values=num_values,title='nonspecific_encoding',log=True,min_line = None,max_line=vmax,figsize = (15,4)).show()

# Upper bound only: vmin is the 0th percentile (the minimum), so a strict lower bound
# would merely discard the cells sitting at the minimum.
if vmin < vmax:
    to_keep = num_values<vmax
else:
    # Degenerate distribution (e.g. the columns were zero-filled): keep every cell.
    print("nonspecific_encoding: percentile bounds coincide, skipping this filter")
    to_keep = np.ones(num_values.shape[0], dtype=bool)
print(f"before filtering: {TMG.N[0]} cells")
TMG.Layers[0].filter(to_keep)
print(f"after filtering: {TMG.N[0]} cells")

In [None]:
# Filter on readout background: drop roughly the top 1% of cells.
num_values = np.asarray(TMG.Layers[0].adata.obs['nonspecific_readout'], dtype=float)
vmin,vmax = np.percentile(num_values,[0,99])
print(vmin,vmax)
ValueDistributions(TMG,num_values=num_values,title='nonspecific_readout',log=True,min_line = None,max_line=vmax,figsize = (15,4)).show()

# Upper bound only: vmin is the 0th percentile (the minimum), so a strict lower bound
# would merely discard the cells sitting at the minimum.
if vmin < vmax:
    to_keep = num_values<vmax
else:
    # Degenerate distribution (e.g. the column was zero-filled): keep every cell.
    print("nonspecific_readout: percentile bounds coincide, skipping this filter")
    to_keep = np.ones(num_values.shape[0], dtype=bool)
print(f"before filtering: {TMG.N[0]} cells")
TMG.Layers[0].filter(to_keep)
print(f"after filtering: {TMG.N[0]} cells")

In [None]:
columns = ['dapi','polyt','polyt_raw','nonspecific_encoding','nonspecific_encoding_raw','nonspecific_readout','nonspecific_readout_raw','size']
# Mean of every QC column per section, as a flat table (one row per section).
section_qc = TMG.Layers[0].adata.obs.groupby('Slice')[columns]
section_qc.mean().reset_index()

In [None]:
# Cell-layer AnnData after all QC filters.
TMG.Layers[0].adata

### Preview Vectors

In [None]:
# Snapshot of the filtered, background-corrected feature matrix for the preview plots.
FISHbasis = TMG.Layers[0].adata.X.copy()

In [None]:
# Cell-layer AnnData being previewed.
TMG.Layers[0].adata 

In [None]:
data = TMG.Layers[0].adata
# Readout name (bitmap column 0) -> hybe number (text after 'hybe' in bitmap column 1).
converter = {r:h.split('hybe')[-1] for r,h,c in config.bitmap}
# Hybe numbers of the measured bits, ordered from lowest to highest mean signal.
# Bits are looked up through var.index: readouts without an encoding weight were
# dropped from the layer, so data column j is not necessarily bitmap row j.
mean_signal = np.asarray(data.X.mean(0)).ravel()
''.join(converter[data.var.index[j]]+',' for j in np.argsort(mean_signal, kind='stable'))

In [None]:
# Spatial map of every bit (rows) in every section (columns), coloured by corrected signal.
data = TMG.Layers[0].adata
sections = data.obs['Slice'].unique()
bits = np.array(list(data.var.index))
n_sections = sections.shape[0]
n_columns = n_sections
n_rows = data.X.shape[1]

# squeeze=False keeps axs 2-D even for a single bit or section, so ravel() always works.
fig, axs = plt.subplots(n_rows, n_columns, figsize=(5*n_columns, 5*n_rows), squeeze=False)
axs = axs.ravel()
# Plain arrays so boolean masks and argsort orders index by position
# (positional indexing of a label-indexed Series is deprecated in pandas).
x = data.obs['stage_x'].to_numpy()
y = data.obs['stage_y'].to_numpy()
slice_labels = data.obs['Slice'].to_numpy()
for b,bit in enumerate(bits):
    for s,section in enumerate(sections):
        m = slice_labels==section
        ax = axs[b*n_columns + s]
        c = FISHbasis[m,b]
        nonzero = c[(c!=0)&(np.isnan(c)==False)]
        # Colour range: 5th-95th percentile of non-zero values; autoscale when there are none.
        vmin,vmax = np.percentile(nonzero,[5,95]) if nonzero.size else (None,None)
        # Draw low values first so high-signal cells end up on top.
        order = np.argsort(c)
        scatter_plot = ax.scatter(x[m][order], y[m][order],c=c[order],vmin=vmin,vmax=vmax,s=0.05,marker='x',cmap='jet')
        fig.colorbar(scatter_plot, ax=ax)
        ax.set_title(section+'\n'+bit)
        ax.tick_params(axis='both', which='both', length=0)
        ax.axis('off')

plt.tight_layout()
plt.show()

### Remove Bad Bits

In [None]:
# Intensity distribution of each bit, with its 25th and 75th percentiles as reference lines.
cell_adata = TMG.Layers[0].adata
for bit_index, bit_name in enumerate(cell_adata.var.index):
    num_values = cell_adata.X[:, bit_index].copy()
    vmin, vmax = np.percentile(num_values, [25, 75])
    ValueDistributions(
        TMG,
        num_values=num_values,
        title=bit_name,
        log=True,
        min_line=vmin,
        max_line=vmax,
        figsize=(15, 4),
    ).show()

In [None]:
# # Remove Bad Bits
# M = [np.percentile(TMG.Layers[0].adata.X[:,i],99)>100 for i in range(TMG.Layers[0].adata.shape[1])]
# print(TMG.Layers[0].adata.var.index[M])
# TMG.Layers[0].adata = TMG.Layers[0].adata[:,M]


In [None]:
# num_values = TMG.Layers[0].adata.X.sum(1).copy()
# vmin,vmax = np.percentile(num_values[num_values>0],[5,100])
# print(vmin,vmax)
# ValueDistributions(TMG,num_values=num_values,title='sum',log=True,min_line = vmin,max_line=vmax,figsize = (15,4)).show()

# to_keep = np.logical_and(num_values>vmin,
#                          num_values<vmax) 
# print(f"before filtering: {TMG.N[0]} cells")
# TMG.Layers[0].filter(to_keep)
# print(f"after filtering: {TMG.N[0]} cells")

### Correct for Cell Staining

In [None]:
# """ Cell Staining Correction  SUM """
# FISHbasis = TMG.Layers[0].adata.X.copy()
# correction = np.sum(FISHbasis.copy(),axis=1).mean()/np.sum(FISHbasis.copy(),axis=1)
# FISHbasis_normalized = (FISHbasis.copy().T*correction).T

In [None]:
""" Cell Staining Correction PolyT """
FISHbasis = TMG.Layers[0].adata.X.copy()
cell_obs = TMG.Layers[0].adata.obs
# Background-subtracted polyT signal per cell.
num_values = np.array(cell_obs['polyt']) - np.array(cell_obs['nonspecific_readout'])
# Per-cell factor that scales each cell's polyT signal to the population mean.
correction = num_values.mean()/num_values
# Apply the per-cell factor to every bit of that cell (row-wise scaling).
FISHbasis_normalized = FISHbasis*correction[:,None]

In [None]:
# Spatial map of every bit (rows) in every section (columns) after polyT normalization.
data = TMG.Layers[0].adata
sections = data.obs['Slice'].unique()
bits = np.array(list(data.var.index))
n_sections = sections.shape[0]
n_columns = n_sections
n_rows = data.X.shape[1]

# squeeze=False keeps axs 2-D even for a single bit or section, so ravel() always works.
fig, axs = plt.subplots(n_rows, n_columns, figsize=(5*n_columns, 5*n_rows), squeeze=False)
axs = axs.ravel()
# Plain arrays so boolean masks and argsort orders index by position
# (positional indexing of a label-indexed Series is deprecated in pandas).
x = data.obs['stage_x'].to_numpy()
y = data.obs['stage_y'].to_numpy()
slice_labels = data.obs['Slice'].to_numpy()
for b,bit in enumerate(bits):
    for s,section in enumerate(sections):
        m = slice_labels==section
        ax = axs[b*n_columns + s]
        c = FISHbasis_normalized[m,b]
        nonzero = c[(c!=0)&(np.isnan(c)==False)]
        # Colour range: 5th-95th percentile of non-zero values; autoscale when there are none.
        vmin,vmax = np.percentile(nonzero,[5,95]) if nonzero.size else (None,None)
        # Draw low values first so high-signal cells end up on top.
        order = np.argsort(c)
        scatter_plot = ax.scatter(x[m][order], y[m][order],c=c[order],vmin=vmin,vmax=vmax,s=0.05,marker='x',cmap='jet')
        fig.colorbar(scatter_plot, ax=ax)
        ax.set_title(section+'\n'+bit)
        ax.tick_params(axis='both', which='both', length=0)
        ax.axis('off')

plt.tight_layout()
plt.show()

### Z-score each bit per section to put all bits on the same scale

In [None]:
""" Zscore """
# For each bit and each section: centre on the median and divide by the standard
# deviation (NaNs ignored), so all bits and sections share a comparable scale.
data = TMG.Layers[0].adata
section = np.array(data.obs['Slice'])
sections = np.unique(section)
FISHbasis_zscored = np.zeros_like(FISHbasis_normalized)
for i in range(FISHbasis_normalized.shape[1]):
    tc = FISHbasis_normalized[:,i].copy()
    for s in sections:
        m = section==s
        c = tc[m]
        finite = c[np.isnan(c)==False]
        if finite.size == 0:
            # All-NaN section for this bit: nothing to centre or scale; leave as is.
            continue
        vmid = np.percentile(finite,50)
        std = np.std(finite)
        c = c-vmid
        # A constant section has std 0; keep the centred zeros instead of producing NaN/inf.
        if std > 0:
            c = c/std
        tc[m] = c
    FISHbasis_zscored[:,i] = tc
FISHbasis_zscored

In [None]:
# Spatial map of every bit (rows) in every section (columns) after per-section z-scoring.
data = TMG.Layers[0].adata
sections = data.obs['Slice'].unique()
bits = np.array(list(data.var.index))
n_sections = sections.shape[0]
n_columns = n_sections
n_rows = data.X.shape[1]

# squeeze=False keeps axs 2-D even for a single bit or section, so ravel() always works.
fig, axs = plt.subplots(n_rows, n_columns, figsize=(5*n_columns, 5*n_rows), squeeze=False)
axs = axs.ravel()
# Plain arrays so boolean masks and argsort orders index by position
# (positional indexing of a label-indexed Series is deprecated in pandas).
x = data.obs['stage_x'].to_numpy()
y = data.obs['stage_y'].to_numpy()
slice_labels = data.obs['Slice'].to_numpy()
for b,bit in enumerate(bits):
    for s,section in enumerate(sections):
        m = slice_labels==section
        ax = axs[b*n_columns + s]
        c = FISHbasis_zscored[m,b]
        nonzero = c[(c!=0)&(np.isnan(c)==False)]
        # Colour range: 5th-95th percentile of non-zero values; autoscale when there are none.
        vmin,vmax = np.percentile(nonzero,[5,95]) if nonzero.size else (None,None)
        # Draw low values first so high-signal cells end up on top.
        order = np.argsort(c)
        scatter_plot = ax.scatter(x[m][order], y[m][order],c=c[order],vmin=vmin,vmax=vmax,s=0.05,marker='x',cmap='jet')
        fig.colorbar(scatter_plot, ax=ax)
        ax.set_title(section+'\n'+bit)
        ax.tick_params(axis='both', which='both', length=0)
        ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Replace the cell-layer features with the per-section z-scored matrix; the clustering
# preview and the TMG pipeline below read TMG.Layers[0].adata.
TMG.Layers[0].adata.X = FISHbasis_zscored

### Preview clustering

In [None]:
import scanpy as sc
import umap
# Quick clustering preview on a copy of the cell layer: kNN graph on X, UMAP, Louvain,
# shown in UMAP space and in stage coordinates.
adata = TMG.Layers[0].adata.copy()
sc.pp.neighbors(adata, n_neighbors=15, use_rep='X')
sc.tl.umap(adata, min_dist=0.1)
sc.tl.louvain(adata)
sc.pl.umap(adata, color='louvain', show=True)
# (n_cells, 2) stage coordinates so they can be plotted as an embedding.
adata.obsm['stage'] = np.column_stack([adata.obs['stage_x'], adata.obs['stage_y']])
sc.pl.embedding(adata, basis='stage', color='louvain', show=False)

In [None]:
data = adata.copy()
sections = data.obs['Slice'].unique()
# One stage-coordinate map of the Louvain clusters per section.
for section in sections:
    print(section)
    in_section = data.obs['Slice'] == section
    sc.pl.embedding(data[in_section], basis='stage', color='louvain', show=True)

In [None]:
# One panel per Louvain cluster (rows) and section (columns), showing where that cluster's cells sit.
data = adata.copy()
sections = data.obs['Slice'].unique()
# Category order (not order of appearance) so cluster index b matches louvain_colors[b].
clusters = np.array(list(data.obs['louvain'].cat.categories))
cluster_colors = adata.uns['louvain_colors']

n_columns = sections.shape[0]
n_rows = clusters.shape[0]
# squeeze=False keeps axs 2-D even for a single cluster or section, so ravel() always works.
fig, axs = plt.subplots(n_rows, n_columns, figsize=(5*n_columns, 5*n_rows), squeeze=False)
axs = axs.ravel()
x = data.obs['stage_x'].to_numpy()
y = data.obs['stage_y'].to_numpy()
slice_labels = data.obs['Slice'].to_numpy()
cluster_labels = data.obs['louvain'].to_numpy()
for b,cluster in enumerate(clusters):
    in_cluster = cluster_labels==cluster
    for s,section in enumerate(sections):
        m = (slice_labels==section) & in_cluster
        ax = axs[b*n_columns + s]
        ax.scatter(x[m], y[m], c=cluster_colors[b], s=0.1, marker='x')
        ax.set_title(f"{section}\n{cluster}")
        ax.tick_params(axis='both', which='both', length=0)
        ax.axis('off')

plt.tight_layout()
plt.show()

## Run TMG

In [None]:
# Build the cell layer's spatial graph (skipped when the layer was created above).
TMG.Layers[0].build_spatial_graph()

In [None]:
# SLOW
# Build the cell layer's feature graph with a correlation metric (also skipped at creation).
TMG.Layers[0].build_feature_graph(metric='correlation')

In [None]:
# Create the classifier
# Project classifier bound to the cell layer; it is trained in the next cell.
optleiden = Classification.OptimalLeidenUnsupervized(TMG.Layers[0])

In [None]:
# train the classifier
# NOTE(review): 'iters' / 'n_consensus' semantics are defined by the project classifier — confirm.
optleiden.train(opt_params={'iters':10, 'n_consensus':1})

In [None]:
# use the classifier to create types and add them to TMG using the Taxonomy created on the fly by the classifier
# feature_mat is presumably the z-scored X assigned above — verify.
type_vec = optleiden.classify(TMG.Layers[0].feature_mat)

In [None]:
# Attach the per-cell types and the classifier's taxonomy to layer 0.
TMG.add_type_information(0, type_vec, optleiden.tax)

In [None]:
TMG.create_isozone_layer()
# print() always shows in the notebook. logging.info is hidden unless the root logger
# is set to INFO, and `logging` is never imported explicitly in this notebook.
print(f"TMG has {len(TMG.Layers)} Layers")

In [None]:
# Candidate numbers of topics; one worker process per candidate.
n_topics_list = [2,5,10,15,20,30,50]
n_procs = len(n_topics_list) 

topic_cls = Classification.TopicClassifier(TMG.Layers[0])
topic_cls.train(n_topics_list=n_topics_list, n_procs=n_procs)
# NOTE(review): topic_cls.Env is presumably the per-cell neighbourhood matrix — confirm.
topics = topic_cls.classify(topic_cls.Env)

In [None]:
TMG.create_region_layer(topics, topic_cls.tax)
# print() always shows in the notebook. logging.info is hidden unless the root logger
# is set to INFO, and `logging` is never imported explicitly in this notebook.
print(f"TMG has {len(TMG.Layers)} Layers")

In [None]:
# Compute geometries for the listed types; redo=False presumably reuses existing ones — confirm.
TMG.add_geoms(geom_types = ["mask","voronoi","isozones","regions"],redo=False)

In [None]:
# Persist the TMG (destination decided by the project; presumably under inputpath).
TMG.save()

In [None]:
# Re-create TMG with redo=False — presumably reloads what TMG.save() wrote; verify.
TMG = TissueMultiGraph(inputpath=inputpath,redo = False)

In [None]:
# Project BasisView of the TMG, rotated by -90 (presumably one panel per bit).
V = BasisView(TMG,rotation=-90)
V.show()

In [None]:
# Random-colour map of each level: cells, isozones, then regions.
for level_type in ["cell", "isozone", "region"]:
    SingleMapView(TMG, level_type=level_type, map_type="random", rotation=-90).show()

In [None]:
# Cell-type map whose colours are assigned by 'linkage' from these colormaps.
colormaps = ['Purples','Oranges','Blues','Greens','Reds','cividis']
SingleMapView(TMG,level_type = "cell", map_type = "type",color_assign_method = 'linkage',colormaps = colormaps,rotation=-90).show()

In [None]:
# UMAP next to the spatial map; qntl / clp_embed look like quantile clipping bounds — confirm.
V = UMAPwithSpatialMap(TMG,qntl = (0.025,0.975),clp_embed = (0.025,0.975),rotation=-90)
V.show()