In [1]:
from scipy.spatial import cKDTree
from scipy.sparse import identity, issparse

import numpy as np
import pandas as pd
import scanpy as sc

from liana.method.sp._spatial_pipe import spatial_neighbors
from liana.method.sp._misty import _get_neighbors

from scipy.sparse import csr_matrix
from mudata import MuData
from anndata import AnnData
import os

In [2]:
# Target AnnData; left unset here, a later cell falls back to xdata when this is empty.
ydata = None

In [3]:
# Read the predictor AnnData from disk; the path is relative to the current working directory.
xdata = sc.read_h5ad('liana/tests/data/synthetic.h5ad')

In [4]:
# Reuse the predictor data as target data only when no separate target object was given.
# The check is an explicit `is None`: a plain truth test would follow the object's length,
# so a target object with zero observations would be silently swapped for xdata.
ydata = xdata if ydata is None else ydata

In [5]:
from liana.method._pipe_utils import prep_check_adata

In [8]:
# Private helpers used by the model-fitting cells below. They are imported through the same
# `liana.method` package path as the other liana imports in this notebook.
from liana.method.sp._misty import (
    _check_target_in_predictors,
    _check_features,
    _single_view_model,
    _mask_connectivity,
    _multi_model,
    _format_targets,
    _format_importances,
    _concat_dataframes,
)

In [9]:
# Settings for building the spatial views.
# juxta_cutoff and spatial_key are forwarded to _get_neighbors (juxta view).
juxta_cutoff = np.inf
# set_diag is forwarded to both _get_neighbors and spatial_neighbors.
set_diag = True
spatial_key = 'spatial'
# bandwidth, kernel and zoi are forwarded to spatial_neighbors (para view).
bandwidth = 10
kernel = 'misty_rbf'
zoi = 0

In [11]:
# _check_anndata_objects_groups(xdata,
#                               ydata,
#                               spatial_key=spatial_key,
#                               group_intra_by=group_intra_by,
#                               group_env_by=group_env_by
#                               )

# predictors = _check_features(xdata, predictors, type_str="predictors")
# targets = _check_features(ydata, targets, type_str="targets")

In [12]:
def _make_view(adata, connecitivity=None, spatial_key=None, layer=None, use_raw=False):
    """Build a minimal AnnData "view" from `adata`.

    `adata` is first passed through `prep_check_adata` (with `groupby=None` and
    `min_cells=None`). The returned object carries the checked `X` and `obs`, plus an
    optional connectivity matrix stored under `.obsp['spatial_connectivities']`.

    Parameters
    ----------
    adata
        Input object handed to `prep_check_adata`.
    connecitivity
        Connectivity matrix to attach to the view. Takes precedence over `spatial_key`.
    spatial_key
        Key in `.obsp` of the checked object to copy the connectivities from; only
        consulted when `connecitivity` is None.
    layer, use_raw
        Forwarded to `prep_check_adata`.

    Returns
    -------
    AnnData built from the checked `X` and `obs`; `obsp` is None when neither
    `connecitivity` nor `spatial_key` is given.
    """
    checked = prep_check_adata(adata, use_raw=use_raw, layer=layer, groupby=None, min_cells=None)

    # resolve where the connectivities come from: explicit matrix first, then the obsp key
    if connecitivity is not None:
        weights = connecitivity
    elif spatial_key is not None:
        weights = checked.obsp[spatial_key]
    else:
        weights = None

    pairwise = None if weights is None else {'spatial_connectivities': weights}
    return AnnData(checked.X, obs=checked.obs, obsp=pairwise)

In [13]:
# # validate inputs 
# if not overwrite and ("misty_results" in mdata.uns.keys()) and inplace:
#     raise ValueError("mdata already contains misty results. Set overwrite=True to overwrite.")
# if x_mod not in mdata.mod.keys():
#     raise ValueError(f"Predictor modality {x_mod} not found in mdata.")
# if y_mod is not None and y_mod not in mdata.mod.keys():
#     raise ValueError(f"Target modality {y_mod} not found in mdata.")
# if add_para and bandwidth is None:
#     raise ValueError("bandwidth must be specified if add_para=True")

