In [3]:
%load_ext autoreload
%autoreload 2

import ncem
import matplotlib.pyplot as plt
import pandas as pd
import os.path as osp
import scanpy as sc
import pytorch_lightning as pl
sc.settings.set_figure_params(dpi=80)
import warnings
warnings.filterwarnings("ignore")


import torch_geometric,torch
from torch_geometric.data import Data, Dataset, InMemoryDataset, LightningDataset
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import VGAE


In [4]:
class HartmannWrapper(Dataset):
    """PyTorch Geometric dataset that stores each Hartmann image as one graph file.

    The raw data is loaded through ncem's ``InterpreterInteraction`` and every
    image is converted to a ``Data`` object with the attributes

    * ``edge_index`` -- (2, n_edges) edge list taken from the non-zero entries
      of the image's adjacency matrix,
    * ``g``          -- one-hot vector of length ``n_domains`` marking the image's domain,
    * ``h_0``        -- per-node cell-type matrix (n_node, n_celltypes),
    * ``h_1``        -- per-node expression matrix (n_node, n_genes),
    * ``sf``         -- per-node size factors,
    * ``num_nodes``  -- explicit node count (rows of ``h_0``).

    Args:
        root: Directory containing the raw data; also passed to ncem as ``data_path``.
        transform: Forwarded to the base ``Dataset``.
        pre_transform: Forwarded to the base ``Dataset`` and applied to every
            graph in :meth:`process` before it is written to disk.
        pre_filter: Forwarded to the base ``Dataset``.
            NOTE(review): it is not applied in :meth:`process`, because dropping
            images would leave gaps in the ``data_{idx}.pt`` numbering that
            ``processed_file_names`` expects.
    """

    def __init__(self, root='./data', transform=None, pre_transform=None, pre_filter=None):
        # Both attributes are assigned before the base constructor runs:
        # `processed_file_names` reads `img_count` and `process` reads `root`.
        self.img_count = 58
        self.root = root
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        """Raw files this dataset is built from."""
        return ['scMEP_MIBI_singlecell/scMEP_MIBI_singlecell.csv']

    @property
    def processed_file_names(self):
        """One processed file per image: ``data_0.pt`` ... ``data_{img_count - 1}.pt``."""
        return [f'data_{idx}.pt' for idx in range(self.img_count)]

    def process(self):
        """Load the raw data with ncem and write one ``Data`` file per image."""
        # Reuse ncem's existing loader; the interpreter is kept on the instance
        # so the loaded arrays stay accessible after processing.
        self.interpreter = ncem.interpretation.interpreter.InterpreterInteraction()
        ip = self.interpreter
        ip.get_data(
            data_origin='hartmann',
            data_path=self.root,
            radius=35,
            node_label_space_id='type',
            node_feature_space_id='standard',
        )
        for idx, k in enumerate(ip.a.keys()):
            a, h_0, h_1, domains, sf = ip.a[k], ip.h_0[k], ip.h_1[k], ip.domains[k], ip.size_factors[k]

            # Edge list (2, n_edges), built directly from the index arrays
            # instead of round-tripping through Python lists.
            row, col = a.nonzero()
            edge_index = torch.stack((
                torch.as_tensor(row, dtype=torch.long),
                torch.as_tensor(col, dtype=torch.long),
            ))

            # X_c from the paper (n_domains=58): one-hot domain vector
            g = torch.zeros(ip.n_domains)
            g[domains] = 1

            # X_l from the paper (n_node, n_celltypes=8)
            h_0 = torch.from_numpy(h_0)

            # size factor
            sf = torch.from_numpy(sf)

            # Y from the paper (n_node, n_genes)
            h_1 = torch.from_numpy(h_1)

            # `num_nodes` is stored explicitly: without an `x` attribute PyG
            # falls back to `edge_index.max() + 1`, which undercounts whenever
            # the highest-indexed cells have no edges.
            data = Data(g=g,
                        sf=sf,
                        h_0=h_0,
                        h_1=h_1,
                        edge_index=edge_index,
                        num_nodes=h_0.shape[0])

            # Honour the user-supplied pre_transform (previously accepted but ignored).
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            # Each graph is saved as its own file.
            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))

    def len(self):
        """Number of graphs, i.e. number of processed files."""
        return len(self.processed_file_names)

    def get(self, idx):
        """Load and return the graph stored in ``data_{idx}.pt``."""
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [5]:
# Build the dataset with its default root ('./data') and take the first graph.
dataset = HartmannWrapper()
data = dataset[0]
# Bare last expression so the notebook displays the Data object.
data

