# Non-Linear NCEM Example

In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import pytorch_lightning as pl
import torch
import anndata as ad
from geome import ann2data, transforms
import warnings
from utils.datasets import DatasetHartmann  # example dataset
from utils.models.non_linear_ncem import NonLinearNCEM  # example model
from utils import datamodule

In [8]:
fields = {
    "x": ["obs/Cluster_preprocessed", "obs/donor"],
    "edge_index": ["uns/edge_index"],
    "y": ["X"],
}



preprocess = transforms.Categorize(["donor", "Cluster_preprocessed", "point"], axis="obs")
transform = transforms.AddEdgeIndex(
    edge_index_key="edge_index",
    spatial_key="spatial",
    key_added="spatial",
    func_args={"n_neighs": 10},
)


category_to_iterate = "point"

a2d = ann2data.Ann2DataByCategory(
    fields=fields,
    category=category_to_iterate,
    preprocess=preprocess,
    transform=transform,
)


# Mibitof
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    dataset = DatasetHartmann(data_path="./example_data/hartmann/")
    adatas = list(dataset.img_celldata.values())

# Merge the list of adatas and convert some string to categories as they should be
adata = ad.concat(adatas)

datas = list(a2d(adata))
datas

Loading data from raw files
registering celldata
collecting image-wise celldata
adding graph-level covariates
Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.


[Data(x=[1338, 12], edge_index=[2, 13380], y=[1338, 36]),
 Data(x=[311, 12], edge_index=[2, 3110], y=[311, 36]),
 Data(x=[768, 12], edge_index=[2, 7680], y=[768, 36]),
 Data(x=[1020, 12], edge_index=[2, 10200], y=[1020, 36]),
 Data(x=[2100, 12], edge_index=[2, 21000], y=[2100, 36]),
 Data(x=[1325, 12], edge_index=[2, 13250], y=[1325, 36]),
 Data(x=[1091, 12], edge_index=[2, 10910], y=[1091, 36]),
 Data(x=[1046, 12], edge_index=[2, 10460], y=[1046, 36]),
 Data(x=[618, 12], edge_index=[2, 6180], y=[618, 36]),
 Data(x=[61, 12], edge_index=[2, 610], y=[61, 36]),
 Data(x=[1316, 12], edge_index=[2, 13160], y=[1316, 36]),
 Data(x=[1540, 12], edge_index=[2, 15400], y=[1540, 36]),
 Data(x=[1822, 12], edge_index=[2, 18220], y=[1822, 36]),
 Data(x=[863, 12], edge_index=[2, 8630], y=[863, 36]),
 Data(x=[564, 12], edge_index=[2, 5640], y=[564, 36]),
 Data(x=[1023, 12], edge_index=[2, 10230], y=[1023, 36]),
 Data(x=[324, 12], edge_index=[2, 3240], y=[324, 36]),
 Data(x=[287, 12], edge_index=[2, 2870

In [9]:
num_features = datas[0].x.shape[1]
out_channels = datas[0].y.shape[1]
num_features, out_channels

(12, 36)

In [10]:
dm = datamodule.GraphAnnDataModule(datas=datas, num_workers=12, batch_size=100, learning_type="node")
model = NonLinearNCEM(
    in_channels=num_features,
    out_channels=out_channels,
    encoder_hidden_dims=[16],
    decoder_hidden_dims=[16],
    latent_dim=14,
    lr=0.001,
    weight_decay=0.00001,
)

In [11]:
trainer: pl.Trainer = pl.Trainer(
    accelerator="gpu" if torch.torch.cuda.is_available() else "cpu", max_epochs=100, log_every_n_steps=10
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/selman.ozleyen/mambaforge/envs/geome/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/selman.ozleyen/mambaforge/envs/geome/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [12]:
trainer.fit(model, datamodule=dm)


  | Name          | Type            | Params
--------------------------------------------------
0 | encoder       | GNNModel        | 446   
1 | decoder_sigma | MLPModel        | 852   
2 | decoder_mu    | MLPModel        | 852   
3 | loss_module   | GaussianNLLLoss | 0     
--------------------------------------------------
2.1 K     Trainable params
0         Non-trainable params
2.1 K     Total params
0.009     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/selman.ozleyen/mambaforge/envs/geome/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                           

/Users/selman.ozleyen/mambaforge/envs/geome/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 1:   0%|          | 0/542 [00:00<?, ?it/s, v_num=16, val_r2_score=-7.78, val_loss=75.20]          



Epoch 1:  83%|████████▎ | 449/542 [00:11<00:02, 40.17it/s, v_num=16, val_r2_score=-7.78, val_loss=75.20]

/Users/selman.ozleyen/mambaforge/envs/geome/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [13]:
trainer.test(model, datamodule=dm)

/Users/selman.ozleyen/mambaforge/envs/geome/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'test_dataloader' to speed up the dataloader worker initialization.


Testing DataLoader 0: 100%|██████████| 32/32 [00:00<00:00, 220.52it/s]


[{'test_r2_score': -8.196203231811523, 'test_loss': 42.403907775878906}]