In [1]:
%load_ext autoreload
%autoreload 2

import abc
import warnings
from collections import OrderedDict
import os
from typing import Dict, List, Optional, Sequence, Tuple, Union

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import squidpy as sq
from anndata import AnnData, read_h5ad
from matplotlib.ticker import FormatStrFormatter
from matplotlib.tri import Triangulation
from omnipath.interactions import import_intercell_network
from pandas import read_csv, read_excel, DataFrame
from scipy import sparse, stats
from tqdm import tqdm

In [2]:

class GraphTools:
    """GraphTools class."""

    celldata: AnnData
    img_celldata: Dict[str, AnnData]

    def compute_adjacency_matrices(
        self, radius: int, coord_type: str = 'generic', n_rings: int = 1, transform: str = None
    ):
        """Compute adjacency matrix for each image in dataset (uses `squidpy.gr.spatial_neighbors`).

        Parameters
        ----------
        radius : int
            Radius of neighbors for non-grid data.
        coord_type : str
            Type of coordinate system.
        n_rings : int
            Number of rings of neighbors for grid data.
        transform : str
            Type of adjacency matrix transform. Valid options are:

            - `spectral` - spectral transformation of the adjacency matrix.
            - `cosine` - cosine transformation of the adjacency matrix.
            - `None` - no transformation of the adjacency matrix.
        """
        pbar_total = len(self.img_celldata.keys())
        with tqdm(total=pbar_total) as pbar:
            for _k, adata in self.img_celldata.items():
                if coord_type == 'grid':
                    radius = None
                else:
                    n_rings = 1
                sq.gr.spatial_neighbors(
                    adata=adata,
                    coord_type=coord_type,
                    radius=radius,
                    n_rings=n_rings,
                    transform=transform,
                    key_added="adjacency_matrix"
                )
                #print(adata.obsp['adjacency_matrix_connectivities'].sum(axis=1).mean())
                pbar.update(1)

    @staticmethod
    def _transform_a(a):
        """Compute degree transformation of adjacency matrix.

        Computes D^(-1) * (A+I), with A an adjacency matrix, I the identity matrix and D the degree matrix.

        Parameters
        ----------
        a
            sparse adjacency matrix.

        Returns
        -------
        degree transformed sparse adjacency matrix
        """
        warnings.filterwarnings("ignore", message="divide by zero encountered in true_divide")
        degrees = 1 / a.sum(axis=0)
        degrees[a.sum(axis=0) == 0] = 0
        degrees = np.squeeze(np.asarray(degrees))
        deg_matrix = sparse.diags(degrees)
        a_out = deg_matrix * a
        return a_out

    def _transform_all_a(self, a_dict: dict):
        """Compute degree transformation for dictionary of adjacency matrices.

        Computes D^(-1) * (A+I), with A an adjacency matrix, I the identity matrix and D the degree matrix for all
        matrices in a dictionary.

        Parameters
        ----------
        a_dict : dict
            a_dict

        Returns
        -------
        dictionary of degree transformed sparse adjacency matrices
        """
        a_transformed = {i: self._transform_a(a) for i, a in a_dict.items()}
        return a_transformed

    @staticmethod
    def _compute_distance_matrix(pos_matrix):
        """Compute distance matrix.

        Parameters
        ----------
        pos_matrix
            Position matrix.

        Returns
        -------
        distance matrix
        """
        diff = pos_matrix[:, :, None] - pos_matrix[:, :, None].T
        return (diff * diff).sum(1)

    def _get_degrees(self, max_distances: list):
        """Get dgrees.

        Parameters
        ----------
        max_distances : list
            List of maximal distances.

        Returns
        -------
        degrees
        """
        degs = {}
        degrees = {}
        for i, adata in self.img_celldata.items():
            pm = np.array(adata.obsm["spatial"])
            dist_matrix = self._compute_distance_matrix(pm)
            degs[i] = {dist: np.sum(dist_matrix < dist * dist, axis=0) for dist in max_distances}
        for dist in max_distances:
            degrees[dist] = [deg[dist] for deg in degs.values()]
        return degrees

    def plot_degree_vs_dist(
        self,
        degree_matrices: Optional[list] = None,
        max_distances: Optional[list] = None,
        lateral_resolution: float = 1.0,
        save: Optional[str] = None,
        suffix: str = "_degree_vs_dist.pdf",
        show: bool = True,
        return_axs: bool = False,
    ):
        """Plot degree versus distances.

        Parameters
        ----------
        degree_matrices : list, optional
            List of degree matrices
        max_distances : list, optional
            List of maximal distances.
        lateral_resolution : float
            Lateral resolution
        save : str, optional
            Whether (if not None) and where (path as string given as save) to save plot.
        suffix : str
            Suffix of file name to save to.
        show : bool
            Whether to display plot.
        return_axs : bool
            Whether to return axis objects.

        Returns
        -------
        axis if `return_axs` is True.

        Raises
        ------
        ValueError
            If `degree_matrices` and `max_distances` are `None`.
        """
        if degree_matrices is None:
            if max_distances is None:
                raise ValueError("Provide either distance matrices or distance values!")
            else:
                degree_matrices = self._get_degrees(max_distances)

        plt.ioff()
        fig = plt.figure(figsize=(4, 3))

        mean_degree = []
        distances = []

        for dist, degrees in degree_matrices.items():
            mean_d = [np.mean(degree) for degree in degrees]
            print(np.mean(mean_d))
            mean_degree += mean_d
            distances += [np.int(dist * lateral_resolution)] * len(mean_d)

        sns_data = pd.DataFrame(
            {
                "dist": distances,
                "mean_degree": mean_degree,
            }
        )
        ax = fig.add_subplot(111)
        sns.boxplot(data=sns_data, x="dist", color="steelblue", y="mean_degree", ax=ax)
        ax.set_yscale("log")
        ax.grid(False)
        plt.ylabel("")
        plt.xlabel("")
        plt.xticks(rotation=90)

        # Save, show and return figure.
        plt.tight_layout()
        if save is not None:
            plt.savefig(save + suffix)

        if show:
            plt.show()

        plt.close(fig)
        plt.ion()

        if return_axs:
            return ax
        else:
            return None


