In [1]:
from scipy.spatial import cKDTree
from scipy.sparse import identity, issparse

import numpy as np
import pandas as pd
import scanpy as sc

from liana.method.sp._spatial_pipe import spatial_neighbors
from liana.method.sp._misty import _get_distance_weights

from scipy.sparse import csr_matrix
from mudata import MuData

In [2]:
from anndata import AnnData

In [3]:
adata = sc.read_h5ad('liana/tests/data/synthetic.h5ad')

In [4]:
# subset adata
adata = sc.pp.subsample(adata, n_obs=100, copy=True)

In [5]:
from liana.method.sp._misty import _check_features, _get_env_groups,\
    _compose_views_groups, _check_anndata_objects_groups, _check_target_in_predictors, _single_view_model, \
        _multi_model, _format_targets, _format_importances, _concat_dataframes, _check_anndata_objects_groups, _get_neighbors

In [6]:
# mdata, 
# x_mod,
y_mod = None
targets = None
predictors = None
keep_same_predictor = False
bandwidth = None
juxta_cutoff = np.inf
zoi = 0
kernel = "misty_rbf"
set_diag = False
spatial_key = "spatial"
add_juxta = True
add_para = True
bypass_intra = False
group_intra_by = None
group_env_by = None
alphas = [0.1, 1, 10]
k_cv = 10
n_estimators = 100
n_jobs = -1
seed = 1337
inplace = True
overwrite = False

In [7]:
mdata = MuData({'rna': adata})
x_mod = 'rna'

In [8]:
# validate inputs 
if not overwrite and ("misty_results" in mdata.uns.keys()) and inplace:
    raise ValueError("mdata already contains misty results. Set overwrite=True to overwrite.")
if x_mod not in mdata.mod.keys():
    raise ValueError(f"Predictor modality {x_mod} not found in mdata.")
if y_mod is not None and y_mod not in mdata.mod.keys():
    raise ValueError(f"Target modality {y_mod} not found in mdata.")

xdata = mdata[x_mod]
ydata = mdata[y_mod] if y_mod else xdata

_check_anndata_objects_groups(xdata,
                              ydata,
                              spatial_key=spatial_key,
                              group_intra_by=group_intra_by,
                              group_env_by=group_env_by)

In [9]:
group_env_by='cell_type'
group_intra_by='cell_type'

In [10]:
predictors = _check_features(xdata, predictors, type_str="predictors")
targets = _check_features(ydata, targets, type_str="targets")

intra_groups = np.unique(ydata.obs[group_intra_by]) if group_intra_by else [None]
env_groups = np.unique(xdata.obs[group_env_by]) if group_env_by else [None]

TO CHANGE:

In [11]:
intra_group = intra_groups[0]
target = targets[0]

In [12]:
bandwidth = 10

In [13]:
views = _compose_views_groups(xdata,
                              predictors,
                              bypass_intra,
                              add_juxta,
                              add_para,
                              group_env_by,
                              juxta_cutoff, 
                              bandwidth, 
                              kernel,
                              zoi,
                              set_diag,
                              spatial_key
                              )
view_str = list(views.keys())



In [14]:
intra_obs_msk = ydata.obs[group_intra_by] == intra_group if intra_group else np.ones(ydata.shape[0], dtype=bool)

In [15]:
(ydata.obs[group_intra_by] == intra_group)

1282    False
3669     True
142      True
3269    False
3755     True
        ...  
3532     True
2626    False
1708    False
189     False
4141     True
Name: cell_type, Length: 100, dtype: bool

In [16]:
if issparse(ydata.X):
    y = np.asarray(ydata[intra_obs_msk, target].X.todense()).reshape(-1)
else:
    y = ydata[intra_obs_msk, target].X.reshape(-1)

In [17]:
# intra is always non-self, while other views can be self
predictors_nonself, insert_index = _check_target_in_predictors(target, predictors)
preds = predictors if keep_same_predictor else predictors_nonself

