## MOFAtalk

In [1]:
import os

In [2]:
# Kidney-injury data directory, given relative to the notebook's working directory
data_dir = os.path.join('..', '..', 'data', 'kidney_injury')

In [3]:
# columns of interest in `adata.obs`
sample_key = "ident"  # passed as `sample_key` to LIANA's by_sample and lrs_to_views
groupby = "cell_type"  # passed as `groupby` to LIANA's by_sample
condition_key = "Group"  # carried over to `mdata.obs` through `obs_keys`

## Setup Environment

In [4]:
import numpy as np
import pandas as pd

import scanpy as sc

In [5]:
import mofax as mofa
import muon as mu
import decoupler as dc

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
import liana as li
import plotnine as p9

In [7]:
li.__version__

'0.1.9'

Load object

In [8]:
# Load the processed AnnData object from the data directory
adata = sc.read_h5ad(os.path.join(data_dir, "aki_processed.h5ad"))

In [9]:
# Preview the per-cell metadata, ordered by the cell-type column configured in `groupby`
adata.obs.sort_values(groupby)

Unnamed: 0,orig.ident,nCount_RNA,nFeature_RNA,Group,Replicates,cell_state,ident,cell_type,full_name,n_genes,keep_sum,keep_min,keep_celltype
IRI12h2_CAACTAGGTTTACTCT,IRI12h2,1428.0,532,12hours,2,CNT,IRI12h2,CNT,Connecting tubule,532,24,True,True
IRIsham1b2_CACAGGCTCAAGGTAA,IRIsham1b2,1331.0,579,Control,1_2,CNT,IRIsham1b2,CNT,Connecting tubule,579,24,True,True
IRI14d1b2_ATGGGAGTCTGCGGCA,IRI14d1b2,1279.0,535,14days,1_2,CNT,IRI14d1b2,CNT,Connecting tubule,535,24,True,True
IRIsham2_GACAGAGTCAACCAAC,IRIsham2,1042.0,532,Control,2,CNT,IRIsham2,CNT,Connecting tubule,532,24,True,True
IRI4h3_AAGGTTCTCCAATGGT,IRI4h3,1446.0,553,4hours,3,CNT,IRI4h3,CNT,Connecting tubule,553,24,True,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...
IRI6w3_ATCTGCCGTGGTGTAG,IRI6w3,1120.0,577,6weeks,3,Uro,IRI6w3,Uro,Urothelial cell,577,23,True,True
IRIsham1b1_CGTCTACAGACGCTTT,IRIsham1b1,1218.0,718,Control,1_1,Uro,IRIsham1b1,Uro,Urothelial cell,718,23,True,True
IRI2d1b1_CGTTAGAAGCTCCTTC,IRI2d1b1,1560.0,667,2days,1_1,Uro,IRI2d1b1,Uro,Urothelial cell,667,23,True,True
IRIsham1b2_CGCTTCATCCTTGACC,IRIsham1b2,1794.0,776,Control,1_2,Uro,IRIsham1b2,Uro,Urothelial cell,776,23,True,True


In [10]:
# Number of cells per cell type
adata.obs[groupby].value_counts()

PT         50159
MTAL       13750
CTAL        9904
EC          9660
PC          8407
DCT         7500
DTL-ATL     4924
CNT         4578
Fib         4350
MO          3036
ICA         2920
ICB         1974
DCT-CNT     1682
Uro         1167
Pod          771
Tcell        617
PEC          518
Per          396
Name: cell_type, dtype: int64

In [11]:
adata.uns.keys()

dict_keys(['X_name', 'cell_type_colors', 'liana_res', 'log1p', 'neighbors', 'pca', 'umap'])

In [12]:
# Map each cell-type label to its full name; select and sort with the same configured
# column (`groupby`) so the cell keeps working if that key is changed
adata.obs[[groupby, 'full_name']].drop_duplicates().sort_values(groupby)

