In [72]:
import itertools
import math
import os
import pickle
import re
from collections import OrderedDict

import h5py
import numpy as np
import pandas as pd
import scanpy as sc
from sklearn import metrics
from sklearn.metrics.pairwise import paired_cosine_distances

from src.preprocessing_mouse_GSE115746 import (
    cell_cluster_cell_type_to_spot_composition,
    cell_subclass_to_spot_composition,
)
from src.preprocessing_spotless import get_st_sub_map



In [2]:
# Splits to evaluate, in the order they should be reported.
splits = ["train", "val", "test"]

# Folder holding the preprocessed DLPFC inputs and the prediction files.
dset_dir = "preprocessed_data/dlpfc/all/raw_counts"


In [4]:
# Ordered list of spatial sample ids.
with open(os.path.join(dset_dir, "st_sample_id_l.pkl"), "rb") as f:
    st_sample_id_l = pickle.load(f)

# Spatial AnnData; its `obs` table provides the "sample_id" and "split"
# columns used below.
mat_sp_samp_split_d = sc.read_h5ad(
    os.path.join(dset_dir, "unscaled/mat_sp_samp_split_d.h5ad")
)
st_splits = mat_sp_samp_split_d.obs["split"].unique().tolist()

# sc_sub_dict is looked up by integer position (0..len-1) when ordering the
# prediction columns; sc_sub_dict2 is used elsewhere in this notebook to go
# from a cell type name to a column position.
with open(os.path.join(dset_dir, "sc_sub_dict2.pkl"), "rb") as f:
    sc_sub_dict2 = pickle.load(f)

with open(os.path.join(dset_dir, "sc_sub_dict.pkl"), "rb") as f:
    sc_sub_dict = pickle.load(f)

# split -> sample ids in that split, keeping the order of st_sample_id_l.
# Splits that do not occur in the data get no entry at all.
st_sample_id_d = OrderedDict()
for split in splits:
    in_split = mat_sp_samp_split_d.obs.loc[:, "split"] == split
    sids = set(mat_sp_samp_split_d.obs.loc[in_split, "sample_id"])
    if split not in st_splits:
        continue
    st_sample_id_d[split] = [sid for sid in st_sample_id_l if sid in sids]


In [5]:
# Raw prediction matrix; its labels live in two header-less, single-column
# CSV files next to it.
with h5py.File(os.path.join(dset_dir, "st_predictions.h5"), "r") as f:
    predictions = f["predictions"][()]

row_names = pd.read_csv(
    os.path.join(dset_dir, "pred_rows.csv"), header=None, index_col=None
)[0]
col_names = pd.read_csv(
    os.path.join(dset_dir, "pred_columns.csv"), header=None, index_col=None
)[0]

predictions = pd.DataFrame(predictions, index=row_names, columns=col_names)

# Map the stored labels back: "." -> "-" in row names, "Slash" -> "/" in
# column names (literal replacements, not regular expressions).
predictions.index = predictions.index.str.replace(".", "-", regex=False)
predictions.columns = predictions.columns.str.replace("Slash", "/", regex=False)

# Columns follow the sc_sub_dict order (absent cell types become 0.0); rows
# follow the AnnData spots (spots without a prediction become NaN, which is
# the default fill value of reindex).
cell_types = [sc_sub_dict[i] for i in range(len(sc_sub_dict))]
predictions = predictions.reindex(columns=cell_types, fill_value=0.0).reindex(
    index=mat_sp_samp_split_d.obs.index
)


In [6]:
# sample id -> prediction array with one row per spot of that sample (in
# AnnData order) and one column per cell type. Spots without a prediction
# are rows of NaN (from the reindex of `predictions`).
pred_sp_d = OrderedDict()

# Walk the splits that were actually recorded in st_sample_id_d: splits
# missing from the data have no entry there, so indexing it with every name
# in `splits` would raise KeyError.
for split_sids in st_sample_id_d.values():
    for split_sid in split_sids:
        spots = mat_sp_samp_split_d.obs.index[
            mat_sp_samp_split_d.obs["sample_id"] == split_sid
        ]
        pred_sp_d[split_sid] = predictions.loc[spots].to_numpy()