class PlottingTools:
    """PlottingTools class."""

    celldata: AnnData
    img_celldata: Dict[str, AnnData]

    def celldata_interaction_matrix(
        self,
        fontsize: Optional[int] = None,
        figsize: Tuple[float, float] = (5, 5),
        title: Optional[str] = None,
        save: Optional[str] = None,
        suffix: str = "_celldata_interaction_matrix.pdf",
    ):
        """Compute and plot interaction matrix of celldata.

        The interaction matrix is computed by `squidpy.gr.interaction_matrix()`.

        Parameters
        ----------
        fontsize : int, optional
            Font size.
        figsize : tuple
            Figure size.
        title : str, optional
            Figure title.
        save : str, optional
            Whether (if not None) and where (path as string given as save) to save plot.
        suffix : str
            Suffix of file name to save to.
        """
        interaction_matrix = []
        cluster_key = self.celldata.uns["metadata"]["cluster_col_preprocessed"]
        with tqdm(total=len(self.img_celldata.keys())) as pbar:
            for adata in self.img_celldata.values():
                im = sq.gr.interaction_matrix(
                    adata, cluster_key=cluster_key, connectivity_key="adjacency_matrix", normalized=False, copy=True
                )
                im = pd.DataFrame(
                    im, columns=list(np.unique(adata.obs[cluster_key])), index=list(np.unique(adata.obs[cluster_key]))
                )
                interaction_matrix.append(im)
                pbar.update(1)
        df_concat = pd.concat(interaction_matrix)
        by_row_index = df_concat.groupby(df_concat.index)
        df_means = by_row_index.sum().sort_index(axis=1)
        interactions = np.array(df_means).T
        self.celldata.uns[f"{cluster_key}_interactions"] = interactions/np.sum(interactions, axis=1)[:, np.newaxis]

        if fontsize:
            sc.set_figure_params(scanpy=True, fontsize=fontsize)
        if save:
            save = save + suffix
        sq.pl.interaction_matrix(
            self.celldata,
            cluster_key=cluster_key,
            connectivity_key="adjacency_matrix",
            figsize=figsize,
            title=title,
            save=save,
        )

    def celldata_nhood_enrichment(
        self,
        fontsize: Optional[int] = None,
        figsize: Tuple[float, float] = (5, 5),
        title: Optional[str] = None,
        save: Optional[str] = None,
        suffix: str = "_celldata_nhood_enrichment.pdf",
    ):
        """Compute and plot neighbourhood enrichment of celldata.

        The enrichment is computed by `squidpy.gr.nhood_enrichment()`.

        Parameters
        ----------
        fontsize : int, optional
            Font size.
        figsize : tuple
            Figure size.
        title : str, optional
            Figure title.
        save : str, optional
            Whether (if not None) and where (path as string given as save) to save plot.
        suffix : str
            Suffix of file name to save to.
        """
        zscores = []
        counts = []
        cluster_key = self.celldata.uns["metadata"]["cluster_col_preprocessed"]
        with tqdm(total=len(self.img_celldata.keys())) as pbar:
            for adata in self.img_celldata.values():
                im = sq.gr.nhood_enrichment(
                    adata,
                    cluster_key=cluster_key,
                    connectivity_key="adjacency_matrix",
                    copy=True,
                    show_progress_bar=False,
                )
                zscore = pd.DataFrame(
                    im[0],
                    columns=list(np.unique(adata.obs[cluster_key])),
                    index=list(np.unique(adata.obs[cluster_key])),
                )
                count = pd.DataFrame(
                    im[1],
                    columns=list(np.unique(adata.obs[cluster_key])),
                    index=list(np.unique(adata.obs[cluster_key])),
                )
                zscores.append(zscore)
                counts.append(count)
                pbar.update(1)
        df_zscores = pd.concat(zscores)
        by_row_index = df_zscores.groupby(df_zscores.index)
        df_zscores = by_row_index.mean().sort_index(axis=1)

        df_counts = pd.concat(counts)
        by_row_index = df_counts.groupby(df_counts.index)
        df_counts = by_row_index.sum().sort_index(axis=1)

        self.celldata.uns[f"{cluster_key}_nhood_enrichment"] = {
            "zscore": np.array(df_zscores).T,
            "count": np.array(df_counts).T,
        }
        if fontsize:
            sc.set_figure_params(scanpy=True, fontsize=fontsize)
        if save:
            save = save + suffix
        sq.pl.nhood_enrichment(
            self.celldata,
            cluster_key=cluster_key,
            connectivity_key="adjacency_matrix",
            figsize=figsize,
            title=title,
            save=save,
        )

    def celltype_frequencies(
        self,
        figsize: Tuple[float, float] = (5.0, 6.0),
        fontsize: Optional[int] = None,
        save: Optional[str] = None,
        suffix: str = "_noise_structure.pdf",
        show: bool = True,
        return_axs: bool = False,
    ):
        """Plot cell type frequencies from celldata on the complete dataset.

        Parameters
        ----------
        fontsize : int, optional
           Font size.
        figsize : tuple
           Figure size.
        save : str, optional
            Whether (if not None) and where (path as string given as save) to save plot.
        suffix : str
            Suffix of file name to save to.
        show : bool
            Whether to display plot.
        return_axs : bool
            Whether to return axis objects.

        Returns
        -------
        axis
            If `return_axs` is True.
        """
        plt.ioff()
        cluster_id = self.celldata.uns["metadata"]["cluster_col_preprocessed"]
        if fontsize:
            sc.set_figure_params(scanpy=True, fontsize=fontsize)

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
        sns.barplot(
            y=self.celldata.obs[cluster_id].value_counts().index,
            x=list(self.celldata.obs[cluster_id].value_counts()),
            color="steelblue",
            ax=ax,
        )
        ax.grid(False)
        # Save, show and return figure.
        plt.tight_layout()
        if save is not None:
            plt.savefig(save + suffix)

        if show:
            plt.show()

        plt.close(fig)
        plt.ion()

        if return_axs:
            return ax
        else:
            return None

    def noise_structure(
        self,
        undefined_type: Optional[str] = None,
        merge_types: Optional[Tuple[list, list]] = None,
        min_x: Optional[float] = None,
        max_x: Optional[float] = None,
        panelsize: Tuple[float, float] = (2.0, 2.3),
        fontsize: Optional[int] = None,
        save: Optional[str] = None,
        suffix: str = "_noise_structure.pdf",
        show: bool = True,
        return_axs: bool = False,
    ):
        """Plot cell type frequencies grouped by cell type.

        Parameters
        ----------
        undefined_type : str, optional
            Undefined cell type.
        merge_types : tuple, optional
            Merge cell types.
        min_x : float, optional
            Minimal x value.
        max_x : float, optional
            Maximal x value.
        fontsize : int, optional
           Font size.
        panelsize : tuple
           Panel size.
        save : str, optional
            Whether (if not None) and where (path as string given as save) to save plot.
        suffix : str
            Suffix of file name to save to.
        show : bool
            Whether to display plot.
        return_axs : bool
            Whether to return axis objects.

        Returns
        -------
        axis
            If `return_axs` is True.
        """
        if fontsize:
            sc.set_figure_params(scanpy=True, fontsize=fontsize)
        feature_mat = pd.concat(
            [
                pd.concat(
                    [
                        pd.DataFrame(
                            {
                                "image": [k for _i in range(adata.shape[0])],
                            }
                        ),
                        pd.DataFrame(adata.X, columns=list(adata.var_names)),
                        pd.DataFrame(
                            np.asarray(list(adata.uns["node_type_names"].values()))[
                                np.argmax(adata.obsm["node_types"], axis=1)
                            ],
                            columns=["cell_type"],
                        ),
                    ],
                    axis=1,
                ).melt(value_name="expression", var_name="gene", id_vars=["cell_type", "image"])
                for k, adata in self.img_celldata.items()
            ]
        )
        feature_mat["log_expression"] = np.log(feature_mat["expression"].values + 1)
        if undefined_type:
            feature_mat = feature_mat[feature_mat["cell_type"] != undefined_type]

        if merge_types:
            for mt in merge_types[0]:
                feature_mat = feature_mat.replace(mt, merge_types[-1])

        plt.ioff()
        ct = np.unique(feature_mat["cell_type"].values)
        nrows = len(ct) // 12 + int(len(ct) % 12 > 0)
        fig, ax = plt.subplots(
            ncols=12, nrows=nrows, figsize=(12 * panelsize[0], nrows * panelsize[1]), sharex="all", sharey="all"
        )
        ax = ax.flat
        for axis in ax[len(ct) :]:
            axis.remove()
        for i, ci in enumerate(ct):
            tab = feature_mat.loc[feature_mat["cell_type"].values == ci, :]
            x = np.log(tab.groupby(["gene"])["expression"].mean() + 1)
            y = np.log(tab.groupby(["gene"])["expression"].var() + 1)
            sns.scatterplot(x=x, y=y, ax=ax[i])
            min_x = np.min(x) if min_x is None else min_x
            max_x = np.max(x) if max_x is None else max_x
            sns.lineplot(x=[min_x, max_x], y=[2 * min_x, 2 * max_x], color="black", ax=ax[i])
            ax[i].grid(False)
            ax[i].set_title(ci, fontsize=fontsize)
            ax[i].set_xlabel("")
            ax[i].set_ylabel("")
            ax[i].yaxis.set_major_formatter(FormatStrFormatter("%0.1f"))
        # Save, show and return figure.
        plt.tight_layout()
        if save is not None:
            plt.savefig(save + suffix)

        if show:
            plt.show()

        plt.close(fig)
        plt.ion()

        if return_axs:
            return ax
        else:
            return None

