In [1]:
# Enable IPython's autoreload extension in mode 2: all modules are reloaded
# before each cell executes, so edits to imported source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from argparse import ArgumentParser, Namespace
from random import choices
import pytorch_lightning as pl
from typing import Callable, List, Optional, Sequence, Union
import squidpy as sq
import torch
from torch_geometric.loader import RandomNodeSampler
import pandas as pd
from torch_geometric.data import Data
from anndata import AnnData
from gpu_spatial_graph_pipeline.utils import adata2data
from gpu_spatial_graph_pipeline.data.datamodule import GraphAnnDataModule
from gpu_spatial_graph_pipeline.models.linear_ncem import LinearNCEM

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the MIBI-TOF dataset through squidpy's dataset loader.
adata = sq.datasets.mibitof()

# Names of the `adata.obs` columns to use as features.
# (Note kept from the original: for the IMC dataset the first obs column,
#  adata.obs.keys()[0], was used instead.)
feature_names = ["Cluster", "batch"]

def mibitof2data(adata):
    """Run ``adata2data`` on ``adata`` with the notebook-level ``feature_names``.

    Single-argument adapter handed to the datamodule as ``adata2data_fn``.
    ``feature_names`` is looked up as a global each time the function is called.
    """
    converted = adata2data(adata, feature_names)
    return converted


# Sizes passed to the model below.
# One entry per name in `feature_names`, in the same order: the number of
# distinct values found in that `adata.obs` column. Deriving the tuple from the
# list (rather than indexing entries 0 and 1 by hand) keeps it in sync if
# features are added or removed.
num_features = tuple(len(set(adata.obs[name])) for name in feature_names)

# Second dimension of `adata.X`, used as the model's output size.
num_genes = adata.X.shape[1]


In [8]:
# Build the datamodule around `adata`, using `mibitof2data` as the conversion
# callback and node-wise learning.
dm = GraphAnnDataModule(
    adata=adata,
    adata2data_fn=mibitof2data,
    num_workers=8,
    batch_size=40,
    learning_type="nodewise",
)

In [7]:
# Call setup() by hand because train_dataloader() is used directly in the next
# cell, before any Trainer exists. Presumably this is where the datasets are
# prepared -- confirm in GraphAnnDataModule.
dm.setup()

In [8]:
# Sanity check: print the first few training batches, then stop.
PREVIEW_BATCHES = 5
for seen, batch in enumerate(dm.train_dataloader(), start=1):
    print(batch)
    if seen >= PREVIEW_BATCHES:
        break

DataBatch(x=[276, 11], edge_index=[2, 240], y=[276, 36], batch=[276], ptr=[4], train_mask=[276], val_mask=[276], test_mask=[276], batch_size=40)
DataBatch(x=[271, 11], edge_index=[2, 240], y=[271, 36], batch=[271], ptr=[4], train_mask=[271], val_mask=[271], test_mask=[271], batch_size=40)
DataBatch(x=[268, 11], edge_index=[2, 240], y=[268, 36], batch=[268], ptr=[4], train_mask=[268], val_mask=[268], test_mask=[268], batch_size=40)
DataBatch(x=[278, 11], edge_index=[2, 240], y=[278, 36], batch=[278], ptr=[4], train_mask=[278], val_mask=[278], test_mask=[278], batch_size=40)
DataBatch(x=[253, 11], edge_index=[2, 240], y=[253, 36], batch=[253], ptr=[4], train_mask=[253], val_mask=[253], test_mask=[253], batch_size=40)


In [9]:
# Spatial LinearNCEM: input sizes from `num_features`, output size `num_genes`.
model = LinearNCEM(
    in_channels=num_features,
    out_channels=num_genes,
    model_type="spatial",
    lr=1e-4,
    weight_decay=1e-6,
)

In [12]:
# Flip `gpu` to train on a GPU; every other Trainer setting is shared.
gpu = False
trainer: pl.Trainer = pl.Trainer(
    accelerator="gpu" if gpu else "cpu",
    max_epochs=1000,
    log_every_n_steps=10,
)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
# Fit `model` with the Trainer configured above, taking its data from `dm`.
trainer.fit(model,datamodule=dm)


  | Name        | Type            | Params
------------------------------------------------
0 | model_sigma | LinearSpatial   | 2.7 K 
1 | model_mu    | LinearSpatial   | 2.7 K 
2 | loss_module | GaussianNLLLoss | 0     
------------------------------------------------
5.5 K     Trainable params
0         Non-trainable params
5.5 K     Total params
0.022     Total estimated model params size (MB)


Epoch 999: 100%|██████████| 80/80 [00:03<00:00, 25.98it/s, loss=-1.72, v_num=2, val_r2_score=0.176, val_loss=-1.73] 


In [11]:
# Evaluate `model` on the test data provided by `dm`.
# NOTE(review): this cell's recorded execution count (11) is lower than those of
# the Trainer-creation (12) and fit (13) cells, so the saved result below may
# come from an earlier run -- restart the kernel and run all cells to confirm.
trainer.test(model, datamodule=dm)

Testing DataLoader 0: 100%|██████████| 5/5 [00:00<00:00, 55.65it/s]


[{'test_r2_score': -0.2056509477392449, 'test_loss': -0.11146972328424454}]