# Creation of torch-geometric data objects from AnnData

In [1]:
# Enable the autoreload extension so edits to imported modules are picked up live
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import warnings

import anndata as ad
import squidpy as sq

from geome import transforms
from geome.adata2data import AnnData2DataByCategory
from geome.transforms import AddAdjMatrix, AddDesignMatrix, AddEdgeIndex, Categorize, Compose
from utils.datasets import DatasetHartmann

# Silence all Python warnings from this point on.
# NOTE(review): this filter is installed after the imports above, and the recorded
# output of this cell still shows a geopandas warning — presumably raised at import
# time. Confirm whether the filter should be set before the third-party imports.
warnings.filterwarnings('ignore')



	geopandas.options.use_pygeos = True

If you intended to use PyGEOS, set the option to False.
  _check_geopandas_using_shapely()


## All NCEM Datasets

### Load Unprocessed Dataset

In [3]:
# Load the Hartmann MIBI-TOF dataset and collect the per-image entries of
# `img_celldata` into a plain list.
dataset = DatasetHartmann(
    data_path='./example_data/hartmann/',
)
adatas = [*dataset.img_celldata.values()]

Loading data from raw files
registering celldata
collecting image-wise celldata
adding graph-level covariates
Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.


### Some Preprocessing Done Manually

These processing steps can also be performed by the `AnnData2DataByCategory` callable (`a2d` below) if they are given as functions to its `preprocess` argument.

In [4]:
# Merge the list of per-image AnnData objects into a single AnnData.
# No string-to-categorical conversion happens in this cell; that is handled by the
# `Categorize` step of the preprocessing pipeline defined further down.
adata = ad.concat(adatas)

In [5]:
# Display the summary of the merged AnnData object.
adata

AnnData object with n_obs × n_vars = 63747 × 36
    obs: 'point', 'cell_id', 'cell_size', 'donor', 'Cluster', 'Cluster_preprocessed'
    obsm: 'spatial', 'node_types'

### Creating A2D


In [6]:
# For each attribute of the resulting Data objects, the AnnData location(s)
# it is assembled from.
fields = dict(
    features=['obs/Cluster_preprocessed', 'obs/donor', 'obsm/design_matrix'],
    labels=['X'],
)

Here we list the preprocessing steps that need to be applied to the AnnData object. Each step takes two parameters, the `adata` and the `fields`, although the `fields` are not used in this example.

In [7]:
# Location inside the AnnData object where the adjacency matrix is stored.
adj_matrix_loc = 'obsp/adjacency_matrix_connectivities'

# Preprocessing pipeline handed to AnnData2DataByCategory as `preprocess`:
#   1. convert the listed `obs` columns to categorical dtype,
#   2. add an adjacency matrix at `adj_matrix_loc`.
preprocess = Compose([
    Categorize(['donor', 'Cluster_preprocessed', 'point'], axis='obs'),
    AddAdjMatrix(adj_matrix_loc),
])

# Transform handed to AnnData2DataByCategory as `transform`.
# The last argument names the output; presumably it ends up at 'obsm/design_matrix',
# which is what `fields['features']` reads — TODO confirm against the geome docs.
transform = AddDesignMatrix(
    'obs/Cluster_preprocessed',
    'obs/donor',
    adj_matrix_loc,
    'design_matrix',
)

In [8]:
# Name of the `obs` column passed as the `category` argument of
# AnnData2DataByCategory below.
category_to_iterate = 'point'

In [9]:
# Inspect the column we iterate over.
# NOTE(review): the recorded output below reports `dtype: category`, which contradicts
# the earlier remark that this column is not categorical — re-run on a fresh kernel
# to confirm the dtype before the preprocessing step is applied.
adata.obs['point']

59191    scMEP_point_1
59192    scMEP_point_1
59193    scMEP_point_1
59194    scMEP_point_1
59195    scMEP_point_1
             ...      
18510    scMEP_point_9
18511    scMEP_point_9
18512    scMEP_point_9
18513    scMEP_point_9
18514    scMEP_point_9
Name: point, Length: 63747, dtype: category
Categories (58, object): ['scMEP_point_1', 'scMEP_point_2', 'scMEP_point_3', 'scMEP_point_4', ..., 'scMEP_point_55', 'scMEP_point_56', 'scMEP_point_57', 'scMEP_point_58']

In [10]:
# Build the AnnData -> Data converter from the pieces defined above.
a2d = AnnData2DataByCategory(
    category=category_to_iterate,
    fields=fields,
    transform=transform,
    preprocess=preprocess,
)

### Convert AnnData to Data on call

In [11]:
# Calling the converter on the AnnData returns a list of Data objects
# (see the recorded output).
datas = a2d(adata)
datas