Unnamed: 0,cell_type,full_name
IRI4h1_AAATGCCTCAAACCAC,CNT,Connecting tubule
IRI4h1_AAACCTGAGATCTGCT,CTAL,Thick ascending limb of loop of Henle
IRI4h1_AAACCTGCACCAACCG,DCT,Distal convoluted tubule
IRI4h1_AAACCTGTCGTCGTTC,DCT-CNT,DCT-CNT
IRI4h1_AAAGCAACATCGGGTC,DTL-ATL,DTL-ATL (thin ascending limb of loop of Henle)
IRI4h1_AAACCTGAGTGTTAGA,EC,Epithelial cells
IRI4h1_ACACTGAAGAGAGCTC,EC,Epithelial cell
IRI4h1_AAACCTGGTAGCGCTC,Fib,Fibroblast
IRI4h1_AAACCTGCAGGGAGAG,ICA,Intercalated cell of collecting duct
IRI4h1_AAACGGGTCTGAGTGT,ICB,Type B intercalated cell


## Run LIANA

In [None]:
# Run LIANA's rank_aggregate separately for every sample in `sample_key`, scoring
# interactions between the cell types in `groupby` with the mouse consensus resource.
# NOTE(review): results are presumably stored in `adata.uns['liana_res']`, which the
# following cells read — confirm against the liana documentation.
li.method.rank_aggregate.by_sample(adata,
                                   groupby=groupby,
                                   sample_key=sample_key,
                                   resource_name='mouseconsensus',
                                   use_raw=False,
                                   verbose=True,
                                   n_perms=None,  # presumably skips the permutation-based scores — TODO confirm
                                   return_all_lrs=False
                                   )

In [None]:
# NOTE: this writes to the same path the object was loaded from, so the input file is
# overwritten with the version that carries the LIANA results
adata.write_h5ad(os.path.join(data_dir, "aki_processed.h5ad"))

In [None]:
# Number of samples with LIANA results; use the configured sample column rather than
# a hard-coded name so it stays in sync with `sample_key`
adata.uns['liana_res'][sample_key].nunique()

## Run MOFA

In [13]:
# Reload the saved object from disk so this section can start from the stored file
adata = sc.read_h5ad(os.path.join(data_dir, "aki_processed.h5ad"))

In [14]:
# LIANA score column used as the value of each interaction when building the views
score_key = 'magnitude_rank'

In [15]:
# Convert the per-sample LIANA results into a multi-view object (`mdata`) that is
# passed to MOFA below; the thresholds filter out sparse interactions, samples and views.
mdata = li.multi.lrs_to_views(adata,
                              sample_key=sample_key,
                              score_key=score_key,
                              obs_keys=[condition_key], # add those to mdata.obs
                              lr_prop = 0.25, # minimum required proportion of samples to keep an LR
                              lrs_per_sample = 15, # minimum number of interactions to keep a sample in a specific view
                              lrs_per_view = 25, # minimum number of interactions to keep a view
                              samples_per_view = 5, # minimum number of samples to keep a view
                              min_variance = 0, # minimum variance to keep an interaction
                              lr_fill = np.nan, # fill missing LR values across samples with this
                              verbose=True
                              )

100%|██████████| 159/159 [00:01<00:00, 121.80it/s]


In [None]:
# Fit a 5-factor MOFA model on the multi-view object. The trained model is written to
# `outfile`, which is re-opened with mofax further below.
# NOTE(review): the factor scores are presumably stored in `mdata.obsm['X_mofa']`
# (the key read by the following cells) — confirm against the muon documentation.
mu.tl.mofa(mdata,
           use_obs='union',
           convergence_mode='medium',
           outfile='models/mofatalk.h5ad',
           n_factors=5,
           )

In [None]:
# Factor scores per sample, with `Group` turned into a categorical whose levels
# follow the time course (control first, then increasing time after injury)
group_order = ['Control', '4hours', '12hours', '2days', '14days', '6weeks']

factor_scores = li.multi.get_factor_scores(mdata, obsm_key='X_mofa')
factor_scores['Group'] = (
    factor_scores['Group']
    .astype('category')
    .cat.reorder_categories(group_order)
)
factor_scores.head()

