# Linear NCEM Example

In [10]:
# Load IPython's autoreload extension and switch it to mode 2, so imported
# modules are re-imported before each cell execution (useful while editing
# the local `utils` / `geome` code alongside this notebook).
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
import pytorch_lightning as pl
import torch
import anndata as ad
from geome import transforms 
from geome.anndata2data import AnnData2DataByCategory
from utils.datasets import DatasetHartmann
from utils.models.linear_ncem import LinearNCEM
from geome.datamodule import GraphAnnDataModule

In [12]:
# Mapping handed to AnnData2DataByCategory as `fields`. Later cells read
# `.x` and `.y` from the produced objects, so each key presumably names an
# attribute filled from the listed AnnData locations -- confirm in geome docs.
fields = {
    "x": ["obs/Cluster_preprocessed", "obs/donor", "obsm/design_matrix"],
    "y": ["X"],
}

# Preprocessing callables for the converter. Each takes two positional
# arguments and ignores the second one.
preprocess = [
    lambda x, _: transforms.categorize_obs(
        x, ["donor", "Cluster_preprocessed", "point"]
    ),
    lambda x, _: transforms.add_design_matrix(
        x, "obs/Cluster_preprocessed", "obs/donor", "design_matrix"
    ),
]

# Name of the category passed as `category` to the converter.
category_to_iterate = "point"

a2d = AnnData2DataByCategory(
    fields=fields,
    category=category_to_iterate,
    preprocess=preprocess,
    yields_edge_index=True,
)

# MIBI-TOF data (Hartmann dataset), read from the local example directory.
dataset = DatasetHartmann(data_path="./example_data/hartmann/")
adatas = [*dataset.img_celldata.values()]

# Merge the individual AnnData objects into a single one. Converting string
# columns to categoricals is left to the `categorize_obs` preprocess step
# (assumed from its name -- TODO confirm).
adata = ad.concat(adatas)

# Run the converter on the merged AnnData; the result is indexed like a
# sequence by the following cells.
datas = a2d(adata)

Loading data from raw files
registering celldata




collecting image-wise celldata
adding graph-level covariates
Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.


In [13]:
# Read the model's input/output widths from the second dimension of the
# first item's `x` and `y`.
num_features, out_channels = datas[0].x.shape[1], datas[0].y.shape[1]
(num_features, out_channels)

(88, 36)

In [14]:
# Wrap the converted data in a datamodule configured for node-level learning.
dm = GraphAnnDataModule(
    datas=datas,
    num_workers=12,
    batch_size=12,
    learning_type="node",
)

# Linear NCEM model sized to the feature/target widths computed above.
model = LinearNCEM(
    in_channels=num_features,
    out_channels=out_channels,
    lr=1e-4,
    weight_decay=1e-6,
)

In [15]:
# Use a GPU when CUDA is available, otherwise fall back to the CPU.
# `torch.cuda.is_available()` is the documented entry point; the previous
# `torch.torch.cuda...` spelling only worked through an incidental
# self-reference on the `torch` module.
trainer: pl.Trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    max_epochs=100,
)



GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [16]:
# Train the model on the batches provided by the datamodule.
trainer.fit(model=model, datamodule=dm)

Missing logger folder: /home/marcella/Documents/geome/docs/notebooks/lightning_logs

  | Name        | Type            | Params
------------------------------------------------
0 | model_sigma | Linear          | 3.2 K 
1 | model_mu    | Linear          | 3.2 K 
2 | loss_module | GaussianNLLLoss | 0     
------------------------------------------------
6.4 K     Trainable params
0         Non-trainable params
6.4 K     Total params
0.026     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [17]:
# Evaluate the model on the datamodule's test split.
trainer.test(model=model, datamodule=dm)

Testing: 0it [00:00, ?it/s]

[{'test_r2_score': -5.123299277161653, 'test_loss': -1.33883798122406}]

In [None]:
# Build a fresh merged AnnData for inspecting the donor annotation.
# `ad.concat` expects a collection of AnnData objects, so pass the list
# `adatas` rather than the already-merged `adata`.
x = ad.concat(adatas)

In [None]:
# Display the donor annotation of every observation.
x.obs["donor"]

59191    90de
59192    90de
59193    90de
59194    90de
59195    90de
         ... 
18510    90de
18511    90de
18512    90de
18513    90de
18514    90de
Name: donor, Length: 63747, dtype: object

In [None]:
# Print the AnnData subset belonging to each distinct donor label.
cats = x.obs["donor"].unique()
for cat in cats:
    print(x[x.obs["donor"] == cat])

View of AnnData object with n_obs × n_vars = 18943 × 36
    obs: 'point', 'cell_id', 'cell_size', 'donor', 'Cluster', 'Cluster_preprocessed'
    obsm: 'spatial', 'node_types'
View of AnnData object with n_obs × n_vars = 22224 × 36
    obs: 'point', 'cell_id', 'cell_size', 'donor', 'Cluster', 'Cluster_preprocessed'
    obsm: 'spatial', 'node_types'
View of AnnData object with n_obs × n_vars = 5811 × 36
    obs: 'point', 'cell_id', 'cell_size', 'donor', 'Cluster', 'Cluster_preprocessed'
    obsm: 'spatial', 'node_types'
View of AnnData object with n_obs × n_vars = 16769 × 36
    obs: 'point', 'cell_id', 'cell_size', 'donor', 'Cluster', 'Cluster_preprocessed'
    obsm: 'spatial', 'node_types'


In [None]:
# List the distinct donor labels.
x.obs["donor"].unique()

59191    90de
59192    90de
59193    90de
59194    90de
59195    90de
         ... 
18510    90de
18511    90de
18512    90de
18513    90de
18514    90de
Name: donor, Length: 63747, dtype: object