In [1]:
%load_ext autoreload
%autoreload 2

import ncem
import matplotlib.pyplot as plt
import pandas as pd
import os.path as osp
import scanpy as sc
import pytorch_lightning as pl
import warnings

# Render scanpy figures at 80 dpi.
sc.settings.set_figure_params(dpi=80)
# NOTE(review): this hides *every* warning (including torch / PyG deprecation
# notices); consider narrowing it to specific categories or modules.
warnings.filterwarnings(action="ignore")


import torch_geometric,torch
from torch_geometric.data import Data, Dataset, InMemoryDataset, LightningDataset
from torch_geometric.loader import NeighborLoader


In [2]:
class HartmannWrapper(Dataset):
    """PyTorch Geometric dataset exposing the Hartmann data set as one graph per image.

    On first use, ``process`` loads the data through ncem's
    ``InterpreterInteraction`` and writes one ``Data`` object per image to
    ``processed_dir`` as ``data_{idx}.pt``. ``get`` loads those files back lazily.

    Args:
        root: Dataset root directory; it is also passed to ncem as ``data_path``.
        transform, pre_transform, pre_filter: Forwarded unchanged to
            ``torch_geometric.data.Dataset``.
        img_count: Number of images (graphs) expected from ncem. Defaults to 58,
            the value this notebook was written against.
    """

    def __init__(self, root='./data', transform=None, pre_transform=None, pre_filter=None,
                 img_count=58):
        # Set before super().__init__, because PyG's base initialiser may query
        # processed_file_names and run process() while initialising.
        self.img_count = img_count
        self.root = root
        super().__init__(root, transform, pre_transform, pre_filter)

    @staticmethod
    def _file_name(idx):
        """File name used for the processed graph of image ``idx``."""
        return f'data_{idx}.pt'

    @staticmethod
    def _adjacency_to_sparse(a):
        """Convert a 2-D adjacency matrix into a coalesced ``torch.long`` sparse COO tensor.

        ``a`` must expose ``.nonzero()`` (returning row and column index arrays)
        and ``.shape``. Every non-zero entry of ``a`` becomes a 1 in the result.
        """
        row, col = a.nonzero()
        indices = torch.stack([
            torch.as_tensor(row, dtype=torch.long),
            torch.as_tensor(col, dtype=torch.long),
        ])
        values = torch.ones(indices.size(1), dtype=torch.long)
        # Pass the size explicitly. Inferring it from the largest index would
        # shrink the matrix whenever trailing nodes have no edges.
        return torch.sparse_coo_tensor(indices, values, size=tuple(a.shape)).coalesce()

    # Raw file expected by the dataset.
    @property
    def raw_file_names(self):
        return ['scMEP_MIBI_singlecell/scMEP_MIBI_singlecell.csv']

    # One processed file per image (graph), loaded lazily by get().
    @property
    def processed_file_names(self):
        return [self._file_name(idx) for idx in range(self.img_count)]

    def process(self):
        """Load all images through ncem and save each one as ``data_{idx}.pt``.

        Raises:
            ValueError: If ncem returns a different number of images than
                ``img_count``. Otherwise ``len()`` and ``get()`` would silently
                disagree with the files written to disk.
        """
        # Reuse ncem's existing loader to obtain the per-image arrays on the CPU.
        self.interpreter = ncem.interpretation.interpreter.InterpreterInteraction()
        ip = self.interpreter
        # ncem reads from `root` directly, not from PyG's `raw_dir`.
        ip.get_data(
            data_origin='hartmann',
            data_path=self.root,
            radius=35,
            node_label_space_id='type',
            node_feature_space_id='standard',
        )

        n_images = len(ip.a)
        if n_images != self.img_count:
            raise ValueError(
                f'ncem returned {n_images} images but img_count={self.img_count}; '
                f'pass the correct img_count to HartmannWrapper.'
            )

        for idx, (k, a) in enumerate(ip.a.items()):
            a_sparse = self._adjacency_to_sparse(a)

            # One-hot vector over ip.n_domains marking this image's domain(s).
            # NOTE(review): assumes ip.domains[k] is a valid index / index set.
            g = torch.zeros(ip.n_domains)
            g[ip.domains[k]] = 1

            # ip.node_covar[k] is also available but is not stored.
            data = Data(
                # NOTE(review): node "features" are the adjacency rows; confirm
                # this is intended rather than h_0.
                x=a_sparse,
                g=g,
                sf=torch.from_numpy(ip.size_factors[k]),
                h_0=torch.from_numpy(ip.h_0[k]),
                h_1=torch.from_numpy(ip.h_1[k]),
                edge_index=a_sparse.indices(),
            )
            torch.save(data, osp.join(self.processed_dir, self._file_name(idx)))

    def len(self):
        """Number of graphs, i.e. the number of processed files."""
        return len(self.processed_file_names)

    def get(self, idx):
        """Load and return the ``Data`` object saved for image ``idx``."""
        path = osp.join(self.processed_dir, self._file_name(idx))
        try:
            # torch>=2.6 defaults to weights_only=True, which refuses to unpickle
            # PyG `Data` objects. These files were written by process() above;
            # never point this at untrusted files (full pickle load).
            return torch.load(path, weights_only=False)
        except TypeError:
            # Older torch (<1.13) does not accept the `weights_only` argument.
            return torch.load(path)

