# In [2]:
import networkx as nx
import numpy as np
import pandas as pd
import scipy.sparse as sp
from spektral.data import Dataset, Graph

class scRNA(Dataset):
    """Spektral dataset with one graph per expression sample.

    Each graph carries node features ``x`` and a label ``y`` only; the
    adjacency matrix is stored once on the dataset as ``self.a`` rather than
    on the individual graphs (presumably Spektral's "mixed mode" layout --
    confirm against the loader that consumes this dataset).
    """

    def __init__(self, **kwargs):
        """Initialise the shared adjacency slot, then defer to ``Dataset``.

        Args:
            **kwargs: Forwarded unchanged to ``Dataset.__init__``.
        """
        # Assigned before the base constructor runs; presumably the base
        # class invokes read(), which fills this attribute -- TODO confirm.
        self.a = None
        super().__init__(**kwargs)

    def read(self):
        """Load features, labels and the adjacency matrix.

        Side effect: stores the adjacency matrix in ``self.a``.

        Returns:
            A list of ``Graph`` objects, one per paired entry of the feature
            array and the label array (pairing stops at the shorter of the
            two, as with ``zip``).
        """
        features, labels, gene_order = _get_scRNA_exprs()
        # Rows/columns of the adjacency follow the feature column order.
        self.a = _get_adjacency(gene_order)

        graphs = []
        for sample_features, sample_label in zip(features, labels):
            graphs.append(Graph(x=sample_features, y=sample_label))
        return graphs

def _get_adjacency(node_order):
    """Return the unweighted adjacency matrix of the gene-gene network.

    The network is read from
    ``./5.1.Edge_of_gene_gene_interaction_network.csv`` in NetworkX
    adjacency-list format with ``,`` as the delimiter (first field of each
    line is a node, the remaining fields are its neighbours).

    Args:
        node_order: Node labels defining the row/column order of the
            returned matrix. Passed to NetworkX as ``nodelist``.
            NOTE(review): labels are presumably expected to match the node
            names in the file exactly -- confirm against the data.

    Returns:
        The SciPy sparse adjacency matrix produced by
        ``nx.adjacency_matrix``; every edge is stored as 1 because
        ``weight=None`` is passed.
    """
    # Read in the edge file.
    graph = nx.read_adjlist(
        "./5.1.Edge_of_gene_gene_interaction_network.csv",
        delimiter=",",
    )
    # ``nx.adj_matrix`` was a deprecated alias that NetworkX 3.0 removed;
    # ``nx.adjacency_matrix`` is the same function under the name that
    # exists in every release.
    return nx.adjacency_matrix(graph, nodelist=node_order, weight=None)

def _get_scRNA_exprs():
    """Load the expression table and the cell labels from disk.

    Reads the ``exprs`` key of ``./6.Filtered_node_exprs.h5`` and the
    ``Number_label`` column of ``./3.Cell_label.csv``.

    NOTE(review): labels are taken in CSV row order and are not aligned to
    the expression table by index -- this assumes both files list the
    samples in the same order; confirm against the data.

    Returns:
        A tuple ``(x, y, node_order)`` where

        * ``x`` is the transposed expression table as an array with a
          trailing axis of length 1, i.e. shape ``(rows, columns, 1)`` of
          the transposed table;
        * ``y`` is the array of ``Number_label`` values;
        * ``node_order`` is the list of column labels of the transposed
          table.
    """
    # Load both input files.
    stored_table = pd.read_hdf("./6.Filtered_node_exprs.h5", key="exprs")
    label_table = pd.read_csv("./3.Cell_label.csv", index_col=0)

    # Work on the transposed table from here on.
    transposed = stored_table.transpose()
    n_rows, n_cols = transposed.shape

    # Append a singleton feature axis to the value matrix.
    x = transposed.values.reshape(n_rows, n_cols, 1)
    y = label_table["Number_label"].values
    node_order = transposed.columns.tolist()
    return x, y, node_order