In [14]:
# Intra view: built from the target data, without any connectivity matrix attached.
intra = _make_view(
    adata=ydata,
    layer=None,
    use_raw=False,
)

In [15]:
# Juxta view: connectivities come from _get_neighbors and are attached to a view of xdata.
neighbors = _get_neighbors(
    adata=xdata,
    juxta_cutoff=juxta_cutoff,
    set_diag=set_diag,
    spatial_key=spatial_key,
)

juxta = _make_view(adata=xdata, connecitivity=neighbors, layer=None, use_raw=False)

In [16]:
# Para view: weights come from spatial_neighbors (returned, not stored, since inplace=False)
# and are attached to a view of xdata.
weights = spatial_neighbors(
    adata=xdata,
    bandwidth=bandwidth,
    kernel=kernel,
    set_diag=set_diag,
    inplace=False,
    cutoff=0,
    zoi=zoi,
)

para = _make_view(adata=xdata, connecitivity=weights, layer=None, use_raw=False)

In [17]:
# Bundle the three views into one MuData, keyed by view name.
views = MuData(
    {
        'intra': intra,
        'juxta': juxta,
        'para': para,
    }
)



In [18]:
# Show the names of the views stored in the MuData container.
views.mod.keys()

dict_keys(['intra', 'juxta', 'para'])

In [None]:
## TODO: two applications
# juxta, para
# misty_lr
# anything else would require a new constructor

#### FIT params

In [40]:
# Model-fitting settings.
# n_estimators, n_jobs and seed are passed to _single_view_model.
n_estimators = 100
n_jobs = -1
seed = 1337
# When True the intra view is neither modelled nor kept in the list of view names.
bypass_intra = True
# When True the view models use the full predictor list, otherwise the target is excluded from it.
keep_same_predictor = False
# k_cv and alphas (together with seed) are passed to _multi_model.
k_cv = 10
alphas = [0.1, 1, 10]

In [41]:
# Optional .obs column names used to split the observations into groups;
# None means a single group containing all observations.
group_intra_by = None
group_env_by = None

In [42]:
# Feature selections; both are passed through _check_features in a later cell.
predictors = None
targets = None

In [43]:
# TODO: move these checks to the point where the views are constructed
predictors = _check_features(xdata, predictors, type_str="predictors")
targets = _check_features(ydata, targets, type_str="targets")

# Without a grouping column a single placeholder group (None) is used.
if group_intra_by:
    intra_groups = np.unique(ydata.obs[group_intra_by])
else:
    intra_groups = [None]

if group_env_by:
    env_groups = np.unique(xdata.obs[group_env_by])
else:
    env_groups = [None]

In [49]:
# Names of the views to model, in the order they are stored in `views`.
view_str = [name for name in views.mod.keys()]
# Drop the intra view from the list when it is bypassed.
if bypass_intra:
    view_str.remove('intra')

In [50]:
# init list to store the results for each intra group and env group as dataframe;
# Accumulators: one dataframe per (target, intra group, env group) is appended to each list.
targets_list = []
importances_list = []