UMAP

In [None]:
# Build the neighbour graph on the MOFA factor space and embed the samples with UMAP.
# UMAP is computed once, with explicit layout parameters and a fixed seed so the
# embedding is reproducible.
sc.pp.neighbors(mdata, use_rep="X_mofa")
sc.tl.umap(mdata, min_dist=.2, spread=1., random_state=10)

In [None]:
# Sample-level UMAP of the MOFA factors, coloured by condition
sc.pl.umap(mdata, color="Group", size=150)

In [None]:
# Kruskal Wallis test
from scipy.stats import kruskal

In [None]:
# Extract the Factor1 scores for each group. `Group` is categorical, so observed=True
# restricts the grouping to categories that actually have samples; an empty group
# would otherwise be handed to the test as an empty array.
groups = [
    group_data['Factor1'].values
    for _, group_data in factor_scores.groupby('Group', observed=True)
]

# Perform the Kruskal-Wallis test across all groups
statistic, p_value = kruskal(*groups)

# Print the test results
print("Kruskal-Wallis Test")
print("Statistic:", statistic)
print("P-value:", p_value)

In [None]:
# boxplot of Factor1 scores per group, with the individual samples jittered on top
(p9.ggplot(factor_scores) +
 p9.aes(x='Group', colour='Group', y='Factor1') +
 p9.geom_boxplot() +
 p9.geom_jitter(size=2, width=0.3) +
 p9.theme_bw(base_size=24) +
 p9.theme(figure_size=(6, 6)) + 
 p9.labs(x='Group', y='Factor 1') +
 # rotate x-axis labels
 p9.theme(axis_text_x=p9.element_text(angle=90, hjust=1)) +
 # Dark2 set 
 p9.scale_color_brewer(type='qual', palette='Dark2') +
 # annotate the Kruskal-Wallis p-value (rounded up to two decimals) at a hard-coded position
 p9.annotate('text', x=4.3, y=0.25, label=f'KW P-value < {np.ceil(p_value * 100) / 100}', size=24)
 
 )

Plot variance explained (R²)

In [None]:
# Open the trained MOFA model written by the MOFA cell above
model = mofa.mofa_model("models/mofatalk.h5ad")
model

In [None]:
# get variance explained by view and factor
rsq = model.get_r2()
# keep Factor1 only; take an explicit copy so that adding columns below writes to an
# independent frame rather than to a slice of `rsq`
factor1_rsq = rsq[rsq['Factor']=='Factor1'].copy()
# separate the view name ("source&target") into its two cell types;
# `n` must be passed by keyword in recent pandas versions
factor1_rsq[['source', 'target']] = factor1_rsq['View'].str.split('&', n=1, expand=True)

# heatmap of the variance explained by Factor1 for every source/target pair
(p9.ggplot(factor1_rsq.reset_index()) +
 p9.aes(x='target', y='source') +
 p9.geom_tile(p9.aes(fill='R2')) +
 p9.scale_fill_gradient2(low='white', high='#c20019') +
 p9.theme_bw(base_size=24) +
 # rotate X axis
 p9.theme(axis_text_x=p9.element_text(angle=90), figure_size=(6, 6)) +
 p9.labs(x='Target', y='Source', fill='  R²')
 )

In [None]:
# unique source cell types among the 10 views with the highest Factor1 R2
top_views_by_source = factor1_rsq.sort_values("R2", ascending=False).head(10)
sources = np.unique(top_views_by_source['source'].values)
sources

In [None]:
# unique target cell types among the 10 views with the highest Factor1 R2
top_views_by_target = factor1_rsq.sort_values("R2", ascending=False).head(10)
targets = np.unique(top_views_by_target['target'].values)
targets

Average R2 per source & target

In [None]:
# mean and std of the Factor1 R2 per target cell type (10 highest means)
factor1_rsq.groupby('target').agg({'R2': ['mean', 'std']}).sort_values(('R2', 'mean'), ascending=False).head(10)

