In [1]:
%load_ext autoreload
%autoreload 2

import ncem
import matplotlib.pyplot as plt
import pandas as pd
import os.path as osp
import scanpy as sc
import pytorch_lightning as pl
import squidpy as sq
sc.settings.set_figure_params(dpi=80)
import warnings
warnings.filterwarnings("ignore")
from ncem.data import DataLoaderHartmann

import torch_geometric,torch
from torch_geometric.data import Data, Dataset, InMemoryDataset, LightningDataset
from torch_geometric.loader import NeighborLoader



## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Dataset Class

In [2]:
class HartmannWrapper(Dataset):
    def __init__(self, root='./data', transform=None, pre_transform=None, pre_filter=None):
        self.img_count = 58
        self.root = root
        super().__init__(root, transform, pre_transform, pre_filter)
    
    # raw file name
    @property
    def raw_file_names(self):
        return ['scMEP_MIBI_singlecell/scMEP_MIBI_singlecell.csv']

    
    # Each graph is saved as a file
    # Is this ideal?
    @property
    def processed_file_names(self):
        return [f'data_{idx}.pt' for idx in range(self.img_count)]


    def process(self):
        
        # Read from already implemented DataLoader to load to np/cpu
        self.interpreter = ncem.interpretation.interpreter.InterpreterInteraction()
        ip = self.interpreter
        # Read data from `raw_path`.
        ip.get_data(
            data_origin='hartmann',
            data_path=self.root,
            radius=35,
            node_label_space_id='type',
            node_feature_space_id='standard',
        )
        for idx,k in enumerate(ip.a.keys()):
            a,h_0,h_1,domains,node_covar,sf = ip.a[k],ip.h_0[k],ip.h_1[k],ip.domains[k],ip.node_covar[k],ip.size_factors[k]
            
            # Edge list (2,n_edges)
            row,col = a.nonzero()
            edge_index = torch.LongTensor([row.tolist(), col.tolist()])
            
            
            # X_c from the paper (n_domains=58)
            # one hot vector
            # g = torch.zeros(ip.n_domains)
            # g[domains] = 1
            # Note: Not used since dataloader encodes the batch id
            
            # X_l from paper (n_node, n_celltypes=8) n_celltypes: count of distinct cell-type labels
            h_0=torch.from_numpy(h_0)
            
            # size factor
            sf=torch.from_numpy(sf)
            
            # Y from paper (n_node, n_genes)
            h_1 = torch.from_numpy(h_1)
            
            # x for pygeometric convention is node to features
            x = torch.hstack((h_0,h_1)).to(torch.float32)
            
            
            # creating data object
            data = Data(sf=sf,
                        x = x,
                        edge_index=edge_index)
            
            # saving it as file
            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            
            
    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [3]:
dataset = HartmannWrapper()
dataset.process() # uncomment if the class is modified
dataloader = torch_geometric.data.DataLoader(dataset, batch_size=1)

Loading data from raw files
registering celldata
collecting image-wise celldata
adding graph-level covariates


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 58/58 [00:01<00:00, 38.64it/s]


Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.
Mean of mean node degree per images across images: 4.416036


In [4]:
dataset[0]

Data(x=[1338, 44], edge_index=[2, 6636], sf=[1338])

In [5]:
# show first 5 batch
for itr,graph in enumerate(dataloader):
    print(graph)
    if itr>4:
        break

DataBatch(x=[1338, 44], edge_index=[2, 6636], sf=[1338], batch=[1338], ptr=[2])
DataBatch(x=[61, 44], edge_index=[2, 36], sf=[61], batch=[61], ptr=[2])
DataBatch(x=[1316, 44], edge_index=[2, 5622], sf=[1316], batch=[1316], ptr=[2])
DataBatch(x=[1540, 44], edge_index=[2, 7898], sf=[1540], batch=[1540], ptr=[2])
DataBatch(x=[1822, 44], edge_index=[2, 10630], sf=[1822], batch=[1822], ptr=[2])
DataBatch(x=[863, 44], edge_index=[2, 2236], sf=[863], batch=[863], ptr=[2])


In [6]:
example_batch = next(iter(dataloader))

In [7]:
example_batch

DataBatch(x=[1338, 44], edge_index=[2, 6636], sf=[1338], batch=[1338], ptr=[2])

In [8]:
dataset.num_features

44

## Training Model

In [9]:
from torch_geometric.nn import GAE,VGAE,GCNConv

In [10]:
# taken from https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial6/Tutorial6.ipynb
class VariationalGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(VariationalGCNEncoder, self).__init__()
        # cached only for transductive learning
        self.conv1 = GCNConv(in_channels, 2 * out_channels) 
        self.conv_mu = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

In [11]:
# taken from https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial7/GNN_overview.html
class NodeLevelGNN(pl.LightningModule):

    def __init__(self, hparams, **model_kwargs):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters(hparams)
        
        self.model = VGAE(VariationalGCNEncoder(hparams['num_features'], hparams['latent_dim']))

    def forward(self, data, mode="train"):
        x, edge_index = data.x, data.edge_index
        x = self.model(x, edge_index)

        
        return x

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(),lr=self.hparams['lr'],weight_decay=self.hparams['weight_decay'])
        return optimizer

    def training_step(self, batch, batch_idx):
        z = self.model.encode(batch.x,batch.edge_index)
        recon_loss = self.model.recon_loss(z, batch.edge_index)
        kl_loss = self.model.kl_loss()
        
        loss = recon_loss + (1.0 / batch.num_nodes) * kl_loss
        
        self.log('train_recon_loss', loss)
        self.log('train_kl_loss', loss)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x = self.forward(batch, mode="val")
        print(x.shape,batch.shape)
        self.log('val_acc', x)

    def test_step(self, batch, batch_idx):
        raise NotImplementedError()
        _, acc = self.forward(batch, mode="test")
        self.log('test_acc', acc)

In [12]:
hparams = {
    'lr' : 0.01,
    'weight_decay':0,
    'num_features': dataset.num_features,
    'latent_dim' : 30
}
model = NodeLevelGNN(hparams)

In [13]:
trainer = pl.Trainer(
    #  callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss",patience=10,mode='max')],
    max_epochs=200,
    accelerator='gpu',
    devices=1
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
trainer.fit(model,dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | VGAE | 6.4 K 
-------------------------------
6.4 K     Trainable params
0         Non-trainable params
6.4 K     Total params
0.025     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]