In [51]:
# Loop over every target and fit one model per view, for each combination of
# intra group and env group; the per-combination results are appended to
# targets_list / importances_list and concatenated at the end.
for target in targets:

    for intra_group in intra_groups:
        # Select the observations of this intra group. The test is `is not None` rather than
        # truthiness so that falsy group labels (e.g. 0 or '') still select their own group
        # instead of silently selecting every observation.
        if intra_group is not None:
            intra_obs_msk = ydata.obs[group_intra_by] == intra_group
        else:
            intra_obs_msk = np.ones(ydata.shape[0], dtype=bool)

        # target values as a flat dense vector
        if issparse(ydata.X):
            y = np.asarray(ydata[intra_obs_msk, target].X.todense()).reshape(-1)
        else:
            y = ydata[intra_obs_msk, target].X.reshape(-1)

        # intra is always non-self, while other views can be self
        predictors_nonself, insert_index = _check_target_in_predictors(target, predictors)
        preds = predictors if keep_same_predictor else predictors_nonself

        importance_dict = {}

        # model the intraview
        if not bypass_intra:
            oob_predictions_intra, importance_dict["intra"] = _single_view_model(y,
                                                                                 ydata,
                                                                                 intra_obs_msk,
                                                                                 predictors_nonself,
                                                                                 n_estimators,
                                                                                 n_jobs,
                                                                                 seed
                                                                                 )
            # pad the intra importances with NaN at the target's position so that they
            # line up with the full predictor list used by the other views
            if insert_index is not None and keep_same_predictor:
                importance_dict["intra"] = np.insert(importance_dict["intra"], insert_index, np.nan)

        # loop over the group_views_by
        for env_group in env_groups:
            # store the oob predictions for each view to construct predictor matrix for meta model
            oob_list = []

            # The env groups are derived from xdata.obs and the mask is applied to xdata,
            # so the mask is built from xdata.obs as well (explicit None check as above).
            if env_group is not None:
                env_obs_msk = xdata.obs[group_env_by] == env_group
            else:
                env_obs_msk = np.ones(xdata.shape[0], dtype=bool)

            if not bypass_intra:
                oob_list.append(oob_predictions_intra)

            # model the juxta and paraview (if applicable)
            ## TODO: remove this thing with all
            for view_name in [v for v in view_str if v != "intra"]:
                connectivity = views.mod[view_name].obsp["spatial_connectivities"]
                # NOTE indexing here is expensive, but we do it to avoid memory issues
                view = _mask_connectivity(xdata, connectivity, env_obs_msk, predictors)

                oob_predictions, importance_dict[view_name] = \
                    _single_view_model(y,
                                    view,
                                    intra_obs_msk,
                                    preds,
                                    n_estimators,
                                    n_jobs,
                                    seed
                                    )
                oob_list.append(oob_predictions)


            # train the meta model with k-fold CV
            intra_r2, multi_r2, coefs = _multi_model(y,
                                                    np.column_stack(oob_list),
                                                    intra_group,
                                                    bypass_intra,
                                                    view_str,
                                                    k_cv,
                                                    alphas,
                                                    seed
                                                    )

            targets_df = _format_targets(target,
                                        intra_group,
                                        env_group,
                                        view_str,
                                        intra_r2,
                                        multi_r2,
                                        coefs
                                        )
            targets_list.append(targets_df)

            importances_df = _format_importances(target,
                                                preds,
                                                intra_group,
                                                env_group,
                                                importance_dict
                                                )
            importances_list.append(importances_df)


# create result dataframes
target_metrics, importances = _concat_dataframes(targets_list,
                                                importances_list,
                                                view_str)

