In [1]:
from scipy.sparse import issparse

import numpy as np
import pandas as pd
import scanpy as sc

from liana.method.sp._spatial_pipe import spatial_neighbors
from liana.method.sp._misty import _get_neighbors

from scipy.sparse import csr_matrix
from mudata import MuData
from anndata import AnnData
import os

In [2]:
from liana.method._pipe_utils import prep_check_adata

In [3]:
from liana.mt.sp._misty import _check_target_in_predictors, _check_features, _single_view_model, \
    _mask_connectivity, _multi_model, _format_targets, _format_importances, _concat_dataframes, _check_if_squidpy

In [4]:
def _make_view(adata, connecitivity=None, spatial_key=None, layer=None, use_raw=False):
    """Build one MISTy view as a fresh AnnData.

    The expression data is prepared with ``prep_check_adata`` (honouring
    ``layer`` / ``use_raw``). ``X``, ``var`` and ``obs`` of the prepared object
    are copied into the new AnnData.

    Parameters
    ----------
    adata : AnnData
        Source object.
    connecitivity : optional
        Connectivity matrix stored as ``.obsp['spatial_connectivities']``.
        It takes precedence over ``spatial_key``. The parameter name is
        misspelled but kept, because callers pass it by keyword.
    spatial_key : str, optional
        Key of ``.obsp`` on the prepared object, used only when
        ``connecitivity`` is None.
    layer, use_raw
        Forwarded to ``prep_check_adata``.

    Returns
    -------
    AnnData
        New object. It has ``obsp={'spatial_connectivities': ...}`` when a
        connectivity was given or looked up, and no obsp otherwise.
    """
    prepared = prep_check_adata(adata, use_raw=use_raw, layer=layer, groupby=None, min_cells=None)

    # An explicit matrix wins; otherwise fall back to an existing .obsp entry.
    if connecitivity is None and spatial_key is not None:
        connecitivity = prepared.obsp[spatial_key]

    obsp = None if connecitivity is None else {'spatial_connectivities': connecitivity}
    return AnnData(X=prepared.X, var=prepared.var, obs=prepared.obs, obsp=obsp)

In [5]:
# Synthetic example dataset from liana's test data.
# Relative path: assumes the kernel runs from the repository root.
synthetic_path = 'liana/tests/data/synthetic.h5ad'
adata = sc.read_h5ad(synthetic_path)

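### MistyData class

A `MuData` subclass holding one AnnData per view. It must contain an `intra` view, and every other view must store its spatial connectivities in `.obsp['spatial_connectivities']`.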

In [6]:
class MistyData(MuData):
    """MuData container of MISTy views.

    ``views`` maps view name -> AnnData and must include an ``"intra"`` view.
    Every other view must carry its spatial connectivities in
    ``.obsp["spatial_connectivities"]``.

    Attributes
    ----------
    view_names : list of str
        Names of the views, in the order of ``self.mod``.
    """

    def __init__(self, views):
        super().__init__(views)
        self.view_names = list(self.mod.keys())
        self._check_views()

    def _check_views(self):
        """Validate the view layout.

        Raises
        ------
        ValueError
            If there is no ``"intra"`` view, or if a non-intra view lacks
            ``.obsp["spatial_connectivities"]``.
        """
        # Explicit exceptions instead of ``assert`` so validation also runs
        # under ``python -O``. No isinstance check is needed: self is a MuData.
        if "intra" not in self.view_names:
            raise ValueError("views must contain an intra view")

        for view in self.view_names:
            if view == "intra":
                continue
            if "spatial_connectivities" not in self.mod[view].obsp:
                raise ValueError(f"view {view} must contain a spatial_connectivities key in .obsp")

    # TODO: assign targets & predictors to views

    def _get_conn(self, view_name):
        """Return ``.obsp["spatial_connectivities"]`` of the view ``view_name``."""
        return self.mod[view_name].obsp["spatial_connectivities"]

 

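### Constructor

`construct_views` builds the `intra` view and, optionally, a `para` view (kernel-weighted neighbours from liana's `spatial_neighbors`) and a `juxta` view (squidpy's `gr.spatial_neighbors`).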

In [7]:
# xdata = adata

In [8]:
# Parameters for the view-construction / fit prototype, grouped by purpose.
# Shared spatial-graph settings
spatial_key = 'spatial'
set_diag = True
# Kernel-weighted neighbourhood settings
kernel = 'misty_rbf'
bandwidth = 10
cutoff = 0.1
zoi = 0
# Nearest-neighbour settings
n_neighs = 6
juxta_cutoff = np.inf
# Optional second AnnData
ydata = None

In [9]:
# ydata = ydata if ydata else xdata

In [10]:
def construct_views(xdata,
                    ydata=None,
                    add_para=True,
                    spatial_key='spatial',
                    set_diag=True,
                    bandwidth=10,
                    kernel='misty_rbf',
                    zoi=0,
                    cutoff=0.1,
                    add_juxta=True,
                    n_neighs=6,
                    juxta_cutoff=np.inf,
                    **kwargs
                    ):
    """Build the MISTy views as a dict of AnnData objects.

    Parameters
    ----------
    xdata : AnnData
        Object from which the spatial views (``para`` / ``juxta``) are built.
    ydata : AnnData, optional
        Object used for the ``intra`` view. Defaults to ``xdata``.
    add_para : bool
        Add a ``para`` view whose connectivities come from liana's
        ``spatial_neighbors``. ``bandwidth``, ``kernel``, ``set_diag``,
        ``cutoff`` and ``zoi`` are forwarded to it.
    spatial_key : str
        Forwarded to squidpy's ``gr.spatial_neighbors`` for the ``juxta`` view.
    set_diag : bool
        Forwarded to both neighbour computations.
    add_juxta : bool
        Add a ``juxta`` view built with squidpy's ``gr.spatial_neighbors``.
        ``n_neighs`` and ``**kwargs`` are forwarded to it. Connectivity entries
        whose distance exceeds ``juxta_cutoff`` are set to 0.

    Returns
    -------
    dict of str -> AnnData
        Always contains ``'intra'``. ``'para'`` and ``'juxta'`` are present
        when requested.
    """
    views = {}
    # Explicit None check: an AnnData's truthiness is not a None test, so a
    # falsy-but-valid ydata must not be silently replaced by xdata.
    ydata = xdata if ydata is None else ydata

    views['intra'] = _make_view(adata=ydata, use_raw=False, layer=None)

    if add_para:
        # NOTE(review): spatial_key is not forwarded here, so this call uses
        # spatial_neighbors' own default key. Confirm, and forward it if supported.
        weights = spatial_neighbors(adata=xdata,
                                    bandwidth=bandwidth,
                                    kernel=kernel,
                                    set_diag=set_diag,
                                    inplace=False,
                                    cutoff=cutoff,
                                    zoi=zoi
                                    )
        views['para'] = _make_view(xdata, use_raw=False, layer=None, connecitivity=weights)

    if add_juxta:
        sq = _check_if_squidpy()
        # With copy=True the call returns a (connectivities, distances) pair.
        neighbors, dists = sq.gr.spatial_neighbors(adata=xdata,
                                                   copy=True,
                                                   spatial_key=spatial_key,
                                                   set_diag=set_diag,
                                                   n_neighs=n_neighs,
                                                   **kwargs
                                                   )
        # Zero out neighbours further away than juxta_cutoff (no-op for np.inf).
        neighbors[dists > juxta_cutoff] = 0

        views['juxta'] = _make_view(xdata, use_raw=False, layer=None, connecitivity=neighbors)

    return views



In [11]:
# Build the views. Extra keyword arguments reach squidpy's gr.spatial_neighbors.
views = construct_views(
    adata,
    coord_type="generic",
    delaunay=True,
)

In [12]:
# Wrap the view dict in a MistyData (MuData subclass); this validates the layout.
mdata = MistyData(views=views)



In [13]:
# Display the MistyData object.
mdata

In [14]:
# View names, in the order of mdata.mod (set in MistyData.__init__).
mdata.view_names

['intra', 'para', 'juxta']

In [15]:
## TODO: two applications
# juxta, para
# misty_lr
# anything else would require a new constructor

In [28]:
# No explicit feature lists: _check_features decides what None resolves to.
predictors = targets = None

In [29]:
# The intra view: targets are read from it (see the fit loop).
intra = mdata.mod["intra"]

In [30]:
# TODO: to be abstracted further.
# Resolve both feature lists against the intra view (evaluated left to right).
predictors, targets = (
    _check_features(intra, predictors, type_str="predictors"),
    _check_features(intra, targets, type_str="targets"),
)

In [31]:
# Resolved target features.
targets

['ECM',
 'ligA',
 'ligB',
 'ligC',
 'ligD',
 'prodA',
 'prodB',
 'prodC',
 'prodD',
 'protE',
 'protF']

#### Fit parameters

In [32]:
# Per-view model settings
n_estimators = 100
n_jobs = -1
seed = 1337
# Meta-model settings: k-fold CV and the alpha grid passed to _multi_model
k_cv = 10
alphas = [0.1, 1, 10]
# Modelling options
bypass_intra = True
keep_same_predictor = False

In [33]:
# No grouping of cells, for either the intra view or the environment views.
group_intra_by = group_env_by = None

In [35]:
# One entry per group; [None] means a single pass over all cells.
if group_intra_by:
    intra_groups = np.unique(intra.obs[group_intra_by])
else:
    intra_groups = [None]

# NOTE: env groups should really be defined for every extra view; ideally mdata
# would be the only place where obs is stored. For now they come from intra.obs.
if group_env_by:
    env_groups = np.unique(intra.obs[group_env_by])
else:
    env_groups = [None]

In [36]:
# Views that get their own model. The intra view is dropped when bypassed.
view_str = [*mdata.view_names]
if bypass_intra:
    view_str.remove("intra")

In [37]:
# Per-(target, intra group, env group) result frames, concatenated after the loop.
targets_list = []
importances_list = []

In [38]:
# For each target: build one RF model per view, then a meta model for every
# (intra group, env group) combination.
for target in targets:

    for intra_group in intra_groups:
        # Compare against None explicitly: a valid group label such as 0 is falsy.
        if intra_group is not None:
            intra_obs_msk = (intra.obs[group_intra_by] == intra_group).to_numpy()
        else:
            intra_obs_msk = np.ones(intra.shape[0], dtype=bool)

        # Response vector as a flat array. X may be sparse or dense.
        y_block = intra[intra_obs_msk, target].X
        y = (y_block.toarray() if issparse(y_block) else np.asarray(y_block)).reshape(-1)

        # intra is always non-self, while other views can be self
        predictors_nonself, insert_index = _check_target_in_predictors(target, predictors)
        preds = predictors if keep_same_predictor else predictors_nonself

        importance_dict = {}

        # model the intraview
        if not bypass_intra:
            oob_predictions_intra, importance_dict["intra"] = _single_view_model(y,
                                                                                 intra,
                                                                                 intra_obs_msk,
                                                                                 predictors_nonself,
                                                                                 n_estimators,
                                                                                 n_jobs,
                                                                                 seed
                                                                                 )
            if insert_index is not None and keep_same_predictor:
                importance_dict["intra"] = np.insert(importance_dict["intra"], insert_index, np.nan)

        # loop over the environment groups
        for env_group in env_groups:
            # OOB predictions of each view form the meta model's predictor matrix
            oob_list = []
            if not bypass_intra:
                oob_list.append(oob_predictions_intra)

            # model the juxta and para views (if present)
            for view_name in [v for v in view_str if v != "intra"]:
                extra = mdata.mod[view_name]

                # Build the env mask from this view's own obs, after `extra` is
                # bound, so it always matches the view being masked.
                if env_group is not None:
                    env_obs_msk = (extra.obs[group_env_by] == env_group).to_numpy()
                else:
                    env_obs_msk = np.ones(extra.shape[0], dtype=bool)

                connectivity = mdata._get_conn(view_name)
                # NOTE indexing here is expensive, but we do it to avoid memory issues
                view = _mask_connectivity(extra, connectivity, env_obs_msk, preds)

                oob_predictions, importance_dict[view_name] = \
                    _single_view_model(y,
                                       view,
                                       intra_obs_msk,
                                       preds,
                                       n_estimators,
                                       n_jobs,
                                       seed
                                       )
                oob_list.append(oob_predictions)

            # train the meta model with k-fold CV
            intra_r2, multi_r2, coefs = _multi_model(y,
                                                     np.column_stack(oob_list),
                                                     intra_group,
                                                     bypass_intra,
                                                     view_str,
                                                     k_cv,
                                                     alphas,
                                                     seed
                                                     )

            targets_df = _format_targets(target,
                                         intra_group,
                                         env_group,
                                         view_str,
                                         intra_r2,
                                         multi_r2,
                                         coefs
                                         )
            targets_list.append(targets_df)

            importances_df = _format_importances(target,
                                                 preds,
                                                 intra_group,
                                                 env_group,
                                                 importance_dict
                                                 )
            importances_list.append(importances_df)


# create result dataframes
target_metrics, importances = _concat_dataframes(targets_list,
                                                 importances_list,
                                                 view_str)

In [39]:
# Per-target metrics table returned by _concat_dataframes.
target_metrics

Unnamed: 0,target,intra_group,env_group,intra.R2,multi.R2,gain.R2,para,juxta
0,ECM,,,0,0.268114,0.268114,0.22824,0.77176
1,ligA,,,0,0.545693,0.545693,0.129649,0.870351
2,ligB,,,0,0.463586,0.463586,0.214763,0.785237
3,ligC,,,0,0.510295,0.510295,0.374634,0.625366
4,ligD,,,0,0.552079,0.552079,0.134186,0.865814
5,prodA,,,0,0.131187,0.131187,0.0,1.0
6,prodB,,,0,0.154396,0.154396,0.0,1.0
7,prodC,,,0,0.113334,0.113334,0.0,1.0
8,prodD,,,0,0.126405,0.126405,0.0,1.0
9,protE,,,0,0.072021,0.072021,0.0,1.0


In [None]:
# Per-target predictor importances returned by _concat_dataframes.
importances

In [None]:
# The dict of view AnnData objects returned by construct_views.
views

In [None]:
# Re-assemble the views from the MistyData object. The bare names `juxta` and
# `para` are not defined anywhere in this notebook, so read them from mdata.
views = {name: mdata.mod[name] for name in ("intra", "juxta", "para")}

In [None]:
# Connectivities of the para view. The MistyData object is `mdata`; no `misty` exists.
mdata._get_conn("para")

In [None]:
def fit(self,
        adata,
        targets=None,
        predictors=None,
        keep_same_predictor=False,  # TODO: maybe rename this variable
        spatial_key="spatial",
        add_juxta=True,
        add_para=True,
        bypass_intra=False,
        group_intra_by=None,
        group_env_by=None,
        alphas=[0.1, 1, 10],  # not mutated below
        k_cv=10,
        n_estimators=100,
        n_jobs=-1,
        seed=1337,
        inplace=True,
        xdata=None,
        ydata=None,
        juxta_cutoff=np.inf,
        set_diag=True,
        bandwidth=10,
        kernel='misty_rbf',
        zoi=0,
        ):
    """Prototype MISTy fit. Left over from a class; ``self`` is unused.

    Parameters
    ----------
    adata : AnnData
        Default data for both the intra view and the spatial views. Results are
        written to ``adata.uns["misty_results"]`` when ``inplace``.
    targets, predictors : list of str, optional
        Passed through ``_check_features`` against ``ydata``, which resolves
        ``None``.
    keep_same_predictor : bool
        If False, the target is removed from the predictors of the non-intra views.
    spatial_key : str
        Forwarded to ``_get_neighbors`` for the juxta view.
    add_juxta, add_para : bool
        Which spatial views to model.
    bypass_intra : bool
        Skip the intra-view model.
    group_intra_by : str, optional
        ``ydata.obs`` column used to group cells for the intra view.
    group_env_by : str, optional
        ``xdata.obs`` column used to group cells for the environment views.
    alphas, k_cv
        Forwarded to ``_multi_model`` (meta model with k-fold CV).
    n_estimators, n_jobs, seed
        Forwarded to ``_single_view_model``.
    inplace : bool
        Store results in ``adata.uns`` (returns None) or return them.
    xdata, ydata : AnnData, optional
        Data for the spatial views and the intra view. ``xdata`` defaults to
        ``adata`` and ``ydata`` defaults to ``xdata``.
    juxta_cutoff, set_diag, bandwidth, kernel, zoi
        Neighbour-graph settings. Defaults match the notebook's config cell.

    Returns
    -------
    dict or None
        ``{"target_metrics": ..., "importances": ...}`` when ``inplace`` is False.
    """
    # Resolve data locally instead of reading notebook globals: a global
    # `xdata` is never bound and the global `ydata` is None.
    xdata = adata if xdata is None else xdata
    ydata = xdata if ydata is None else ydata

    # NOTE(review): predictors are resolved against ydata but also used for the
    # xdata-based views. This is fine while xdata and ydata share features.
    predictors = _check_features(ydata, predictors, type_str="predictors")
    targets = _check_features(ydata, targets, type_str="targets")

    intra_groups = np.unique(ydata.obs[group_intra_by]) if group_intra_by else [None]
    env_groups = np.unique(xdata.obs[group_env_by]) if group_env_by else [None]

    connectivities = {}
    if add_juxta:
        connectivities['juxta'] = _get_neighbors(xdata,
                                                 juxta_cutoff=juxta_cutoff,
                                                 set_diag=set_diag,
                                                 spatial_key=spatial_key
                                                 )
    if add_para:
        # NOTE(review): cutoff is hard-coded to 0 here (construct_views defaults to 0.1).
        connectivities['para'] = spatial_neighbors(adata=xdata,
                                                   bandwidth=bandwidth,
                                                   kernel=kernel,
                                                   set_diag=set_diag,
                                                   inplace=False,
                                                   cutoff=0,
                                                   zoi=zoi
                                                   )
    view_str = list(connectivities.keys())
    if not bypass_intra:
        view_str = ['intra'] + view_str

    # init list to store the results for each intra group and env group as dataframe
    targets_list, importances_list = [], []

    # loop over each target and build one RF model for each view
    for target in targets:

        for intra_group in intra_groups:
            # Compare against None explicitly: a valid group label such as 0 is falsy.
            if intra_group is not None:
                intra_obs_msk = (ydata.obs[group_intra_by] == intra_group).to_numpy()
            else:
                intra_obs_msk = np.ones(ydata.shape[0], dtype=bool)

            # Response vector as a flat array. X may be sparse or dense.
            y_block = ydata[intra_obs_msk, target].X
            y = (y_block.toarray() if issparse(y_block) else np.asarray(y_block)).reshape(-1)

            # intra is always non-self, while other views can be self
            predictors_nonself, insert_index = _check_target_in_predictors(target, predictors)
            preds = predictors if keep_same_predictor else predictors_nonself

            importance_dict = {}

            # model the intraview
            if not bypass_intra:
                oob_predictions_intra, importance_dict["intra"] = _single_view_model(y,
                                                                                     ydata,
                                                                                     intra_obs_msk,
                                                                                     predictors_nonself,
                                                                                     n_estimators,
                                                                                     n_jobs,
                                                                                     seed
                                                                                     )
                if insert_index is not None and keep_same_predictor:
                    importance_dict["intra"] = np.insert(importance_dict["intra"], insert_index, np.nan)

            # loop over the environment groups
            for env_group in env_groups:
                # OOB predictions of each view form the meta model's predictor matrix
                oob_list = []

                # env groups come from xdata.obs, so the mask must as well
                if env_group is not None:
                    env_obs_msk = (xdata.obs[group_env_by] == env_group).to_numpy()
                else:
                    env_obs_msk = np.ones(xdata.shape[0], dtype=bool)

                if not bypass_intra:
                    oob_list.append(oob_predictions_intra)

                # model the juxta and para views (if applicable)
                for view_name in [v for v in view_str if v != "intra"]:
                    connectivity = connectivities[view_name]
                    # NOTE indexing here is expensive, but we do it to avoid memory issues
                    view = _mask_connectivity(xdata, connectivity, env_obs_msk, predictors)

                    oob_predictions, importance_dict[view_name] = \
                        _single_view_model(y,
                                           view,
                                           intra_obs_msk,
                                           preds,
                                           n_estimators,
                                           n_jobs,
                                           seed
                                           )
                    oob_list.append(oob_predictions)

                # train the meta model with k-fold CV
                intra_r2, multi_r2, coefs = _multi_model(y,
                                                         np.column_stack(oob_list),
                                                         intra_group,
                                                         bypass_intra,
                                                         view_str,
                                                         k_cv,
                                                         alphas,
                                                         seed
                                                         )

                targets_df = _format_targets(target,
                                             intra_group,
                                             env_group,
                                             view_str,
                                             intra_r2,
                                             multi_r2,
                                             coefs
                                             )
                targets_list.append(targets_df)

                importances_df = _format_importances(target,
                                                     preds,
                                                     intra_group,
                                                     env_group,
                                                     importance_dict
                                                     )
                importances_list.append(importances_df)

    # create result dataframes
    target_metrics, importances = _concat_dataframes(targets_list,
                                                     importances_list,
                                                     view_str)
    if inplace:
        adata.uns["misty_results"] = {"target_metrics": target_metrics,
                                      "importances": importances
                                      }
    else:
        return {"target_metrics": target_metrics,
                "importances": importances}