In [1]:
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import EdgeConv, NNConv
from torch_geometric.data import Data, DataLoader

In [2]:
import pandas as pd
import os

### Load training samples

In [3]:
# Collect the regular files in the training-sample directory (sub-directories are skipped).
# os.listdir returns entries in arbitrary order, so sort them to keep the sample order
# (and therefore the saved dataset) reproducible across machines and runs.
files = sorted(
    f for f in os.listdir('training-samples')
    if os.path.isfile(os.path.join('training-samples', f))
)
len(files)

820

### Data Preparation for GNN

- `edge_attr`: edge features that the GNN layers can use
    - Edge feature matrix with shape `[num_edges, num_edge_features]`
- `edge_index`: Graph connectivity with shape `[2, num_edges]`

References: https://pytorch-geometric.readthedocs.io/en/2.6.0/get_started/introduction.html#data-handling-of-graphs

**Create Node Label to Integer Index Mapping**
- Extract all unique node labels from the dataset (they can be taken from `union_ppi`)
- Create a dictionary that maps each label to a unique integer index

In [4]:
# Load the headerless, tab-separated union_ppi table; columns are addressed by position.
union_ppi = pd.read_csv('processed-data/union_ppi.txt', header=None, sep='\t')
# Every label that appears in either of the first two columns.
unique_nodes = set(union_ppi[0].tolist()) | set(union_ppi[1].tolist())

In [5]:
# label -> int ID: labels are numbered consecutively in sorted order,
# so the assignment does not depend on set iteration order
label_id_map = dict(zip(sorted(unique_nodes), range(len(unique_nodes))))
num_nodes = len(label_id_map)
print(f"Total unique nodes: {num_nodes}")

Total unique nodes: 17407


**Build the `Data` objects - one for each training sample**

In [6]:
data_folder = 'training-samples/'
data_list = []

# Every sample shares the same node set, so its size is computed once, outside the loop.
num_nodes = len(label_id_map)

for training_sample in files:
    print('Processing: ', training_sample)
    # Each row of a sample is unpacked below as (node1, node2, prize, flag, label).
    training_df = pd.read_csv(os.path.join(data_folder, training_sample), sep='\t')
    training_data = training_df.values.tolist()

    edge_index_list = [[], []]
    edge_attr_list = []     # each edge attribute: [prize]
    edge_label_list = []    # this will get passed as data.y
    selected_nodes = set()  # integer IDs of nodes touched by an edge with label 1

    for node1, node2, prize, _flag, label in training_data:
        try:
            idx1 = label_id_map[node1]
            idx2 = label_id_map[node2]
        except KeyError as err:
            # Name the offending file; a bare KeyError is hard to trace across many samples.
            raise KeyError(
                f"{training_sample}: node label {err.args[0]!r} is not in label_id_map"
            ) from err
        edge_index_list[0].append(idx1)
        edge_index_list[1].append(idx2)
        edge_attr_list.append([prize])
        edge_label_list.append(label)

        # if edge label = 1, then record its nodes
        if label == 1:
            selected_nodes.update((idx1, idx2))

    # lists -> torch tensors
    edge_index = torch.tensor(edge_index_list, dtype=torch.long)
    # view(-1, 1) keeps the [num_edges, 1] layout even for a sample with no rows,
    # where torch.tensor([]) alone would produce a 1-D tensor.
    edge_attr = torch.tensor(edge_attr_list, dtype=torch.float).view(-1, 1)
    edge_labels = torch.tensor(edge_label_list, dtype=torch.float).view(-1, 1)

    # init node features: one feature per node (0 by default)
    x = torch.zeros((num_nodes, 1), dtype=torch.float)

    # update node features - mark nodes as 1 if they are connected by a selected edge
    # NOTE(review): x is derived from the same edge labels that are stored in y, so the
    # input features encode the target - confirm this is intended before training on x.
    if selected_nodes:
        x[torch.tensor(sorted(selected_nodes), dtype=torch.long)] = 1

    data_obj = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=edge_labels)
    data_obj.file_name = training_sample
    data_list.append(data_obj)

Processing:  alanine__aspartate_a_train_2535.csv
Processing:  alanine__aspartate_a_train_3273.csv
Processing:  alanine__aspartate_a_train_3711.csv
Processing:  alanine__aspartate_a_train_3781.csv
Processing:  alanine__aspartate_a_train_6130.csv
Processing:  alanine__aspartate_a_train_6672.csv
Processing:  alanine__aspartate_a_train_6727.csv
Processing:  alanine__aspartate_a_train_8527.csv
Processing:  alanine__aspartate_a_train_9941.csv
Processing:  alpha_linolenic_acid_train_1854.csv
Processing:  alpha_linolenic_acid_train_2371.csv
Processing:  alpha_linolenic_acid_train_4263.csv
Processing:  alpha_linolenic_acid_train_4939.csv
Processing:  alpha_linolenic_acid_train_5635.csv
Processing:  alpha_linolenic_acid_train_5783.csv
Processing:  alpha_linolenic_acid_train_6633.csv
Processing:  alpha_linolenic_acid_train_7454.csv
Processing:  alpha_linolenic_acid_train_9373.csv
Processing:  alpha_linolenic_acid_train_9806.csv
Processing:  aminoacyl_trna_biosy_train_1565.csv
Processing:  aminoac

