In [1]:
import sys, os
# Make this notebook's folder and its parent importable (needed for the
# project-local `node.gene2vec` import below), skipping entries already present.
module_paths = [os.path.abspath(rel) for rel in ('.', '..')]
for candidate in module_paths:
    if candidate in sys.path:
        continue
    sys.path.append(candidate)

import numpy as np
import pandas as pd
import scanpy as sc
import networkx as nx

import torch
import torch_geometric as pyg

In [2]:
# Dataset to process (other values used with this notebook: "sciplex", "L008").
dataset = "marson"

# 1. initial configuration: shared GRN pickle and 8-dimensional gene2vec embeddings.
graph, dimension = "grn.pkl", 8

In [3]:
# 2. updated configuration (disabled). To use a dataset-specific GRN pickle
# and 64-dimensional embeddings instead of the initial settings, uncomment:
# graph = "%s_grn.pkl" % dataset
# dimension = 64

'\n# 2. updated\ngraph = "%s_grn.pkl" % dataset\ndimension = 64\n'

In [4]:
# Load the preprocessed AnnData object for the selected dataset and display
# its per-gene metadata table (one row per gene).
prepped_path = '../datasets/%s_prepped.h5ad' % dataset
adata = sc.read(prepped_path)
adata.var

Unnamed: 0,ENSGNM,name,is_sgRNA,mito,highly_variable,means,dispersions,dispersions_norm
ISG15,ENSG00000187608,ISG15,False,False,True,0.464187,0.646769,2.577193
TNFRSF18,ENSG00000186891,TNFRSF18,False,False,True,0.069438,0.918909,6.711458
TNFRSF4,ENSG00000186827,TNFRSF4,False,False,True,0.220937,1.205736,19.591536
SLC35E2A,ENSG00000215790,SLC35E2A,False,False,True,0.054039,0.745175,1.580994
PLCH2,ENSG00000149527,PLCH2,False,False,True,0.002080,0.842152,1.986914
...,...,...,...,...,...,...,...,...
UTY,ENSG00000183878,UTY,False,False,True,0.063287,0.762314,2.413199
TTTY14,ENSG00000176728,TTTY14,False,False,True,0.008703,0.807486,1.762861
KDM5D,ENSG00000012817,KDM5D,False,False,True,0.047676,0.753803,1.807152
TTTY10,ENSG00000229236,TTTY10,False,False,True,0.007920,0.810274,1.819597


In [5]:
# The var index (gene symbols) defines the graph's node set and its order;
# later cells index nodes by their position in this array.
gene_index = adata.var.index
nodes = np.array(gene_index)
nodes

array(['ISG15', 'TNFRSF18', 'TNFRSF4', ..., 'KDM5D', 'TTTY10', 'EIF1AY'],
      dtype=object)

In [6]:
# Per-dataset input (gene list) and output (embeddings) folders for gene2vec.
source_path = "./node/source/%s" % dataset
export_path = "./node/export/%s" % dataset
for folder in (source_path, export_path):
    os.makedirs(folder, exist_ok=True)

In [7]:
# Write one gene symbol per line to source_path/genes.txt (the file gene2vec
# picks up in the next cell, per its log output).
genes_file = os.path.join(source_path, "genes.txt")
np.savetxt(genes_file, nodes[:, np.newaxis], fmt="%s")

In [8]:
from node.gene2vec import gene2vec

# Train gene embeddings of size `dimension` from the files in source_path and
# write them to export_path (read back in the next cell).
# NOTE(review): 'txt' presumably selects the input file extension, and no seed
# is passed (the log shows shuffling) — confirm in node/gene2vec.py.
abs_source = os.path.abspath(source_path)
abs_export = os.path.abspath(export_path)
gene2vec(abs_source, abs_export, 'txt', dimension=dimension)

2023-04-13 00:01:55.379201
current file genes.txt num: 1 total files 1
2023-04-13 00:01:55.381714
shuffle start 2013
2023-04-13 00:01:55.382906
shuffle done 2013
gene2vec dimension 8 iteration 1 start
gene2vec dimension 8 iteration 1 done
2023-04-13 00:01:56.125130
shuffle start 2013
2023-04-13 00:01:56.126350
shuffle done 2013
gene2vec dimension 8 iteration 2 start
gene2vec dimension 8 iteration 2 done
2023-04-13 00:01:56.487606
shuffle start 2013
2023-04-13 00:01:56.488896
shuffle done 2013
gene2vec dimension 8 iteration 3 start
gene2vec dimension 8 iteration 3 done
2023-04-13 00:01:56.717192
shuffle start 2013
2023-04-13 00:01:56.718600
shuffle done 2013
gene2vec dimension 8 iteration 4 start
gene2vec dimension 8 iteration 4 done
2023-04-13 00:01:56.963208
shuffle start 2013
2023-04-13 00:01:56.964628
shuffle done 2013
gene2vec dimension 8 iteration 5 start
gene2vec dimension 8 iteration 5 done
2023-04-13 00:01:57.138769
shuffle start 2013
2023-04-13 00:01:57.140060
shuffle done 201

