In [30]:
import torch
import numpy as np
import pandas as pd
from torch_geometric.data import HeteroData

In [3]:
## load hetero full undirected data
hetero_data = torch.load('data/combined_dbs_heteroGraph.pt')
print(hetero_data)

HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 7635],
    edge_attr=[7635, 4],
  },
  (protein, interacts, protein)={
    edge_index=[2, 148992],
    edge_attr=[148992, 1],
  }
)


In [17]:
from torch_geometric.transforms import RandomLinkSplit

# split the data with RandomLinkSplit

split_transform = RandomLinkSplit(
    num_val=0.1,  # 10% validation
    num_test=0.1,  # 10% test
    is_undirected=False,  
    edge_types=[("lncRNA", "interacts", "protein"), ("protein", "interacts", "protein")],
    rev_edge_types=None,  # 
    disjoint_train_ratio=0.3,  
    add_negative_train_samples=False,  
)

train_data, val_data, test_data = split_transform(hetero_data)

In [22]:
print('Original Data')
print('-----------')
print(hetero_data)

print('Train Data')
print('-----------')
print(train_data)

print('Validation Data')
print('-----------')
print(val_data)


print('Test Data')
print('-----------')
print(test_data)

Original Data
-----------
HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 7635],
    edge_attr=[7635, 4],
  },
  (protein, interacts, protein)={
    edge_index=[2, 148992],
    edge_attr=[148992, 1],
  }
)
Train Data
-----------
HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 4277],
    edge_attr=[4277, 4],
    edge_label=[1832],
    edge_label_index=[2, 1832],
  },
  (protein, interacts, protein)={
    edge_index=[2, 83436],
    edge_attr=[83436, 1],
    edge_label=[35758],
    edge_label_index=[2, 35758],
  }
)
Validation Data
-----------
HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 6109],
    edge_attr=[6109, 4],
    edge_label=[1526],
    edge_label_index=[2, 1526],
  },
  (protein, interacts, protein)={
    edge_index=[2, 119194],
    edge_attr=[119194, 1],
    edge_label

In [23]:
# split data manually ..

# Define a function to split edges randomly into train / validation / test
# You can control how much goes to training (default = 80%) and validation (default = 10%)

def split_edges(edge_index, edge_attr, train_ratio=0.8, val_ratio=0.1, seed=42):
    
    # Set random seed for reproducibility — so every run gives same result.
    torch.manual_seed(seed)
    
    # Count the number of edges (number of columns in edge_index)
    num_edges = edge_index.size(1)
    
    #  Create a random permutation of edge indices → this shuffles the edges
    indices = torch.randperm(num_edges)

    #  Compute where to "cut" the shuffled list to get train, val, and test
    train_cutoff = int(train_ratio * num_edges)
    val_cutoff = int((train_ratio + val_ratio) * num_edges)

    # Use slicing to divide the indices into:training indices, validation indices and test indices
    train_idx = indices[:train_cutoff]
    val_idx = indices[train_cutoff:val_cutoff]
    test_idx = indices[val_cutoff:]
    

    # Return a dictionary that contains :
    # edge_index and edge_attr for training
    # edge_index and edge_attr for validation
    # edge_index and edge_attr for test

    return {
        'train': (edge_index[:, train_idx], edge_attr[train_idx]),
        'val': (edge_index[:, val_idx], edge_attr[val_idx]),
        'test': (edge_index[:, test_idx], edge_attr[test_idx]),
    }


In [24]:
# Split the lncRNA → protein edges into training, validation, and test sets
# using the edge_index and edge_attr stored in the HeteroData graph

split_lnc_protein = split_edges(
    hetero_data['lncRNA', 'interacts', 'protein'].edge_index,
    hetero_data['lncRNA', 'interacts', 'protein'].edge_attr
)

In [25]:
# Split the protein ↔ protein edges into training, validation, and test sets
# using the edge_index and edge_attr from the HeteroData graph

split_protein_protein = split_edges(
    hetero_data['protein', 'interacts', 'protein'].edge_index,
    hetero_data['protein', 'interacts', 'protein'].edge_attr
)

In [26]:
# Define a function to build a HeteroData graph from a given edge split
def build_graph_from_split(node_feats, lnc_split, ppi_split):
    
    # Create an empty heterogeneous graph
    g = HeteroData()

    # Assign node features for lncRNA nodes
    g['lncRNA'].x = node_feats['lncRNA']
    # Assign node features for protein nodes
    g['protein'].x = node_feats['protein']

    # Add lncRNA → protein edge_index and edge_attr from the split
    g['lncRNA', 'interacts', 'protein'].edge_index = lnc_split[0]
    g['lncRNA', 'interacts', 'protein'].edge_attr = lnc_split[1]

    # Add protein ↔ protein edge_index and edge_attr from the split
    g['protein', 'interacts', 'protein'].edge_index = ppi_split[0]
    g['protein', 'interacts', 'protein'].edge_attr = ppi_split[1]

    # Return the constructed HeteroData graph
    return g

# Collect the node features from the existing full data
node_feats = {
    'lncRNA': hetero_data['lncRNA'].x,
    'protein': hetero_data['protein'].x
}

# Build the training graph using training edges for both lncRNA–protein and protein–protein
train_data = build_graph_from_split(node_feats, split_lnc_protein['train'], split_protein_protein['train'])

# Build the validation graph using validation edges
val_data   = build_graph_from_split(node_feats, split_lnc_protein['val'],   split_protein_protein['val'])

# Build the test graph using test edges
test_data  = build_graph_from_split(node_feats, split_lnc_protein['test'],  split_protein_protein['test'])

In [27]:
print('Original Data')
print('-----------')
print(hetero_data)

print('Train Data')
print('-----------')
print(train_data)

print('Validation Data')
print('-----------')
print(val_data)


print('Test Data')
print('-----------')
print(test_data)

Original Data
-----------
HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 7635],
    edge_attr=[7635, 4],
  },
  (protein, interacts, protein)={
    edge_index=[2, 148992],
    edge_attr=[148992, 1],
  }
)
Train Data
-----------
HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 6108],
    edge_attr=[6108, 4],
  },
  (protein, interacts, protein)={
    edge_index=[2, 119193],
    edge_attr=[119193, 1],
  }
)
Validation Data
-----------
HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 763],
    edge_attr=[763, 4],
  },
  (protein, interacts, protein)={
    edge_index=[2, 14899],
    edge_attr=[14899, 1],
  }
)
Test Data
-----------
HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 764],
    edge_attr=[764, 4],
  },
  

