# Imports

In [1]:
# imports
!pip install torch torch-geometric scikit-learn
!pip install optuna

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  

In [2]:
import pickle
import os
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib
import matplotlib.pyplot as plt
# import torch geometric
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.data import Data, DataLoader
from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GraphConv
# from torch_geometric.nn import MessagePassing
# from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool
# from torch_geometric.utils import degree, to_dense_adj, to_dense_batch, to_undirected
# from torch_geometric.utils import from_networkx, to_networkx
# from torch_geometric.utils import negative_sampling, remove_self_loops, add_self_loops
# from torch_geometric.utils import dropout_adj, to_undirected, add_self_loops
# from torch_geometric.utils import to_dense_adj, to_dense_batch, dense_to_sparse

# print versions
# Record the deep-learning stack versions and CUDA visibility in the cell
# output so the environment this notebook ran in is documented.
print("torch version: ", torch.__version__)
print("torch geometric version: ", torch_geometric.__version__)
print("torch cuda version: ", torch.version.cuda)
print("torch cuda available: ", torch.cuda.is_available())
# print("torch cuda device count: ", torch.cuda.device_count())
# print("torch cuda current device name: ", torch.cuda.get_device_name(torch.cuda.current_device()))

# print versions of the data / plotting / graph libraries
print(f"Pandas version: {pd.__version__}")
print(f"Numpy version: {np.__version__}")
print(f"Matplotlib version: {matplotlib.__version__}")
print(f"NetworkX version: {nx.__version__}")

# Set pandas display options so cell contents are never shortened and every
# column is shown. Only column width and column count are changed here; the
# row-count limit is left at its default.
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_columns', None)
print("\nRemoved truncation of columns")


torch version:  2.6.0+cu124
torch geometric version:  2.6.1
torch cuda version:  12.4
torch cuda available:  True
Pandas version: 2.2.2
Numpy version: 2.0.2
Matplotlib version: 3.10.0
NetworkX version: 3.4.2

Removed truncation of columns


# Load data

In [3]:
# Mount Google Drive under /content/drive so that later cells can change into
# the project directory stored there and read the pickled inputs.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
%ls

02-create-hetero-network-graph.ipynb  max_20000_nodes_similarity.gpickle
03-rgcn-model.ipynb                   max_22793_nodes_graph.gpickle
graph_max_100_nodes.html              max_22793_nodes_similarity.gpickle
hgnn_hr_ai.ipynb                      max_24370_nodes_graph.gpickle
max_100_nodes_graph.gpickle           max_24370_nodes_similarity.gpickle
max_100_nodes_similarity.gpickle      rgcn_model.pth
max_20000_nodes_graph.gpickle


In [5]:
%cd drive/MyDrive/AI_in_Healthcare/high_risk_project/3_approach_HGNN/

/content/drive/MyDrive/AI_in_Healthcare/high_risk_project/3_approach_HGNN


## Load networkx graph

In [20]:
# load the networkx graph
def load_graph(graph_path):
    """Unpickle and return the object stored at ``graph_path``.

    Args:
        graph_path: Path of a pickle file, opened in binary read mode.

    Returns:
        Whatever object the pickle contains (this notebook stores a NetworkX
        graph there).

    Note:
        ``pickle.load`` can execute arbitrary code while deserialising, so
        only point this at files you created or otherwise trust.
    """
    with open(graph_path, mode='rb') as handle:
        return pickle.load(handle)

# Earlier variant that resolved an absolute path for the 20000-node graph;
# kept commented out for reference.
# CURR_DIR_PATH = os.getcwd()
# PICKLE_GRAPH_FILE_NAME = f"max_20000_nodes_graph.gpickle"
# PICKLE_GRAPH_FILE_NAME = os.path.abspath(PICKLE_GRAPH_FILE_NAME)

# print(f"Loading graph from {PICKLE_GRAPH_FILE_NAME}...")
# Load the 24370-node graph; the relative path is resolved against the
# current working directory of the kernel.
ntx_graph = load_graph('max_24370_nodes_graph.gpickle')
print(f"Graph loaded. Number of nodes: {ntx_graph.number_of_nodes()}, number of edges: {ntx_graph.number_of_edges()}")


Graph loaded. Number of nodes: 24370, number of edges: 260505


24370

In [21]:
# count nodes of type patient, diagnosis and procedure
def count_node_types(graph):
    """Tally how many nodes carry each value of the ``'type'`` attribute.

    Args:
        graph: Object whose ``nodes()`` yields node ids and whose
            ``nodes[node_id]`` is a mapping containing a ``'type'`` key.

    Returns:
        dict mapping each node type to its count, keyed in order of first
        appearance.

    Raises:
        KeyError: If a node has no ``'type'`` attribute.
    """
    tally = {}
    for node_id in graph.nodes():
        kind = graph.nodes[node_id]['type']
        tally[kind] = tally.get(kind, 0) + 1
    return tally

