In [1]:
import numpy as np
import pandas as pd
import scanpy as sc

from liana.method.sp._spatial_pipe import spatial_neighbors

from mudata import MuData
from anndata import AnnData
import os

In [2]:
from liana.method._pipe_utils._pre import _choose_mtx_rep

In [3]:
from liana.mt.sp._misty import _check_target_in_predictors, _check_features, _single_view_model, \
    _mask_connectivity, _multi_model, _format_targets, _format_importances, _concat_dataframes, _check_if_squidpy

In [4]:
def _make_view(adata, obs=None, use_raw=False, layer=None, connecitivity=None, spatial_key=None, verbose=False):
    """Build a single MISTy view as a new, lightweight AnnData.

    Parameters
    ----------
    adata : AnnData
        Source object providing the expression matrix and, optionally, spatial coordinates.
    obs : pd.DataFrame, optional
        Observation annotations for the new view; ``None`` leaves them empty.
    use_raw : bool
        Forwarded to ``_choose_mtx_rep`` to select the raw matrix.
    layer : str, optional
        Forwarded to ``_choose_mtx_rep`` to select a layer.
    connecitivity : matrix, optional
        Spatial weights, stored under ``.obsp[f"{spatial_key}_connectivities"]``.
        (The misspelled keyword is kept so existing callers keep working.)
    spatial_key : str, optional
        Key of the coordinates in ``adata.obsm``; when given, they are copied to the view's ``.obsm``.
    verbose : bool
        Forwarded to ``_choose_mtx_rep``.

    Returns
    -------
    AnnData
        A new object holding ``X``, a ``var`` index matching ``X``, and optional ``obsp``/``obsm``.

    Raises
    ------
    ValueError
        If ``spatial_key`` is given but not present in ``adata.obsm``.
    """
    X = _choose_mtx_rep(adata=adata, use_raw=use_raw, layer=layer, verbose=verbose)

    # The var index must describe the matrix actually chosen. With use_raw=True the matrix
    # presumably comes from ``adata.raw``, whose gene set can differ from ``adata.var_names``.
    # Pairing it with adata.var_names would mislabel columns or fail on a shape mismatch.
    var_names = adata.raw.var_names if use_raw else adata.var_names

    obsp = None
    if connecitivity is not None:
        obsp = {f"{spatial_key}_connectivities": connecitivity}

    obsm = {}
    if spatial_key is not None:
        if spatial_key not in adata.obsm.keys():
            raise ValueError(f"spatial_key {spatial_key} not found in adata.obsm.keys()")
        obsm[spatial_key] = adata.obsm[spatial_key]

    return AnnData(X=X, obs=obs, var=pd.DataFrame(index=var_names), obsp=obsp, obsm=obsm)

In [5]:
# Synthetic spatial dataset from the liana test suite (relative path: presumably run from the repo root)
SYNTHETIC_H5AD = 'liana/tests/data/synthetic.h5ad'
adata = sc.read_h5ad(SYNTHETIC_H5AD)

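### MistyData class

A `MuData` subclass that holds the `intra` view together with spatially weighted views (e.g. `para`, `juxta`). It checks that every non-intra view stores its spatial connectivities in `.obsp`.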

In [6]:
class MistyData(MuData):
    """MuData container for MISTy: an ``intra`` view plus spatially weighted views.

    Every modality other than ``intra`` must store its spatial weights in
    ``.obsp[f"{spatial_key}_connectivities"]``.

    Parameters
    ----------
    data : dict
        Mapping of view name -> AnnData, forwarded to ``MuData.__init__``.
    obs : pd.DataFrame
        Observation table assigned to ``self.obs``.
    spatial_key : str
        Prefix used to look up each view's connectivities in ``.obsp``.

    Raises
    ------
    ValueError
        If there is no ``intra`` view, or a non-intra view lacks its connectivities.
    """
    # TODO: change to SpatialData when Squidpy is updated

    def __init__(self, data, obs, spatial_key):
        super().__init__(data)
        # snapshot of the modality names at construction time
        self.view_names = list(self.mod.keys())
        self.obs = obs
        self.spatial_key = spatial_key
        self._check_views()

    def _check_views(self):
        """Validate that an intra view exists and every other view has connectivities."""
        # Explicit exceptions (not ``assert``) so validation still runs under ``python -O``.
        if "intra" not in self.view_names:
            raise ValueError("views must contain an intra view")

        for view in self.view_names:
            if view == "intra":
                continue
            if f"{self.spatial_key}_connectivities" not in self.mod[view].obsp.keys():
                raise ValueError(f"view {view} does not contain `{self.spatial_key}_connectivities` key in .obsp")

    def _get_conn(self, view_name):
        """Return the spatial connectivity matrix stored for ``view_name``."""
        return self.mod[view_name].obsp[f"{self.spatial_key}_connectivities"]

 

### Constructor

In [8]:
# Neighbourhood settings; the values match genericMistyData's defaults.
# NOTE(review): these globals are not passed to genericMistyData below, which uses its own defaults.
spatial_key = 'spatial'
set_diag = True
# paraview (kernel-weighted) settings
kernel = 'misty_rbf'
bandwidth = 10
cutoff = 0.1
zoi = 0
# juxtaview (squidpy neighbour graph) settings
n_neighs = 6
juxta_cutoff = np.inf

In [None]:
def genericMistyData(intra,
                     extra=None,
                     add_para=True,
                     spatial_key='spatial',
                     set_diag=True,
                     bandwidth=10,
                     kernel='misty_rbf',
                     zoi=0,
                     cutoff=0.1,
                     add_juxta=True,
                     n_neighs=6,
                     juxta_cutoff=np.inf,
                     extra_use_raw=False,
                     extra_layer=None,
                     intra_use_raw=False,
                     intra_layer=None,
                     verbose=False,
                     **kwargs,
                     ):
    """Build a :class:`MistyData` with an ``intra`` view and optional ``para``/``juxta`` views.

    Parameters
    ----------
    intra : AnnData
        Target data; must hold coordinates in ``.obsm[spatial_key]``. Its ``.obs`` becomes
        the ``.obs`` of the returned object.
    extra : AnnData, optional
        Data for the para/juxta views. Defaults to the original ``intra`` object, so
        ``extra_use_raw``/``extra_layer`` are resolved against it.
    add_para : bool
        Add a ``para`` view weighted by liana's ``spatial_neighbors``
        (``bandwidth``, ``kernel``, ``cutoff``, ``zoi``, ``set_diag``).
    spatial_key : str
        Key of the coordinates in ``.obsm``; also prefixes the ``.obsp`` connectivity key.
    add_juxta : bool
        Add a ``juxta`` view from ``squidpy.gr.spatial_neighbors`` (``n_neighs``, ``set_diag``,
        ``**kwargs``). Edges whose distance exceeds ``juxta_cutoff`` are zeroed.
    intra_use_raw, intra_layer : bool, str
        Matrix choice for the intra view (forwarded to ``_make_view``).
    extra_use_raw, extra_layer : bool, str
        Matrix choice for the para/juxta views (forwarded to ``_make_view``).
    verbose : bool
        Forwarded to ``_make_view``.
    **kwargs
        Extra keyword arguments for ``squidpy.gr.spatial_neighbors`` (juxta view only).

    Returns
    -------
    MistyData

    Raises
    ------
    ValueError
        From ``_make_view`` if ``spatial_key`` is missing from ``.obsm``.
    """
    # init views
    views = {}

    # Resolve the default *before* ``intra`` is rebound to its derived view below. That view only
    # carries X/var/obs/obsm (no layers, no .raw), so ``extra_layer``/``extra_use_raw`` could not
    # be looked up on it.
    if extra is None:
        extra = intra

    # NOTE the intra view is the one with obs
    intra = _make_view(adata=intra, obs=intra.obs, use_raw=intra_use_raw, layer=intra_layer,
                       spatial_key=spatial_key, verbose=verbose)
    views['intra'] = intra

    if add_para:
        weights = spatial_neighbors(adata=extra,
                                    spatial_key=spatial_key,
                                    bandwidth=bandwidth,
                                    kernel=kernel,
                                    set_diag=set_diag,
                                    inplace=False,
                                    cutoff=cutoff,
                                    zoi=zoi
                                    )
        views['para'] = _make_view(adata=extra, use_raw=extra_use_raw, layer=extra_layer,
                                   spatial_key=spatial_key, connecitivity=weights, verbose=verbose)

    if add_juxta:
        sq = _check_if_squidpy()
        # with copy=True squidpy returns (connectivities, distances) instead of writing to extra
        neighbors, dists = sq.gr.spatial_neighbors(adata=extra,
                                                   copy=True,
                                                   spatial_key=spatial_key,
                                                   set_diag=set_diag,
                                                   n_neighs=n_neighs,
                                                   **kwargs
                                                   )
        # drop juxta edges longer than juxta_cutoff (a no-op for the default np.inf)
        neighbors[dists > juxta_cutoff] = 0
        if hasattr(neighbors, "eliminate_zeros"):
            # presumably a scipy sparse matrix: remove the explicit zeros stored by the mask above
            neighbors.eliminate_zeros()

        views['juxta'] = _make_view(adata=extra, use_raw=extra_use_raw, layer=extra_layer,
                                    spatial_key=spatial_key, connecitivity=neighbors, verbose=verbose)

    return MistyData(views, intra.obs, spatial_key)



In [None]:
# Build the MISTy views from the synthetic data. Extra keyword arguments are forwarded to
# squidpy's spatial_neighbors for the juxtaview (here: Delaunay triangulation on generic coordinates).
# NOTE(review): with delaunay=True squidpy presumably ignores n_neighs — confirm.
mdata = genericMistyData(intra=adata, coord_type="generic", delaunay=True)

In [None]:
mdata  # display the MistyData (a MuData subclass) and its views

In [None]:
mdata.view_names  # names captured at construction; intra/para/juxta with the default add_* flags

In [None]:
# TODO: two applications to support:
#   - juxta + para views (built by genericMistyData above)
#   - misty_lr
# Anything else would require a new constructor.

In [None]:
# No explicit selection; _check_features resolves both lists below (presumably to all intra features)
predictors = targets = None

In [None]:
# The intra view supplies the target expression for every model below
intra = mdata['intra']

In [None]:
# Resolve the predictor and target feature lists against the intra view
# TODO to be abstracted further
predictors, targets = (
    _check_features(intra, predictors, type_str="predictors"),
    _check_features(intra, targets, type_str="targets"),
)

In [None]:
targets  # target features as returned by _check_features

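#### Fit parameters

Settings for the per-view random-forest models and for the meta model, which is fitted with k-fold cross-validation.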

In [None]:
# per-view random-forest settings
n_estimators = 20
n_jobs = -1
seed = 1337
# which views / predictors enter the models
bypass_intra = True
keep_same_predictor = False
# meta-model cross-validation settings
k_cv = 10
alphas = [0.1, 1, 10]

In [None]:
# Optional .obs columns that stratify the intra (target) cells and the environment cells; None = no grouping
group_intra_by = group_env_by = None

In [None]:
# TODO: function that checks if the groupby is in the obs
# and does this for both extra & intra
# Without a grouping column, a single `None` group stands for "all cells"
if group_intra_by:
    intra_groups = np.unique(mdata.obs[group_intra_by])
else:
    intra_groups = [None]

if group_env_by:
    extra_groups = np.unique(mdata.obs[group_env_by])
else:
    extra_groups = [None]

In [None]:
# Views that get their own model; the intra view is skipped when bypass_intra is set
view_str = [*mdata.view_names]
if bypass_intra:
    del view_str[view_str.index('intra')]

In [None]:
# One result frame per (target, intra group, env group); concatenated after the loop
targets_list = []
importances_list = []

In [None]:
# All intra feature names; per target, _check_target_in_predictors derives the non-self subset
intra_features = intra.var_names.tolist()

In [None]:
# Loop over each target: fit one RF model per view, then a meta model that combines the
# views' out-of-bag predictions, for every (intra group, env group) combination.
for target in targets:

    for intra_group in intra_groups:
        # Compare against None explicitly: a real group label such as 0 or "" is falsy and
        # would otherwise silently select every cell instead of its own group.
        if intra_group is not None:
            intra_obs_msk = intra.obs[group_intra_by] == intra_group
        else:
            intra_obs_msk = np.ones(intra.shape[0], dtype=bool)

        # Target expression as a flat vector. X may be scipy-sparse or a dense ndarray depending
        # on the matrix the intra view was built from; only the sparse case has .toarray().
        y_mtx = intra[intra_obs_msk, target].X
        y = (y_mtx.toarray() if hasattr(y_mtx, "toarray") else np.asarray(y_mtx)).reshape(-1)

        # intra is always non-self, while other views can be self
        # NOTE(review): candidates are drawn from all intra features, so the `predictors` list
        # resolved earlier is only used when keep_same_predictor=True — confirm intent.
        predictors_nonself, insert_index = _check_target_in_predictors(target, intra_features)
        _predictors = predictors if keep_same_predictor else predictors_nonself

        # TODO: rename to target_importances
        importance_dict = {}

        # model the intraview
        if not bypass_intra:
            obp_intra, importance_dict["intra"] = _single_view_model(y,
                                                                     intra,
                                                                     intra_obs_msk,
                                                                     predictors_nonself,
                                                                     n_estimators,
                                                                     n_jobs,
                                                                     seed
                                                                     )
            # presumably re-inserts a NaN slot at the target's position so the importances
            # line up with a predictor list that still contains the target
            if insert_index is not None and keep_same_predictor:
                importance_dict["intra"] = np.insert(importance_dict["intra"], insert_index, np.nan)

        # loop over the group_views_by
        for extra_group in extra_groups:
            # store the oob predictions for each view to construct predictor matrix for meta model
            oob_list = []

            if not bypass_intra:
                oob_list.append(obp_intra)

            # model the juxta and paraview (if applicable)
            for view_name in [v for v in view_str if v != "intra"]:
                extra = mdata.mod[view_name]
                # same explicit None check as for the intra groups
                if extra_group is not None:
                    extra_obs_msk = mdata.obs[group_env_by] == extra_group
                else:
                    extra_obs_msk = np.ones(extra.shape[0], dtype=bool)

                extra_features = extra.var_names.to_list()
                _predictors, _ = _check_target_in_predictors(target, extra_features)

                # NOTE indexing here is expensive, but we do it to avoid memory issues
                connectivity = mdata._get_conn(view_name)
                view = _mask_connectivity(extra, connectivity, extra_obs_msk, _predictors)

                oob_predictions, importance_dict[view_name] = \
                    _single_view_model(y,
                                       view,
                                       intra_obs_msk,
                                       _predictors,
                                       n_estimators,
                                       n_jobs,
                                       seed
                                       )
                oob_list.append(oob_predictions)

            # train the meta model with k-fold CV
            intra_r2, multi_r2, coefs = _multi_model(y,
                                                     np.column_stack(oob_list),
                                                     intra_group,
                                                     bypass_intra,
                                                     view_str,
                                                     k_cv,
                                                     alphas,
                                                     seed
                                                     )

            # write the results to a dataframe
            targets_df = _format_targets(target,
                                         intra_group,
                                         extra_group,
                                         view_str,
                                         intra_r2,
                                         multi_r2,
                                         coefs
                                         )
            targets_list.append(targets_df)

            # NOTE(review): `_predictors` here is the list from the last view modelled above
            # (or the intra list when no extra view exists) — confirm it fits every view's importances
            importances_df = _format_importances(target,
                                                 _predictors,
                                                 intra_group,
                                                 extra_group,
                                                 importance_dict
                                                 )
            importances_list.append(importances_df)


# create result dataframes
target_metrics, importances = _concat_dataframes(targets_list,
                                                 importances_list,
                                                 view_str)

In [None]:
target_metrics  # per-target metrics frame returned by _concat_dataframes

In [None]:
importances  # per-view predictor importances frame returned by _concat_dataframes