In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import squidpy as sq
import torch
import pandas as pd
import numpy as np
from typing import Sequence, Union
from torch_geometric.data import Data
from anndata import AnnData
import scipy
import torch.nn as nn




In [5]:
def linear_ncem_adata2data_helper(
    adatas: Sequence[AnnData], feature_names
) -> Sequence[Data]:
    dataset = []
    for adata in adatas:
        # Get adjacency matrices
        spatial_connectivities = _get_adj_matrix(adata)

        # Get features (there is one case since we need the domain for Xd)
        df1 = pd.get_dummies(adata.obs[feature_names[0]])
        df2 = pd.DataFrame(
            0,
            index=np.arange(len(adata.obs)),
            columns=set(adata.uns["node_type_names"].values()),
        )
        df2[df1.columns[0]] = list(df1[df1.columns[0]])
        cell_type = torch.from_numpy(df2.to_numpy())

        df1 = pd.get_dummies(
            adata.obs[feature_names[1]],
            columns=set(adata.uns["img_to_patient_dict"].values()),
        )
        df2 = pd.DataFrame(
            0,
            index=np.arange(len(adata.obs)),
            columns=set(adata.uns["img_to_patient_dict"].values()),
        )
        df2[df1.columns[0]] = list(df1[df1.columns[0]])
        domain = torch.from_numpy(df2.to_numpy())

        # Get gene expression matrix
        gene_expression = torch.from_numpy(adata.X.astype(float))
        # Create design matrix for linear model
        Xd = design_matrix(
            torch.Tensor(spatial_connectivities.todense()), cell_type.float(), domain
        )

        data = Data(y=gene_expression, x=Xd)
        dataset.append(data)

    return dataset

In [6]:
def adata2data(
    adatas: Union[AnnData, Sequence[AnnData]], feature_names
) -> Union[Data, Sequence[Data]]:
    """Function that takes in input a sequence of anndata objects and returns a Pytorch Geometric (PyG) data object or a sequence thereof.
    Each data object represents a graph of an image stored in the anndata object.


    :param adata: Anndata object storing the images to be trained on
    :type adata: AnnData
    :param feature_names: The feature names to be used for training, extracted from anndata.obs
    :type feature_names: tuple
    :return: PyG data object or sequence thereof if more than one image is stored in the anndata object
    :rtype: Union[Data, Sequence[Data]]
    """

    dataset = []

    for adata in adatas:

        # Get adjacency matrices
        spatial_connectivities = _get_adj_matrix(adata)

        nodes1, nodes2 = spatial_connectivities.nonzero()
        edge_index = torch.vstack(
            [
                torch.from_numpy(nodes1).to(torch.long),
                torch.from_numpy(nodes2).to(torch.long),
            ]
        )

        # Get features
        if len(feature_names) > 1:
            df1 = pd.get_dummies(adata.obs[feature_names[0]])
            df2 = pd.DataFrame(
                0,
                index=np.arange(len(adata.obs)),
                columns=set(adata.uns["node_type_names"].values()),
            )
            df2[df1.columns[0]] = list(df1[df1.columns[0]])
            cell_type = torch.from_numpy(df2.to_numpy())

            df1 = pd.get_dummies(
                adata.obs[feature_names[1]],
                columns=set(adata.uns["img_to_patient_dict"].values()),
            )
            df2 = pd.DataFrame(
                0,
                index=np.arange(len(adata.obs)),
                columns=set(adata.uns["img_to_patient_dict"].values()),
            )
            df2[df1.columns[0]] = list(df1[df1.columns[0]])
            domain = torch.from_numpy(df2.to_numpy())

            features_combined = torch.cat([cell_type, domain], dim=1)
        else:
            features_combined = torch.from_numpy(
                pd.get_dummies(adata.obs[feature_names]).to_numpy()
            )

        # Get gene expression matrix
        gene_expression = torch.from_numpy(adata.X.astype(float))
        # Create design matrix for linear model
        Xd = design_matrix(
            torch.Tensor(spatial_connectivities.todense()), cell_type.float(), domain
        )

        data = Data(
            edge_index=edge_index, y=gene_expression, x=features_combined, Xd=Xd
        )
        dataset.append(data)

    return dataset