importance_dict = {}

# model the intraview
if not bypass_intra:
    oob_predictions_intra, importance_dict["intra"] = _single_view_model(y,
                                                                         views["intra"],
                                                                         intra_obs_msk, 
                                                                         predictors_nonself, 
                                                                         n_estimators,
                                                                         n_jobs, 
                                                                         seed
                                                                         )
    if insert_index is not None and keep_same_predictor:
        importance_dict["intra"] = np.insert(importance_dict["intra"], insert_index, np.nan)

In [18]:
env_group = env_groups[0]

In [19]:
view_name = view_str[1]

In [20]:
view_name

'juxta'

In [21]:
view_str

['intra', 'juxta', 'para']

In [22]:
view = views[view_name][env_group] if env_group else views[view_name]["all"]

In [23]:
views[view_name]

{'A': AnnData object with n_obs × n_vars = 100 × 11
     obs: 'cell_type',
 'B': AnnData object with n_obs × n_vars = 100 × 11
     obs: 'cell_type'}

In [24]:
connectivity = _get_neighbors(xdata,
               juxta_cutoff=juxta_cutoff,
               set_diag=set_diag, 
               spatial_key=spatial_key
               )

In [25]:
weights = connectivity.copy()
weights[:, adata.obs[group_env_by]!=env_group] = 0
X = weights @ adata[:, predictors].X
adata = AnnData(X=X, obs=adata.obs, var=pd.DataFrame(index=predictors))



In [None]:

oob_predictions, importance_dict[view_name] = \
    _single_view_model(y, 
                        view, 
                        intra_obs_msk, 
                        preds, 
                        n_estimators,
                        n_jobs,
                        seed
                        )

In [None]:
oob_predictions

In [None]:

# init list to store the results for each intra group and env group as dataframe;
targets_list, importances_list = [], []

# loop over each target and build one RF model for each view
for target in targets:
    
    for intra_group in intra_groups:
        intra_group_bool = ydata.obs[group_intra_by] == intra_group if intra_group else np.ones(ydata.shape[0], dtype=bool)
        
        if issparse(ydata.X):
            y = np.asarray(ydata[intra_group_bool, target].X.todense()).reshape(-1)
        else:
            y = ydata[intra_group_bool, target].X.reshape(-1)

        # intra is always non-self, while other views can be self
        predictors_nonself, insert_index = _check_target_in_predictors(target, predictors)
        preds = predictors if keep_same_predictor else predictors_nonself

        importance_dict = {}
        
        # model the intraview
        if not bypass_intra:
            oob_predictions_intra, importance_dict["intra"] = _single_view_model(y,
                                                                                 views["intra"],
                                                                                 intra_group_bool,
                                                                                 predictors_nonself,
                                                                                 n_estimators,
                                                                                 n_jobs,
                                                                                 seed
                                                                                 )
            if insert_index is not None and keep_same_predictor:
                importance_dict["intra"] = np.insert(importance_dict["intra"], insert_index, np.nan)

        # loop over the group_views_by
        for env_group in env_groups:
            
            # store the oob predictions for each view to construct predictor matrix for meta model
            oob_list = []

            if not bypass_intra:
                oob_list.append(oob_predictions_intra)

            # model the juxta and paraview (if applicable)
            ## TODO: remove this thing with all
            for view_name in [v for v in view_str if v != "intra"]:
                view = views[view_name][env_group] if env_group else views[view_name]["all"]
                oob_predictions, importance_dict[view_name] = \
                    _single_view_model(y, 
                                        view, 
                                        intra_group_bool, 
                                        preds, 
                                        n_estimators,
                                        n_jobs,
                                        seed
                                        )
                oob_list.append(oob_predictions)

            # train the meta model with k-fold CV 
            intra_r2, multi_r2, coefs = _multi_model(y,
                                                        np.column_stack(oob_list),
                                                        intra_group, 
                                                        bypass_intra, 
                                                        view_str, 
                                                        k_cv, 
                                                        alphas, 
                                                        seed
                                                        )
            
            targets_df = _format_targets(target,
                                            intra_group,
                                            env_group,
                                            view_str,
                                            intra_r2,
                                            multi_r2,
                                            coefs
                                        )
            targets_list.append(targets_df)
            
            importances_df = _format_importances(target, 
                                                    preds, 
                                                    intra_group, 
                                                    env_group,
                                                    importance_dict
                                                    )
            importances_list.append(importances_df)


