In [None]:
import os
import numpy as np
import pandas as pd

import scanpy as sc
import decoupler as dc
import liana as li
import squidpy as sq

import muon as mu

In [None]:
from scipy.sparse import csr_matrix

In [None]:
# directory holding the heart Visium .h5ad datasets (two levels up from this notebook)
data_dir = os.path.join(os.pardir, os.pardir, 'data', 'heart_visium')

In [None]:
# every .h5ad file in data_dir (NOTE: not referenced again further down in this notebook)
dataset_names = [fname for fname in os.listdir(data_dir) if fname.endswith('.h5ad')]

In [None]:
# per-slide Visium metadata (provides `slide_name` and `major_labl` used below)
visium_meta_path = os.path.join("..", "spot_calling", "visium_meta.csv")
metadata = pd.read_csv(visium_meta_path)

In [None]:
# keep only the samples whose major label is 'IZ'
is_iz = metadata['major_labl'] == 'IZ'
metadata = metadata[is_iz]

In [None]:
# load DoRothEA regulons (source -> target network used by dc.run_ulm in the loop below).
# NOTE(review): this comment previously said "collecTRI", but the call loads DoRothEA.
# If collecTRI was intended, use `dc.get_collectri(organism='human')` instead; that
# changes every downstream TF-activity result.
net = dc.get_dorothea(organism='human')

In [None]:
# global scanpy figure defaults for this notebook
sc.set_figure_params(figsize=(7, 7), fontsize=16)

Read in the NMF H matrix (ligand-receptor interaction loadings per factor)

In [None]:
# NMF factor whose top-loading interactions are used as MISTy targets below
selected_factor = "Factor4"

In [None]:
# interaction x factor loadings; the first CSV column is used as the index
lr_loadings_path = os.path.join("results", "lr_loadings.csv")
lr_loadings = pd.read_csv(lr_loadings_path, index_col=0)

In [None]:
# ten interactions with the highest loading on the selected factor
lr_loadings.sort_values(by=selected_factor, ascending=False).head(n=10)

In [None]:
# distribution of all interaction loadings on the selected factor
factor_loadings = lr_loadings[selected_factor]
factor_loadings.hist(figsize=(10, 10), bins=50)

In [None]:
# the 20 interactions loading highest on `selected_factor`; these become the MISTy targets
ranked_loadings = lr_loadings.sort_values(selected_factor, ascending=False)
top_features = ranked_loadings.index[:20]

In [None]:
# inspect the selected target interactions
(top_features)

Load the per-sample files, join them into a MuData object, and fit MISTy per sample

In [None]:
# Per-sample pipeline: spatial neighbours -> TF activities (ULM) -> MuData -> MISTy.
# Outputs are collected in the dicts below, keyed by slide name.
mdatas = {}
target_metrics = {}
interactions = {}

for sample in metadata['slide_name']:
    print(f"Now running: {sample}")

    # precomputed ligand-receptor scores and the processed expression object for this slide
    lr = sc.read_h5ad(os.path.join('results', 'lr', sample + '.h5ad'))
    adata = sc.read_h5ad(os.path.join('results', 'processed', sample + '.h5ad'))
    li.mt.spatial_neighbors(adata, bandwidth=100, set_diag=False, cutoff=0.01)

    # TF activity inference (univariate linear model) against the regulon network
    dc.run_ulm(adata, net, source="source", target="target", use_raw=False, min_n=5)

    # turn obsm matrices into stand-alone AnnData views
    comp = li.fun.obsm_to_adata(adata, 'compositions')
    tf = li.fun.obsm_to_adata(adata, 'ulm_estimate')

    mdata = mu.MuData({"lr": lr, "comp": comp, "tf": tf})
    mu.pp.intersect_obs(mdata)

    # TODO(open question): restrict lr scores to positive interactions before MISTy,
    # e.g. zero out entries where the 'local_cats' obsm of the lr modality is <= 0?

    # only the top-loading interactions of `selected_factor` are used as MISTy targets
    lr_var_names = mdata.mod['lr'].var_names
    msk = lr_var_names[lr_var_names.isin(top_features)]

    # NOTE: the original author saw an error unless X is sparse (cause not identified)
    for mod_name in ('comp', 'lr', 'tf'):
        mdata.mod[mod_name].X = csr_matrix(mdata.mod[mod_name].X)

    # TODO: MistyData also accepts a MuData, not only a dict of AnnData views
    misty = li.mt.MistyData({"intra": mdata.mod['lr'][:, msk], "tf": mdata.mod['tf'], "comp": mdata.mod['comp']})
    misty(model='linear', verbose=True, bypass_intra=True)

    target_metrics[sample] = misty.uns['target_metrics']
    interactions[sample] = misty.uns['interactions']
    mdatas[sample] = mdata