In [None]:
"""
in the original data we have:
------------------------------
7635 lncRNA, interacts, protein 
148992 protein, interacts, protein 

and after splitting we have:
-----------------------------
6108  in Train + 763   in validation + 764   in Test which  = = 7635   (lncRNA,  interacts, protein) in original graph
119193 in Train + 14899 in Validation + 14900 in Test which = = 148992 (protein, interacts, protein) in original graph

Now we can use these splits
"""


In [15]:
# Define a helper function to print summary information about a HeteroData graph

def print_graph_info(graph, name):
    # Print the name or label of the graph (e.g., "Train", "Validation", "Test")
    print(f"\nInfo for: {name}")
    print("*" * 26)
    
    # Print the shapes of the node feature matrices
    print("Node features:")
    print("-" * 20)
    print("  lncRNA:", graph['lncRNA'].x.shape)
    print("  protein:", graph['protein'].x.shape)
    
    print("\n Edge features:")
    print("-" * 20)

     # Print information about the lncRNA → protein edge type
    print("lncRNA → protein:")
    print("  edge_index:", graph['lncRNA', 'interacts', 'protein'].edge_index.shape)
    print("  edge_attr :", graph['lncRNA', 'interacts', 'protein'].edge_attr.shape)
    
    # Print information about the protein ↔ protein edge type
    print("protein ↔ protein:")
    print("  edge_index:", graph['protein', 'interacts', 'protein'].edge_index.shape)
    print("  edge_attr :", graph['protein', 'interacts', 'protein'].edge_attr.shape)


In [16]:
print_graph_info(train_data, "TRAIN Graph")
print_graph_info(val_data, "VALIDATION Graph")
print_graph_info(test_data, "TEST Graph")


Info for: TRAIN Graph
**************************
Node features:
--------------------
  lncRNA: torch.Size([1269, 2])
  protein: torch.Size([11585, 2])

 Edge features:
--------------------
lncRNA → protein:
  edge_index: torch.Size([2, 6108])
  edge_attr : torch.Size([6108, 4])
protein ↔ protein:
  edge_index: torch.Size([2, 119193])
  edge_attr : torch.Size([119193, 1])

Info for: VALIDATION Graph
**************************
Node features:
--------------------
  lncRNA: torch.Size([1269, 2])
  protein: torch.Size([11585, 2])

 Edge features:
--------------------
lncRNA → protein:
  edge_index: torch.Size([2, 763])
  edge_attr : torch.Size([763, 4])