[Data(features=[1338, 88], labels=[1338, 36]),
 Data(features=[311, 88], labels=[311, 36]),
 Data(features=[768, 88], labels=[768, 36]),
 Data(features=[1020, 88], labels=[1020, 36]),
 Data(features=[2100, 88], labels=[2100, 36]),
 Data(features=[1325, 88], labels=[1325, 36]),
 Data(features=[1091, 88], labels=[1091, 36]),
 Data(features=[1046, 88], labels=[1046, 36]),
 Data(features=[618, 88], labels=[618, 36]),
 Data(features=[61, 88], labels=[61, 36]),
 Data(features=[1316, 88], labels=[1316, 36]),
 Data(features=[1540, 88], labels=[1540, 36]),
 Data(features=[1822, 88], labels=[1822, 36]),
 Data(features=[863, 88], labels=[863, 36]),
 Data(features=[564, 88], labels=[564, 36]),
 Data(features=[1023, 88], labels=[1023, 36]),
 Data(features=[324, 88], labels=[324, 36]),
 Data(features=[287, 88], labels=[287, 36]),
 Data(features=[636, 88], labels=[636, 36]),
 Data(features=[890, 88], labels=[890, 36]),
 Data(features=[1235, 88], labels=[1235, 36]),
 Data(features=[1020, 88], labels=[

In [12]:
# Inspect the `features` of the first Data object together with its shape.
datas[0].features, datas[0].features.shape

(tensor([[0., 0., 0.,  ..., 1., 0., 0.],
         [1., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.]]),
 torch.Size([1338, 88]))

## Squidpy Datasets

In [13]:
# Load squidpy's built-in MIBI-TOF example dataset and display its summary.
# Note: this rebinds `adata`, replacing the merged Hartmann AnnData from the
# previous section.
adata = sq.datasets.mibitof()
adata

AnnData object with n_obs × n_vars = 3309 × 36
    obs: 'row_num', 'point', 'cell_id', 'X1', 'center_rowcoord', 'center_colcoord', 'cell_size', 'category', 'donor', 'Cluster', 'batch', 'library_id'
    var: 'mean-0', 'std-0', 'mean-1', 'std-1', 'mean-2', 'std-2'
    uns: 'Cluster_colors', 'batch_colors', 'neighbors', 'spatial', 'umap'
    obsm: 'X_scanorama', 'X_umap', 'spatial'
    obsp: 'connectivities', 'distances'

In [16]:
# For each attribute of the resulting Data objects, the AnnData location(s)
# it is assembled from.
fields = {
    'features': ['obs/Cluster', 'obs/donor'],
    'labels': ['X'],
    'edge_index': ['uns/edge_index'],
}

# This dataset already ships with a `connectivities` matrix in `obsp`
# (see the AnnData summary above), so no adjacency matrix has to be added first.
adj_matrix_loc = 'obsp/connectivities'

# Transform handed to AnnData2DataByCategory as `transform`: add an edge index
# derived from `adj_matrix_loc`, then a design matrix.
# NOTE(review): `fields` does not read 'design_matrix' here — confirm whether the
# AddDesignMatrix step is still required in this example.
transform = Compose([
    AddEdgeIndex(adj_matrix_loc, 'edge_index', overwrite=True),
    AddDesignMatrix('obs/Cluster', 'obs/donor', adj_matrix_loc, 'design_matrix'),
])

# `obs` column passed as the `category` argument of the converter.
category_to_iterate = 'library_id'

a2c = AnnData2DataByCategory(
    fields=fields,
    category=category_to_iterate,
    preprocess=None,
    transform=transform,
)

In [17]:
# Convert the squidpy AnnData; the recorded output shows three Data objects,
# each carrying an `edge_index` in addition to `features` and `labels`.
datas = a2c(adata)
datas

[Data(edge_index=[2, 8878], features=[1023, 10], labels=[1023, 36]),
 Data(edge_index=[2, 17770], features=[1241, 10], labels=[1241, 36]),
 Data(edge_index=[2, 3944], features=[1045, 10], labels=[1045, 36])]

In [18]:
# Inspect the `labels` of the first Data object together with its shape.
datas[0].labels, datas[0].labels.shape

(array([[-0.01463166, -0.2530882 , -0.07003902, ..., -0.13318786,
         -0.06862581, -0.19842437],
        [-0.25636178, -0.0944411 , -0.04101864, ..., -0.10532203,
         -0.02109102, -0.10195286],
        [-0.32266545, -0.22462557, -0.0605919 , ..., -0.17151919,
         -0.06436188, -0.04061128],
        ...,
        [-0.14503248, -0.03822452,  0.07996222, ..., -0.29326773,
         -0.25504532,  0.253582  ],
        [-0.11058567, -0.28841913, -0.09688035, ..., -0.38152164,
         -0.11625874,  0.08042709],
        [-0.09426577, -0.1985481 , -0.06324905, ..., -0.3534353 ,
         -0.0251825 , -0.05501218]], dtype=float32),
 (1023, 36))