# Per-type node counts for the loaded graph (dict: node type -> count).
node_types = count_node_types(ntx_graph)
print(f"Node counts: {node_types}")

# count edges by relation name (has_diagnosis, has_procedure)
def count_edge_types(graph):
    """Tally how many edges carry each value of the ``'relation'`` attribute.

    Args:
        graph: Multigraph-like object whose ``edges(keys=True)`` yields
            ``(u, v, key)`` triples and whose ``edges[u, v, key]`` is a
            mapping containing a ``'relation'`` key.

    Returns:
        dict mapping each relation name to its edge count, keyed in order of
        first appearance.

    Raises:
        KeyError: If an edge has no ``'relation'`` attribute.
    """
    tally = {}
    for src, dst, edge_key in graph.edges(keys=True):
        relation = graph.edges[src, dst, edge_key]['relation']
        tally[relation] = tally.get(relation, 0) + 1
    return tally

# Per-relation edge counts for the loaded graph (dict: relation -> count).
edge_types = count_edge_types(ntx_graph)
print(f"Edge counts: {edge_types}")

Node counts: {'patient': 17678, 'diagnosis': 5184, 'procedure': 1508}
Edge counts: {'has_diagnosis': 189708, 'has_procedure': 70797}


## Load patient node similarity dataframe

In [22]:
# load the patient similarity dataframe
def load_patient_similarity_df(df_path):
    """Unpickle and return the object stored at ``df_path``.

    Args:
        df_path: Path of a pickle file, opened in binary read mode.

    Returns:
        Whatever object the pickle contains (this notebook stores the patient
        similarity DataFrame there).

    Note:
        ``pickle.load`` can execute arbitrary code while deserialising, so
        only point this at files you created or otherwise trust.
    """
    with open(df_path, mode='rb') as handle:
        return pickle.load(handle)

# Resolve the pickle location once and use that single value both for the log
# message and for the actual load, so the path that is printed is always the
# path that is read.
CURR_DIR_PATH = os.getcwd()
PATIENT_SIMILARITY_DF_FILE_NAME = os.path.abspath("max_24370_nodes_similarity.gpickle")

print(f"Loading patient similarity dataframe from {PATIENT_SIMILARITY_DF_FILE_NAME}...")
patient_similarity_df = load_patient_similarity_df(PATIENT_SIMILARITY_DF_FILE_NAME)
print(f"Patient similarity dataframe loaded. Number of rows: {patient_similarity_df.shape[0]}")

Loading patient similarity dataframe from /content/drive/MyDrive/AI_in_Healthcare/high_risk_project/3_approach_HGNN/max_24370_nodes_similarity.gpickle...
Patient similarity dataframe loaded. Number of rows: 2713620


In [23]:
# Preview the first five rows to check the column layout after loading.
patient_similarity_df.head(5)

  has_large_values = (abs_vals > 1e6).any()


Unnamed: 0,patient1,patient2,patient1_id,patient2_id,jaccard_similarity,same_gender,same_age_bucket,is_similar
17680,patient-31182,patient-72623,31182,72623,0.333252,True,False,True
17777,patient-31182,patient-28675,31182,28675,0.300049,True,False,True
17810,patient-31182,patient-70605,31182,70605,0.3125,True,False,True
17877,patient-31182,patient-17208,31182,17208,0.333252,False,False,True
17967,patient-31182,patient-27351,31182,27351,0.353027,True,True,True


In [24]:
# Rename column patient1 to source_node and patient2 to target_node.
# NOTE: inplace=True mutates the DataFrame object loaded earlier, so any
# previously displayed preview of it no longer matches its current columns.
patient_similarity_df.rename(columns={'patient1': 'source_node', 'patient2': 'target_node'}, inplace=True)
patient_similarity_df.head(5)

  has_large_values = (abs_vals > 1e6).any()


Unnamed: 0,source_node,target_node,patient1_id,patient2_id,jaccard_similarity,same_gender,same_age_bucket,is_similar
17680,patient-31182,patient-72623,31182,72623,0.333252,True,False,True
17777,patient-31182,patient-28675,31182,28675,0.300049,True,False,True
17810,patient-31182,patient-70605,31182,70605,0.3125,True,False,True
17877,patient-31182,patient-17208,31182,17208,0.333252,False,False,True
17967,patient-31182,patient-27351,31182,27351,0.353027,True,True,True


# Create R-GCN and train the model

### DEBUG

In [25]:
# Debug walkthrough of the node-id -> index mapping logic on the first three
# nodes of the graph.
# Get node types
# NOTE: this rebinds `node_types` to a dict of empty lists keyed by node type.
node_types = {node_type: [] for node_type in ['patient', 'diagnosis', 'procedure']}
print(f"node_types: {node_types}")


# Create mappings from original node IDs to new indices for each node type
node_mappings = {node_type: {} for node_type in node_types}