In [None]:
# stack the per-sample target metrics; the outer dict key becomes the `sample` column
targets = (
    pd.concat(target_metrics)
    .reset_index()
    .rename(columns={'level_0': 'sample'})
    .drop(columns='level_1')
)
targets.to_csv(os.path.join("results", "misty_targets.csv"))

In [None]:
# stack the per-sample predictor importances; the outer dict key becomes the `sample` column
ints = (
    pd.concat(interactions)
    .reset_index()
    .rename(columns={'level_0': 'sample'})
    .drop(columns='level_1')
)
ints.to_csv(os.path.join("results", "misty_interactions.csv"))

In [None]:
# NOTE: `misty` is the object left over from the LAST sample processed in the loop above
li.pl.target_metrics(misty, stat='gain_R2', figure_size=(10, 5), return_fig=True)

In [None]:
# view contributions for the last sample processed in the loop (see `misty`)
li.pl.contributions(misty, figure_size=(5, 5), return_fig=True)

In [None]:
# NOTE(review): this rebinds `interactions`, which held the per-sample dict built in the
# loop; re-running the `pd.concat(interactions)` cell after this one will fail.
last_sample_interactions = misty.uns['interactions']
interactions = last_sample_interactions.copy()

In [None]:
# 15 TF predictors with the largest absolute importance (last processed sample)
tf_interactions = interactions[interactions['view'] == 'tf']
tf_interactions.sort_values("importances", key=abs, ascending=False).head(15)

Plot sample averages

In [None]:
import plotnine as p9

In [None]:
# reload the saved MISTy outputs so this section can run without re-fitting
results_dir = "results"
ints = pd.read_csv(os.path.join(results_dir, "misty_interactions.csv"), index_col=0)
targets = pd.read_csv(os.path.join(results_dir, "misty_targets.csv"), index_col=0)

Contributions

In [None]:
# long format: one row per (target, sample, view) holding that view's contribution
contributions = targets.loc[:, ['target', 'tf', 'comp', 'sample']].melt(
    id_vars=['target', 'sample'],
    var_name='view',
    value_name='contribution',
)

In [None]:
# boxplot of per-sample view contributions, one facet per target
# TODO: consider p9.coord_flip() for readability
contributions_theme = p9.theme(
    figure_size=(10, 5),
    axis_text_x=p9.element_text(rotation=90),
    axis_text_y=p9.element_text(size=8, colour="black"),
    strip_background=p9.element_rect(fill="white"),
    strip_text=p9.element_text(size=8, colour="black", rotation=90),
)
(
    p9.ggplot(contributions, p9.aes(x='view', y='contribution', fill='view'))
    + p9.geom_boxplot()
    + p9.facet_wrap('~ target', nrow=1)
    + p9.theme_bw()
    + contributions_theme
)


#### R2 multi

In [None]:
# preview the per-sample target metrics instead of dumping the full frame
targets.head()

In [None]:
# boxplot of multi-view R2 (multi_R2) per target across samples.
# coord_cartesian zooms to [0, 1] without discarding data. p9.ylim sets scale limits,
# so out-of-range values (e.g. negative R2) would be removed before the box statistics
# are computed, which biases the boxes.
(
    p9.ggplot(targets, p9.aes(x='target', y='multi_R2'))
    + p9.geom_boxplot()
    + p9.theme_bw()
    + p9.theme(axis_text_x=p9.element_text(rotation=90))
    + p9.coord_cartesian(ylim=(0, 1))
)

### NOTE: there is an issue where NaN values are assigned from one extra view to another.
Remove the `intra_group` and `extra_group` columns if they are None.

In [None]:
# Clean the importances before summarising:
# 1) drop columns that are entirely empty (presumably intra_group / extra_group when no
#    grouping was used; None values round-trip through CSV as all-NaN columns);
# 2) drop the remaining rows that contain any NaN.
# Without step 1, a single all-NaN column would make `dropna()` remove every row.
sum_ints = ints.dropna(axis=1, how='all').dropna()

In [None]:
# per (target, view, predictor): mean / std / median of the importances across samples.
# NOTE: overwrites `sum_ints`; re-run the previous cell before re-running this one.
sum_ints = (
    sum_ints.groupby(['target', 'view', 'predictor'])
    .agg(
        mean=('importances', 'mean'),
        std=('importances', 'std'),
        median=('importances', 'median'),
    )
    .reset_index()
)