protein ↔ protein:
  edge_index: torch.Size([2, 14899])
  edge_attr : torch.Size([14899, 1])

Info for: TEST Graph
**************************
Node features:
--------------------
  lncRNA: torch.Size([1269, 2])
  protein: torch.Size([11585, 2])

 Edge features:
--------------------
lncRNA → protein:
  edge_index: torch.Size([2, 764])
  edge_

In [None]:
## When we achieved a very high score on the link prediction task and suspect on the result we tried to examine everything
## and srart from splitting data .. 

In [28]:
## check if there are overlap edges between the train , val and test sets
def check_edge_overlap(split1, split2, name1, name2, edge_type):
    e1 = set(map(tuple, split1[0].t().tolist()))
    e2 = set(map(tuple, split2[0].t().tolist()))

    overlap = e1.intersection(e2)
    print(f"[{edge_type}] Overlap between {name1} and {name2}: {len(overlap)} edges")
    if len(overlap) > 0:
        print(" WARNING: Data leakage detected!")


In [29]:
# lncRNA → protein
check_edge_overlap(split_lnc_protein['train'], split_lnc_protein['val'], 'train', 'val', 'lncRNA→protein')
check_edge_overlap(split_lnc_protein['train'], split_lnc_protein['test'], 'train', 'test', 'lncRNA→protein')
check_edge_overlap(split_lnc_protein['val'], split_lnc_protein['test'], 'val', 'test', 'lncRNA→protein')

# protein ↔ protein
check_edge_overlap(split_protein_protein['train'], split_protein_protein['val'], 'train', 'val', 'protein↔protein')
check_edge_overlap(split_protein_protein['train'], split_protein_protein['test'], 'train', 'test', 'protein↔protein')
check_edge_overlap(split_protein_protein['val'], split_protein_protein['test'], 'val', 'test', 'protein↔protein')


[lncRNA→protein] Overlap between train and val: 166 edges
[lncRNA→protein] Overlap between train and test: 173 edges
[lncRNA→protein] Overlap between val and test: 39 edges
[protein↔protein] Overlap between train and val: 22 edges
[protein↔protein] Overlap between train and test: 27 edges
[protein↔protein] Overlap between val and test: 2 edges


In [None]:
## our suspiction is true there are data leakage we try to split again carefully ..

In [37]:
import torch
import pandas as pd
import numpy as np

def clean_and_split_edges(edge_index, edge_attr, train_ratio=0.8, val_ratio=0.1, seed=42):
    
    # Set random seed for reproducibility
    torch.manual_seed(seed)

    # Separate source and destination node indices
    i = edge_index[0]
    j = edge_index[1]
    
    # Make edge direction consistent by sorting node pairs
    i_min = torch.min(i, j)
    i_max = torch.max(i, j)

     # Combine sorted edges into shape [num_edges, 2]
    unified_edges = torch.stack([i_min, i_max], dim=0).t()  # shape: [num_edges, 2]

    
    df = pd.DataFrame(unified_edges.tolist(), columns=['src', 'dst'])
    for k in range(edge_attr.size(1)):
        df[f'attr_{k}'] = edge_attr[:, k]
        
    # Remove duplicate edges (based on src–dst) and reset index
    df = df.drop_duplicates(subset=['src', 'dst']).reset_index(drop=True)

    # Shuffle
    df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)

    # Split
    num_total = len(df)
    train_end = int(train_ratio * num_total)
    val_end = int((train_ratio + val_ratio) * num_total)

    splits = {}
    for name, start, end in zip(['train', 'val', 'test'], [0, train_end, val_end], [train_end, val_end, num_total]):
        part = df.iloc[start:end]
        edge_index = torch.tensor(part[['src', 'dst']].values.T, dtype=torch.long)
        edge_attr = torch.tensor(part[[col for col in part.columns if col.startswith('attr_')]].values, dtype=torch.float)
        splits[name] = (edge_index, edge_attr)

    return splits


In [38]:
# lncRNA–protein
split_lnc_protein = clean_and_split_edges(
    hetero_data['lncRNA', 'interacts', 'protein'].edge_index,
    hetero_data['lncRNA', 'interacts', 'protein'].edge_attr
)

# protein–protein
split_protein_protein = clean_and_split_edges(
    hetero_data['protein', 'interacts', 'protein'].edge_index,
    hetero_data['protein', 'interacts', 'protein'].edge_attr
)