In [9]:
# Load the gene2vec embedding table written to export_path. Its file name
# encodes the embedding dimension and the iteration number.
# NOTE(review): 10 is assumed to be gene2vec's final iteration — confirm in node/gene2vec.py.
GENE2VEC_ITER = 10
# Format only the file name, not the joined path: a '%' in export_path must not be
# interpreted as a format directive.
emb_file = os.path.join(
    export_path, "gene2vec_dim_%s_iter_%s.txt" % (dimension, GENE2VEC_ITER)
)
# Space-separated, no header; first column is the gene symbol (used as index).
nodes_emb = pd.read_csv(emb_file, sep=" ", index_col=0, header=None)
assert nodes_emb.shape[1] == dimension, (
    "expected %d embedding columns, got %d" % (dimension, nodes_emb.shape[1])
)
nodes_emb

Unnamed: 0_level_0,1,2,3,4,5,6,7,8
0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
TRDC,-0.014802,0.044940,-0.048579,0.022127,0.009083,-0.035273,-0.034215,-0.013072
AC008555.4,-0.027248,-0.000775,0.046132,-0.014350,0.036312,-0.003436,-0.031057,0.051704
ZNF91,0.056041,-0.030958,0.039535,-0.008874,-0.013574,0.021708,-0.035313,-0.021563
CPM,0.029942,-0.061144,-0.008050,-0.046416,-0.056540,0.012549,0.015265,-0.014975
AL137779.1,0.050672,-0.029744,-0.026961,0.000353,-0.008679,-0.032039,-0.004612,-0.027411
...,...,...,...,...,...,...,...,...
SERPINB9,-0.026467,0.034402,0.044665,-0.000727,-0.035947,-0.017527,-0.058751,-0.032231
TMOD2,-0.017628,-0.050379,0.011948,-0.002261,0.047951,0.039177,0.046972,-0.057398
CD2,-0.003123,-0.059434,-0.000001,-0.039323,-0.004405,-0.017672,-0.032516,0.000295
DISC1,0.012189,-0.019545,-0.032533,0.015102,0.016515,-0.024574,-0.002918,-0.011705


In [10]:
# gene symbol -> embedding vector (list of floats, in column order).
nodes_emb_dict = dict(zip(nodes_emb.index, nodes_emb.values.tolist()))

In [11]:
import pickle

# The default grn.pkl can be regenerated from the TF-info parquet table with:
#   from src.utils.graph_utils import parse_grn
#   graph_df = pd.read_parquet('hg19_TFinfo_dataframe_gimmemotifsv5_fpr2_threshold_10_20210630.parquet')
#   nx_graph = parse_grn(graph_df, 'gene_short_name')
#   with open('./grn.pkl', 'wb') as f:
#       pickle.dump(nx_graph, f)

# SECURITY: pickle.load can execute arbitrary code; only load GRN pickles that
# you generated yourself or otherwise trust.
with open(graph, 'rb') as f:
    nx_graph = pickle.load(f)

In [12]:
# Ensure every gene in `nodes` exists in the graph (as an isolated node if the GRN
# has no edges for it), then restrict the graph to exactly that gene set.
# add_nodes_from leaves already-present nodes (and their attributes) unchanged.
nx_graph.add_nodes_from(nodes)

# Set membership is O(1); `n in nodes` on a numpy array scans the whole array
# for every graph node the view filters.
node_set = set(nodes)


def filter_node(n):
    """Return True if graph node ``n`` is one of the genes in ``nodes``."""
    return n in node_set


# Pass the filter by keyword: newer networkx makes filter_node keyword-only.
nx_graph = nx.subgraph_view(nx_graph, filter_node=filter_node)

In [13]:
# gene symbol -> row index in `nodes` (and thus in the feature matrix).
nodes_dict = {gene: idx for idx, gene in enumerate(nodes)}

In [14]:
# Walk genes in `nodes` order, collecting each gene's embedding row and turning
# every GRN in-edge u -> gene into an (index of u, index of gene) pair.
x = []
edge_index = []
for target_idx, gene in enumerate(nodes):
    x.append(nodes_emb_dict[gene])
    in_edges = [(nodes_dict[edge[0]], target_idx) for edge in nx_graph.in_edges(gene)]
    edge_index.extend(in_edges)

In [15]:
# Stack the per-gene embedding rows (in `nodes` order) into the node-feature matrix.
x = torch.Tensor(x)
x.shape

torch.Size([2013, 8])

In [16]:
# (num_edges, 2) long tensor of (source, target) node indices. reshape(-1, 2)
# keeps the 2-column shape even when there are no edges (the legacy
# torch.LongTensor([]) would give shape (0,) and break the later transpose).
edge_index = torch.as_tensor(edge_index, dtype=torch.long).reshape(-1, 2)
edge_index.size()

torch.Size([42563, 2])

In [17]:
# PyG expects edge_index as a (2, num_edges) tensor: transpose the (num_edges, 2)
# edge list and make it contiguous, as recommended by the PyG docs.
grn = pyg.data.Data(x=x, edge_index=edge_index.t().contiguous())
grn_path = '%s_grn_%s.pth' % (dataset, dimension)
torch.save(grn, grn_path)

In [18]:
# Reload the saved graph as a sanity check. The file holds a pickled PyG Data
# object rather than plain tensors, so weights_only=False is required: PyTorch
# >= 2.6 defaults to weights_only=True and would refuse to load it. This is safe
# only because the file was written by the previous cell; do not use it on untrusted files.
torch.load('%s_grn_%s.pth' % (dataset, dimension), weights_only=False)

Data(x=[2013, 8], edge_index=[2, 42563])