In [7]:
# Ex cluster number -> set of layer numbers in which that cluster is counted
# as a positive. Keys are matched against the number parsed from "Ex_<n>..."
# cell type names; values are matched against the digits found in the
# `spatialLIBD` spot labels.
# NOTE(review): the source of this mapping is not shown here — TODO cite it.
Ex_to_L_d = {
    1: {5, 6},
    2: {5},
    3: {4, 5},
    4: {6},
    5: {5},
    6: {4, 5, 6},
    7: {4, 5, 6},
    8: {5, 6},
    9: {5, 6},
    10: {2, 3, 4},
}

def _plot_roc(
    visnum,
    adata,
    pred_sp,
    num_name_exN_l,
    numlist,
):
    """Plot ROC for a given visnum"""

    Ex_l = [t[2] for t in num_name_exN_l]
    num_to_ex_d = dict(zip(numlist, Ex_l))

    def layer_to_layer_number(x):
        """Converts a string of layers to a list of layer numbers"""
        for char in x:
            if char.isdigit():
                # if in (ordinal -> ex number -> layers)
                if int(char) in Ex_to_L_d[num_to_ex_d[visnum]]:
                    return 1
        return 0

    y_pred = pred_sp[:, visnum]

    # set unpredicted to 0
    # This means to include unpredicted spots, right side (FPR=1) of ROC will 
    # represent threshold of 0; TPR will not include unpredicted spots until
    # that point, removing proportion of unpredicted from the AUROC.
    y_pred[~np.isfinite(y_pred)] = 0.0
    y_true = adata.obs["spatialLIBD"].map(layer_to_layer_number).fillna(0)

    if y_true.sum() > 0:


        return metrics.roc_auc_score(y_true, y_pred)

    return np.nan

In [18]:
# Collect the "Ex" cell types as (ordinal, name, Ex number) triples; the Ex
# number is the second "_"-separated token of the name.
num_name_exN_l = []
for k, v in sc_sub_dict.items():
    if "Ex" in v:
        # (clust_ordinal, clust_name, Ex_clust_num)
        num_name_exN_l.append((k, v, int(v.split("_")[1])))

num_name_exN_l.sort(key=lambda a: a[2])

numlist = [t[0] for t in num_name_exN_l]  # clust ordinals

# (split, sample_id) -> mean AUROC over the Ex clusters of that sample.
spot_scores = OrderedDict()

# Iterate over the splits actually present in st_sample_id_d, so a split that
# is missing from the data does not raise KeyError.
for split, split_sids in st_sample_id_d.items():
    for sample_id in split_sids:
        # Subset the AnnData once per sample rather than once per cluster.
        sample_adata = mat_sp_samp_split_d[
            mat_sp_samp_split_d.obs["sample_id"] == sample_id
        ]
        da_aucs = [
            _plot_roc(
                num,
                sample_adata,
                pred_sp_d[sample_id],
                num_name_exN_l,
                numlist,
            )
            for num in numlist
        ]

        # nanmean ignores clusters for which _plot_roc returned NaN.
        spot_scores[(split, sample_id)] = np.nanmean(da_aucs)





In [24]:
# dataset name -> Series holding one mean score per split.
scores_d = OrderedDict()

# spot_scores is keyed by (split, sample_id), so level 0 of the resulting
# index is the split: average the samples of each split, then put the splits
# in the order given by `splits`.
dlpfc_sample_aucs = pd.Series(spot_scores)
scores_d["dlpfc"] = dlpfc_sample_aucs.groupby(level=0).mean().reindex(splits)

OrderedDict([('dlpfc',
              train    0.553414
              val      0.482188
              test     0.497089
              dtype: float64)])

In [26]:
# Only the "train" split is evaluated for PDAC.
splits = ["train"]