In [None]:
# mean and std of the Factor1 R2 per source cell type (10 highest means)
factor1_rsq.groupby('source').agg({'R2': ['mean', 'std']}).sort_values(('R2', 'mean'), ascending=False).head(10)

In [None]:
factor1_rsq['R2'].mean()

In [None]:
factor1_rsq.sort_values("R2", ascending=False).head(10)

In [None]:
adata.obs[[groupby, "full_name"]].drop_duplicates()

In [None]:
# Loadings per interaction, with the view and interaction names split on the given
# separators so they can be passed to the dotplot below
variable_loadings =  li.multi.get_variable_loadings(mdata,
                                                    view_separator=':',
                                                    pair_separator="&",
                                                    variable_separator="^")
# constant value used for the dot size in the dotplot
variable_loadings['size'] = 3
# preview as the last expression so it is actually displayed
variable_loadings.head()

In [None]:
# absolute Factor1 loading, used to rank interactions regardless of sign
variable_loadings['abs_F1'] = variable_loadings['Factor1'].abs()

In [None]:
variable_loadings

In [None]:
# Dotplot of the interactions with the largest absolute Factor1 loadings, restricted
# to the source and target cell types selected from the top-R2 views
my_plot = li.pl.dotplot(liana_res = variable_loadings,
                        size='size',
                        colour='Factor1',
                        orderby='abs_F1',
                        top_n=10,
                        source_labels=sources,
                        target_labels=targets,
                        orderby_ascending=False,
                        size_range=(0.1, 5),
                        figure_size=(6, 8)
                        )
# change colour, with mid as white
# NOTE(review): the mid colour set below is 'lightgray', not white, and the
# figure_size in the theme below presumably overrides the (6, 8) passed above — confirm
(my_plot + 
 p9.scale_color_gradient2(low='#1f77b4', mid='lightgray', high='#c20019') + 
 p9.theme_bw(base_size=16) +
 p9.theme(figure_size=(8, 5)) +
p9.theme(axis_text_x=p9.element_text(angle=90)) +
 # remove size from legend
 p9.guides(size=False)
)

Pathway enrichment

In [None]:
# Long-format loadings (one row per view and interaction), indexed by the interaction name
lr_loadings = (
    li.multi.get_variable_loadings(mdata, view_separator=':')
    .set_index('variable')
)

In [None]:
# Save the long-format loadings next to the input data
lr_loadings.to_csv(os.path.join(data_dir, 'lr_loadings.csv'))

In [None]:
# load PROGENy pathways
net = dc.get_progeny(organism='Mus musculus', top=5000)
# load full list of ligand-receptor pairs
lr_pairs = li.resource.select_resource('mouseconsensus')

# generate ligand-receptor geneset; "^" is the same separator used between ligand and
# receptor in the variable names of the loadings
# NOTE(review): the result is used below with columns `source`, `interaction` and
# `weight` — confirm how weights are assigned to a ligand-receptor pair
lr_progeny = li.fun.generate_lr_geneset(lr_pairs, net, lr_separator="^")
lr_progeny.head()

In [None]:
# NOTE: should instead do enrichment for each view separately?
# here, I'm inflating the number of targets per pathway, e.g. it could be that a line is fit between 2 interactions...

# pivot views to wide: rows stay the current index (`variable`), one column per view,
# values are the Factor1 loadings
lr_loadings = lr_loadings.pivot(columns='view', values='Factor1')


In [None]:
# interactions absent from a view get a loading of 0
lr_loadings = lr_loadings.fillna(0)

In [None]:
# Keep `variable` as the index: the enrichment and melt cells further down expect a
# purely numeric interaction x view matrix, and moving the names into a column would
# add a string row to the transposed matrix. Preview without mutating instead.
lr_loadings.reset_index().head()

In [None]:
# After the pivot the views are the columns of `lr_loadings` (there is no `view`
# column any more); leave out `variable` in case the index was moved into a column
lr_loadings.columns.drop('variable', errors='ignore')

In [None]:
# per-view enrichment estimates, keyed by view name (filled by the loop below)
view_enrichments = {}