In [3]:

class DataLoader(GraphTools, PlottingTools):
    """DataLoader class. Inherits all functions from GraphTools and PlottingTools."""

    def __init__(
        self,
        data_path: str,
        radius: Optional[int] = None,
        coord_type: str = 'generic',
        n_rings: int = 1,
        label_selection: Optional[List[str]] = None,
        n_top_genes: Optional[int] = None
    ):
        """Initialize DataLoader.

        Loads the complete dataset, splits it by image, attaches graph-level covariates and computes the
        spatial adjacency matrix of every image, then prints a short summary.

        Parameters
        ----------
        data_path : str
            Data path.
        radius : int, optional
            Radius passed on to `compute_adjacency_matrices`.
        coord_type : str
            Type of coordinate system passed on to `compute_adjacency_matrices`.
        n_rings : int
            Number of rings passed on to `compute_adjacency_matrices`.
        label_selection : list, optional
            label selection.
        n_top_genes : int, optional
            Passed on to `register_celldata`.
        """
        self.data_path = data_path

        print("Loading data from raw files")
        self.register_celldata(n_top_genes=n_top_genes)
        self.register_img_celldata()
        self.register_graph_features(label_selection=label_selection)
        self.compute_adjacency_matrices(radius=radius, coord_type=coord_type, n_rings=n_rings)
        self.radius = radius

        n_cells, n_features = self.celldata.shape[0], self.celldata.shape[1]
        summary = (
            len(self.img_celldata),
            len(self.patients),
            n_cells,
            n_features,
            len(self.celldata.uns["node_type_names"]),
        )
        print(
            "Loaded %i images with complete data from %i patients "
            "over %i cells with %i cell features and %i distinct celltypes."
            % summary
        )

    @property
    def patients(self):
        """Return the unique patient identifiers found in celldata.

        Returns
        -------
        sorted array of the distinct values of ``celldata.uns["img_to_patient_dict"]``
        """
        patient_per_image = list(self.celldata.uns["img_to_patient_dict"].values())
        return np.unique(np.asarray(patient_per_image))

    def register_celldata(self, n_top_genes: Optional[int] = None):
        """Load AnnData object of complete dataset.

        Delegates to the subclass hook `_register_celldata` and checks that it populated `self.celldata`.

        Parameters
        ----------
        n_top_genes : int, optional
            Passed on to `_register_celldata`.
        """
        print("registering celldata")
        self._register_celldata(n_top_genes=n_top_genes)
        # `celldata` is only annotated on the class, so a hook that sets nothing would otherwise fail with
        # an AttributeError before the intended assertion message is shown.
        assert getattr(self, "celldata", None) is not None, "celldata was not loaded"

    def register_img_celldata(self):
        """Load dictionary of image-wise celldata objects with {image key : anndata object of image}.

        Delegates to the subclass hook `_register_img_celldata` and checks that it populated
        `self.img_celldata`.
        """
        print("collecting image-wise celldata")
        self._register_img_celldata()
        # `img_celldata` is only annotated on the class, so a hook that sets nothing would otherwise fail
        # with an AttributeError before the intended assertion message is shown.
        assert getattr(self, "img_celldata", None) is not None, "image-wise celldata was not loaded"

    def register_graph_features(self, label_selection):
        """Attach graph level covariates by delegating to the subclass hook `_register_graph_features`.

        Parameters
        ----------
        label_selection
            Label selection, forwarded unchanged to the hook.
        """
        print("adding graph-level covariates")
        hook = self._register_graph_features
        hook(label_selection=label_selection)

    @abc.abstractmethod
    def _register_celldata(self, n_top_genes: Optional[int] = None):
        """Hook for subclasses: load the AnnData object of the complete dataset into `self.celldata`.

        The base implementation does nothing.
        """
        return None

    @abc.abstractmethod
    def _register_img_celldata(self):
        """Hook for subclasses: fill `self.img_celldata` with {image key : anndata object of image}.

        The base implementation does nothing.
        """
        return None

    @abc.abstractmethod
    def _register_graph_features(self, label_selection):
        """Hook for subclasses: load graph level covariates.

        The base implementation does nothing.

        Parameters
        ----------
        label_selection
            Label selection.
        """
        return None

    def size_factors(self):
        """Get size factors (Only makes sense with positive input).

        For every node the size factor is the global mean of the per-node feature sums divided by the
        feature sum of that node. A warning is printed for every image that contains a node with a
        non-positive feature sum.

        Returns
        -------
        sf_dict
            Dictionary mapping each image key to the size factors of its nodes.
        """
        # Per-node feature totals of every image.
        node_sums = {key: np.sum(adata.X, axis=1) for key, adata in self.img_celldata.items()}
        # Check if irregular sums are encountered:
        for key, sums in node_sums.items():
            if np.any(sums <= 0):
                print("WARNING: found irregular node sizes in image %s" % str(key))
        # Get global mean of feature intensity across all features:
        global_mean_per_node = self.celldata.X.sum(axis=1).mean(axis=0)
        sf_dict = {}
        for key, sums in node_sums.items():
            sf_dict[key] = global_mean_per_node / sums
        return sf_dict

    @property
    def var_names(self):
        """Return the variable (feature) names of `celldata`."""
        names = self.celldata.var_names
        return names