# Folder holding the preprocessed PDAC inputs and the prediction files.
dset_dir = "preprocessed_data/pdac/all/raw_counts"


In [27]:
# Ordered list of spatial sample ids.
with open(os.path.join(dset_dir, "st_sample_id_l.pkl"), "rb") as f:
    st_sample_id_l = pickle.load(f)

# Spatial AnnData; its `obs` table provides the "sample_id" and "split"
# columns used below. Note the file name differs from the other datasets.
mat_sp_samp_split_d = sc.read_h5ad(
    os.path.join(dset_dir, "unscaled/mat_sp_train_d.h5ad")
)
st_splits = mat_sp_samp_split_d.obs["split"].unique().tolist()

# sc_sub_dict is looked up by integer position (0..len-1) when ordering the
# prediction columns; sc_sub_dict2 is used elsewhere in this notebook to go
# from a cell type name to a column position.
with open(os.path.join(dset_dir, "sc_sub_dict2.pkl"), "rb") as f:
    sc_sub_dict2 = pickle.load(f)

with open(os.path.join(dset_dir, "sc_sub_dict.pkl"), "rb") as f:
    sc_sub_dict = pickle.load(f)

# split -> sample ids in that split, keeping the order of st_sample_id_l.
# Splits that do not occur in the data get no entry at all.
st_sample_id_d = OrderedDict()
for split in splits:
    in_split = mat_sp_samp_split_d.obs.loc[:, "split"] == split
    sids = set(mat_sp_samp_split_d.obs.loc[in_split, "sample_id"])
    if split not in st_splits:
        continue
    st_sample_id_d[split] = [sid for sid in st_sample_id_l if sid in sids]


In [28]:
# Raw prediction matrix; its labels live in two header-less, single-column
# CSV files next to it.
with h5py.File(os.path.join(dset_dir, "st_predictions.h5"), "r") as f:
    predictions = f["predictions"][()]

row_names = pd.read_csv(
    os.path.join(dset_dir, "pred_rows.csv"), header=None, index_col=None
)[0]
col_names = pd.read_csv(
    os.path.join(dset_dir, "pred_columns.csv"), header=None, index_col=None
)[0]

predictions = pd.DataFrame(predictions, index=row_names, columns=col_names)

# Map the stored labels back: "." -> "-" in row names, "Slash" -> "/" in
# column names (literal replacements, not regular expressions).
predictions.index = predictions.index.str.replace(".", "-", regex=False)
predictions.columns = predictions.columns.str.replace("Slash", "/", regex=False)

# Columns follow the sc_sub_dict order (absent cell types become 0.0); rows
# follow the AnnData spots (spots without a prediction become NaN, which is
# the default fill value of reindex).
cell_types = [sc_sub_dict[i] for i in range(len(sc_sub_dict))]
predictions = predictions.reindex(columns=cell_types, fill_value=0.0).reindex(
    index=mat_sp_samp_split_d.obs.index
)


In [29]:
# sample id -> prediction array with one row per spot of that sample (in
# AnnData order) and one column per cell type. Spots without a prediction
# are rows of NaN (from the reindex of `predictions`).
pred_sp_d = OrderedDict()

# Walk the splits that were actually recorded in st_sample_id_d: splits
# missing from the data have no entry there, so indexing it with every name
# in `splits` would raise KeyError.
for split_sids in st_sample_id_d.values():
    for split_sid in split_sids:
        spots = mat_sp_samp_split_d.obs.index[
            mat_sp_samp_split_d.obs["sample_id"] == split_sid
        ]
        pred_sp_d[split_sid] = predictions.loc[spots].to_numpy()