In [None]:
# heatmap of the median importance of each composition predictor for each target.
# Uses `rotation` (the plotnine element_text parameter, as in the other plots here);
# `angle` is ggplot2 syntax and is not a plotnine element_text parameter.
# NOTE(review): 'coolwarm' is diverging but the scale is not centred at 0; consider
# p9.scale_fill_gradient2(midpoint=0) if the sign of the importance matters.
comp_summary = sum_ints[sum_ints['view'] == 'comp']
(
    p9.ggplot(comp_summary, p9.aes(x='predictor', y='target', fill='median'))
    + p9.geom_tile()
    + p9.theme_bw(base_size=16)
    + p9.theme(axis_text_x=p9.element_text(rotation=90), figure_size=(5, 5))
    + p9.scale_fill_cmap('coolwarm')
)

Top TFs

In [None]:
# unique (sorted) TF names appearing in the 100 highest-median (target, TF) pairs
tf_ranked = sum_ints[sum_ints['view'] == 'tf'].sort_values('median', ascending=False)
top_predictors = np.unique(tf_ranked['predictor'].head(100))

In [None]:
# heatmap of the median importance of the top TF predictors for each target.
# Uses `rotation` (the plotnine element_text parameter, as in the other plots here);
# `angle` is ggplot2 syntax and is not a plotnine element_text parameter.
top_tf_summary = sum_ints[sum_ints['predictor'].isin(top_predictors)]
(
    p9.ggplot(top_tf_summary, p9.aes(x='predictor', y='target', fill='median'))
    + p9.geom_tile()
    + p9.theme_bw(base_size=16)
    + p9.theme(axis_text_x=p9.element_text(rotation=90), figure_size=(8, 5))
    + p9.scale_fill_cmap('coolwarm')
)

Boxplot of the top predictors for the selected interaction (`interaction`)

In [None]:
# interaction inspected below; '&' and '_' separate its genes (see the re.split further down)
interaction = "FN1&ITGA5_ITGB1"

In [None]:
# 15 predictors with the largest absolute median importance for `interaction`
interaction_summary = sum_ints.loc[sum_ints['target'] == interaction]
top_ints = interaction_summary.sort_values("median", key=abs, ascending=False).head(15)

In [None]:
# predictor names in ranked order (strongest first); reused as the x-axis order below
ranked_predictor_col = top_ints['predictor']
top_predictors = ranked_predictor_col.values

In [None]:
# per-sample importance rows for `interaction`, restricted to the top predictors
in_target = ints['target'] == interaction
in_top_predictors = ints['predictor'].isin(top_predictors)
top_ints = ints[in_target & in_top_predictors].copy()

In [None]:
# fix the x-axis order to the ranking computed above
top_ints['predictor'] = pd.Categorical(
    top_ints['predictor'],
    categories=top_predictors,
    ordered=True,
)

In [None]:
# Human-readable view labels. Assign the result back instead of calling
# `top_ints['view'].replace(..., inplace=True)`. That form is chained assignment:
# pandas warns about it and it does not update `top_ints` under Copy-on-Write.
top_ints['view'] = top_ints['view'].replace({'tf': 'Regulator', 'comp': 'Composition'})

In [None]:
# Per-sample importances of the top predictors for `interaction`, coloured by view.
# The boxes show the distribution across samples (the predictors were ranked by median),
# so the y axis shows individual importances, not a median.
(
    p9.ggplot(top_ints, p9.aes(x='predictor', y='importances', color='view'))
    + p9.geom_boxplot()
    + p9.theme_minimal(base_size=14)
    + p9.theme(axis_text_x=p9.element_text(rotation=90), figure_size=(5, 4))
    + p9.labs(x='Predictor', y='Importance (t-value)', color='View')
)

Example Slide

In [None]:
# slides processed by the MISTy loop
available_slides = mdatas.keys()
available_slides

In [None]:
# example slide plus the composition and TF predictors to visualise on it
# (NOTE: rebinds `comp` / `tf`, which held AnnData objects inside the loop)
slide, comp, tf = 'AKK003_157775', 'Fib', 'PAX6'

In [None]:
import re

In [None]:
# Copy the example slide's MuData. The scaling / bivar cells below modify `mdata`,
# and the copy keeps `mdatas[slide]` itself unchanged. Index by `slide` rather than
# repeating the literal, so changing `slide` updates every downstream cell.
mdata = mdatas[slide].copy()

In [None]:
# spatial map of the chosen TF's activity estimate
tf_view = mdata.mod['tf']
sq.pl.spatial_scatter(tf_view, color=tf, cmap="coolwarm", size=1.5, img_alpha=0.05)