# create result dataframes
target_metrics, importances = _concat_dataframes(targets_list,
                                                 importances_list,
                                                 view_str)

Rewrite misty

In [27]:
views = _compose_views_groups(xdata, 
                                predictors,
                                bypass_intra, 
                                add_juxta, 
                                add_para, 
                                group_env_by, 
                                juxta_cutoff, 
                                bandwidth, 
                                kernel,
                                zoi,
                                set_diag, 
                                spatial_key)
view_str = list(views.keys())

# init list to store the results for each intra group and env group as dataframe;
targets_list, importances_list = [], []




In [28]:
connectivities = {}

In [30]:
env_obs_msk = ydata.obs[group_env_by] == env_group if env_group else np.ones(xdata.shape[0], dtype=bool)

In [31]:
weights = connectivities[view_name]
weights[:, adata.obs[group_env_by]!=env_group] = 0
X = weights @ adata[:, predictors].X




In [37]:
def _mask_connectivity(xdata, connectivity, env_group):
    
    weights = connectivity.copy()
    weights[:, xdata.obs[group_env_by]!=env_group] = 0
    X = weights @ xdata[:, predictors].X
    view = AnnData(X=X, obs=xdata.obs, var=pd.DataFrame(index=predictors))
    
    return view



AnnData object with n_obs × n_vars = 100 × 11
    obs: 'cell_type'

In [40]:
if add_juxta:
    connectivities['juxta'] = _get_neighbors(xdata,
                                            juxta_cutoff=juxta_cutoff,
                                            set_diag=set_diag, 
                                            spatial_key=spatial_key
                                            )
if add_para:
    connectivities['para'] = distance_weights = spatial_neighbors(adata=xdata,
                                                                bandwidth=bandwidth, 
                                                                kernel=kernel,
                                                                set_diag=set_diag, 
                                                                inplace=False,
                                                                cutoff=0, 
                                                                zoi=zoi
                                                                )
    