In [36]:
def _plot_roc_pdac(visnum, adata, pred_sp, sc_to_st_celltype):
    """Plot ROC for a given visnum (PDAC)"""
    try:
        cell_name = sc_sub_dict[visnum]
    except TypeError:
        cell_name = "Other"

    def st_sc_bin(cell_type):
        return int(cell_type in sc_to_st_celltype.get(cell_name, set()))

    y_pred = pred_sp[:, visnum].squeeze()
    y_pred = np.nan_to_num(y_pred, nan=0.0)
    if y_pred.ndim > 1:
        y_pred = y_pred.sum(axis=1)
    y_true = adata.obs["cell_type"].map(st_sc_bin).fillna(0)

    if y_true.sum() > 0:
        return metrics.roc_auc_score(y_true, y_pred)

    return np.nan

In [42]:
# Single-cell type name -> set of spot annotations (values of
# `adata.obs["cell_type"]`) in which that cell type is counted as a positive
# when computing the AUROC.
sc_to_st_celltype = {
    # Peng et al., 2019: "Taken together, these results show that
    # type 2 ductal cells are the major source of malignant cells in
    # PDACs."
    "Ductal cell type 2": {"Cancer region"},
    "T cell": {"Cancer region", "Stroma"},
    "Macrophage cell": {"Cancer region", "Stroma"},
    "Fibroblast cell": {"Cancer region", "Stroma"},
    "B cell": {"Cancer region", "Stroma"},
    "Ductal cell type 1": {"Duct epithelium"},
    "Endothelial cell": {"Interstitium"},
    "Stellate cell": {"Stroma", "Pancreatic tissue"},
    "Acinar cell": {"Pancreatic tissue"},
    "Endocrine cell": {"Pancreatic tissue"},
}


celltypes = list(sc_to_st_celltype.keys()) + ["Other"]
# NOTE(review): n_celltypes / n_rows are not used anywhere in the visible
# notebook (they look like leftovers of a plotting grid); kept in case a
# later cell relies on them.
n_celltypes = len(celltypes)
n_rows = int(math.ceil(n_celltypes / 5))

# Ordinals of the mapped cell types first (None for a mapped name missing
# from sc_sub_dict2), followed by the ordinals of every other cell type.
numlist = [sc_sub_dict2.get(t) for t in celltypes[:-1]]
numlist.extend([v for k, v in sc_sub_dict2.items() if k not in celltypes[:-1]])

# sc_to_st_celltype has no "Other" entry, so _plot_roc_pdac returns NaN for
# every cell type name that is not mapped, and nanmean ignores those values.
# NOTE(review): numlist[:-1] drops the last ordinal unconditionally. That is
# harmless while at least one unmapped cell type follows the mapped ones,
# but if every cell type is mapped it silently excludes the last mapped
# type ("Endocrine cell") — confirm the intent.
# NOTE(review): the recorded output is exactly 0.5 for both samples, which
# is what the AUROC is for a constant score. Check that the predictions of
# these spots are not all NaN/0 (e.g. spot names not matching in the
# reindex of `predictions`).
spot_scores = OrderedDict()

# Iterate over the splits actually present in st_sample_id_d, so a split that
# is missing from the data does not raise KeyError.
for split, split_sids in st_sample_id_d.items():
    spot_scores[split] = OrderedDict()
    for sample_id in split_sids:
        # Subset the AnnData once per sample rather than once per cell type.
        sample_adata = mat_sp_samp_split_d[
            mat_sp_samp_split_d.obs["sample_id"] == sample_id
        ]
        da_aucs = [
            _plot_roc_pdac(
                num,
                sample_adata,
                pred_sp_d[sample_id],
                sc_to_st_celltype,
            )
            for num in numlist[:-1]
        ]

        spot_scores[split][sample_id] = np.nanmean(da_aucs)



In [43]:
# Display the mean AUROC per split and sample.
spot_scores

OrderedDict([('train', OrderedDict([('pdac_a', 0.5), ('pdac_b', 0.5)]))])

In [55]:
# Stack the per-split dicts into one Series indexed by (split, sample_id),
# then average the samples of each split. Unlike the DLPFC summary this is
# not reindexed by `splits`, so the splits come out in groupby's sorted order.
pdac_sample_aucs = pd.concat(
    {split: pd.Series(per_sample) for split, per_sample in spot_scores.items()}
)
scores_d["pdac"] = pdac_sample_aucs.groupby(level=0).mean()