print(f"node_mappings: {node_mappings}")
# i counts how many nodes have been newly added to a mapping so far.
i=0
for node, attr in ntx_graph.nodes(data=True):
    node_type = attr['type']
    # Show the raw node and its attributes only before the first node is mapped.
    if i ==0:
        print(f"node: {node}")
        print(f"attr: {attr}")
        print(f"node_type: {node_type}")

    # Assign the next free index within this node type on first encounter.
    if node not in node_mappings[node_type]:
        node_mappings[node_type][node] = len(node_mappings[node_type])

        i+=1
        # Both prints show the same assigned index; only the label differs.
        print(f"node_mappings[{node_type}][{node}]: {node_mappings[node_type][node]}")
        print(f"node_mappings[node_type][node]: {node_mappings[node_type][node]}")

    # Stop after three nodes have been mapped; this cell is for inspection only.
    if i==3:
        break

# PROMPT:
# Create a relational graph neural network for contrastive learning. The objective is to predict patient similarity.
# The data is in the networkx graph named ntx_graph. It contains 3 types of nodes: patients, diagnosis and procedure.
# It has edges with the relation of has_diagnosis and had_procedure.
# The dataframe patient_similarity_df contains patient similarity score in the column "jaccard_similarity". It has source node in column source_node and target node in column target_node. The similar rows have the column is_similar=True and dissimilar rows have is_similar=False.
# Create a relational graph neural network that uses the training data derived from ntx_graph and validation data derived from patient_similarity_df.
# Train and validate the model and print the loss in every iteration.

node_types: {'patient': [], 'diagnosis': [], 'procedure': []}
node_mappings: {'patient': {}, 'diagnosis': {}, 'procedure': {}}
node: patient-56375
attr: {'gender': 'F', 'age_bucket': 40, 'hadm_id': 176768, 'type': 'patient'}
node_type: patient
node_mappings[patient][patient-56375]: 0
node_mappings[node_type][node]: 0
node_mappings[patient][patient-31182]: 1
node_mappings[node_type][node]: 1
node_mappings[patient][patient-30931]: 2
node_mappings[node_type][node]: 2


# Neural Network

# Just for training

In [26]:
import torch
import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.nn import GraphConv, HeteroConv
import random

import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T

# Convert NetworkX graph to PyTorch Geometric HeteroData
def convert_to_hetero_data(nx_graph):
    """Build a PyTorch Geometric ``HeteroData`` object from a typed multigraph.

    Nodes are grouped by their ``'type'`` attribute ('patient', 'diagnosis',
    'procedure'); nodes of any other type are ignored. Within each type a node
    receives the next free integer index in iteration order. Every node type
    gets an identity matrix as its feature matrix (one-hot features), so the
    feature matrix of a type with N nodes has shape (N, N).

    Edges are read from their ``'relation'`` attribute. Only 'has_diagnosis'
    and 'has_procedure' are kept, and each is stored patient -> diagnosis /
    patient -> procedure regardless of the orientation it has in ``nx_graph``.

    Args:
        nx_graph: Graph supporting ``nodes(data=True)``, ``nodes[n]`` and
            ``edges(keys=True, data=True)``.

    Returns:
        tuple ``(data, node_mappings)`` where ``data`` is the populated
        ``HeteroData`` and ``node_mappings`` maps node type ->
        {original node id -> index}.

    Raises:
        KeyError: If a node has no ``'type'`` attribute.
    """
    hetero = HeteroData()

    # Per-type dictionaries: original node id -> contiguous row index.
    index_of = {kind: {} for kind in ['patient', 'diagnosis', 'procedure']}
    for node_id, node_attr in nx_graph.nodes(data=True):
        per_type = index_of.get(node_attr['type'])
        # `is not None` (not truthiness): an empty dict is a valid, known type.
        if per_type is not None and node_id not in per_type:
            per_type[node_id] = len(per_type)

    # One-hot node features; a type without nodes gets an empty (0, 0) matrix.
    for kind, per_type in index_of.items():
        count = len(per_type)
        if count > 0:
            hetero[kind].x = torch.eye(count)
        else:
            hetero[kind].x = torch.empty((0, 0), dtype=torch.float)

    # Relation name -> canonical (source type, relation, destination type).
    triplets = {
        'has_diagnosis': ('patient', 'has_diagnosis', 'diagnosis'),
        'has_procedure': ('patient', 'has_procedure', 'procedure')
    }

    # Register every edge type up front with an empty (2, 0) index.
    for triplet in triplets.values():
        hetero[triplet].edge_index = torch.empty((2, 0), dtype=torch.long)

    # Collect source and destination indices per relation.
    sources = {name: [] for name in triplets}
    targets = {name: [] for name in triplets}

    for u, v, _edge_key, edge_attr in nx_graph.edges(keys=True, data=True):
        name = edge_attr.get('relation')
        if name not in triplets:
            continue
        src_kind, _, dst_kind = triplets[name]

        u_kind = nx_graph.nodes[u]['type']
        v_kind = nx_graph.nodes[v]['type']

        # Orient the edge as (source type -> destination type); edges whose
        # endpoint types match neither orientation are dropped.
        if u_kind == src_kind and v_kind == dst_kind:
            head, tail = u, v
        elif u_kind == dst_kind and v_kind == src_kind:
            head, tail = v, u
        else:
            continue

        if head in index_of[src_kind] and tail in index_of[dst_kind]:
            sources[name].append(index_of[src_kind][head])
            targets[name].append(index_of[dst_kind][tail])

    # Store the collected indices; view(2, -1) keeps the (2, num_edges) shape
    # even when a relation has no edges.
    for name, triplet in triplets.items():
        stacked = torch.tensor([sources[name], targets[name]], dtype=torch.long)
        hetero[triplet].edge_index = stacked.view(2, -1)

    return hetero, index_of