# loop over each target and build one RF model for each view
for target in targets:
    
    for intra_group in intra_groups:
        intra_obs_msk = ydata.obs[group_intra_by] == \
                intra_group if intra_group else np.ones(ydata.shape[0], dtype=bool)
        
        if issparse(ydata.X):
            y = np.asarray(ydata[intra_obs_msk, target].X.todense()).reshape(-1)
        else:
            y = ydata[intra_obs_msk, target].X.reshape(-1)

        # intra is always non-self, while other views can be self
        predictors_nonself, insert_index = _check_target_in_predictors(target, predictors)
        preds = predictors if keep_same_predictor else predictors_nonself

        importance_dict = {}
        
        # model the intraview
        if not bypass_intra:
            oob_predictions_intra, importance_dict["intra"] = _single_view_model(y,
                                                                                 views["intra"], 
                                                                                 intra_obs_msk, 
                                                                                 predictors_nonself, 
                                                                                 n_estimators, 
                                                                                 n_jobs, 
                                                                                 seed
                                                                                 )
            if insert_index is not None and keep_same_predictor:
                importance_dict["intra"] = np.insert(importance_dict["intra"], insert_index, np.nan)

        # loop over the group_views_by
        for env_group in env_groups:
            
            # store the oob predictions for each view to construct predictor matrix for meta model
            oob_list = []

            if not bypass_intra:
                oob_list.append(oob_predictions_intra)

            # model the juxta and paraview (if applicable)
            ## TODO: remove this thing with all
            for view_name in [v for v in view_str if v != "intra"]:
                connectivity = connectivities[view_name]
                view = _mask_connectivity(xdata, connectivity, env_group)
                oob_predictions, importance_dict[view_name] = \
                    _single_view_model(y, 
                                        view, 
                                        intra_obs_msk, 
                                        preds, 
                                        n_estimators,
                                        n_jobs,
                                        seed
                                        )
                oob_list.append(oob_predictions)


            # train the meta model with k-fold CV 
            intra_r2, multi_r2, coefs = _multi_model(y,
                                                        np.column_stack(oob_list),
                                                        intra_group, 
                                                        bypass_intra, 
                                                        view_str, 
                                                        k_cv, 
                                                        alphas, 
                                                        seed
                                                        )
            
            targets_df = _format_targets(target,
                                         intra_group,
                                         env_group,
                                         view_str,
                                         intra_r2,
                                         multi_r2,
                                         coefs
                                        )
            targets_list.append(targets_df)
            
            importances_df = _format_importances(target, 
                                                    preds, 
                                                    intra_group, 
                                                    env_group,
                                                    importance_dict
                                                    )
            importances_list.append(importances_df)


# create result dataframes
target_metrics, importances = _concat_dataframes(targets_list,
                                                    importances_list,
                                                    view_str)



In [41]:
importances

Unnamed: 0,target,predictor,intra_group,env_group,view,value
0,ECM,ligA,A,A,intra,0.044582
1,ECM,ligB,A,A,intra,0.032802
2,ECM,ligC,A,A,intra,0.090857
3,ECM,ligD,A,A,intra,0.043156
4,ECM,protE,A,A,intra,0.677274
...,...,...,...,...,...,...
2635,prodD,protE,B,B,para,0.147171
2636,prodD,protF,B,B,para,0.046902
2637,prodD,prodA,B,B,para,0.037213
2638,prodD,prodB,B,B,para,0.039037


In [34]:
importances

Unnamed: 0,target,predictor,intra_group,env_group,view,value
0,ECM,ligA,A,A,intra,0.044582
1,ECM,ligB,A,A,intra,0.032802
2,ECM,ligC,A,A,intra,0.090857
3,ECM,ligD,A,A,intra,0.043156
4,ECM,protE,A,A,intra,0.677274
...,...,...,...,...,...,...
1315,prodD,protE,B,B,para,0.147171
1316,prodD,protF,B,B,para,0.046902
1317,prodD,prodA,B,B,para,0.037213
1318,prodD,prodB,B,B,para,0.039037


In [35]:
target_metrics

Unnamed: 0,target,intra_group,env_group,intra.R2,multi.R2,intra,juxta,para,gain.R2
0,ECM,A,A,0.217495,0.21479,0.885643,0.0,0.114357,-0.002705
1,ECM,A,B,0.217495,0.227228,0.822258,0.125818,0.051924,0.009732
2,ECM,B,A,-6.480144,-6.795081,1.0,0.0,0.0,-0.314938
3,ECM,B,B,-6.480144,-6.388007,0.868735,0.046875,0.084389,0.092137
4,ligA,A,A,0.433328,0.468598,0.698738,0.114799,0.186463,0.035269
5,ligA,A,B,0.433328,0.377478,0.876626,0.10068,0.022694,-0.055851
6,ligA,B,A,0.845237,0.81261,1.0,0.0,0.0,-0.032627
7,ligA,B,B,0.845237,0.840463,1.0,0.0,0.0,-0.004774
8,ligB,A,A,-1.708177,-1.730911,0.947551,0.0,0.052449,-0.022734
9,ligB,A,B,-1.708177,-1.730489,1.0,0.0,0.0,-0.022311