In [56]:
# Display the PDAC mean AUROC per split.
scores_d["pdac"]

train    0.5
dtype: float64

In [64]:
# Splits to evaluate, in the order they should be reported.
splits = ["train", "val", "test"]

# Folder holding the preprocessed spotless inputs and the prediction files.
dset_dir = "preprocessed_data/spotless/all/raw_counts"


In [65]:
# Ordered list of spatial sample ids.
with open(os.path.join(dset_dir, "st_sample_id_l.pkl"), "rb") as f:
    st_sample_id_l = pickle.load(f)

# Spatial AnnData; its `obs` table provides the "sample_id" and "split"
# columns used below.
mat_sp_samp_split_d = sc.read_h5ad(
    os.path.join(dset_dir, "unscaled/mat_sp_samp_split_d.h5ad")
)
st_splits = mat_sp_samp_split_d.obs["split"].unique().tolist()

# sc_sub_dict is looked up by integer position (0..len-1) when ordering the
# prediction columns; sc_sub_dict2 is used elsewhere in this notebook to go
# from a cell type name to a column position.
with open(os.path.join(dset_dir, "sc_sub_dict2.pkl"), "rb") as f:
    sc_sub_dict2 = pickle.load(f)

with open(os.path.join(dset_dir, "sc_sub_dict.pkl"), "rb") as f:
    sc_sub_dict = pickle.load(f)

# split -> sample ids in that split, keeping the order of st_sample_id_l.
# Splits that do not occur in the data get no entry at all.
st_sample_id_d = OrderedDict()
for split in splits:
    in_split = mat_sp_samp_split_d.obs.loc[:, "split"] == split
    sids = set(mat_sp_samp_split_d.obs.loc[in_split, "sample_id"])
    if split not in st_splits:
        continue
    st_sample_id_d[split] = [sid for sid in st_sample_id_l if sid in sids]


In [66]:
# Raw prediction matrix; its labels live in two header-less, single-column
# CSV files next to it.
with h5py.File(os.path.join(dset_dir, "st_predictions.h5"), "r") as f:
    predictions = f["predictions"][()]

row_names = pd.read_csv(
    os.path.join(dset_dir, "pred_rows.csv"), header=None, index_col=None
)[0]
col_names = pd.read_csv(
    os.path.join(dset_dir, "pred_columns.csv"), header=None, index_col=None
)[0]

predictions = pd.DataFrame(predictions, index=row_names, columns=col_names)

# Map the stored labels back: "." -> "-" in row names, "Slash" -> "/" in
# column names (literal replacements, not regular expressions).
predictions.index = predictions.index.str.replace(".", "-", regex=False)
predictions.columns = predictions.columns.str.replace("Slash", "/", regex=False)

# Columns follow the sc_sub_dict order (absent cell types become 0.0); rows
# follow the AnnData spots (spots without a prediction become NaN, which is
# the default fill value of reindex).
cell_types = [sc_sub_dict[i] for i in range(len(sc_sub_dict))]
predictions = predictions.reindex(columns=cell_types, fill_value=0.0).reindex(
    index=mat_sp_samp_split_d.obs.index
)


In [68]:
# Display the prediction table (rows: spots, columns: cell types).
predictions