In [6]:

class DataLoaderHartmann(DataLoader):
    """DataLoaderHartmann class. Inherits all functions from DataLoader."""

    cell_type_merge_dict = {
        "Imm_other": "Other immune cells",
        "Epithelial": "Epithelial",
        "Tcell_CD4": "CD4 T cells",
        "Myeloid_CD68": "CD68 Myeloid",
        "Fibroblast": "Fibroblast",
        "Tcell_CD8": "CD8 T cells",
        "Endothelial": "Endothelial",
        "Myeloid_CD11c": "CD11c Myeloid",
    }

    def _register_celldata(self, n_top_genes: Optional[int] = None):
        """Load AnnData object of complete dataset.

        Reads the single-cell csv file below `self.data_path`, builds an AnnData object with the marker
        intensities as `X`, and stores it in `self.celldata`. Besides `X` and `obs` the following entries
        are written: ``uns["metadata"]``, ``uns["img_keys"]``, ``uns["img_to_patient_dict"]``,
        ``uns["node_type_names"]``, ``obsm["spatial"]`` and ``obsm["node_types"]`` (one-hot cell type).

        Parameters
        ----------
        n_top_genes : int, optional
            Not used by this implementation.
        """
        # Dataset description; also stored in `celldata.uns["metadata"]` for use by other methods.
        metadata = {
            "lateral_resolution": 400 / 1024,
            "fn": ["scMEP_MIBI_singlecell/scMEP_MIBI_singlecell.csv", "scMEP_sample_description.xlsx"],
            "image_col": "point",
            "pos_cols": ["center_colcoord", "center_rowcoord"],
            "cluster_col": "Cluster",
            "cluster_col_preprocessed": "Cluster_preprocessed",
            "patient_col": "donor",
        }
        celldata_df = read_csv(os.path.join(self.data_path, metadata["fn"][0]))
        # Turn the image identifier into a string key of the form "scMEP_point_<id>".
        celldata_df["point"] = [f"scMEP_point_{str(x)}" for x in celldata_df["point"]]
        # Missing values in every column (not only the features) are replaced by 0.
        celldata_df = celldata_df.fillna(0)
        # celldata_df = celldata_df.dropna(inplace=False).reset_index()
        # Columns of the csv file that are used as features, in the order they appear in `X`.
        feature_cols = [
            "H3",
            "vimentin",
            "SMA",
            "CD98",
            "NRF2p",
            "CD4",
            "CD14",
            "CD45",
            "PD1",
            "CD31",
            "SDHA",
            "Ki67",
            "CS",
            "S6p",
            "CD11c",
            "CD68",
            "CD36",
            "ATP5A",
            "CD3",
            "CD39",
            "VDAC1",
            "G6PD",
            "XBP1",
            "PKM2",
            "ASCT2",
            "GLUT1",
            "CD8",
            "CD57",
            "LDHA",
            "IDH2",
            "HK1",
            "Ecad",
            "CPT1A",
            "CK",
            "NaKATPase",
            "HIF1A",
            # "X1",
            # "cell_size",
            # "category",
            # "donor",
            # "Cluster",
        ]
        # Names given to the features in the AnnData object; positionally aligned with `feature_cols`.
        # NOTE(review): presumably gene symbols for the measured markers. Some pairs look questionable
        # (e.g. "SMA" -> 'SMN1', "CK" -> 'CKM', "CD3" -> 'CD247') - TODO confirm the intended mapping.
        var_names = [
            'H3-4',
            'VIM',
            'SMN1',
            'SLC3A2',
            'NFE2L2',
            'CD4',
            'CD14',
            'PTPRC',
            'PDCD1',
            'PECAM1',
            'SDHA',
            'MKI67',
            'CS',
            'RPS6',
            'ITGAX',
            'CD68',
            'CD36',
            'ATP5F1A',
            'CD247',
            'ENTPD1',
            'VDAC1',
            'G6PD',
            'XBP1',
            'PKM',
            'SLC1A5',
            'SLC2A1',
            'CD8A',
            'B3GAT1',
            'LDHA',
            'IDH2',
            'HK1',
            'CDH1',
            'CPT1A',
            'CKM',
            'ATP1A1',
            'HIF1A'
        ]

        # All obs columns, including the numeric "cell_size", are cast to categoricals.
        celldata = AnnData(
            X=pd.DataFrame(np.array(celldata_df[feature_cols]), columns=var_names), obs=celldata_df[
                ["point", "cell_id", "cell_size", "donor", "Cluster"]
            ].astype("category"),
        )

        celldata.uns["metadata"] = metadata
        img_keys = list(np.unique(celldata_df[metadata["image_col"]]))
        celldata.uns["img_keys"] = img_keys

        # register x and y coordinates into obsm
        celldata.obsm["spatial"] = np.array(celldata_df[metadata["pos_cols"]])

        # Built by iterating over every cell; the patient of the last cell of an image wins.
        img_to_patient_dict = {
            str(x): celldata_df[metadata["patient_col"]].values[i]
            for i, x in enumerate(celldata_df[metadata["image_col"]].values)
        }
        # img_to_patient_dict = {k: "p_1" for k in img_keys}
        celldata.uns["img_to_patient_dict"] = img_to_patient_dict
        self.img_to_patient_dict = img_to_patient_dict

        # add clean cluster column: raw cluster labels are mapped to display names via `cell_type_merge_dict`
        # NOTE(review): labels missing from `cell_type_merge_dict` presumably become NaN here - verify that
        # the dictionary covers every value of the cluster column.
        celldata.obs[metadata["cluster_col_preprocessed"]] = list(
            pd.Series(list(celldata.obs[metadata["cluster_col"]]), dtype="category").map(self.cell_type_merge_dict)
        )
        celldata.obs[metadata["cluster_col_preprocessed"]] = celldata.obs[metadata["cluster_col_preprocessed"]].astype(
            "category"
        )

        # register node type names
        node_type_names = list(np.unique(celldata.obs[metadata["cluster_col_preprocessed"]]))
        celldata.uns["node_type_names"] = {x: x for x in node_type_names}
        # One-hot encode the cell type of every cell; column order follows `node_type_names`.
        node_types = np.zeros((celldata.shape[0], len(node_type_names)))
        node_type_idx = np.array(
            [
                node_type_names.index(x) for x in celldata.obs[metadata["cluster_col_preprocessed"]].values
            ]  # index in encoding vector
        )
        node_types[np.arange(0, node_type_idx.shape[0]), node_type_idx] = 1
        celldata.obsm["node_types"] = node_types

        self.celldata = celldata

    def _register_img_celldata(self):
        """Split `self.celldata` by image and store {image key : anndata object of image}.

        Every entry is an independent copy of the rows of `self.celldata` that belong to the image.
        """
        image_col = self.celldata.uns["metadata"]["image_col"]
        self.img_celldata = {
            str(key): self.celldata[self.celldata.obs[image_col] == key].copy()
            for key in self.celldata.uns["img_keys"]
        }

    def _register_graph_features(self, label_selection):
        """Load graph level covariates.

        Reads the patient table "scMEP_sample_description.xlsx" below `self.data_path`, z-scores the
        continuous labels, one-hot encodes the categorical labels and stores the result under
        ``uns["graph_covariates"]`` of every image-wise AnnData (including the label tensors of the image's
        patient) and of `self.celldata` (without label tensors).

        Parameters
        ----------
        label_selection
            Label selection. Names of the labels to load; `None` selects all known labels. Names that are
            not known are ignored.
        """
        # DEFINE COLUMN NAMES FOR TABULAR DATA.
        # Define column names to extract from patient-wise tabular data:
        patient_col = "ID"
        # These are required to assign the image to dieased and non-diseased:
        disease_features = {"Diagnosis": "categorical"}
        patient_features = {"ID": "categorical", "Age": "continuous", "Sex": "categorical"}
        label_cols = {}
        label_cols.update(disease_features)
        label_cols.update(patient_features)

        if label_selection is None:
            label_selection = set(label_cols.keys())
        else:
            label_selection = set(label_selection)
        label_cols_toread = list(label_selection.intersection(set(list(label_cols.keys()))))
        # NOTE(review): "ID" appears twice in `usecols` whenever it is part of the selection (as it is by
        # default) - confirm that `read_excel` tolerates the duplicate.
        usecols = label_cols_toread + [patient_col]

        tissue_meta_data = read_excel(os.path.join(self.data_path, "scMEP_sample_description.xlsx"), usecols=usecols)
        # BUILD LABEL VECTORS FROM LABEL COLUMNS
        # The columns contain unprocessed numeric and categorical entries that are now processed to prediction-ready
        # numeric tensors. Here we first generate a dictionary of tensors for each label (label_tensors). We then
        # transform this to have as output of this section dictionary by image with a dictionary by labels as values
        # which can be easily queried by image in a data generator.
        # Subset labels and label types:
        label_cols = {label: nt for label, nt in label_cols.items() if label in label_selection}
        label_tensors = {}
        label_names = {}  # Names of individual variables in each label vector (eg. categories in onehot-encoding).
        # 1. Standardize continuous labels to z-scores:
        continuous_mean = {
            feature: tissue_meta_data[feature].mean(skipna=True)
            for feature in list(label_cols.keys())
            if label_cols[feature] == "continuous"
        }
        continuous_std = {
            feature: tissue_meta_data[feature].std(skipna=True)
            for feature in list(label_cols.keys())
            if label_cols[feature] == "continuous"
        }
        for feature in list(label_cols.keys()):
            if label_cols[feature] == "continuous":
                label_tensors[feature] = (tissue_meta_data[feature].values - continuous_mean[feature]) / continuous_std[
                    feature
                ]
                label_names[feature] = [feature]
        # 2. One-hot encode categorical columns
        # Force all entries in categorical columns to be string so that GLM-like formula processing can be performed.
        for feature in list(label_cols.keys()):
            if label_cols[feature] == "categorical":
                tissue_meta_data[feature] = tissue_meta_data[feature].astype("str")
        # One-hot encode each string label vector:
        # (the index `i` of the enumeration is not used)
        for i, feature in enumerate(list(label_cols.keys())):
            if label_cols[feature] == "categorical":
                oh = pd.get_dummies(tissue_meta_data[feature], prefix=feature, prefix_sep=">", drop_first=False)
                # Change all entries of corresponding observation to np.nan instead.
                idx_nan_col = np.array([i for i, x in enumerate(oh.columns) if x.endswith(">nan")])
                if len(idx_nan_col) > 0:
                    assert len(idx_nan_col) == 1, "fatal processing error"
                    nan_rows = np.where(oh.iloc[:, idx_nan_col[0]].values == 1.0)[0]
                    # `nan_rows` holds positions but is used as labels; assumes a default integer index.
                    oh.loc[nan_rows, :] = np.nan
                # Drop nan element column.
                oh = oh.loc[:, [x for x in oh.columns if not x.endswith(">nan")]]
                label_tensors[feature] = oh.values
                label_names[feature] = oh.columns
        # Make sure all tensors are 2D for indexing:
        for feature in list(label_tensors.keys()):
            if len(label_tensors[feature].shape) == 1:
                label_tensors[feature] = np.expand_dims(label_tensors[feature], axis=1)
        # The dictionary of tensor is nested in slices in a dictionary by image which is easier to query with a
        # generator.
        # NOTE(review): when "ID" is selected it has been cast to str above, so the lookups below compare
        # the patient values of `img_to_patient_dict` with strings - presumably those are strings too; verify.
        tissue_meta_data_patients = tissue_meta_data[patient_col].values.tolist()
        label_tensors = {
            img: {
                feature_name: np.array(features[tissue_meta_data_patients.index(patient), :], ndmin=1)
                for feature_name, features in label_tensors.items()
            }
            if patient in tissue_meta_data_patients
            else None
            for img, patient in self.celldata.uns["img_to_patient_dict"].items()
        }
        # Reduce to observed patients:
        label_tensors = dict([(k, v) for k, v in label_tensors.items() if v is not None])

        # Save processed data to attributes.
        # NOTE(review): images whose patient is missing from the table were dropped from `label_tensors`
        # above, so `label_tensors[k]` raises a KeyError for them - confirm every image has a patient entry.
        for k, adata in self.img_celldata.items():
            graph_covariates = {
                "label_names": label_names,
                "label_tensors": label_tensors[k],
                "label_selection": list(label_cols.keys()),
                "continuous_mean": continuous_mean,
                "continuous_std": continuous_std,
                "label_data_types": label_cols,
            }
            adata.uns["graph_covariates"] = graph_covariates

        # Dataset-level copy of the covariate description (no per-image label tensors).
        graph_covariates = {
            "label_names": label_names,
            "label_selection": list(label_cols.keys()),
            "continuous_mean": continuous_mean,
            "continuous_std": continuous_std,
            "label_data_types": label_cols,
        }
        self.celldata.uns["graph_covariates"] = graph_covariates

        # self.ref_img_keys = {k: [] for k, v in self.nodes_by_image.items()}