In [None]:
# spatial map of the chosen interaction's local score
lr_view = mdata.mod['lr']
sq.pl.spatial_scatter(lr_view, color=interaction, cmap="cividis", size=1.5, img_alpha=0.05, alpha=1, shape='hex')

In [None]:
# NOTE: not used below; `local_cats` is recomputed from the bivariate results later
lr_modality = mdata.mod['lr']
local_cats = li.fun.obsm_to_adata(lr_modality, 'local_cats')

In [None]:
# Expression of the individual genes that make up `interaction` on the example slide.
# Load this slide's processed AnnData explicitly. The `adata` left over from the MISTy
# loop belongs to the LAST sample processed, which is not necessarily `slide`.
slide_adata = sc.read_h5ad(os.path.join('results', 'processed', slide + '.h5ad'))
genes = re.split('[&_]', interaction)
sq.pl.spatial_scatter(slide_adata, color=genes, cmap="cividis", img_alpha=0.05, alpha=1, shape='hex', size=1.5)

In [None]:
# spatial map of the chosen composition predictor; use the `comp` parameter set above
# instead of repeating the literal, so changing it updates this plot too
sq.pl.spatial_scatter(mdata.mod['comp'], color=comp, cmap="cividis", img_alpha=0.05, size=1.5)

BASIS

In [None]:
# handle on the ligand-receptor modality of the example slide
lrdata = mdata.mod["lr"]

In [None]:
# copy the lr modality's obsp / obsm / uns onto the MuData itself
# (presumably so li.mt.bivar below finds the spatial graph at the MuData level; TODO confirm)
for slot_name in ('obsp', 'obsm', 'uns'):
    setattr(mdata, slot_name, getattr(lrdata, slot_name).copy())

In [None]:

# keep the unscaled matrices in layers['X'] before z-scoring; to restore them:
#   mdata.mod[name].X = mdata.mod[name].layers['X'].copy()
for mod_name in ('lr', 'comp'):
    mdata.mod[mod_name].layers['X'] = mdata.mod[mod_name].X.copy()
sc.pp.scale(mdata.mod['lr'])
sc.pp.scale(mdata.mod['comp'])


In [None]:
# keep the unscaled TF activities in layers['X'], then z-score X in place
tf_modality = mdata.mod['tf']
tf_modality.layers['X'] = tf_modality.X.copy()
sc.pp.scale(tf_modality)

In [None]:
# local bivariate 'cosine' scores between the TF (x) and the interaction (y),
# with permutation p-values and local categories
bivar_pairs = [(tf, interaction)]
li.mt.bivar(
    mdata,
    x_mod='tf',
    y_mod='lr',
    function_name='cosine',
    interactions=bivar_pairs,
    positive_only=True,
    add_categories=True,
    pvalue_method='permutation',
)

In [None]:
# spatial map of the local TF^interaction score
pair_key = f'{tf}^{interaction}'
sq.pl.spatial_scatter(mdata.mod['local_scores'], color=pair_key, cmap="magma", size=1.5, img_alpha=0.05)

In [None]:
# local categories produced by bivar (add_categories=True)
local_scores_mod = mdata.mod['local_scores']
local_cats = li.fun.obsm_to_adata(local_scores_mod, 'local_cats')

In [None]:
# spatial map of the local category for the TF^interaction pair
sq.pl.spatial_scatter(local_cats, color=f'{tf}^{interaction}', size=1.5, img_alpha=0.05, cmap="coolwarm")

In [None]:
# Permutation p-values of the bivariate scores. Read them from the 'local_scores'
# modality, the same place the local categories come from above. `mdata.obsm` was
# overwritten with the lr modality's obsm, so it does not hold the TF^interaction results.
local_pvals = li.fun.obsm_to_adata(mdata.mod['local_scores'], 'local_pvals')

In [None]:
# spatial map of the permutation p-values for the TF^interaction pair
sq.pl.spatial_scatter(local_pvals, color=f'{tf}^{interaction}', size=1.5, img_alpha=0.05, cmap="Blues")

In [None]:
# NMF results for the example slide
nmf_path = os.path.join('results', 'nmf', f'{slide}.h5ad')
nmfdata = sc.read_h5ad(nmf_path)

In [None]:
# spatial map of the selected NMF factor on the example slide (viridis colormap);
# uses `selected_factor` so it stays in sync with the factor analysed above
sq.pl.spatial_scatter(nmfdata, color=selected_factor, cmap="viridis", img_alpha=0.05, size=1.5)