# Non-Linear NCEM Example

In [12]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
import pytorch_lightning as pl
import torch
import anndata as ad
from geome import transforms
from geome.anndata2data import AnnData2DataByCategory
from utils.datasets import DatasetHartmann
from utils.models.non_linear_ncem import NonLinearNCEM
from geome.datamodule import GraphAnnDataModule


In [14]:
fields = {
    'x':['obs/Cluster_preprocessed','obs/donor'],
    'y':['X']
}


preprocess = [
    lambda x,_: transforms.categorize_obs(x,['donor', 'Cluster_preprocessed', 'point']),
]

category_to_iterate = 'point'

a2d = AnnData2DataByCategory(
    fields=fields,
    category=category_to_iterate,
    preprocess=preprocess,
    yields_edge_index=True,
)


#Mibitof
dataset = DatasetHartmann(data_path='./example_data/hartmann/')
adatas = list(dataset.img_celldata.values())

# Merge the list of adatas and convert some string to categories as they should be
adata = ad.concat(adatas)

datas = a2d(adata)
datas

Loading data from raw files
registering celldata




collecting image-wise celldata
adding graph-level covariates
Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.


[Data(x=[1338, 12], edge_index=[2, 8028], y=[1338, 36]),
 Data(x=[61, 12], edge_index=[2, 366], y=[61, 36]),
 Data(x=[1316, 12], edge_index=[2, 7896], y=[1316, 36]),
 Data(x=[1540, 12], edge_index=[2, 9240], y=[1540, 36]),
 Data(x=[1822, 12], edge_index=[2, 10932], y=[1822, 36]),
 Data(x=[863, 12], edge_index=[2, 5178], y=[863, 36]),
 Data(x=[564, 12], edge_index=[2, 3384], y=[564, 36]),
 Data(x=[1023, 12], edge_index=[2, 6138], y=[1023, 36]),
 Data(x=[324, 12], edge_index=[2, 1944], y=[324, 36]),
 Data(x=[287, 12], edge_index=[2, 1722], y=[287, 36]),
 Data(x=[636, 12], edge_index=[2, 3816], y=[636, 36]),
 Data(x=[311, 12], edge_index=[2, 1866], y=[311, 36]),
 Data(x=[890, 12], edge_index=[2, 5340], y=[890, 36]),
 Data(x=[1235, 12], edge_index=[2, 7410], y=[1235, 36]),
 Data(x=[1020, 12], edge_index=[2, 6120], y=[1020, 36]),
 Data(x=[1241, 12], edge_index=[2, 7446], y=[1241, 36]),
 Data(x=[1438, 12], edge_index=[2, 8628], y=[1438, 36]),
 Data(x=[1021, 12], edge_index=[2, 6126], y=[1021

In [15]:
num_features = datas[0].x.shape[1]
out_channels = datas[0].y.shape[1]
num_features, out_channels

(12, 36)

In [16]:
dm = GraphAnnDataModule(datas=datas, num_workers = 12, batch_size=100,learning_type='node')
model = NonLinearNCEM(
    in_channels=num_features,
    out_channels=out_channels,
    encoder_hidden_dims=[16],
    decoder_hidden_dims=[16],
    latent_dim=14,
    lr=0.001,weight_decay=0.00001)

In [17]:
trainer:pl.Trainer = pl.Trainer(accelerator='gpu' if torch.torch.cuda.is_available() else 'cpu',
                                max_epochs=100,log_every_n_steps=10)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [18]:
trainer.fit(model,datamodule=dm)


  | Name          | Type            | Params
--------------------------------------------------
0 | encoder       | GNNModel        | 446   
1 | decoder_sigma | MLPModel        | 852   
2 | decoder_mu    | MLPModel        | 852   
3 | loss_module   | GaussianNLLLoss | 0     
--------------------------------------------------
2.1 K     Trainable params
0         Non-trainable params
2.1 K     Total params
0.009     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [19]:
trainer.test(model, datamodule=dm)

Testing: 0it [00:00, ?it/s]

[{'test_r2_score': -0.9155278940025104, 'test_loss': 43.77671432495117}]