In [None]:
# Per-view enrichment needs the long-format loadings (one row per view and
# interaction). `lr_loadings` has been pivoted to wide by this point, so the long
# table is fetched again instead of indexing columns that no longer exist.
lr_loadings_long = li.multi.get_variable_loadings(mdata, view_separator=':')

for view in lr_loadings_long['view'].unique():
    # 1 x n_interactions matrix with the Factor1 loadings of this view
    in_view = lr_loadings_long['view'] == view
    mat = lr_loadings_long.loc[in_view, ['variable', 'Factor1']].set_index("variable").transpose()
    try:
        # separate name so the notebook-level `estimate` is not overwritten
        view_estimate, _ = dc.run_ulm(mat, lr_progeny,
                                      source="source", target="interaction",
                                      use_raw=False, min_n=5)
    except ValueError as err:
        # too few interactions shared with the gene sets for this view; report the
        # reason and move on instead of hiding every possible error
        print("Not enough interactions:", view, f"({err})")
        continue
    view_enrichments[view] = view_estimate


In [None]:
# PROGENy enrichment on the Factor1 loadings, with views as rows and interactions as columns
views_by_interaction = lr_loadings.transpose()
estimate, pvals = dc.run_ulm(
    views_by_interaction,
    lr_progeny,
    source="source",
    target="interaction",
    use_raw=False,
    min_n=5,
)
# reshape the view x pathway estimates to long format (one row per view and pathway)
estimate = estimate.melt(ignore_index=False, value_name='estimate', var_name='pathway')
estimate = estimate.reset_index().rename(columns={'index': 'view'})


In [None]:
# estimate[['source', 'target']] = estimate['view'].str.split('&', 1, expand=True)
# # source in sources, and target in targets
# estimate = estimate[(estimate['source'].isin(sources)) & (estimate['target'].isin(targets))]

In [None]:
# cap the estimates to the [-5, 5] range so extreme values do not dominate the colour scale
estimate['estimate'] = estimate['estimate'].clip(lower=-5, upper=5)

In [None]:
# Heatmap of the PROGENy enrichment estimates: pathways on x, views on y
(p9.ggplot(estimate) +
 p9.aes(x='pathway', y='view') +
 p9.geom_tile(p9.aes(fill='estimate')) +
 p9.scale_fill_gradient2(low='#1f77b4', high='#c20019') +
 p9.theme_bw(base_size=14) +
 p9.theme(figure_size=(10, 20))
 # the fill values are capped at +/-5 in the cell above
)


Explore loadings x PROGENy, for a specific view

In [None]:
factor1_rsq.sort_values("R2", ascending=False).head(10)

In [None]:
# pivot back to long: one row per (interaction, view) with its loading; the
# interaction names come from the index, which reset_index turns into a column
lr_loadings = lr_loadings.melt(ignore_index=False, value_name='loading', var_name='view').reset_index()

In [None]:
lr_loadings.rename(columns={'variable':'interaction'}, inplace=True)

In [None]:
lr_loadings

In [None]:
# attach the PROGENy pathway and weight of every interaction; rows with any missing
# value (e.g. interactions without a PROGENy entry) are dropped
progeny_weights = lr_progeny.rename(columns={"source": "pathway"})
lr_loadings = lr_loadings.merge(progeny_weights, on='interaction', how='left')
lr_loadings = lr_loadings.dropna()

In [None]:
# label the direction of the PROGENy weight (a weight of exactly 0 counts as negative)
lr_loadings['sign'] = np.where(lr_loadings['weight'] > 0, 'positive', 'negative')

In [None]:
# keep the interaction name only where the loading is non-zero and the absolute
# PROGENy weight is at least 0.5; everything else is None (no label in the plot)
has_loading = lr_loadings['loading'].abs() > 0
strong_weight = lr_loadings['weight'].abs() >= 0.5
lr_loadings['relevant_interactions'] = np.where(
    has_loading & strong_weight, lr_loadings['interaction'], None
)