Unnamed: 0,Astro,Batch Grouping,CR,Doublet Astro Aqp4 Ex,Doublet Endo,Doublet VISp L5 NP and L6 CT,Endo,High Intronic,L2/3 IT,L4,...,NP,Oligo,Peri,Pvalb,SMC,Serpinf1,Sncg,Sst,VLMC,Vip
spot_1,0.065574,0.000094,0.0,0.0,0.021691,0.000094,0.046968,0.203390,0.131708,0.000094,...,0.009127,0.045656,0.033060,0.041988,0.034747,0.005776,0.016008,0.032795,0.067503,0.054646
spot_2,0.088842,0.005988,0.0,0.0,0.034059,0.000093,0.064074,0.180086,0.072099,0.000093,...,0.014626,0.053172,0.028659,0.018539,0.033710,0.005619,0.081648,0.033538,0.068744,0.065875
spot_3,0.283756,0.000045,0.0,0.0,0.000055,0.000045,0.066877,0.090290,0.011668,0.000045,...,0.003812,0.058927,0.028577,0.000045,0.053665,0.000045,0.005422,0.039513,0.066950,0.061182
spot_4,0.059405,0.000097,0.0,0.0,0.031219,0.000097,0.046600,0.197762,0.185305,0.048333,...,0.004197,0.053192,0.032850,0.010993,0.031830,0.022006,0.012809,0.033963,0.053180,0.043159
spot_5,0.060381,0.020950,0.0,0.0,0.023918,0.000095,0.043542,0.221648,0.111834,0.000095,...,0.014673,0.054637,0.028244,0.018600,0.039960,0.011814,0.004062,0.029430,0.063621,0.068352
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
spot_5-3,0.061157,0.000093,0.0,0.0,0.018658,0.015643,0.055937,0.238554,0.004464,0.000096,...,0.017319,0.090831,0.043767,0.007020,0.030524,0.015671,0.011555,0.023536,0.059534,0.023090
spot_6-3,0.069513,0.000093,0.0,0.0,0.013734,0.000093,0.052950,0.233804,0.000093,0.001807,...,0.020554,0.051433,0.032784,0.018034,0.030266,0.025363,0.014803,0.056242,0.051951,0.026492
spot_7-3,0.056943,0.000094,0.0,0.0,0.017736,0.012482,0.047838,0.249264,0.034728,0.000095,...,0.012017,0.047171,0.030835,0.064733,0.029037,0.009194,0.016414,0.039645,0.060597,0.032436
spot_8-3,0.058560,0.000094,0.0,0.0,0.000972,0.018618,0.073867,0.244776,0.026950,0.000105,...,0.003171,0.053187,0.035445,0.007027,0.041457,0.000593,0.011494,0.036121,0.062051,0.051379


In [67]:
# sample id -> prediction array with one row per spot of that sample (in
# AnnData order) and one column per cell type. Spots without a prediction
# are rows of NaN (from the reindex of `predictions`).
pred_sp_d = OrderedDict()

# Walk the splits that were actually recorded in st_sample_id_d: splits
# missing from the data have no entry there, so indexing it with every name
# in `splits` would raise KeyError.
for split_sids in st_sample_id_d.values():
    for split_sid in split_sids:
        spots = mat_sp_samp_split_d.obs.index[
            mat_sp_samp_split_d.obs["sample_id"] == split_sid
        ]
        pred_sp_d[split_sid] = predictions.loc[spots].to_numpy()


In [69]:
# Display which sample ids belong to each split.
st_sample_id_d

OrderedDict([('train',
              ['Eng2019_cortex_svz_fov0',
               'Eng2019_cortex_svz_fov1',
               'Eng2019_cortex_svz_fov4',
               'Eng2019_cortex_svz_fov5',
               'Eng2019_cortex_svz_fov6']),
             ('val', ['Eng2019_cortex_svz_fov2']),
             ('test', ['Eng2019_cortex_svz_fov3'])])

In [73]:
def _merge_sc_preds(pred_sp_d, cell_type_index, merged_to_sc):
    new_pred_dict = {}
    for sid, pred in pred_sp_d.items():
        new_pred_dict[sid] = np.empty((pred.shape[0], len(cell_type_index)), dtype=pred.dtype)
        for i, merged_cell_type in enumerate(cell_type_index):
            if merged_cell_type in merged_to_sc:
                old_idxs = [
                    sc_sub_dict2[cell_type] for cell_type in merged_to_sc[merged_cell_type]
                ]
                new_pred_dict[sid][:, i] = pred[:, old_idxs].sum(axis=1)
            else:
                # no sc cell types map to this st cell type
                new_pred_dict[sid][:, i] = 0

    return new_pred_dict