In [7]:
# Load the Hartmann dataset from ./data; `radius` keeps its default (None) and `coord_type` is 'generic'.
data = DataLoaderHartmann(data_path='data')

Loading data from raw files
registering celldata


  celldata = AnnData(


collecting image-wise celldata
adding graph-level covariates


100%|██████████| 58/58 [00:01<00:00, 44.21it/s]

Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.





In [None]:
import torch
from torch_geometric.data import InMemoryDataset, download_url


class HartmannDataset(InMemoryDataset):
    """In-memory graph dataset skeleton for the Hartmann data.

    NOTE(review): this class still contains template placeholders (raw file names, `url`, the data list in
    `process`) and the cell has not been executed; it has to be completed before it can be used.
    """

    def __init__(self, root='data', transform=None, pre_transform=None, pre_filter=None):
        """Set up the dataset below `root` and load the processed tensors from the first processed path."""
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): the loaded `data` object is not used afterwards - presumably meant for `process`.
        data = DataLoaderHartmann(data_path=root)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Names of the raw files (placeholder values; the list also contains a literal Ellipsis)."""
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        """Names of the processed files."""
        return ['data.pt']

    def download(self):
        """Download the raw files into `self.raw_dir`."""
        # Download to `self.raw_dir`.
        # NOTE(review): `url` is not defined anywhere in the visible code, so this raises a NameError.
        download_url(url, self.raw_dir)
        ...

    def process(self):
        """Filter, transform and collate the data list, then save it to the first processed path."""
        # Read data into huge `Data` list.
        # NOTE(review): placeholder - the list holds a literal Ellipsis instead of `Data` objects.
        data_list = [...]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [11]:
# Inspect the image-wise AnnData objects; this echoes the repr of every image in the dictionary.
data.img_celldata

{'scMEP_point_1': AnnData object with n_obs × n_vars = 1338 × 36
     obs: 'point', 'cell_id', 'cell_size', 'donor', 'Cluster', 'Cluster_preprocessed'
     uns: 'metadata', 'img_keys', 'img_to_patient_dict', 'node_type_names', 'graph_covariates', 'adjacency_matrix_neighbors'
     obsm: 'spatial', 'node_types'
     obsp: 'adjacency_matrix_connectivities', 'adjacency_matrix_distances',
 'scMEP_point_10': AnnData object with n_obs × n_vars = 61 × 36
     obs: 'point', 'cell_id', 'cell_size', 'donor', 'Cluster', 'Cluster_preprocessed'
     uns: 'metadata', 'img_keys', 'img_to_patient_dict', 'node_type_names', 'graph_covariates', 'adjacency_matrix_neighbors'
     obsm: 'spatial', 'node_types'
     obsp: 'adjacency_matrix_connectivities', 'adjacency_matrix_distances',
 'scMEP_point_11': AnnData object with n_obs × n_vars = 1316 × 36
     obs: 'point', 'cell_id', 'cell_size', 'donor', 'Cluster', 'Cluster_preprocessed'
     uns: 'metadata', 'img_keys', 'img_to_patient_dict', 'node_type_names',