In [None]:
### Plot the selected pathway and view: PROGENy weight on x, loading on y

In [None]:
# pathway and view (named "source&target") to inspect in the following cells
selected_pathway = 'NFkB'
selected_view = 'Fib&PT'


In [None]:
# loadings and PROGENy weights for the selected pathway and view (copied so later
# changes do not touch `lr_loadings`)
data = lr_loadings[(lr_loadings['pathway']==selected_pathway) & (lr_loadings['view']==selected_view)].copy()

In [None]:
data.sort_values('relevant_interactions', ascending=False)

In [None]:
estimate[estimate['view']==selected_view]

In [None]:

# Scatter of PROGENy weights against loadings for the selected pathway and view, with a
# linear fit and labels on the interactions flagged in `relevant_interactions`.
(
    p9.ggplot(data) +
    p9.aes(x='weight', y='loading') +
    p9.geom_point(p9.aes(colour='sign')) +
    p9.geom_smooth(method='lm') +
    p9.geom_label(p9.aes(label='relevant_interactions'), size=10, nudge_y=0.01, nudge_x=0.01) +
    p9.scale_colour_manual(values=["royalblue", "red"]) +
    p9.labs(title="{} | {}".format(selected_pathway, selected_view), x="PROGENy Weights", y="Loadings") +
    # p9.xlim(-2, 6.5) +
    # the complete theme goes first: added after a partial theme() it would reset
    # the legend setting that follows
    p9.theme_bw(base_size=14) +
    p9.theme(legend_position='none') +
    p9.guides(colour=False)
)

In [None]:
estimate.sort_values("estimate")

MSigDB

In [None]:
msigdb = dc.get_resource('MSigDB')
# Filter by hallmark and drop duplicated geneset/gene entries; copy so that the
# rename below writes to an independent frame rather than to a filtered slice
msigdb = msigdb[msigdb['collection']=='hallmark']
msigdb = msigdb[~msigdb.duplicated(['geneset', 'genesymbol'])].copy()

# Rename: strip the "HALLMARK_" prefix from the gene set names
msigdb['geneset'] = [name.split('HALLMARK_')[1] for name in msigdb['geneset']]

In [None]:
# NOTE(review): presumably converts the gene symbols to the organism of the data using
# decoupler's default source/target organisms — confirm the defaults match (mouse)
msigdb = dc.translate_net(msigdb, 'genesymbol')

In [None]:
msigdb.head()

In [None]:
# generate ligand-receptor geneset
lr_msigdb = li.fun.generate_lr_geneset(lr_pairs, msigdb, lr_separator="^", weight=None, source='geneset', target='genesymbol')
lr_msigdb.head()

In [None]:
# By this point `lr_loadings` is the long table merged with the PROGENy weights (string
# columns, one row per interaction/view/pathway), not a numeric matrix. Rebuild the
# interaction x view matrix of Factor1 loadings, with missing interactions set to 0.
lr_loadings_wide = (
    li.multi.get_variable_loadings(mdata, view_separator=':')
    .pivot(index='variable', columns='view', values='Factor1')
    .fillna(0)
)

# run pathway enrichment analysis
estimate, pvals =  dc.run_ulm(lr_loadings_wide.transpose(), lr_msigdb,
                              source="geneset", target="interaction",
                              use_raw=False, min_n=5)
# pivot columns to long
estimate = (estimate.
            melt(ignore_index=False, value_name='estimate', var_name='pathway').
            reset_index().
            rename(columns={'index':'view'})
            )

In [None]:
# Heatmap of the MSigDB hallmark enrichment estimates: gene sets on x, views on y
(p9.ggplot(estimate) +
 p9.aes(x='pathway', y='view') +
 p9.geom_tile(p9.aes(fill='estimate')) +
 p9.scale_fill_gradient2(low='#1f77b4', high='#c20019') +
 p9.theme_bw(base_size=14) +
 p9.theme(figure_size=(10, 20))
)


In [None]:
# close the MOFA model opened with mofa.mofa_model above
model.close()