In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from argparse import ArgumentParser, Namespace
from random import choices
import pytorch_lightning as pl
from typing import Callable, List, Optional, Sequence, Union
import squidpy as sq
import torch
from torch_geometric.loader import RandomNodeSampler
import pandas as pd
from torch_geometric.data import Data
from anndata import AnnData
from gpu_spatial_graph_pipeline.utils import adata2data
from gpu_spatial_graph_pipeline.data.datamodule import GraphAnnDataModule
from gpu_spatial_graph_pipeline.models.linear_ncem import LinearNCEM

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#Mibitof
adata = sq.datasets.mibitof()
#feature_name=adata.obs.keys()[0] #Use for IMC dataset

#specify features to use
feature_names=['Cluster','batch']

#input of datamodule
num_features=(len(set(adata.obs[feature_names[0]])),len(set(adata.obs[feature_names[1]])))
num_genes=adata.X.shape[1]


In [4]:
dm = GraphAnnDataModule(adata=adata, feature_names=feature_names, adata2data_fn=adata2data, num_workers = 16, batch_size=40,learning_type='nodewise')

In [5]:
dm.setup()

In [6]:
for batch in dm.train_dataloader():
    print(batch)

DataBatch(x=[259, 11], edge_index=[2, 240], y=[259, 36], batch=[259], ptr=[4], train_mask=[259], val_mask=[259], test_mask=[259], batch_size=40)
DataBatch(x=[271, 11], edge_index=[2, 240], y=[271, 36], batch=[271], ptr=[4], train_mask=[271], val_mask=[271], test_mask=[271], batch_size=40)
DataBatch(x=[277, 11], edge_index=[2, 240], y=[277, 36], batch=[277], ptr=[4], train_mask=[277], val_mask=[277], test_mask=[277], batch_size=40)
DataBatch(x=[277, 11], edge_index=[2, 240], y=[277, 36], batch=[277], ptr=[4], train_mask=[277], val_mask=[277], test_mask=[277], batch_size=40)
DataBatch(x=[271, 11], edge_index=[2, 240], y=[271, 36], batch=[271], ptr=[4], train_mask=[271], val_mask=[271], test_mask=[271], batch_size=40)
DataBatch(x=[269, 11], edge_index=[2, 240], y=[269, 36], batch=[269], ptr=[4], train_mask=[269], val_mask=[269], test_mask=[269], batch_size=40)
DataBatch(x=[276, 11], edge_index=[2, 240], y=[276, 36], batch=[276], ptr=[4], train_mask=[276], val_mask=[276], test_mask=[276], 

In [7]:
model = LinearNCEM(in_channels=num_features,out_channels=num_genes, model_type='spatial', lr=0.0001,weight_decay=0.000001)

In [8]:
gpu=False
if gpu:
    trainer:pl.Trainer = pl.Trainer(accelerator='gpu',max_epochs=30,log_every_n_steps=1)
else:
    trainer:pl.Trainer = pl.Trainer(accelerator='cpu',max_epochs=30,log_every_n_steps=1)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
trainer.fit(model,datamodule=dm)


  | Name        | Type            | Params
------------------------------------------------
0 | model_sigma | LinearSpatial   | 2.7 K 
1 | model_mu    | LinearSpatial   | 2.7 K 
2 | loss_module | GaussianNLLLoss | 0     
------------------------------------------------
5.5 K     Trainable params
0         Non-trainable params
5.5 K     Total params
0.022     Total estimated model params size (MB)


                                                                           

  rank_zero_warn(
  rank_zero_warn(


Epoch 29: 100%|██████████| 84/84 [00:01<00:00, 80.90it/s, loss=-0.462, v_num=38, val_r2_score=0.352, val_loss=-.490]  

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 84/84 [00:01<00:00, 80.56it/s, loss=-0.462, v_num=38, val_r2_score=0.352, val_loss=-.490]


In [11]:
trainer.test(model, datamodule=dm)

  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 119.53it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss          -0.46245187520980835
      test_r2_score         0.3123172731002297
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_r2_score': 0.3123172731002297, 'test_loss': -0.46245187520980835}]