# Create RGCN model for contrastive learning
class RGCN(torch.nn.Module):
    """Two-layer heterogeneous GNN that outputs projected patient embeddings.

    Each layer is a ``HeteroConv`` holding one ``GraphConv`` per edge type,
    with the per-relation results summed. A ReLU follows the first layer only.
    The second layer's 'patient' output is passed through a two-layer MLP
    projection head.

    Args:
        hidden_channels: Width of both convolution layers and of the hidden
            layer of the projection head.
        out_channels: Size of the returned patient embedding.
        metadata: Sequence whose element ``[1]`` is the iterable of edge types
            (the notebook passes ``data.metadata()``).
    """

    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()

        first_layer_convs, second_layer_convs = {}, {}
        # Layers are created pairwise per edge type, in the order the edge
        # types are listed. -1 lets the first layer infer its input width.
        for relation_triplet in metadata[1]:
            first_layer_convs[relation_triplet] = GraphConv(-1, hidden_channels)
            second_layer_convs[relation_triplet] = GraphConv(hidden_channels, hidden_channels)

        self.conv1 = torch_geometric.nn.HeteroConv(first_layer_convs, aggr='sum')
        self.conv2 = torch_geometric.nn.HeteroConv(second_layer_convs, aggr='sum')

        # Projection head applied to patient node representations only.
        self.patient_proj = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels),
        )

    def forward(self, x_dict, edge_index_dict):
        """Return projected patient embeddings.

        Args:
            x_dict: Mapping node type -> feature tensor.
            edge_index_dict: Mapping edge type -> edge index tensor.

        Returns:
            Tensor produced by the projection head for the 'patient' nodes, or
            ``None`` when the second layer yields no 'patient' entry.
        """
        hidden = self.conv1(x_dict, edge_index_dict)
        hidden = {node_type: F.relu(features) for node_type, features in hidden.items()}

        # No non-linearity after the second layer; the projection head follows.
        hidden = self.conv2(hidden, edge_index_dict)

        if 'patient' not in hidden:
            return None
        return self.patient_proj(hidden['patient'])

