In [1]:
import pandas as pd

In [5]:
nodes_df = pd.read_csv('../KG_node_map_test.csv')
edges_df = pd.read_csv('../KG_edgelist_mask_test.csv')

In [8]:
node_indices = set(nodes_df['node_idx'])
edge_node_indices = set(edges_df['x_idx']).union(set(edges_df['y_idx']))

missing_nodes = node_indices - edge_node_indices
if missing_nodes:
    print(f"Warning - {len(missing_nodes)} nodes are missing from the edge list")
else:
    print("All nodes are present in the edge list")

All nodes are present in the edge list


In [11]:
import torch
from torch_geometric.data import HeteroData
from torch_geometric.utils import to_undirected

In [12]:
data = HeteroData()

In [14]:
node_types = nodes_df['node_type'].unique()
print(f"Node types: {node_types}")

Node types: ['gene/protein' 'effect/phenotype' 'disease' 'biological_process'
 'molecular_function' 'cellular_component' 'pathway']


In [16]:
node_type_to_idx = {ntype: idx for idx, ntype in enumerate(node_types)}
print(f"Node type to index: {node_type_to_idx}")

Node type to index: {'gene/protein': 0, 'effect/phenotype': 1, 'disease': 2, 'biological_process': 3, 'molecular_function': 4, 'cellular_component': 5, 'pathway': 6}


In [17]:
import numpy as np

In [18]:
node_feature = {}

for idx, row in nodes_df.iterrows():
    node_idx = row['node_idx']
    node_type = row['node_type']
    type_idx  = node_type_to_idx[node_type]

    # One-hot encoding of node type
    type_one_hot = np.zeros(len(node_types))
    type_one_hot[type_idx] = 1
    node_feature[node_idx] = type_one_hot  

In [21]:
print(f"Node feature for node 0: {node_feature[0]}")

Node feature for node 0: [1. 0. 0. 0. 0. 0. 0.]


In [30]:
print(f"Node feature for node 105000: {node_feature[105209]}")

Node feature for node 105000: [1. 0. 0. 0. 0. 0. 0.]


In [31]:
for node_type in node_types:
    # get node indices of this type
    indices = nodes_df[nodes_df['node_type'] == node_type]['node_idx'].values
    # get the features for these nodes
    features = np.array([node_feature[idx] for idx in indices])

    # add to HeteroData
    data[node_type].x = torch.tensor(features, dtype=torch.float)

    # map nodes indices to local indices for Pytorch Geometric
    mapping = {idx: i for i, idx in enumerate(indices)}

    # Store the mapping for edge indexing later
    nodes_df.loc[nodes_df['node_type'] == node_type, 'local_idx'] = range(len(indices))
    nodes_df.loc[nodes_df['node_type'] == node_type, 'local_mapping'] = [mapping[idx] for idx in indices]

In [34]:
print(f"Node feature for node 0: {data['gene/protein'].x[0]}")

Node feature for node 0: tensor([1., 0., 0., 0., 0., 0., 0.])


In [36]:
edge_relations = edges_df['full_relation'].unique()
relation_to_edge_type = {rel: idx for idx, rel in enumerate(edge_relations)}
print(relation_to_edge_type)

