# Creation of torch-geometric data objects from AnnData

In [5]:
%load_ext autoreload
%autoreload 2

In [6]:
import warnings
import anndata as ad
import squidpy as sq
from geome import transforms
from geome.anndata2data import AnnData2DataByCategory
from utils.datasets import DatasetHartmann

# Suppress every warning for the rest of the session to keep cell output short.
# NOTE(review): this also hides deprecation warnings raised by the imported libraries.
warnings.filterwarnings('ignore')


## All NCEM Datasets

### Load Unprocessed Dataset

In [7]:
# Hartmann dataset (MIBI-TOF), read from the local example_data folder.
dataset = DatasetHartmann(data_path='./example_data/hartmann/')
# Collect the values of `img_celldata` into a list, keeping their stored order.
adatas = [*dataset.img_celldata.values()]

Loading data from raw files
registering celldata
collecting image-wise celldata
adding graph-level covariates
Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.


### Some Preprocessing Done Manually

These preprocessing steps can also be performed inside the `AnnData2DataByCategory` callable by passing them as functions in its `preprocess` list.

In [8]:
# Merge the list of AnnData objects into a single AnnData.
# The string-to-categorical conversion of obs columns is not done in this cell;
# it is delegated to the `categorize_obs` step in the `preprocess` list.
adata = ad.concat(adatas)

In [9]:
# Display the summary of the merged AnnData.
adata

AnnData object with n_obs × n_vars = 63747 × 36
    obs: 'point', 'cell_id', 'cell_size', 'donor', 'Cluster', 'Cluster_preprocessed'
    obsm: 'spatial', 'node_types'

### Creating A2D


In [10]:
# Each key names an attribute of the resulting Data objects; each value lists
# the AnnData locations that feed that attribute.
fields = dict(
    features=[
        'obs/Cluster_preprocessed',
        'obs/donor',
        # Not present in `adata.obsm` yet; presumably created by the
        # `add_design_matrix` preprocessing step — TODO confirm.
        'obsm/design_matrix',
    ],
    labels=['X'],
)

Here we list the preprocessing steps to be applied to the AnnData object. Each step is a function that takes two parameters, the `adata` and the `fields`; we don't use `fields` in this example.

In [11]:
# Ordered preprocessing steps; each one accepts (adata, fields) and ignores `fields`.
preprocess = [
    # Presumably casts the listed obs columns to categorical dtype — TODO confirm.
    lambda adata, _fields: transforms.categorize_obs(
        adata, ['donor', 'Cluster_preprocessed', 'point']
    ),
    # Design matrix built from the two obs columns, stored under the key 'design_matrix'.
    lambda adata, _fields: transforms.add_design_matrix(
        adata, 'obs/Cluster_preprocessed', 'obs/donor', 'design_matrix'
    ),
]

In [12]:
# Name of the obs column that is passed as `category` to AnnData2DataByCategory.
category_to_iterate = 'point'

In [13]:
# Note: at this point the dtype is still `object`, not categorical.
adata.obs['point']

59191    scMEP_point_1
59192    scMEP_point_1
59193    scMEP_point_1
59194    scMEP_point_1
59195    scMEP_point_1
             ...      
18510    scMEP_point_9
18511    scMEP_point_9
18512    scMEP_point_9
18513    scMEP_point_9
18514    scMEP_point_9
Name: point, Length: 63747, dtype: object

In [14]:
# Build the AnnData -> Data converter.
# yields_edge_index=False: no edge index is requested for the resulting Data objects.
a2d = AnnData2DataByCategory(fields=fields,
                             category=category_to_iterate,
                             preprocess=preprocess,
                             yields_edge_index=False)

### Convert AnnData to Data on call

In [15]:
# Calling the converter on the AnnData returns a list of Data objects.
datas = a2d(adata)
# Show the resulting list.
datas