# Create a contrastive loss function
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss over pairs of embedding rows.

    For a pair at Euclidean distance ``d`` with label ``y``:

    * ``y == 0`` (similar):    loss = d ** 2
    * ``y == 1`` (dissimilar): loss = max(0, margin - d) ** 2

    The result is the mean over all pairs whose two indices are valid rows of
    ``embeddings``.

    Args:
        margin: Distance beyond which dissimilar pairs stop contributing.
    """

    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, embeddings, pairs, labels):
        """Compute the mean contrastive loss.

        Args:
            embeddings: 2-D tensor with one embedding per row.
            pairs: Integer tensor of shape (num_pairs, 2) holding row indices
                into ``embeddings``.
            labels: Float tensor of shape (num_pairs,), 0 for similar pairs
                and 1 for dissimilar pairs.

        Returns:
            Scalar tensor. When no pair has both indices in range, a zero
            tensor with ``requires_grad=True`` is returned.
        """
        # Keep pairs and labels on the same device as the embeddings.
        pairs = pairs.to(embeddings.device)
        labels = labels.to(embeddings.device)

        # Keep only pairs whose indices address real rows. Negative indices
        # are rejected as well: tensor indexing would otherwise wrap them
        # around and silently compare the wrong rows.
        num_rows = embeddings.size(0)
        endpoints = pairs[:, :2]
        in_bounds = ((endpoints >= 0) & (endpoints < num_rows)).all(dim=1)
        valid_pairs = pairs[in_bounds]
        valid_labels = labels[in_bounds]

        if valid_pairs.size(0) == 0:
            # Zero loss if no valid pairs.
            return torch.tensor(0.0, device=embeddings.device, requires_grad=True)

        first = embeddings[valid_pairs[:, 0]]
        second = embeddings[valid_pairs[:, 1]]

        # Euclidean distance between the two members of each pair.
        distances = F.pairwise_distance(first, second)

        # Similar pairs are pulled together; dissimilar pairs are pushed apart
        # until they are at least `margin` away from each other.
        loss_similar = (1 - valid_labels) * torch.pow(distances, 2)
        loss_dissimilar = valid_labels * torch.pow(torch.clamp(self.margin - distances, min=0.0), 2)

        return (loss_similar + loss_dissimilar).mean()

# Process NetworkX graph
print("Converting NetworkX graph to HeteroData...")
data, node_mappings = convert_to_hetero_data(ntx_graph)

# Add reverse edges to make the graph effectively undirected for message passing
transform = T.ToUndirected()
data = transform(data)

# Create a lookup from original patient node ID to new sequential index
patient_node_lookup = {node: idx for node, idx in node_mappings.get('patient', {}).items()}

# Process similarity data for training
print("Processing similarity data...")
sim_pairs_list = []
sim_labels_list = []

# One seed for both the pair sampling and the model weight initialisation, so
# a re-run starts from the same sample and the same initial weights.
SEED = 42

# Consider using a subset for faster training/debugging if the dataset is large
# Determine sample size (e.g., 50k or full dataset)
sample_size = min(50000, len(patient_similarity_df))
# Use random sampling if taking a subset
subset_df = patient_similarity_df.sample(n=sample_size, random_state=SEED) if sample_size < len(patient_similarity_df) else patient_similarity_df

# Iterate over the three needed columns directly; this avoids building a
# Series object per row as DataFrame.iterrows() does.
for source, target, is_similar in zip(
    subset_df['source_node'], subset_df['target_node'], subset_df['is_similar']
):
    # Check if both source and target patients are in our mapping
    if source in patient_node_lookup and target in patient_node_lookup:
        sim_pairs_list.append([patient_node_lookup[source], patient_node_lookup[target]])
        # Use 0 for similar, 1 for dissimilar for the contrastive loss formula used
        sim_labels_list.append(0.0 if is_similar else 1.0)

if not sim_pairs_list:
     raise ValueError("No valid training pairs found. Check patient IDs in similarity data and graph.")

sim_pairs = torch.tensor(sim_pairs_list, dtype=torch.long)
sim_labels = torch.tensor(sim_labels_list, dtype=torch.float)

print(f"Created {len(sim_pairs)} training pairs from {sample_size} samples.")

# --- Model, Optimizer, and Training Setup ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Ensure data is on the correct device
data = data.to(device)
sim_pairs = sim_pairs.to(device)
sim_labels = sim_labels.to(device)

# Seed torch's random number generators before any weights are created.
# Some GPU kernels may still be non-deterministic, so this improves
# repeatability rather than guaranteeing bit-identical runs.
torch.manual_seed(SEED)

# Create the model
hidden_channels = 64
out_channels = 32
model = RGCN(hidden_channels, out_channels, data.metadata()).to(device)

# Set up optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) # Added weight decay
criterion = ContrastiveLoss(margin=1.0) # Adjusted margin
# Training loop
def train():
    """Run one full-batch optimisation step and return the loss as a float.

    Reads the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``data``, ``sim_pairs`` and ``sim_labels``.

    Returns:
        The scalar loss value. ``0.0`` is returned, without a parameter
        update, when the model produces no patient embeddings. When the loss
        does not require gradients its value is returned without a backward
        pass or optimiser step.
    """
    model.train()
    optimizer.zero_grad()

    # Forward pass over the whole graph to obtain patient embeddings.
    patient_embeddings = model(data.x_dict, data.edge_index_dict)

    nothing_embedded = patient_embeddings is None or patient_embeddings.size(0) == 0
    if nothing_embedded:
        print("Warning: No patient embeddings generated.")
        return 0.0

    # Contrastive loss over all sampled similarity pairs.
    loss = criterion(patient_embeddings, sim_pairs, sim_labels)

    if not loss.requires_grad:
        # Nothing to backpropagate through; just report the value.
        print("Warning: Loss does not require gradients.")
        return loss.item()

    loss.backward()
    optimizer.step()
    return loss.item()

# Train the model
# Each call to train() performs one forward/backward pass over all sampled
# pairs, so one epoch here equals one optimiser step.
print("Training model...")
num_epochs = 100 # Increased epochs
for epoch in range(num_epochs):
    loss = train()
    if (epoch + 1) % 10 == 0: # Print loss every 10 epochs
        print(f'Epoch: {epoch+1:03d}, Loss: {loss:.4f}')

print("Training finished.")

Converting NetworkX graph to HeteroData...
Processing similarity data...
Created 50000 training pairs from 50000 samples.
Using device: cuda
Training model...
Epoch: 010, Loss: 0.8461
Epoch: 020, Loss: 0.0626
Epoch: 030, Loss: 0.0411
Epoch: 040, Loss: 0.0287
Epoch: 050, Loss: 0.0227
Epoch: 060, Loss: 0.0197
Epoch: 070, Loss: 0.0180
Epoch: 080, Loss: 0.0169
Epoch: 090, Loss: 0.0165
Epoch: 100, Loss: 0.0156
Training finished.


# Save and Load Model

In [27]:
# Save the model to a file.
# Only the state_dict (parameter tensors) is written, so loading it requires
# an RGCN instance to be constructed first.
MODEL_FILE_NAME = f"rgcn_model_38.pth"
torch.save(model.state_dict(), MODEL_FILE_NAME)
print(f"Model saved to {MODEL_FILE_NAME}")

Model saved to rgcn_model_38.pth


# The cells above are only needed for training

# just for inference

In [13]:
import torch
import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.nn import GraphConv, HeteroConv
import random

import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T

# Convert NetworkX graph to PyTorch Geometric HeteroData
def convert_to_hetero_data(nx_graph):
    """Build a PyTorch Geometric ``HeteroData`` object from a typed multigraph.

    Nodes are grouped by their ``'type'`` attribute ('patient', 'diagnosis',
    'procedure'); nodes of any other type are ignored. Within each type a node
    receives the next free integer index in iteration order. Every node type
    gets an identity matrix as its feature matrix (one-hot features), so the
    feature matrix of a type with N nodes has shape (N, N).

    Edges are read from their ``'relation'`` attribute. Only 'has_diagnosis'
    and 'has_procedure' are kept, and each is stored patient -> diagnosis /
    patient -> procedure regardless of the orientation it has in ``nx_graph``.

    Args:
        nx_graph: Graph supporting ``nodes(data=True)``, ``nodes[n]`` and
            ``edges(keys=True, data=True)``.

    Returns:
        tuple ``(data, node_mappings)`` where ``data`` is the populated
        ``HeteroData`` and ``node_mappings`` maps node type ->
        {original node id -> index}.

    Raises:
        KeyError: If a node has no ``'type'`` attribute.
    """
    hetero = HeteroData()

    # Per-type dictionaries: original node id -> contiguous row index.
    index_of = {kind: {} for kind in ['patient', 'diagnosis', 'procedure']}
    for node_id, node_attr in nx_graph.nodes(data=True):
        per_type = index_of.get(node_attr['type'])
        # `is not None` (not truthiness): an empty dict is a valid, known type.
        if per_type is not None and node_id not in per_type:
            per_type[node_id] = len(per_type)

    # One-hot node features; a type without nodes gets an empty (0, 0) matrix.
    for kind, per_type in index_of.items():
        count = len(per_type)
        if count > 0:
            hetero[kind].x = torch.eye(count)
        else:
            hetero[kind].x = torch.empty((0, 0), dtype=torch.float)

    # Relation name -> canonical (source type, relation, destination type).
    triplets = {
        'has_diagnosis': ('patient', 'has_diagnosis', 'diagnosis'),
        'has_procedure': ('patient', 'has_procedure', 'procedure')
    }

    # Register every edge type up front with an empty (2, 0) index.
    for triplet in triplets.values():
        hetero[triplet].edge_index = torch.empty((2, 0), dtype=torch.long)

    # Collect source and destination indices per relation.
    sources = {name: [] for name in triplets}
    targets = {name: [] for name in triplets}

    for u, v, _edge_key, edge_attr in nx_graph.edges(keys=True, data=True):
        name = edge_attr.get('relation')
        if name not in triplets:
            continue
        src_kind, _, dst_kind = triplets[name]

        u_kind = nx_graph.nodes[u]['type']
        v_kind = nx_graph.nodes[v]['type']

        # Orient the edge as (source type -> destination type); edges whose
        # endpoint types match neither orientation are dropped.
        if u_kind == src_kind and v_kind == dst_kind:
            head, tail = u, v
        elif u_kind == dst_kind and v_kind == src_kind:
            head, tail = v, u
        else:
            continue

        if head in index_of[src_kind] and tail in index_of[dst_kind]:
            sources[name].append(index_of[src_kind][head])
            targets[name].append(index_of[dst_kind][tail])

    # Store the collected indices; view(2, -1) keeps the (2, num_edges) shape
    # even when a relation has no edges.
    for name, triplet in triplets.items():
        stacked = torch.tensor([sources[name], targets[name]], dtype=torch.long)
        hetero[triplet].edge_index = stacked.view(2, -1)

    return hetero, index_of

# Create RGCN model for contrastive learning
class RGCN(torch.nn.Module):
    """Two-layer heterogeneous GNN that outputs projected patient embeddings.

    Each layer is a ``HeteroConv`` holding one ``GraphConv`` per edge type,
    with the per-relation results summed. A ReLU follows the first layer only.
    The second layer's 'patient' output is passed through a two-layer MLP
    projection head.

    Args:
        hidden_channels: Width of both convolution layers and of the hidden
            layer of the projection head.
        out_channels: Size of the returned patient embedding.
        metadata: Sequence whose element ``[1]`` is the iterable of edge types
            (the notebook passes ``data.metadata()``).
    """

    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()

        first_layer_convs, second_layer_convs = {}, {}
        # Layers are created pairwise per edge type, in the order the edge
        # types are listed. -1 lets the first layer infer its input width.
        for relation_triplet in metadata[1]:
            first_layer_convs[relation_triplet] = GraphConv(-1, hidden_channels)
            second_layer_convs[relation_triplet] = GraphConv(hidden_channels, hidden_channels)

        self.conv1 = torch_geometric.nn.HeteroConv(first_layer_convs, aggr='sum')
        self.conv2 = torch_geometric.nn.HeteroConv(second_layer_convs, aggr='sum')

        # Projection head applied to patient node representations only.
        self.patient_proj = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels),
        )

    def forward(self, x_dict, edge_index_dict):
        """Return projected patient embeddings.

        Args:
            x_dict: Mapping node type -> feature tensor.
            edge_index_dict: Mapping edge type -> edge index tensor.

        Returns:
            Tensor produced by the projection head for the 'patient' nodes, or
            ``None`` when the second layer yields no 'patient' entry.
        """
        hidden = self.conv1(x_dict, edge_index_dict)
        hidden = {node_type: F.relu(features) for node_type, features in hidden.items()}

        # No non-linearity after the second layer; the projection head follows.
        hidden = self.conv2(hidden, edge_index_dict)

        if 'patient' not in hidden:
            return None
        return self.patient_proj(hidden['patient'])

# Create a contrastive loss function
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss over pairs of embedding rows.

    For a pair at Euclidean distance ``d`` with label ``y``:

    * ``y == 0`` (similar):    loss = d ** 2
    * ``y == 1`` (dissimilar): loss = max(0, margin - d) ** 2

    The result is the mean over all pairs whose two indices are valid rows of
    ``embeddings``.

    Args:
        margin: Distance beyond which dissimilar pairs stop contributing.
    """

    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, embeddings, pairs, labels):
        """Compute the mean contrastive loss.

        Args:
            embeddings: 2-D tensor with one embedding per row.
            pairs: Integer tensor of shape (num_pairs, 2) holding row indices
                into ``embeddings``.
            labels: Float tensor of shape (num_pairs,), 0 for similar pairs
                and 1 for dissimilar pairs.

        Returns:
            Scalar tensor. When no pair has both indices in range, a zero
            tensor with ``requires_grad=True`` is returned.
        """
        # Keep pairs and labels on the same device as the embeddings.
        pairs = pairs.to(embeddings.device)
        labels = labels.to(embeddings.device)

        # Keep only pairs whose indices address real rows. Negative indices
        # are rejected as well: tensor indexing would otherwise wrap them
        # around and silently compare the wrong rows.
        num_rows = embeddings.size(0)
        endpoints = pairs[:, :2]
        in_bounds = ((endpoints >= 0) & (endpoints < num_rows)).all(dim=1)
        valid_pairs = pairs[in_bounds]
        valid_labels = labels[in_bounds]

        if valid_pairs.size(0) == 0:
            # Zero loss if no valid pairs.
            return torch.tensor(0.0, device=embeddings.device, requires_grad=True)

        first = embeddings[valid_pairs[:, 0]]
        second = embeddings[valid_pairs[:, 1]]

        # Euclidean distance between the two members of each pair.
        distances = F.pairwise_distance(first, second)

        # Similar pairs are pulled together; dissimilar pairs are pushed apart
        # until they are at least `margin` away from each other.
        loss_similar = (1 - valid_labels) * torch.pow(distances, 2)
        loss_dissimilar = valid_labels * torch.pow(torch.clamp(self.margin - distances, min=0.0), 2)

        return (loss_similar + loss_dissimilar).mean()