Data(edge_index=[2, 6636], g=[58], sf=[1338], h_0=[1338, 8], h_1=[1338, 36])

In [6]:
# Explicitly re-run preprocessing; this rewrites every data_{idx}.pt file in processed_dir.
# NOTE(review): the constructor presumably already triggers processing when the
# processed files are missing -- confirm whether this manual call is still needed.
dataset.process()

Loading data from raw files
registering celldata
collecting image-wise celldata
adding graph-level covariates


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 58/58 [00:01<00:00, 43.44it/s]

Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.
Mean of mean node degree per images across images: 4.416036





In [7]:
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import train_test_split_edges

In [8]:
class VariationalGCNEncoder(torch.nn.Module):
    """GCN encoder with one shared layer and two output heads.

    Used in this notebook as the encoder wrapped by ``VGAE``. ``forward``
    returns a pair of tensors produced by the ``conv_mu`` and ``conv_logstd``
    heads, each with ``out_channels`` features per node.

    Args:
        in_channels: Number of input features per node.
        out_channels: Number of features produced by each head; the shared
            hidden layer is twice as wide.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 2 * out_channels
        # Shared first layer feeding both heads.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv_mu = GCNConv(hidden_channels, out_channels)
        self.conv_logstd = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return ``(conv_mu output, conv_logstd output)`` for the given graph."""
        hidden = torch.relu(self.conv1(x, edge_index))
        mu = self.conv_mu(hidden, edge_index)
        logstd = self.conv_logstd(hidden, edge_index)
        return mu, logstd

In [9]:

# Hyperparameters
out_channels = 40                  # width of each encoder output head
num_features = data.h_1.shape[1]   # input width = number of columns of h_1
epochs = 100

# Seed the Python, NumPy and PyTorch RNGs so that weight initialisation and
# training are reproducible under Restart & Run All.
torch_geometric.seed_everything(0)

# Variational graph auto-encoder wrapping the GCN encoder defined in this notebook.
model = VGAE(VariationalGCNEncoder(num_features, out_channels))

# Move the model and the single training graph to the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
x = data.h_1.to(device)
train_pos_edge_index = data.edge_index.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [10]:
def train():
    """Run one optimisation step on the notebook's global graph.

    Uses the module-level ``model``, ``optimizer``, ``x``,
    ``train_pos_edge_index`` and ``data``. The objective is the reconstruction
    loss plus the KL term scaled by ``1 / data.num_nodes``.

    Returns:
        The total loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    latent = model.encode(x, train_pos_edge_index)
    recon_term = model.recon_loss(latent, train_pos_edge_index)

    # The KL term is scaled down by the number of nodes in the graph.
    kl_weight = 1 / data.num_nodes
    total = recon_term + kl_weight * model.kl_loss()

    total.backward()
    optimizer.step()
    return float(total)



In [11]:
# Run `epochs` optimisation steps via train() and print the loss after each one.
for epoch in range(1, epochs + 1):
    loss = train()
    print(loss)


7.717545032501221
4.558232307434082
3.575956344604492
2.937342405319214
2.481600761413574
2.100987195968628
1.7955822944641113
1.6123396158218384
1.4992790222167969
1.465749979019165
1.4823532104492188
1.479212760925293
1.4713855981826782
1.4611177444458008
1.4399937391281128
1.4218488931655884
1.4179853200912476
1.4139822721481323
1.416288137435913
1.4166896343231201
1.4108033180236816
1.4061412811279297
1.4008482694625854
1.3929495811462402
1.383109450340271
1.3763481378555298
1.370819330215454
1.3599259853363037
1.34652841091156
1.336995005607605
1.323346495628357
1.3131020069122314
1.302992582321167
1.2905007600784302
1.2787847518920898
1.276010274887085
1.264812707901001
1.2548037767410278
1.2538169622421265
1.248529076576233
1.2573899030685425
1.2374838590621948
1.246610403060913
1.2307378053665161
1.2483290433883667
1.2382105588912964
1.2340731620788574
1.2299022674560547
1.2261991500854492
1.2201356887817383
1.2206525802612305
1.2183811664581299
1.2183687686920166
1.20435166358