[Data(features=[1338, 88], labels=[1338, 36]),
 Data(features=[61, 88], labels=[61, 36]),
 Data(features=[1316, 88], labels=[1316, 36]),
 Data(features=[1540, 88], labels=[1540, 36]),
 Data(features=[1822, 88], labels=[1822, 36]),
 Data(features=[863, 88], labels=[863, 36]),
 Data(features=[564, 88], labels=[564, 36]),
 Data(features=[1023, 88], labels=[1023, 36]),
 Data(features=[324, 88], labels=[324, 36]),
 Data(features=[287, 88], labels=[287, 36]),
 Data(features=[636, 88], labels=[636, 36]),
 Data(features=[311, 88], labels=[311, 36]),
 Data(features=[890, 88], labels=[890, 36]),
 Data(features=[1235, 88], labels=[1235, 36]),
 Data(features=[1020, 88], labels=[1020, 36]),
 Data(features=[1241, 88], labels=[1241, 36]),
 Data(features=[1438, 88], labels=[1438, 36]),
 Data(features=[1021, 88], labels=[1021, 36]),
 Data(features=[1632, 88], labels=[1632, 36]),
 Data(features=[780, 88], labels=[780, 36]),
 Data(features=[524, 88], labels=[524, 36]),
 Data(features=[669, 88], labels=[6

In [16]:
# Feature tensor of the first Data object, together with its shape.
datas[0].features, datas[0].features.shape

(tensor([[0., 0., 0.,  ..., 1., 0., 0.],
         [0., 1., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.]]),
 torch.Size([1338, 88]))

## Squidpy Datasets

In [17]:
# Load the MIBI-TOF example dataset shipped with squidpy.
# Note: this rebinds `adata`, replacing the merged Hartmann AnnData used above.
adata = sq.datasets.mibitof()
adata

AnnData object with n_obs × n_vars = 3309 × 36
    obs: 'row_num', 'point', 'cell_id', 'X1', 'center_rowcoord', 'center_colcoord', 'cell_size', 'category', 'donor', 'Cluster', 'batch', 'library_id'
    var: 'mean-0', 'std-0', 'mean-1', 'std-1', 'mean-2', 'std-2'
    uns: 'Cluster_colors', 'batch_colors', 'neighbors', 'spatial', 'umap'
    obsm: 'X_scanorama', 'X_umap', 'spatial'
    obsp: 'connectivities', 'distances'

In [18]:
# Each key names an attribute of the resulting Data objects; each value lists
# the AnnData locations that feed that attribute.
fields = dict(
    features=['obs/Cluster', 'obs/donor'],
    labels=['X'],
)


# Single preprocessing step; it accepts (adata, fields) and ignores `fields`.
# NOTE(review): `fields` in this cell does not list 'obsm/design_matrix', unlike the
# Hartmann example — confirm whether this step is needed here.
preprocess = [
    lambda adata, _fields: transforms.add_design_matrix(
        adata, 'obs/Cluster', 'obs/donor', 'design_matrix'
    ),
]

# Name of the obs column that is passed as `category` to the converter.
category_to_iterate = 'library_id'

# Build the AnnData -> Data converter.
# yields_edge_index=True: an edge index is requested for the resulting Data objects.
a2c = AnnData2DataByCategory(
    fields=fields,
    category=category_to_iterate,
    preprocess=preprocess,
    yields_edge_index=True,
)

In [19]:
# Calling the converter on the AnnData returns a list of Data objects.
datas = a2c(adata)
# Show the resulting list.
datas

[Data(edge_index=[2, 6138], features=[1023, 10], labels=[1023, 36]),
 Data(edge_index=[2, 7446], features=[1241, 10], labels=[1241, 36]),
 Data(edge_index=[2, 6270], features=[1045, 10], labels=[1045, 36])]

In [20]:
# Label tensor of the first Data object, together with its shape.
datas[0].labels, datas[0].labels.shape

(tensor([[-0.0146, -0.2531, -0.0700,  ..., -0.1332, -0.0686, -0.1984],
         [-0.2564, -0.0944, -0.0410,  ..., -0.1053, -0.0211, -0.1020],
         [-0.3227, -0.2246, -0.0606,  ..., -0.1715, -0.0644, -0.0406],
         ...,
         [-0.1450, -0.0382,  0.0800,  ..., -0.2933, -0.2550,  0.2536],
         [-0.1106, -0.2884, -0.0969,  ..., -0.3815, -0.1163,  0.0804],
         [-0.0943, -0.1985, -0.0632,  ..., -0.3534, -0.0252, -0.0550]]),
 torch.Size([1023, 36]))