{'gene/protein;protein_protein;gene/protein': 0, 'effect/phenotype;phenotype_protein;gene/protein': 1, 'effect/phenotype;phenotype_phenotype;effect/phenotype': 2, 'disease;disease_phenotype_negative;effect/phenotype': 3, 'disease;disease_phenotype_positive;effect/phenotype': 4, 'disease;disease_protein;gene/protein': 5, 'disease;disease_disease;disease': 6, 'biological_process;bioprocess_bioprocess;biological_process': 7, 'molecular_function;molfunc_molfunc;molecular_function': 8, 'cellular_component;cellcomp_cellcomp;cellular_component': 9, 'gene/protein;protein_molfunc;molecular_function': 10, 'gene/protein;protein_cellcomp;cellular_component': 11, 'gene/protein;protein_bioprocess;biological_process': 12, 'pathway;pathway_pathway;pathway': 13, 'gene/protein;protein_pathway;pathway': 14, 'gene/protein;protein_protein_rev;gene/protein': 15, 'effect/phenotype;phenotype_phenotype_rev;effect/phenotype': 16, 'disease;disease_disease_rev;disease': 17, 'biological_process;bioprocess_bioproce

In [37]:
for relation in edge_relations:
    # Get edges of this relation
    rel_edges = edges_df[edges_df['full_relation'] == relation]
    # Extract node types from the relation string
    src_type, _, dst_type = relation.split(';')
    # Get local indices for source and target nodes
    src_mapping = nodes_df[nodes_df['node_type'] == src_type].set_index('node_idx')['local_mapping']
    dst_mapping = nodes_df[nodes_df['node_type'] == dst_type].set_index('node_idx')['local_mapping']
    # Map global node indices to local indices
    src_local_idx = rel_edges['x_idx'].map(src_mapping).values
    dst_local_idx = rel_edges['y_idx'].map(dst_mapping).values
    # Remove edges with missing nodes (if any)
    valid_edges = (~np.isnan(src_local_idx)) & (~np.isnan(dst_local_idx))
    src_local_idx = src_local_idx[valid_edges].astype(int)
    dst_local_idx = dst_local_idx[valid_edges].astype(int)
    # Create edge index tensor
    edge_index = torch.tensor([src_local_idx, dst_local_idx], dtype=torch.long)
    # Add to HeteroData
    data[(src_type, relation, dst_type)].edge_index = edge_index

  edge_index = torch.tensor([src_local_idx, dst_local_idx], dtype=torch.long)


In [38]:
from torch_geometric.loader import NeighborLoader

In [39]:
# Define batch size and number of neighbors to sample
batch_size = 1024
num_neighbors = [10, 10]  # Number of neighbors to sample at each layer

In [41]:
# Create a dictionary mapping node type to the indices of nodes
node_indices = {}
for node_type in node_types:
    node_indices[node_type] = torch.arange(data[node_type].num_nodes)
    print(f"Number of {node_type} nodes: {data[node_type].num_nodes}")

Number of gene/protein nodes: 21610
Number of effect/phenotype nodes: 15874
Number of disease nodes: 21233
Number of biological_process nodes: 28642
Number of molecular_function nodes: 11169
Number of cellular_component nodes: 4176
Number of pathway nodes: 2516


In [71]:
# Create NeighborLoader for each node type
train_loader = NeighborLoader(
    data,
    input_nodes=('gene/protein', node_indices['gene/protein']),
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    shuffle=True
)

# Create NeighborLoader for each node type
train_loader = NeighborLoader(
    data,
    input_nodes=('effect/phenotype', node_indices['effect/phenotype']),
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    shuffle=True
)

train_loader = NeighborLoader(
    data,
    input_nodes=('disease', node_indices['disease']),
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    shuffle=True
)

train_loader = NeighborLoader(
    data,
    input_nodes=('biological_process', node_indices['biological_process']),
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    shuffle=True
)

train_loader = NeighborLoader(
    data,
    input_nodes=('molecular_function', node_indices['molecular_function']),
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    shuffle=True
)

train_loader = NeighborLoader(
    data,
    input_nodes=('pathway', node_indices['pathway']),
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    shuffle=True
)



In [51]:
# Create NeighborLoader for each node type
train_loader_gene_protein = NeighborLoader(
    data,
    input_nodes=('gene/protein', node_indices['gene/protein']),
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    shuffle=True
)

In [52]:
# Create NeighborLoader for each node type
train_loader_effect_phenotype = NeighborLoader(
    data,
    input_nodes=('effect/phenotype', node_indices['effect/phenotype']),
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    shuffle=True
)

In [54]:
train_loader_disease = NeighborLoader(
    data,
    input_nodes=('disease', node_indices['disease']),
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    shuffle=True
)

In [56]:
train_loader_biological_process = NeighborLoader(
    data,
    input_nodes=('biological_process', node_indices['biological_process']),
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    shuffle=True
)

In [58]:
train_loader_molecular_function = NeighborLoader(
    data,
    input_nodes=('molecular_function', node_indices['molecular_function']),
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    shuffle=True
)

In [60]:
train_loader_cellular_component = NeighborLoader(
    data,
    input_nodes=('cellular_component', node_indices['cellular_component']),
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    shuffle=True
)

In [62]:
train_loader_pathway = NeighborLoader(
    data,
    input_nodes=('pathway', node_indices['pathway']),
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    shuffle=True
)

In [65]:
import torch.nn as nn
from torch_geometric.nn import HeteroConv, GATConv

class HeteroGAT(nn.Module):
    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                ('gene/protein', 'protein_protein', 'gene/protein'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('effect/phenotype', 'phenotype_protein', 'gene/protein'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('effect/phenotype', 'phenotype_phenotype', 'effect/phenotype'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('disease', 'disease_phenotype_negative', 'effect/phenotype'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('disease', 'disease_phenotype_positive', 'effect/phenotype'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('disease', 'disease_protein', 'gene/protein'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('disease', 'disease_disease', 'disease'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('biological_process', 'bioprocess_bioprocess', 'biological_process'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('molecular_function', 'molfunc_molfunc', 'molecular_function'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('cellular_component', 'cellcomp_cellcomp', 'cellular_component'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('gene/protein', 'protein_molfunc', 'molecular_function'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('gene/protein', 'protein_cellcomp', 'cellular_component'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('gene/protein', 'protein_bioprocess', 'biological_process'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('pathway', 'pathway_pathway', 'pathway'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('gene/protein', 'protein_pathway', 'pathway'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('gene/protein', 'protein_protein_rev', 'gene/protein'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('effect/phenotype', 'phenotype_phenotype_rev', 'effect/phenotype'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('disease', 'disease_disease_rev', 'disease'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('biological_process', 'bioprocess_bioprocess_rev', 'biological_process'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('molecular_function', 'molfunc_molfunc_rev', 'molecular_function'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('cellular_component', 'cellcomp_cellcomp_rev', 'cellular_component'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('pathway', 'pathway_pathway_rev', 'pathway'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('gene/protein', 'phenotype_protein', 'effect/phenotype'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('effect/phenotype', 'disease_phenotype_negative', 'disease'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('effect/phenotype', 'disease_phenotype_positive', 'disease'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('gene/protein', 'disease_protein', 'disease'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('molecular_function', 'protein_molfunc', 'gene/protein'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('cellular_component', 'protein_cellcomp', 'gene/protein'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('biological_process', 'protein_bioprocess', 'gene/protein'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('pathway', 'protein_pathway', 'gene/protein'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            }, aggr='sum')
            self.convs.append(conv)

        self.lin = nn.Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        return x_dict

In [66]:
model = HeteroGAT(hidden_channels=64, out_channels=1, num_layers=2)

In [67]:
import torch.optim as optim

loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)

In [69]:
num_epochs = 10

In [72]:
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        # Move batch to device (GPU if available)
        batch = batch.to(device)
        # Get model predictions
        out = model(batch.x_dict, batch.edge_index_dict)
        # Extract predictions for gene nodes
        gene_preds = out['gene/protein'].squeeze()
        # Get labels for gene nodes
        gene_labels = batch['gene/protein'].y.float()
        # Compute loss
        loss = loss_fn(gene_preds, gene_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch}, Loss: {total_loss / len(train_loader)}")

NameError: name 'device' is not defined