In [None]:
class MistyConstructor():
    """Callable wrapper around a MuData of views that runs the multi-view model fit.

    The instance stores the views and validates them on construction. Calling the
    instance fits one model per view for every target and returns, or stores in
    `adata.uns["misty_results"]`, the target metrics and predictor importances.
    """

    def __init__(self, views):
        # validate the container type before touching `.mod`, so a wrong input fails
        # with the intended message rather than an AttributeError
        assert isinstance(views, MuData), "views must be a MuData object"
        self.views = views
        self.view_names = views.mod.keys()
        self._check_views()

    def _get_view(self, view_name):
        """Return the AnnData stored under `view_name`."""
        return self.views.mod[view_name]

    def _get_view_names(self):
        """Return the names of the stored views."""
        return self.view_names

    def _check_views(self):
        """Check that `views` is a MuData and that every spatial view carries connectivities.

        Raises
        ------
        ValueError
            If a view other than "intra" lacks "spatial_connectivities" in `.obsp`.
        """
        views = self.views
        assert isinstance(views, MuData), "views must be a MuData object"

        for view in self.view_names:
            # the intra view is never combined with a connectivity matrix, so it is
            # not required to carry one
            if view == "intra":
                continue
            if "spatial_connectivities" not in views.mod[view].obsp.keys():
                raise ValueError(f"view {view} must contain a spatial_connectivities key in .obsp")

    def _get_connectivity(self, view_name):
        """Return the connectivity matrix stored in the given view."""
        return self._get_view(view_name).obsp["spatial_connectivities"]

    def _get_X(self, view_name):
        """Return the data matrix of the given view."""
        return self._get_view(view_name).X

    @staticmethod
    def _group_mask(adata, group_key, group):
        """Boolean selector for the observations of `group`; selects everything when `group` is None."""
        # explicit None check: a falsy label such as 0 or '' must still select its own group
        if group is None:
            return np.ones(adata.shape[0], dtype=bool)
        return adata.obs[group_key] == group

    def __call__(self,
                 adata,
                 targets = None,
                 predictors = None,
                 keep_same_predictor = False,  # TODO: maybe rename this variable
                 spatial_key = "spatial",
                 add_juxta = True,
                 add_para = True,
                 bypass_intra = False,
                 group_intra_by = None,
                 group_env_by = None,
                 alphas = [0.1, 1, 10],
                 k_cv = 10,
                 n_estimators = 100,
                 n_jobs = -1,
                 seed = 1337,
                 inplace = True,
                 juxta_cutoff = np.inf,
                 set_diag = True,
                 bandwidth = 10,
                 kernel = 'misty_rbf',
                 zoi = 0,
                 ):
        """Fit the per-view models and the meta model for every target.

        All data is taken from `adata`; the method does not rely on notebook-level
        variables. The same object supplies both the targets and the predictors.

        Parameters
        ----------
        adata
            Object holding targets and predictors; also receives the results when
            `inplace` is True.
        targets, predictors
            Feature selections; both are passed through `_check_features` before use
            (presumably None selects all features - confirm against `_check_features`).
        keep_same_predictor
            If True the view models use the full predictor list, otherwise the target
            is excluded from it. The intra model always excludes the target.
        spatial_key
            Forwarded to `_get_neighbors`.
        add_juxta, add_para
            Whether to build the juxta (`_get_neighbors`) and para (`spatial_neighbors`)
            connectivities.
        bypass_intra
            If True the intra view is not modelled.
        group_intra_by, group_env_by
            Optional `.obs` columns used to split the observations into groups.
        alphas, k_cv
            Forwarded to `_multi_model`.
        n_estimators, n_jobs, seed
            Forwarded to `_single_view_model` (`seed` also to `_multi_model`).
        inplace
            If True store the results in `adata.uns["misty_results"]` and return None,
            otherwise return them as a dict.
        juxta_cutoff, set_diag
            Forwarded to `_get_neighbors` (`set_diag` also to `spatial_neighbors`).
        bandwidth, kernel, zoi
            Forwarded to `spatial_neighbors`.

        Raises
        ------
        ValueError
            If no view is left to model (juxta and para disabled while intra is bypassed).
        """
        # resolve the feature selections, so that the None defaults are usable
        predictors = _check_features(adata, predictors, type_str="predictors")
        targets = _check_features(adata, targets, type_str="targets")

        intra_groups = np.unique(adata.obs[group_intra_by]) if group_intra_by else [None]
        env_groups = np.unique(adata.obs[group_env_by]) if group_env_by else [None]

        connectivities = {}
        if add_juxta:
            connectivities['juxta'] = _get_neighbors(adata,
                                                     juxta_cutoff=juxta_cutoff,
                                                     set_diag=set_diag,
                                                     spatial_key=spatial_key
                                                     )
        if add_para:
            connectivities['para'] = spatial_neighbors(adata=adata,
                                                       bandwidth=bandwidth,
                                                       kernel=kernel,
                                                       set_diag=set_diag,
                                                       inplace=False,
                                                       cutoff=0,
                                                       zoi=zoi
                                                       )
        view_str = list(connectivities.keys())
        if not bypass_intra:
            view_str = ['intra'] + view_str

        # fail early with a clear message instead of an error from stacking an empty list
        if not view_str:
            raise ValueError("No views to model: enable add_juxta or add_para, or set bypass_intra=False.")

        # init list to store the results for each intra group and env group as dataframe;
        targets_list, importances_list = [], []

        # loop over each target and build one RF model for each view
        for target in targets:

            for intra_group in intra_groups:
                intra_obs_msk = self._group_mask(adata, group_intra_by, intra_group)

                # target values as a flat dense vector
                if issparse(adata.X):
                    y = np.asarray(adata[intra_obs_msk, target].X.todense()).reshape(-1)
                else:
                    y = adata[intra_obs_msk, target].X.reshape(-1)

                # intra is always non-self, while other views can be self
                predictors_nonself, insert_index = _check_target_in_predictors(target, predictors)
                preds = predictors if keep_same_predictor else predictors_nonself

                importance_dict = {}

                # model the intraview
                if not bypass_intra:
                    oob_predictions_intra, importance_dict["intra"] = _single_view_model(y,
                                                                                         adata,
                                                                                         intra_obs_msk,
                                                                                         predictors_nonself,
                                                                                         n_estimators,
                                                                                         n_jobs,
                                                                                         seed
                                                                                         )
                    # pad with NaN at the target's position to line up with the full predictor list
                    if insert_index is not None and keep_same_predictor:
                        importance_dict["intra"] = np.insert(importance_dict["intra"], insert_index, np.nan)

                # loop over the group_views_by
                for env_group in env_groups:
                    # store the oob predictions for each view to construct predictor matrix for meta model
                    oob_list = []

                    env_obs_msk = self._group_mask(adata, group_env_by, env_group)

                    if not bypass_intra:
                        oob_list.append(oob_predictions_intra)

                    # model the juxta and paraview (if applicable)
                    ## TODO: remove this thing with all
                    for view_name in [v for v in view_str if v != "intra"]:
                        connectivity = connectivities[view_name]
                        # NOTE indexing here is expensive, but we do it to avoid memory issues
                        view = _mask_connectivity(adata, connectivity, env_obs_msk, predictors)

                        oob_predictions, importance_dict[view_name] = \
                            _single_view_model(y,
                                               view,
                                               intra_obs_msk,
                                               preds,
                                               n_estimators,
                                               n_jobs,
                                               seed
                                               )
                        oob_list.append(oob_predictions)

                    # train the meta model with k-fold CV
                    intra_r2, multi_r2, coefs = _multi_model(y,
                                                             np.column_stack(oob_list),
                                                             intra_group,
                                                             bypass_intra,
                                                             view_str,
                                                             k_cv,
                                                             alphas,
                                                             seed
                                                             )

                    targets_df = _format_targets(target,
                                                 intra_group,
                                                 env_group,
                                                 view_str,
                                                 intra_r2,
                                                 multi_r2,
                                                 coefs
                                                 )
                    targets_list.append(targets_df)

                    importances_df = _format_importances(target,
                                                         preds,
                                                         intra_group,
                                                         env_group,
                                                         importance_dict
                                                         )
                    importances_list.append(importances_df)

        # create result dataframes
        target_metrics, importances = _concat_dataframes(targets_list,
                                                         importances_list,
                                                         view_str)
        if inplace:
            adata.uns["misty_results"] = {"target_metrics": target_metrics,
                                          "importances": importances
                                          }
        else:
            return {"target_metrics": target_metrics,
                    "importances": importances}

In [None]:
class Misty:
    """Placeholder for the Misty class; no attributes or methods are defined yet."""