In [7]:
def adata2data_sq(
    adatas: Union[AnnData, Sequence[AnnData]], feature_names
) -> Union[Data, Sequence[Data]]:
    """Function that takes in input an anndata object from a squidpy example dataset and returns a Pytorch Geometric (PyG) data object or a sequence thereof.
    Each data object represents a graph of an image stored in the anndata object.


    :param adata: Anndata object storing the images to be trained on
    :type adata: AnnData
    :param feature_names: The feature names to be used for training, extracted from anndata.obs
    :type feature_names: tuple
    :return: PyG data object or sequence thereof if more than one image is stored in the anndata object
    :rtype: Union[Data, Sequence[Data]]
    """

    dataset = []

    # if isinstance(adata, list):
    #     adatas=adata
    # else:
    #     adatas = [adata]

    for adata in adatas:
        # Set cases for when one or more images are to be extracted from anndata

        # Case where multiply images are stored in one anndata
        if "library_id" in adata.obs.keys():
            library_ids = [library_id for library_id in adata.uns["spatial"].keys()]
            lib_indices = adata.obs["library_id"] == library_ids[0]

            for i in range(len(library_ids) - 1):
                lib_indices = pd.concat(
                    [lib_indices, adata.obs["library_id"] == library_ids[i + 1]], axis=1
                )
            lib_indices.columns = library_ids

        # Case where one image is stored in one anndata
        else:
            lib_indices = pd.DataFrame(data=range(len(adata.obs)), columns=[""])

        for library_id in lib_indices.columns:

            # Get adjacency matrices
            if "adjacency_matrix_connectivities" in adata.obsp.keys():
                spatial_connectivities = adata.obsp["adjacency_matrix_connectivities"]

            else:
                spatial_connectivities, _ = sq.gr.spatial_neighbors(
                    adata[lib_indices[library_id]],
                    coord_type="generic",
                    key_added=library_id + "spatial",
                    copy=True,
                )

            nodes1, nodes2 = spatial_connectivities.nonzero()
            edge_index = torch.vstack(
                [
                    torch.from_numpy(nodes1).to(torch.long),
                    torch.from_numpy(nodes2).to(torch.long),
                ]
            )

            # Get features
            if len(feature_names) > 1:
                cell_type = torch.from_numpy(
                    pd.get_dummies(
                        adata.obs[feature_names[0]][lib_indices[library_id]]
                    ).to_numpy()
                )
                domain = torch.from_numpy(
                    pd.get_dummies(
                        adata.obs[feature_names[1]][lib_indices[library_id]]
                    ).to_numpy()
                )
                features_combined = torch.cat([cell_type, domain], dim=1)
            else:
                features_combined = torch.from_numpy(
                    pd.get_dummies(adata.obs[feature_names]).to_numpy()
                )

            # Get gene expression matrix
            X = adata.X[lib_indices[library_id]]

            if scipy.sparse.issparse(X):
                coo = X.tocoo()
                values = coo.data
                indices = np.vstack((coo.row, coo.col))
                i = torch.LongTensor(indices)
                v = torch.FloatTensor(values)
                shape = coo.shape
                gene_expression = torch.sparse.FloatTensor(
                    i, v, torch.Size(shape)
                ).to_dense()
            else:
                gene_expression = torch.from_numpy(adata.X)

            # Create design matrix for linear model
            Xd = design_matrix(
                torch.Tensor(spatial_connectivities.todense()),
                cell_type.float(),
                domain,
            )

            data = Data(
                edge_index=edge_index, y=gene_expression, x=features_combined, Xd=Xd
            )
            dataset.append(data)

    return dataset



def design_matrix(A, Xl, Xc):
    N, L = Xl.shape
    Xs = (A @ Xl > 0).to(torch.float)  # N x L
    Xts = (torch.einsum("bp,br->bpr", Xs, Xl).reshape((N, L * L)) > 0).to(torch.float)
    Xd = torch.hstack((Xl, Xts, Xc))
    return Xd