# Process NetworkX graph
print("Converting NetworkX graph to HeteroData...")
data, node_mappings = convert_to_hetero_data(ntx_graph)

# Add reverse edges to make the graph effectively undirected for message passing
transform = T.ToUndirected()
data = transform(data)

# Create a lookup from original patient node ID to new sequential index
patient_node_lookup = {node: idx for node, idx in node_mappings.get('patient', {}).items()}

# Process similarity data for training
print("Processing similarity data...")
sim_pairs_list = []
sim_labels_list = []

# Consider using a subset for faster training/debugging if the dataset is large
# Determine sample size (e.g., 50k or full dataset)
sample_size = min(50000, len(patient_similarity_df))
# Use random sampling if taking a subset
subset_df = patient_similarity_df.sample(n=sample_size, random_state=42) if sample_size < len(patient_similarity_df) else patient_similarity_df

# Iterate over the three needed columns directly; this avoids building a
# Series object per row as DataFrame.iterrows() does.
for source, target, is_similar in zip(
    subset_df['source_node'], subset_df['target_node'], subset_df['is_similar']
):
    # Check if both source and target patients are in our mapping
    if source in patient_node_lookup and target in patient_node_lookup:
        sim_pairs_list.append([patient_node_lookup[source], patient_node_lookup[target]])
        # Use 0 for similar, 1 for dissimilar for the contrastive loss formula used
        sim_labels_list.append(0.0 if is_similar else 1.0)

