In [3]:
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, Linear 
from torch_geometric.data import HeteroData
import json 
import random
from sklearn.model_selection import train_test_split
import os

In [4]:
# device set to cuda otherwise cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [44]:
def activate_phenotypes(x, active_phenotype_indices, activation_value=1.0):
    x_activated = x.clone()
    x_activated[active_phenotype_indices] += activation_value
    return x_activated

In [45]:
def retrieve_subgraph(node_scores, num_nodes=2500):
    # Select top k nodes based on their scores
    _, top_indices = torch.topk(node_scores.squeeze(), num_nodes)
    return top_indices

In [46]:
def process_patient(model, graph, patient_data):
    # Combine all node features and edge indices
    x = torch.cat([graph.x_dict[node_type] for node_type in graph.node_types], dim=0)
    edge_index = torch.cat([graph.edge_index_dict[edge_type] for edge_type in graph.edge_types], dim=1)
    
    # Activate patient-specific phenotypes
    active_phenotypes = torch.tensor(patient_data['positive_phenotypes']['indices'], dtype=torch.long)
    x_activated = activate_phenotypes(x, active_phenotypes)
    
    # Get node scores from the model
    node_scores = model(x_activated, edge_index)
    
    # Retrieve subgraph
    subgraph_nodes = retrieve_subgraph(node_scores)
    
    return subgraph_nodes, node_scores

In [43]:
def check_activated_phenotypes(x, x_activated, active_phenotype_indices):
    print(f"Original x shape: {x.shape}")
    print(f"Activated x shape: {x_activated.shape}")
    print(f"Number of active phenotypes: {len(active_phenotype_indices)}")

    diff = x_activated - x
    
    diff_indices = torch.nonzero(torch.any(diff != 0, dim=1)).squeeze()
    diff_indices = sorted(diff_indices.tolist())
    active_phenotype_indices = sorted(active_phenotype_indices.tolist())
    
    print(f"Number of different nodes: {len(diff_indices)}")
    print(f"Indices of different nodes: {diff_indices}")
    print(f"Active phenotype indices: {active_phenotype_indices}")
    
    if active_phenotype_indices == diff_indices:
        print("\nSUCCESSFUL\nActive phenotype indices are equal to the diff indices")
        return True
    else:
        print("\nFAILED\nActive phenotype indices are not equal to the diff indices")
        extra_activated = set(diff_indices) - set(active_phenotype_indices)
        not_activated = set(active_phenotype_indices) - set(diff_indices)
        if extra_activated:
            print(f"Extra nodes activated: {extra_activated}")
        if not_activated:
            print(f"Nodes not activated: {not_activated}")
        return False

In [5]:
data = torch.load('./KGs/Shepherd_KG_with_pretrained_embeddings3.pt')
print("Graph loaded from 'Shepherd_KG_with_pretrained_embeddings3.pt'.")
print(data)