Processing:  d_glutamine_and_d_gl_train_3853.csv
Processing:  d_glutamine_and_d_gl_train_5918.csv
Processing:  d_glutamine_and_d_gl_train_6415.csv
Processing:  d_glutamine_and_d_gl_train_6453.csv
Processing:  d_glutamine_and_d_gl_train_6575.csv
Processing:  d_glutamine_and_d_gl_train_7147.csv
Processing:  d_glutamine_and_d_gl_train_7653.csv
Processing:  d_glutamine_and_d_gl_train_8627.csv
Processing:  d_glutamine_and_d_gl_train_9192.csv
Processing:  d_glutamine_and_d_gl_train_9381.csv
Processing:  drug_metabolism___cy_train_1576.csv
Processing:  drug_metabolism___cy_train_2581.csv
Processing:  drug_metabolism___cy_train_2665.csv
Processing:  drug_metabolism___cy_train_3545.csv
Processing:  drug_metabolism___cy_train_4503.csv
Processing:  drug_metabolism___cy_train_6262.csv
Processing:  drug_metabolism___cy_train_7207.csv
Processing:  drug_metabolism___cy_train_7291.csv
Processing:  drug_metabolism___cy_train_7643.csv
Processing:  drug_metabolism___cy_train_7684.csv
Processing:  drug_me

Processing:  glycosaminoglycan_bi_train_9258.csv
Processing:  glycosaminoglycan_bi_train_9612.csv
Processing:  glycosaminoglycan_bi_train_9670.csv
Processing:  glycosaminoglycan_de_train_1224.csv
Processing:  glycosaminoglycan_de_train_1711.csv
Processing:  glycosaminoglycan_de_train_2488.csv
Processing:  glycosaminoglycan_de_train_3162.csv
Processing:  glycosaminoglycan_de_train_3395.csv
Processing:  glycosaminoglycan_de_train_4549.csv
Processing:  glycosaminoglycan_de_train_4714.csv
Processing:  glycosaminoglycan_de_train_7274.csv
Processing:  glycosaminoglycan_de_train_8470.csv
Processing:  glycosaminoglycan_de_train_9198.csv
Processing:  glycosphingolipid_bi__1_train_1944.csv
Processing:  glycosphingolipid_bi__1_train_2264.csv
Processing:  glycosphingolipid_bi__1_train_2271.csv
Processing:  glycosphingolipid_bi__1_train_3112.csv
Processing:  glycosphingolipid_bi__1_train_4340.csv
Processing:  glycosphingolipid_bi__1_train_4541.csv
Processing:  glycosphingolipid_bi__1_train_6175.csv

Processing:  nitrogen_metabolism_train_7075.csv
Processing:  nitrogen_metabolism_train_7440.csv
Processing:  nitrogen_metabolism_train_8986.csv
Processing:  nitrogen_metabolism_train_9395.csv
Processing:  nitrogen_metabolism_train_9476.csv
Processing:  nitrogen_metabolism_train_9511.csv
Processing:  one_carbon_pool_by_f_train_1293.csv
Processing:  one_carbon_pool_by_f_train_1466.csv
Processing:  one_carbon_pool_by_f_train_1656.csv
Processing:  one_carbon_pool_by_f_train_1671.csv
Processing:  one_carbon_pool_by_f_train_3035.csv
Processing:  one_carbon_pool_by_f_train_3699.csv
Processing:  one_carbon_pool_by_f_train_4046.csv
Processing:  one_carbon_pool_by_f_train_4656.csv
Processing:  one_carbon_pool_by_f_train_5988.csv
Processing:  one_carbon_pool_by_f_train_9250.csv
Processing:  other_types_of_o_gly_train_2394.csv
Processing:  other_types_of_o_gly_train_2494.csv
Processing:  other_types_of_o_gly_train_2893.csv
Processing:  other_types_of_o_gly_train_5221.csv
Processing:  other_types_o

Processing:  sphingolipid_metabol_train_5576.csv
Processing:  sphingolipid_metabol_train_6402.csv
Processing:  sphingolipid_metabol_train_7075.csv
Processing:  sphingolipid_metabol_train_7417.csv
Processing:  sphingolipid_metabol_train_7532.csv
Processing:  sphingolipid_metabol_train_7670.csv
Processing:  sphingolipid_metabol_train_8777.csv
Processing:  starch_and_sucrose_m_train_1473.csv
Processing:  starch_and_sucrose_m_train_2356.csv
Processing:  starch_and_sucrose_m_train_3953.csv
Processing:  starch_and_sucrose_m_train_4604.csv
Processing:  starch_and_sucrose_m_train_5326.csv
Processing:  starch_and_sucrose_m_train_5329.csv
Processing:  starch_and_sucrose_m_train_5671.csv
Processing:  starch_and_sucrose_m_train_5869.csv
Processing:  starch_and_sucrose_m_train_5925.csv
Processing:  starch_and_sucrose_m_train_8147.csv
Processing:  steroid_biosynthesis_train_1126.csv
Processing:  steroid_biosynthesis_train_3268.csv
Processing:  steroid_biosynthesis_train_4516.csv
Processing:  steroid

In [7]:
# Persist the list of per-sample Data objects as a single file in the working directory
torch.save(obj=data_list, f='dataset.pt')