In [61]:
class AnnData2DataCallable:
    def __init__(self, is_sq=False, has_edge_index=True, x_names,y_names):
        self.is_sq = is_sq
        self.has_edge_index = has_edge_index
        if is_sq:
            self._adata_iter = AnnData2DataCallable._get_sq_adata_iter
        else:
            self._adata_iter = lambda x: x # identity

        self.xs = x_names
        self.ys = y_names
    @staticmethod
    def _get_sq_adata_iter(adata):
        cats = adata.obs.library_id.dtypes.categories
        for cat in cats:
            yield adata[adata.obs.library_id == cat]
                    
    @staticmethod
    def _get_adj_matrix(adata):
        """helper function to create adj matrix depending on the adata"""
    
        # Get adjacency matrices
        if "adjacency_matrix_connectivities" in adata.obsp.keys():
            spatial_connectivities = adata.obsp["adjacency_matrix_connectivities"]

        else:
            spatial_connectivities, _ = sq.gr.spatial_neighbors(
                adata,
                coord_type="generic",
                key_added="spatial",
                copy=True,
            )
        return spatial_connectivities

    def _create_data_obj(self,adata,spatial_connectivities):
        obj = dict()
        if self.has_edge_index:
            nodes1, nodes2 = spatial_connectivities.nonzero()
            obj["edge_index"] = torch.vstack(
                [
                    torch.from_numpy(nodes1).to(torch.long),
                    torch.from_numpy(nodes2).to(torch.long),
                ]
            )
        
        features_dict = self._extract_features(adata, self.
    
    def __call__(self, adatas):
        dataset = []
        
        for adata in self._adata_iter(adatas):
            
            spatial_connectivities = self._get_adj_matrix(adata)
            
            data = self._create_data_obj(adata, spatial_connectivities)
            
            
            # Get features
            
            
            

        return dataset

In [62]:


#Mibitof
adata = sq.datasets.mibitof()
#feature_name=adata.obs.keys()[0] #Use for IMC dataset

#specify features to use
feature_names=['Cluster','batch']

def mibitof2data(adata):
    return adata2data(adata,feature_names)


#input of datamodule
num_features=(len(set(adata.obs[feature_names[0]])),len(set(adata.obs[feature_names[1]])))

num_genes=adata.X.shape[1]



In [63]:
AnnData2DataCallable(is_sq=True)(adata)

KeyError: 'node_type_names'

In [83]:
adata = sq.datasets.mibitof()

In [84]:
adata.obs['library_id']

3034-0     point23
3035-0     point23
3036-0     point23
3037-0     point23
3038-0     point23
            ...   
47342-2    point16
47343-2    point16
47344-2    point16
47345-2    point16
47346-2    point16
Name: library_id, Length: 3309, dtype: category
Categories (3, object): ['point16', 'point23', 'point8']

In [89]:
adata[adata.obs.library_id == "point16"].obs.library_id == adata.obs.library_id.dtypes.categories

ValueError: ('Lengths must match to compare', (1023,), (3,))

In [33]:
adata.obs.library_id.dtypes.categories

Index(['point16', 'point23', 'point8'], dtype='object')

In [67]:
from gpu_spatial_graph_pipeline.data.datasets import DatasetHartmann

In [68]:
dataset = DatasetHartmann(data_path='./example_data/hartmann/')
adata = list(dataset.img_celldata.values())

Loading data from raw files
registering celldata




collecting image-wise celldata
adding graph-level covariates
Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.


In [70]:
adata = adata[0]

In [82]:
adata.uns["img_keys"]

['scMEP_point_1',
 'scMEP_point_10',
 'scMEP_point_11',
 'scMEP_point_12',
 'scMEP_point_13',
 'scMEP_point_14',
 'scMEP_point_15',
 'scMEP_point_16',
 'scMEP_point_17',
 'scMEP_point_18',
 'scMEP_point_19',
 'scMEP_point_2',
 'scMEP_point_20',
 'scMEP_point_21',
 'scMEP_point_22',
 'scMEP_point_23',
 'scMEP_point_24',
 'scMEP_point_25',
 'scMEP_point_26',
 'scMEP_point_27',
 'scMEP_point_28',
 'scMEP_point_29',
 'scMEP_point_3',
 'scMEP_point_30',
 'scMEP_point_31',
 'scMEP_point_32',
 'scMEP_point_33',
 'scMEP_point_34',
 'scMEP_point_35',
 'scMEP_point_36',
 'scMEP_point_37',
 'scMEP_point_38',
 'scMEP_point_39',
 'scMEP_point_4',
 'scMEP_point_40',
 'scMEP_point_41',
 'scMEP_point_42',
 'scMEP_point_43',
 'scMEP_point_44',
 'scMEP_point_45',
 'scMEP_point_46',
 'scMEP_point_47',
 'scMEP_point_48',
 'scMEP_point_49',
 'scMEP_point_5',
 'scMEP_point_50',
 'scMEP_point_51',
 'scMEP_point_52',
 'scMEP_point_53',
 'scMEP_point_54',
 'scMEP_point_55',
 'scMEP_point_56',
 'scMEP_point_57'

In [76]:
pd.get_dummies(adata.obs.point)

Unnamed: 0,scMEP_point_1
59191,1
59192,1
59193,1
59194,1
59195,1
...,...
60524,1
60525,1
60526,1
60527,1