Graph loaded from 'Shepherd_KG_with_pretrained_embeddings3.pt'.
HeteroData(
  gene/protein={
    x=[21610, 2048],
    node_ids=[21610],
    node_type='gene/protein',
    node_name=[21610],
  },
  effect/phenotype={
    x=[15874, 2048],
    node_ids=[15874],
    node_type='effect/phenotype',
    node_name=[15874],
  },
  disease={
    x=[21233, 2048],
    node_ids=[21233],
    node_type='disease',
    node_name=[21233],
  },
  biological_process={
    x=[28642, 2048],
    node_ids=[28642],
    node_type='biological_process',
    node_name=[28642],
  },
  molecular_function={
    x=[11169, 2048],
    node_ids=[11169],
    node_type='molecular_function',
    node_name=[11169],
  },
  cellular_component={
    x=[4176, 2048],
    node_ids=[4176],
    node_type='cellular_component',
    node_name=[4176],
  },
  pathway={
    x=[2516, 2048],
    node_ids=[2516],
    node_type='pathway',
    node_name=[2516],
  },
  (gene/protein, protein_protein, gene/protein)={
    edge_index=[2, 321075],
  

In [6]:
def load_patient_data(file_path):
    with open(file_path, 'r') as f:
        return json.load(f)

In [7]:
patient_files = [
    f"./patient_subgraph_data/labeled_patient_data/patient_{i}_result.json" 
    for i in range(4000) 
    if os.path.exists(f"./patient_subgraph_data/labeled_patient_data/patient_{i}_result.json")
]

print(len(patient_files)) 

2135


In [8]:
train_files, test_files = train_test_split(patient_files, test_size=0.2, random_state=42)
train_files, val_files = train_test_split(train_files, test_size=0.2, random_state=42)

In [9]:
print(f"Train: {len(train_files)}, Validation: {len(val_files)}, Test: {len(test_files)}")

Train: 1366, Validation: 342, Test: 427


In [10]:
class GuidedGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        self.phenotype_embedding = nn.Embedding(in_channels, hidden_channels)
        
        
        self.conv1 = GATConv(in_channels, hidden_channels, heads=num_heads, dropout=0.2)
        self.head_transform_1 = Linear(hidden_channels * num_heads, hidden_channels)

        self.conv2 = GATConv(hidden_channels, hidden_channels, heads=num_heads, dropout=0.2)
        self.head_transform_2 = Linear(hidden_channels * num_heads, hidden_channels)

        self.conv3 = GATConv(hidden_channels, hidden_channels, heads=num_heads, dropout=0.2)
        self.head_transform_3 = Linear(hidden_channels * num_heads, hidden_channels)

        self.linear1 = Linear(hidden_channels, out_channels)
    
    def forward(self, x, edge_index, active_phenotypes):
        phenotype_emb = self.phenotype_embedding(active_phenotypes)
        x = x + phenotype_emb

        x = self.conv1(x, edge_index)
        x = self.head_transform_1(F.elu(x))

        x = self.conv2(x, edge_index)
        x = self.head_transform_2(F.elu(x))

        x = self.conv3(x, edge_index)
        x = self.head_transform_3(F.elu(x))

        return self.linear1(x) 


In [11]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [12]:
#print architecture and paramenters of the model
model = GuidedGNN(in_channels=2048, hidden_channels=2048, out_channels=512, num_heads=3, num_layers=3)
model = model.to(device)
print(f"Number of paramenters: {count_parameters(model)}")
print(model)

Number of paramenters: 80802304
GuidedGNN(
  (phenotype_embedding): Embedding(2048, 2048)
  (conv1): GATConv(2048, 2048, heads=3)
  (head_transform_1): Linear(6144, 2048, bias=True)
  (conv2): GATConv(2048, 2048, heads=3)
  (head_transform_2): Linear(6144, 2048, bias=True)
  (conv3): GATConv(2048, 2048, heads=3)
  (head_transform_3): Linear(6144, 2048, bias=True)
  (linear1): Linear(2048, 512, bias=True)
)


In [15]:
x_t = torch.cat([data.x_dict[node_type] for node_type in data.node_types], dim=0)
print(x_t.shape)


torch.Size([105220, 2048])


In [16]:
data.x_dict

{'gene/protein': tensor([[-1.5311e+00, -7.5088e-01, -6.5134e-01,  ...,  5.3504e-01,
           1.5169e-01,  1.7642e-01],
         [-2.7281e-01, -1.5860e-03, -1.3548e+00,  ..., -7.5338e-01,
          -1.0343e+00,  1.8137e+00],
         [-1.9237e+00, -2.3205e-01, -2.1809e-02,  ...,  1.0526e-01,
          -4.0546e-01, -4.1800e-01],
         ...,
         [ 3.4648e-01,  1.1551e+00, -9.9104e-02,  ..., -1.3863e+00,
          -6.2624e-03,  1.7522e+00],
         [ 5.3702e-01, -3.1417e-01, -5.7738e-01,  ...,  3.9759e-01,
           1.0083e+00, -8.5591e-02],
         [ 4.6956e-01, -1.2878e+00, -1.7059e+00,  ..., -1.1443e+00,
           3.9692e-01, -2.2757e+00]]),
 'effect/phenotype': tensor([[ 0.9057, -0.4680, -1.3922,  ..., -0.4763,  0.4993, -2.0206],
         [ 1.2355, -0.9945, -0.1036,  ...,  0.1707, -1.7682,  0.6473],
         [ 2.6220,  0.8790,  1.7658,  ...,  1.4747, -0.4595, -0.2842],
         ...,
         [-1.0179, -0.0909, -0.2129,  ...,  0.5464,  0.2933,  1.2364],
         [ 1.4223,  

In [17]:
data.edge_index_dict

{('gene/protein',
  'protein_protein',
  'gene/protein'): tensor([[    0,     1,     2,  ...,  3839,  1600, 10974],
         [ 8889,  2798,  5646,  ...,   226,  5680, 12723]]),
 ('effect/phenotype',
  'phenotype_protein',
  'gene/protein'): tensor([[14012, 14012, 14012,  ..., 14681, 14469, 19686],
         [ 7097,  5230,  4315,  ...,  9851, 13297, 51965]]),
 ('effect/phenotype',
  'phenotype_phenotype',
  'effect/phenotype'): tensor([[14378, 14682, 14683,  ..., 14594, 14594, 14594],
         [14740, 60100, 14684,  ..., 14744, 64415, 14724]]),
 ('disease',
  'disease_phenotype_negative',
  'effect/phenotype'): tensor([[19687, 19687, 19687,  ..., 20602, 20602, 20603],
         [14051, 15274, 14577,  ..., 14430, 14153, 15137]]),
 ('disease',
  'disease_phenotype_positive',
  'effect/phenotype'): tensor([[20604, 19687, 19687,  ..., 27911, 20603, 22582],
         [16985, 14066, 14054,  ..., 14528, 60713, 15027]]),
 ('disease',
  'disease_protein',
  'gene/protein'): tensor([[22008, 24601, 2

In [18]:
edge_index_t = torch.cat([data.edge_index_dict[edge_type] for edge_type in data.edge_types], dim=1)
print(edge_index_t.shape)

torch.Size([2, 2190938])


In [19]:
active_phenotypes_t = torch.tensor(patient_data['positive_phenotypes']['indices'], dtype=torch.long)
print(patient_data['patient_id'])
print(active_phenotypes_t)


347
tensor([63501, 14162, 16074, 61172, 63139, 63489, 16000, 18088, 15410, 14012,
        14451, 69673, 16981])


In [20]:
def activate_phenotypes(x, active_phenotype_indices, activation_value=1.0):
    # Assuming 'effect/phenotype' nodes are at the beginning of x
    print(f"x shape: {x.shape}")
    print(f"Indices of active phenotypes: {active_phenotype_indices}")
    x_activated = x.clone()
    print(f"X activated shape: {x_activated.shape}")
    x_activated[active_phenotype_indices] += activation_value
    print(f"X activated shape after adding activation value: {x_activated.shape}")
    return x_activated

In [21]:
x_activated_t = activate_phenotypes(x_t, active_phenotypes_t)

x shape: torch.Size([105220, 2048])
Indices of active phenotypes: tensor([63501, 14162, 16074, 61172, 63139, 63489, 16000, 18088, 15410, 14012,
        14451, 69673, 16981])
X activated shape: torch.Size([105220, 2048])
X activated shape after adding activation value: torch.Size([105220, 2048])


In [22]:
print(x_activated_t.shape)

torch.Size([105220, 2048])


In [23]:
# do a function to check which indices are different from x_active_t and x_t
def check_activated_phenotypes(x, x_activated):
    print(f"x shape: {x.shape}")
    print(f"x_activated shape: {x_activated.shape}")
    diff_indices = torch.nonzero(x_activated - x)
    return diff_indices

In [24]:
check_activated_phenotypes(x_t, x_activated_t)

x shape: torch.Size([105220, 2048])
x_activated shape: torch.Size([105220, 2048])


tensor([[14012,     0],
        [14012,     1],
        [14012,     2],
        ...,
        [69673,  2045],
        [69673,  2046],
        [69673,  2047]])

In [None]:
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, Linear 
from torch_geometric.data import HeteroData
import json 
import random
from sklearn.model_selection import train_test_split
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

data = torch.load('./KGs/Shepherd_KG_with_pretrained_embeddings3.pt')
print("Graph loaded from 'Shepherd_KG_with_pretrained_embeddings3.pt'.")
print(data)

def activate_phenotypes(x, active_phenotype_indices, activation_value=1.0):
    x_activated = x.clone()
    x_activated[active_phenotype_indices] += activation_value
    return x_activated

def retrieve_subgraph(node_scores, num_nodes=2500):
    # Select top k nodes based on their scores
    _, top_indices = torch.topk(node_scores.squeeze(), num_nodes)
    return top_indices

def process_patient(model, graph, patient_data):
    # Combine all node features and edge indices
    x = torch.cat([graph.x_dict[node_type] for node_type in graph.node_types], dim=0)
    edge_index = torch.cat([graph.edge_index_dict[edge_type] for edge_type in graph.edge_types], dim=1)
    
    # Activate patient-specific phenotypes
    active_phenotypes = torch.tensor(patient_data['positive_phenotypes']['indices'], dtype=torch.long)
    x_activated = activate_phenotypes(x, active_phenotypes)
    
    # Get node scores from the model
    node_scores = model(x_activated, edge_index)
    
    # Retrieve subgraph
    subgraph_nodes = retrieve_subgraph(node_scores)
    
    return subgraph_nodes, node_scores


class GuidedGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        self.phenotype_embedding = nn.Embedding(in_channels, hidden_channels)
        
        
        self.conv1 = GATConv(in_channels, hidden_channels, heads=num_heads, dropout=0.2)
        self.head_transform_1 = Linear(hidden_channels * num_heads, hidden_channels)

        self.conv2 = GATConv(hidden_channels, hidden_channels, heads=num_heads, dropout=0.2)
        self.head_transform_2 = Linear(hidden_channels * num_heads, hidden_channels)

        self.conv3 = GATConv(hidden_channels, hidden_channels, heads=num_heads, dropout=0.2)
        self.head_transform_3 = Linear(hidden_channels * num_heads, hidden_channels)

        self.linear1 = Linear(hidden_channels, out_channels)
    
    def forward(self, x, edge_index, active_phenotypes):
        phenotype_emb = self.phenotype_embedding(active_phenotypes)
        x = x + phenotype_emb

        x = self.conv1(x, edge_index)
        x = self.head_transform_1(F.elu(x))

        x = self.conv2(x, edge_index)
        x = self.head_transform_2(F.elu(x))

        x = self.conv3(x, edge_index)
        x = self.head_transform_3(F.elu(x))

        return self.linear1(x) 

# Training loop
model = GuidedGNN(in_channels=2048, hidden_channels=2048,  heads=3, num_layers=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    for patient_file in train_files:
        patient_data = load_patient_data(patient_file)
        optimizer.zero_grad()
        
        subgraph_nodes, node_scores = process_patient(model, graph, patient_data)
        
        # Check if true gene is in the retrieved subgraph
        true_gene_idx = patient_data['true_gene']['index']
        true_gene_in_subgraph = true_gene_idx in subgraph_nodes
        
        # Compute loss
        if true_gene_in_subgraph:
            # Encourage high score for true gene
            loss = F.binary_cross_entropy_with_logits(node_scores[true_gene_idx], torch.tensor([1.0]))
        else:
            # Penalize model for not including true gene
            loss = F.binary_cross_entropy_with_logits(node_scores[true_gene_idx], torch.tensor([1.0])) + \
                   F.binary_cross_entropy_with_logits(node_scores[subgraph_nodes].max(), torch.tensor([0.0]))
        
        loss.backward()
        optimizer.step()
        
        print(f"Epoch {epoch}, Patient {patient_data['patient_id']}")
        print(f"True gene in subgraph: {true_gene_in_subgraph}")
        print(f"Loss: {loss.item()}")
        print("---")