if not sim_pairs_list:
     raise ValueError("No valid training pairs found. Check patient IDs in similarity data and graph.")

sim_pairs = torch.tensor(sim_pairs_list, dtype=torch.long)
sim_labels = torch.tensor(sim_labels_list, dtype=torch.float)

print(f"Created {len(sim_pairs)} training pairs from {sample_size} samples.")

# --- Model, Optimizer, and Training Setup ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Ensure data is on the correct device
data = data.to(device)
sim_pairs = sim_pairs.to(device)
sim_labels = sim_labels.to(device)


# Create the model
hidden_channels = 64
out_channels = 32
model = RGCN(hidden_channels, out_channels, data.metadata()).to(device)

# Optimizer and loss are not created in this cell; the original setup is kept
# commented out below for reference.
# # Set up optimizer
# optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) # Added weight decay
# criterion = ContrastiveLoss(margin=1.0) # Adjusted margin

Converting NetworkX graph to HeteroData...
Processing similarity data...
Created 50000 training pairs from 50000 samples.
Using device: cuda


## Save model

## Load model

In [28]:

# load the model from a file
def load_model(model, file_path):
    """Load saved weights into ``model`` and switch it to evaluation mode.

    The checkpoint is first read onto the CPU, so a file written on a GPU
    machine can also be loaded on a host without CUDA. ``load_state_dict``
    then copies the values into the model's parameters on whatever device
    they already live on.

    Args:
        model: Module whose architecture matches the saved state_dict.
        file_path: Path of a file written with
            ``torch.save(model.state_dict(), ...)``.

    Returns:
        The same ``model`` instance, in eval mode.
    """
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict holds.
    state_dict = torch.load(file_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Build a fresh RGCN with the same hyper-parameters and graph metadata, then
# load the saved weights into it; load_model also puts it in eval mode.
model = RGCN(hidden_channels, out_channels, data.metadata()).to(device)
model = load_model(model, 'rgcn_model_38.pth')
# print(f"Model loaded from {MODEL_FILE_NAME}")


# print model summary
def print_model_summary(model):
    """Print the module structure and the patient embeddings it produces.

    Runs one forward pass on the notebook-level ``data`` object and prints
    the resulting tensor.

    Args:
        model: Module callable as ``model(x_dict, edge_index_dict)``.
    """
    print("Model Summary:")
    print(model)
    print("Output Data:")
    patient_embeddings = model(data.x_dict, data.edge_index_dict)
    print(f"patient_embeddings: {patient_embeddings}")
    print("\nModel Summary Finished.")

# Show the loaded model's structure and the embeddings it produces.
print_model_summary(model)
# print model summary
print("Model summary printed.")

Model Summary:
RGCN(
  (conv1): HeteroConv(num_relations=4)
  (conv2): HeteroConv(num_relations=4)
  (patient_proj): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=32, bias=True)
  )
)
Output Data:
patient_embeddings: tensor([[ 3.0191e-01,  7.3386e-01,  6.1928e-02,  ...,  4.4904e-01,
          2.8683e-02, -2.7177e-01],
        [ 2.9513e-01,  3.2462e-01, -1.0928e-01,  ...,  2.0872e-01,
          1.4677e-01, -3.7743e-02],
        [ 1.7052e-03, -1.0870e-02,  1.1687e-02,  ..., -8.2174e-03,
          3.3596e-04, -1.1501e-02],
        ...,
        [ 3.1578e-01,  4.8475e-01,  2.7280e-02,  ...,  4.6020e-01,
          2.7777e-01, -1.6426e-01],
        [ 1.8604e-01,  9.6066e-01,  1.3552e-01,  ...,  4.4749e-01,
         -7.3286e-02, -2.1700e-01],
        [ 3.1369e-01,  3.5424e-01, -1.0978e-01,  ...,  2.0534e-01,
          1.4170e-01, -3.5695e-02]], device='cuda:0', grad_fn=<AddmmBackward0>)

Model Summary Finish

In [29]:
# Node types that have a feature matrix in the HeteroData object.
data.x_dict.keys()

dict_keys(['patient', 'diagnosis', 'procedure'])

In [30]:
# Number of rows in the patient feature matrix (one row per patient node).
len(data.x_dict['patient'])

17678

In [31]:
# Inference only: disable autograd so the forward pass does not record a
# computation graph and the resulting tensor carries no grad_fn.
with torch.no_grad():
    embeddings = model(data.x_dict, data.edge_index_dict)

In [32]:
# Check that the model returned a tensor rather than None.
type(embeddings)

torch.Tensor

In [33]:
# Shape of the embedding tensor: (number of patients, out_channels).
embeddings.shape

torch.Size([17678, 32])

In [34]:
# Save to pickle file
# Detach from the autograd graph and move to CPU first, so the file holds a
# plain tensor that can be unpickled on a machine without a GPU.
with open('embeddings_38_pickle.pkl', 'wb') as f:
    pickle.dump(embeddings.detach().cpu(), f)