In [1]:
import torch
from collections import Counter
from collections import defaultdict
from torch_geometric.data import HeteroData

In [2]:
# File stem shared by the input graph (../data/<stem>.pt) and the saved
# heterogeneous graph (../data/<stem>_hetero.pt).
graph_path_name = "graph_{model}_{dim}_dim".format(model="biobert", dim=128)

In [3]:
# Load the homogeneous graph. The file holds a pickled PyG `Data` object (see the
# printed repr below), so weights_only=False is needed to unpickle it.
# NOTE(review): this executes arbitrary pickle code — only load files you trust.
graph_file = f"../data/{graph_path_name}.pt"
data = torch.load(graph_file, weights_only=False)
print(data)

Data(
  x=[3908, 135],
  edge_index=[2, 8360],
  edge_attr=[8360, 13],
  node_names=[3908],
  node_to_idx={
    NDRG1=0,
    PIM1=1,
    BMPR1B=2,
    LINC01391=3,
    SCDAL=4,
    uc003fir=5,
    ZC3H12A=6,
    EPHB6=7,
    miR-27a-3p=8,
    MIR143=9,
    TGFbeta1=10,
    EIF4G1=11,
    MIR145=12,
    LINC00319=13,
    ZFAS2=14,
    CDH2=15,
    CDKN1A=16,
    RAC1=17,
    CADM1=18,
    snaR=19,
    lncRNA 00312=20,
    DCTN6=21,
    FAM83A=22,
    ATP8B=23,
    JUNB=24,
    KLF16=25,
    SUSAJ1=26,
    PARP9=27,
    SPRY4-IT1=28,
    SLC16A1-AS1=29,
    GAB1=30,
    FOSL2=31,
    AJAP1=32,
    FGF19=33,
    LERFS=34,
    SP3=35,
    SLC7A5=36,
    SERPINH1=37,
    DRAM1=38,
    COL1A2-AS1=39,
    RAD18=40,
    LINC00160=41,
    ZBTB5=42,
    LATS2=43,
    CCR6=44,
    CRCMSL=45,
    IGF2-AS=46,
    miR-146b-5p=47,
    miR-196a=48,
    SMOC1=49,
    RAPIA=50,
    STXBP5-AS1=51,
    LAMB3=52,
    miR-23b-3p=53,
    USP22=54,
    TAC1=55,
    miR-202-5p=56,
    miR-1273h-5p=57,
    SKIL

In [4]:
# Decode each node's type from the one-hot block stored in the trailing columns
# of data.x. Cell 6 strips the same number of columns to obtain the features.
node_type_names = ["miRNA", "PCG", "TF", "lncRNA", "snoRNA", "circRNA"]
num_type_cols = 7  # width of the one-hot type block at the end of data.x
# NOTE(review): there are 7 one-hot columns but only 6 names — confirm what the
# 7th column encodes and add its name here if it can ever be active.
type_onehot = data.x[:, -num_type_cols:]

# Guard against silent mislabeling. argmax of an all-zero row returns 0, which
# would be reported as "miRNA". A hit in a column with no name would otherwise
# surface as an opaque IndexError in the list comprehension below.
single_active = (type_onehot != 0).sum(dim=1) == 1
assert bool(single_active.all()), (
    f"{int((~single_active).sum())} nodes do not have exactly one active type column"
)
node_type_ids = torch.argmax(type_onehot, dim=1)
max_type_id = int(node_type_ids.max()) if node_type_ids.numel() else -1
assert max_type_id < len(node_type_names), (
    f"type column {max_type_id} has no entry in node_type_names "
    f"({len(node_type_names)} names for {num_type_cols} one-hot columns)"
)
node_types = [node_type_names[i] for i in node_type_ids.tolist()]

# Count how many nodes belong to each type
node_type_counts = Counter(node_types)

# Print summary of node types
print("Node type counts:")
for t, c in node_type_counts.items():
    print(f"{t}: {c}")

Node type counts:
PCG: 1717
lncRNA: 1333
miRNA: 512
TF: 336
snoRNA: 3
circRNA: 7


In [5]:
# Partition the global node indices (all node types together) by type, keeping
# first-appearance order of types and ascending global order within each type.
node_type_to_indices = defaultdict(list)
for global_pos, ntype in enumerate(node_types):
    node_type_to_indices[ntype].append(global_pos)

# (type, global index) -> position of that node inside its own type's list,
# i.e. the row index it will have in hetero_data[type].x.
global_to_local = {
    (ntype, global_pos): local_pos
    for ntype, members in node_type_to_indices.items()
    for local_pos, global_pos in enumerate(members)
}

In [6]:
# Create one node store per type in a fresh HeteroData. Each store holds only the
# feature columns of its nodes; the trailing 7 one-hot type columns are dropped.
hetero_data = HeteroData()

for ntype, members in node_type_to_indices.items():
    row_ids = torch.tensor(members, dtype=torch.long)
    hetero_data[ntype].x = data.x[row_ids, :-7]
    print(
        f"{ntype}: added {len(members)} nodes with shape {hetero_data[ntype].x.shape}"
    )

PCG: added 1717 nodes with shape torch.Size([1717, 128])
lncRNA: added 1333 nodes with shape torch.Size([1333, 128])
miRNA: added 512 nodes with shape torch.Size([512, 128])
TF: added 336 nodes with shape torch.Size([336, 128])
snoRNA: added 3 nodes with shape torch.Size([3, 128])
circRNA: added 7 nodes with shape torch.Size([7, 128])


In [7]:
# Bucket every edge by its (source type, "regulates", target type) relation.
# Global node indices are kept here; cell 8 converts them to per-type indices.
edge_groups = defaultdict(list)
edge_attrs = defaultdict(list)

# Convert edge_index to Python ints once instead of calling .item() per element.
sources, targets = data.edge_index.tolist()
for edge_pos, (g_src, g_dst) in enumerate(zip(sources, targets)):
    relation = (node_types[g_src], "regulates", node_types[g_dst])
    edge_groups[relation].append((g_src, g_dst))
    edge_attrs[relation].append(data.edge_attr[edge_pos])

In [8]:
# For each relation, translate its global (src, dst) pairs into local per-type
# indices. Store them as a [2, num_edges] edge_index alongside the stacked
# per-edge attribute rows.
for relation, pairs in edge_groups.items():
    head_type, _, tail_type = relation
    local_pairs = [
        (global_to_local[(head_type, g_src)], global_to_local[(tail_type, g_dst)])
        for g_src, g_dst in pairs
    ]
    rel_index = torch.tensor(local_pairs, dtype=torch.long).t().contiguous()
    rel_attr = torch.stack(edge_attrs[relation])

    store = hetero_data[relation]
    store.edge_index = rel_index
    store.edge_attr = rel_attr

    print(
        f"Added edge type {relation}: {rel_index.shape[1]} edges with attr shape {rel_attr.shape}"
    )

Added edge type ('lncRNA', 'regulates', 'miRNA'): 1088 edges with attr shape torch.Size([1088, 13])
Added edge type ('lncRNA', 'regulates', 'PCG'): 5029 edges with attr shape torch.Size([5029, 13])
Added edge type ('lncRNA', 'regulates', 'TF'): 1570 edges with attr shape torch.Size([1570, 13])
Added edge type ('PCG', 'regulates', 'miRNA'): 11 edges with attr shape torch.Size([11, 13])
Added edge type ('TF', 'regulates', 'lncRNA'): 303 edges with attr shape torch.Size([303, 13])
Added edge type ('lncRNA', 'regulates', 'lncRNA'): 34 edges with attr shape torch.Size([34, 13])
Added edge type ('miRNA', 'regulates', 'lncRNA'): 82 edges with attr shape torch.Size([82, 13])
Added edge type ('PCG', 'regulates', 'PCG'): 36 edges with attr shape torch.Size([36, 13])
Added edge type ('circRNA', 'regulates', 'miRNA'): 3 edges with attr shape torch.Size([3, 13])
Added edge type ('circRNA', 'regulates', 'PCG'): 9 edges with attr shape torch.Size([9, 13])
Added edge type ('PCG', 'regulates', 'lncRNA'

In [9]:
# Structural sanity check before inspecting the graph. PyG's HeteroData.validate()
# checks e.g. that edge indices fit within each node type's count, and it raises
# on failure when raise_on_error=True.
hetero_data.validate(raise_on_error=True)

# At the time of writing: 6 node types, each node with 128 features (see the
# shapes printed when the node stores were added). 18 edge types, each edge with
# 13 attributes.
print(hetero_data)

HeteroData(
  PCG={ x=[1717, 128] },
  lncRNA={ x=[1333, 128] },
  miRNA={ x=[512, 128] },
  TF={ x=[336, 128] },
  snoRNA={ x=[3, 128] },
  circRNA={ x=[7, 128] },
  (lncRNA, regulates, miRNA)={
    edge_index=[2, 1088],
    edge_attr=[1088, 13],
  },
  (lncRNA, regulates, PCG)={
    edge_index=[2, 5029],
    edge_attr=[5029, 13],
  },
  (lncRNA, regulates, TF)={
    edge_index=[2, 1570],
    edge_attr=[1570, 13],
  },
  (PCG, regulates, miRNA)={
    edge_index=[2, 11],
    edge_attr=[11, 13],
  },
  (TF, regulates, lncRNA)={
    edge_index=[2, 303],
    edge_attr=[303, 13],
  },
  (lncRNA, regulates, lncRNA)={
    edge_index=[2, 34],
    edge_attr=[34, 13],
  },
  (miRNA, regulates, lncRNA)={
    edge_index=[2, 82],
    edge_attr=[82, 13],
  },
  (PCG, regulates, PCG)={
    edge_index=[2, 36],
    edge_attr=[36, 13],
  },
  (circRNA, regulates, miRNA)={
    edge_index=[2, 3],
    edge_attr=[3, 13],
  },
  (circRNA, regulates, PCG)={
    edge_index=[2, 9],
    edge_attr=[9, 13],
  },


In [10]:
# Metadata summary: print how many node/edge types exist, then list them.
node_type_list = hetero_data.node_types
edge_type_list = hetero_data.edge_types

for count, header, listing in (
    (len(node_type_list), "\nNode types:", node_type_list),
    (len(edge_type_list), "Edge types:", edge_type_list),
):
    print(count)
    print(header, listing)

6

Node types: ['PCG', 'lncRNA', 'miRNA', 'TF', 'snoRNA', 'circRNA']
18
Edge types: [('lncRNA', 'regulates', 'miRNA'), ('lncRNA', 'regulates', 'PCG'), ('lncRNA', 'regulates', 'TF'), ('PCG', 'regulates', 'miRNA'), ('TF', 'regulates', 'lncRNA'), ('lncRNA', 'regulates', 'lncRNA'), ('miRNA', 'regulates', 'lncRNA'), ('PCG', 'regulates', 'PCG'), ('circRNA', 'regulates', 'miRNA'), ('circRNA', 'regulates', 'PCG'), ('PCG', 'regulates', 'lncRNA'), ('PCG', 'regulates', 'TF'), ('lncRNA', 'regulates', 'snoRNA'), ('circRNA', 'regulates', 'TF'), ('TF', 'regulates', 'miRNA'), ('TF', 'regulates', 'PCG'), ('TF', 'regulates', 'circRNA'), ('TF', 'regulates', 'TF')]


In [11]:
# Persist the heterogeneous graph next to its source file with a "_hetero" suffix.
save_path = f"../data/{graph_path_name}_hetero.pt"
torch.save(hetero_data, save_path)
print("HeteroData graph saved to:", save_path)

HeteroData graph saved to: ../data/graph_biobert_128_dim_hetero.pt