In [39]:
# lncRNA → protein
check_edge_overlap(split_lnc_protein['train'], split_lnc_protein['val'], 'train', 'val', 'lncRNA→protein')
check_edge_overlap(split_lnc_protein['train'], split_lnc_protein['test'], 'train', 'test', 'lncRNA→protein')
check_edge_overlap(split_lnc_protein['val'], split_lnc_protein['test'], 'val', 'test', 'lncRNA→protein')

# protein ↔ protein
check_edge_overlap(split_protein_protein['train'], split_protein_protein['val'], 'train', 'val', 'protein↔protein')
check_edge_overlap(split_protein_protein['train'], split_protein_protein['test'], 'train', 'test', 'protein↔protein')
check_edge_overlap(split_protein_protein['val'], split_protein_protein['test'], 'val', 'test', 'protein↔protein')


[lncRNA→protein] Overlap between train and val: 0 edges
[lncRNA→protein] Overlap between train and test: 0 edges
[lncRNA→protein] Overlap between val and test: 0 edges
[protein↔protein] Overlap between train and val: 0 edges
[protein↔protein] Overlap between train and test: 0 edges
[protein↔protein] Overlap between val and test: 0 edges


In [None]:
## now we trust these splits .. we will make graph for each one and save it.. 

In [40]:
# Collect the node features from the existing full data
node_feats = {
    'lncRNA': hetero_data['lncRNA'].x,
    'protein': hetero_data['protein'].x
}

# Build the training graph using training edges for both lncRNA–protein and protein–protein
train_data = build_graph_from_split(node_feats, split_lnc_protein['train'], split_protein_protein['train'])

# Build the validation graph using validation edges
val_data   = build_graph_from_split(node_feats, split_lnc_protein['val'],   split_protein_protein['val'])

# Build the test graph using test edges
test_data  = build_graph_from_split(node_feats, split_lnc_protein['test'],  split_protein_protein['test'])

In [41]:
# Define a helper function to print summary information about a HeteroData graph

def print_graph_info(graph, name):
    # Print the name or label of the graph (e.g., "Train", "Validation", "Test")
    print(f"\nInfo for: {name}")
    print("*" * 26)
    
    # Print the shapes of the node feature matrices
    print("Node features:")
    print("-" * 20)
    print("  lncRNA:", graph['lncRNA'].x.shape)
    print("  protein:", graph['protein'].x.shape)
    
    print("\n Edge features:")
    print("-" * 20)

     # Print information about the lncRNA → protein edge type
    print("lncRNA → protein:")
    print("  edge_index:", graph['lncRNA', 'interacts', 'protein'].edge_index.shape)
    print("  edge_attr :", graph['lncRNA', 'interacts', 'protein'].edge_attr.shape)
    
    # Print information about the protein ↔ protein edge type
    print("protein ↔ protein:")
    print("  edge_index:", graph['protein', 'interacts', 'protein'].edge_index.shape)
    print("  edge_attr :", graph['protein', 'interacts', 'protein'].edge_attr.shape)


In [42]:
print_graph_info(train_data, "TRAIN Graph")
print_graph_info(val_data, "VALIDATION Graph")
print_graph_info(test_data, "TEST Graph")


Info for: TRAIN Graph
**************************
Node features:
--------------------
  lncRNA: torch.Size([1269, 2])
  protein: torch.Size([11585, 2])

 Edge features:
--------------------
lncRNA → protein:
  edge_index: torch.Size([2, 5119])
  edge_attr : torch.Size([5119, 4])
protein ↔ protein:
  edge_index: torch.Size([2, 59544])
  edge_attr : torch.Size([59544, 1])

Info for: VALIDATION Graph
**************************
Node features:
--------------------
  lncRNA: torch.Size([1269, 2])
  protein: torch.Size([11585, 2])

 Edge features:
--------------------
lncRNA → protein:
  edge_index: torch.Size([2, 640])
  edge_attr : torch.Size([640, 4])
protein ↔ protein:
  edge_index: torch.Size([2, 7443])
  edge_attr : torch.Size([7443, 1])

Info for: TEST Graph
**************************
Node features:
--------------------
  lncRNA: torch.Size([1269, 2])
  protein: torch.Size([11585, 2])

 Edge features:
--------------------
lncRNA → protein:
  edge_index: torch.Size([2, 640])
  edge_attr

In [45]:
## save graph.. 
torch.save(train_data, 'data/directed_train_graph.pt')
torch.save(val_data,   'data/directed_val_graph.pt')
torch.save(test_data,  'data/directed_test_graph.pt')