In [3]:
# Build (or reload from disk) the per-image graphs under ./data, then show the first one.
dataset = HartmannWrapper(root='./data')
dataset[0]

Data(x=[1338, 1338], edge_index=[2, 6636], g=[58], sf=[1338], h_0=[1338, 8], h_1=[1338, 36])

In [53]:
# Neighbour-sampling loader over the first image's graph. Seed nodes come in
# batches of 128, and each of the two hops adds at most 30 sampled neighbours
# per node. No `input_nodes` mask is passed, so every node is used as a seed
# (not only "training" nodes).
loader = NeighborLoader(
    dataset[0],
    batch_size=128,
    num_neighbors=[30, 30],
)


In [54]:
# Print the shape summary of every sampled mini-batch.
for batch in loader:
    print(batch)

Data(x=[175, 1338], edge_index=[2, 623], g=[58], sf=[175], h_0=[175, 8], h_1=[175, 36], batch_size=128)
Data(x=[234, 1338], edge_index=[2, 918], g=[58], sf=[234], h_0=[234, 8], h_1=[234, 36], batch_size=128)
Data(x=[240, 1338], edge_index=[2, 1028], g=[58], sf=[240], h_0=[240, 8], h_1=[240, 36], batch_size=128)
Data(x=[232, 1338], edge_index=[2, 957], g=[58], sf=[232], h_0=[232, 8], h_1=[232, 36], batch_size=128)
Data(x=[263, 1338], edge_index=[2, 972], g=[58], sf=[263], h_0=[263, 8], h_1=[263, 36], batch_size=128)
Data(x=[253, 1338], edge_index=[2, 1065], g=[58], sf=[253], h_0=[253, 8], h_1=[253, 36], batch_size=128)
Data(x=[247, 1338], edge_index=[2, 996], g=[58], sf=[247], h_0=[247, 8], h_1=[247, 36], batch_size=128)
Data(x=[243, 1338], edge_index=[2, 920], g=[58], sf=[243], h_0=[243, 8], h_1=[243, 36], batch_size=128)
Data(x=[259, 1338], edge_index=[2, 950], g=[58], sf=[259], h_0=[259, 8], h_1=[259, 36], batch_size=128)
Data(x=[227, 1338], edge_index=[2, 1051], g=[58], sf=[227], h_

In [None]:
# Wrap the dataset for PyTorch Lightning, then inspect the size factors of its first graph.
ldataset = LightningDataset(
    train_dataset=dataset,
)
ldataset.train_dataset[0].sf