In [85]:



# Key of the ground-truth composition table in `obsm`.
hue = "relative_spot_composition"
st_sub_map = get_st_sub_map()
cell_type_index = sorted(list(st_sub_map.keys())) + ["Other"]

# create a mapping from spotless cell types to sc cell types
merged_to_sc = {k: [] for k in cell_type_index}

for k, v in itertools.chain(
    cell_cluster_cell_type_to_spot_composition.items(),
    cell_subclass_to_spot_composition.items(),
):
    if k != "keep_the_rest":
        if len(v) > 0:
            # The merged name is the "/"-joined, sorted target names.
            # NOTE(review): raises KeyError if that name is not a key of
            # st_sub_map — presumably guaranteed by the preprocessing module.
            merged_to_sc["/".join(sorted(list(v)))].append(k)
        else:
            merged_to_sc["Other"].append(k)

new_pred_sp_d = _merge_sc_preds(pred_sp_d, cell_type_index, merged_to_sc)

# Ground-truth columns use "." where the merged names have " " or "/".
# Raw string: "\/" in a plain literal is an invalid escape sequence.
st_cell_types_to_sc = {re.sub(r"( |\/)", ".", name): name for name in cell_type_index}

# split -> sample id -> mean cosine distance between predicted and true
# compositions (lower is better).
ctps = OrderedDict()
# Iterate over the splits actually present in st_sample_id_d, so a split that
# is missing from the data does not raise KeyError.
for split, split_sids in st_sample_id_d.items():
    ctps[split] = OrderedDict()
    for sample_id in split_sids:
        if sample_id not in new_pred_sp_d:
            continue
        dists_true = (
            mat_sp_samp_split_d[mat_sp_samp_split_d.obs["sample_id"] == sample_id]
            .obsm[hue]
            .rename(columns=st_cell_types_to_sc)
            # this adds an "Other" column with all 0s
            .reindex(columns=cell_type_index, fill_value=0.0)
            .to_numpy()
        )

        # Rows whose prediction is entirely finite, i.e. predicted spots.
        mask = np.all(np.isfinite(new_pred_sp_d[sample_id]), axis=1)

        y_pred = new_pred_sp_d[sample_id][mask]
        y_true = dists_true[mask]

        # One entry per spot of the sample. Unpredicted spots keep a distance
        # of 1.0. The array must have the length of `mask` (not of the
        # masked rows), otherwise the boolean assignment below raises
        # IndexError as soon as one spot is unpredicted.
        result = np.full(mask.shape[0], 1.0)
        result[mask] = paired_cosine_distances(y_pred, y_true)

        ctps[split][sample_id] = result.mean()



In [86]:
# Stack the per-split dicts into one Series indexed by (split, sample_id),
# then average the samples of each split. These values are mean cosine
# distances (lower is better), unlike the AUROC entries of scores_d. The
# result is not reindexed by `splits`, so the splits come out in groupby's
# sorted order.
mouse_cortex_sample_dists = pd.concat(
    {split: pd.Series(per_sample) for split, per_sample in ctps.items()}
)
scores_d["mouse_cortex"] = mouse_cortex_sample_dists.groupby(level=0).mean()

In [87]:
# Display the mouse cortex mean cosine distance per split.
scores_d["mouse_cortex"]

test     0.121183
train    0.534351
val      0.277304
dtype: float64

In [88]:
# Display the summary for all datasets. The "dlpfc" and "pdac" entries are
# AUROCs, while "mouse_cortex" is a cosine distance.
scores_d

OrderedDict([('dlpfc',
              train    0.553414
              val      0.482188
              test     0.497089
              dtype: float64),
             ('pdac',
              train    0.5
              dtype: float64),
             ('mouse_cortex',
              test     0.121183
              train    0.534351
              val      0.277304
              dtype: float64)])