# Imports

In [23]:
import pickle
import os
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib
import matplotlib.pyplot as plt
# import torch geometric
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.data import Data, DataLoader
from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GraphConv
# from torch_geometric.nn import MessagePassing
# from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool
# from torch_geometric.utils import degree, to_dense_adj, to_dense_batch, to_undirected
# from torch_geometric.utils import from_networkx, to_networkx
# from torch_geometric.utils import negative_sampling, remove_self_loops, add_self_loops
# from torch_geometric.utils import dropout_adj, to_undirected, add_self_loops
# from torch_geometric.utils import to_dense_adj, to_dense_batch, dense_to_sparse

# Report library versions and CUDA availability for the deep-learning stack.
# print() receives the label and the value as two arguments, so a separating
# space is inserted after the label (which already ends in a space).
for label, value in (
    ("torch version: ", torch.__version__),
    ("torch geometric version: ", torch_geometric.__version__),
    ("torch cuda version: ", torch.version.cuda),
    ("torch cuda available: ", torch.cuda.is_available()),
):
    print(label, value)
# print("torch cuda device count: ", torch.cuda.device_count())
# print("torch cuda current device name: ", torch.cuda.get_device_name(torch.cuda.current_device()))

# Report versions of the data, plotting and graph libraries.
for lib_label, lib_module in (
    ("Pandas", pd),
    ("Numpy", np),
    ("Matplotlib", matplotlib),
    ("NetworkX", nx),
):
    print(f"{lib_label} version: {lib_module.__version__}")

# Lift pandas' column-width and column-count limits so wide frames are shown
# in full (the row limit is left at its default).
for option_name in ('display.max_colwidth', 'display.max_columns'):
    pd.set_option(option_name, None)
print("\nRemoved truncation of columns")


torch version:  2.6.0+cpu
torch geometric version:  2.6.1
torch cuda version:  None
torch cuda available:  False
Pandas version: 2.2.3
Numpy version: 2.2.4
Matplotlib version: 3.10.1
NetworkX version: 3.4.2

Removed truncation of columns


# Load data

## Load networkx graph

In [None]:
# load the networkx graph
def load_graph(graph_path):
    """Unpickle and return the object stored at ``graph_path``.

    The file is opened in binary mode and passed to ``pickle.load``; whatever
    object was pickled is returned unchanged.

    NOTE(review): unpickling can execute arbitrary code, so this should only
    be pointed at files produced by a trusted source.
    """
    with open(graph_path, mode='rb') as pickle_file:
        return pickle.load(pickle_file)

# Locate the pickled graph relative to the current working directory.
# os.path.join picks the separator for the host OS, so the notebook is not
# tied to Windows-style "\\" paths.
CURR_DIR_PATH = os.getcwd()
PICKLE_GRAPH_FILE_NAME = os.path.abspath(
    os.path.join(CURR_DIR_PATH, "pickle", "max_20000_nodes_graph.gpickle")
)

print(f"Loading graph from {PICKLE_GRAPH_FILE_NAME}...")
ntx_graph = load_graph(PICKLE_GRAPH_FILE_NAME)
print(f"Graph loaded. Number of nodes: {ntx_graph.number_of_nodes()}, number of edges: {ntx_graph.number_of_edges()}")


Loading graph from d:\Repos\ut-health-final-proj\pickle\max_20000_nodes_graph.gpickle...
Graph loaded. Number of nodes: 17654, number of edges: 172728


In [16]:
# count nodes of type patient, diagnosis and procedure
def count_node_types(graph):
    """Count nodes per value of their ``'type'`` attribute.

    Args:
        graph: object whose ``nodes()`` yields node ids and whose
            ``nodes[node_id]`` is an attribute dict containing ``'type'``.

    Returns:
        dict mapping each node type to its number of nodes, keyed in the
        order the types are first encountered.
    """
    tally = {}
    for node_id in graph.nodes():
        kind = graph.nodes[node_id]['type']
        tally[kind] = tally.get(kind, 0) + 1
    return tally

# Tally the loaded graph's nodes per 'type' attribute and show the counts.
# NOTE: the DEBUG cell further down rebinds `node_types` to a dict of empty lists.
node_types = count_node_types(ntx_graph)
print(f"Node counts: {node_types}")

# count edges by relation type (has_diagnosis, has_procedure)
def count_edge_types(graph):
    """Count edges per value of their ``'relation'`` attribute.

    Args:
        graph: object whose ``edges(keys=True)`` yields ``(u, v, key)``
            triples and whose ``edges[u, v, key]`` is an attribute dict
            containing ``'relation'``.

    Returns:
        dict mapping each relation name to its number of edges, keyed in the
        order the relations are first encountered.
    """
    tally = {}
    for src, dst, edge_key in graph.edges(keys=True):
        relation = graph.edges[src, dst, edge_key]['relation']
        tally[relation] = tally.get(relation, 0) + 1
    return tally

# Tally the loaded graph's edges per 'relation' attribute and show the counts.
edge_types = count_edge_types(ntx_graph)
print(f"Edge counts: {edge_types}")

Node counts: {'patient': 11630, 'diagnosis': 4680, 'procedure': 1344}
Edge counts: {'has_diagnosis': 123987, 'has_procedure': 48741}


## Load patient node similarity dataframe

In [8]:
# load the patient similarity dataframe
def load_patient_similarity_df(df_path):
    """Unpickle and return the object stored at ``df_path``.

    The file is read in binary mode via ``pickle.load`` and the stored object
    is returned as-is.

    NOTE(review): unpickling can execute arbitrary code, so this should only
    be pointed at files produced by a trusted source.
    """
    with open(df_path, mode='rb') as pickle_file:
        return pickle.load(pickle_file)

# Locate the pickled similarity table relative to the current working
# directory. os.path.join picks the separator for the host OS, so the
# notebook is not tied to Windows-style "\\" paths.
CURR_DIR_PATH = os.getcwd()
PATIENT_SIMILARITY_DF_FILE_NAME = os.path.abspath(
    os.path.join(CURR_DIR_PATH, "pickle", "max_20000_nodes_similarity.gpickle")
)

print(f"Loading patient similarity dataframe from {PATIENT_SIMILARITY_DF_FILE_NAME}...")
patient_similarity_df = load_patient_similarity_df(PATIENT_SIMILARITY_DF_FILE_NAME)
print(f"Patient similarity dataframe loaded. Number of rows: {patient_similarity_df.shape[0]}")

Loading patient similarity dataframe from d:\Repos\ut-health-final-proj\pickle\max_20000_nodes_similarity.gpickle...
Patient similarity dataframe loaded. Number of rows: 1288402


In [9]:
# Preview the first five rows to check the column names before they are renamed below.
patient_similarity_df.head(5)

Unnamed: 0,patient1,patient2,patient1_id,patient2_id,jaccard_similarity,same_gender,same_age_bucket,is_similar
423,patient-4074,patient-10139,4074,10139,0.304348,False,False,True
9670,patient-4074,patient-26572,4074,26572,0.3,True,True,True
16410,patient-90889,patient-84020,90889,84020,0.344828,True,False,True
17072,patient-90889,patient-67648,90889,67648,0.4,False,False,True
19221,patient-90889,patient-6138,90889,6138,0.333333,True,False,True


In [None]:
# Rename patient1 -> source_node and patient2 -> target_node.
# The result is reassigned rather than produced with inplace=True, so the
# frame object displayed by the previous cell is not mutated behind its output.
patient_similarity_df = patient_similarity_df.rename(
    columns={'patient1': 'source_node', 'patient2': 'target_node'}
)
patient_similarity_df.head(5)

Unnamed: 0,source_node,target_node,patient1_id,patient2_id,jaccard_similarity,same_gender,same_age_bucket,is_similar
423,patient-4074,patient-10139,4074,10139,0.304348,False,False,True
9670,patient-4074,patient-26572,4074,26572,0.3,True,True,True
16410,patient-90889,patient-84020,90889,84020,0.344828,True,False,True
17072,patient-90889,patient-67648,90889,67648,0.4,False,False,True
19221,patient-90889,patient-6138,90889,6138,0.333333,True,False,True


# Create R-GCN and train the model

### DEBUG

In [None]:
# Debug scratch: start from the three expected node types, each with an empty list.
node_types = {node_type: [] for node_type in ['patient', 'diagnosis', 'procedure']}
print(f"node_types: {node_types}")

# One (original node id -> running index) dictionary per node type.
node_mappings = {node_type: {} for node_type in node_types}
print(f"node_mappings: {node_mappings}")

# Walk the graph until three nodes have been assigned an index, printing each
# assignment; the raw id/attributes are echoed while nothing is mapped yet.
newly_mapped = 0
for node, attr in ntx_graph.nodes(data=True):
    node_type = attr['type']
    if newly_mapped == 0:
        print(f"node: {node}")
        print(f"attr: {attr}")
        print(f"node_type: {node_type}")

    type_mapping = node_mappings[node_type]
    if node not in type_mapping:
        # The next free index is the current size of this type's mapping.
        type_mapping[node] = len(type_mapping)
        newly_mapped += 1
        print(f"node_mappings[{node_type}][{node}]: {node_mappings[node_type][node]}")
        print(f"node_mappings[node_type][node]: {node_mappings[node_type][node]}")

    if newly_mapped == 3:
        break

# PROMPT:
# Create a relational graph neural network for contrastive learning. The objective is to predict patient similarity. 
# The data is in the networkx graph named ntx_graph. It contains 3 types of nodes: patients, diagnosis and procedure.
# It has edges with the relation of has_diagnosis and has_procedure. 
# The dataframe patient_similarity_df contains patient similarity score in the column "jaccard_similarity". It has source node in column source_node and target node in column target_node. The similar rows have the column is_similar=True and dissimilar rows have is_similar=False.
# Create a relational graph neural network that uses the training data derived from ntx_graph and validation data derived from patient_similarity_df.
# Train and validate the model and print the loss in every iteration.

node_types: {'patient': [], 'diagnosis': [], 'procedure': []}
node_mappings: {'patient': {}, 'diagnosis': {}, 'procedure': {}}
node: patient-4074
attr: {'gender': 'M', 'age_bucket': 80, 'hadm_id': 137421, 'type': 'patient'}
node_type: patient
node_mappings[patient][patient-4074]: 0
node_mappings[node_type][node]: 0
node_mappings[patient][patient-90889]: 1
node_mappings[node_type][node]: 1
node_mappings[patient][patient-72753]: 2
node_mappings[node_type][node]: 2


# Neural Network

In [37]:
import torch
import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.nn import GraphConv, HeteroConv
import random

import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T

# Convert NetworkX graph to PyTorch Geometric HeteroData
def convert_to_hetero_data(nx_graph):
    """Build a ``HeteroData`` object from the typed NetworkX graph.

    Nodes are grouped by their ``'type'`` attribute (patient, diagnosis,
    procedure) and numbered from 0 within each group in first-seen order.
    Each node type receives an identity matrix as features (one-hot per node).
    Edges whose ``'relation'`` attribute is ``has_diagnosis`` or
    ``has_procedure`` are stored pointing from the patient to the
    diagnosis/procedure node, whichever orientation they have in
    ``nx_graph``; all other edges are ignored.

    Args:
        nx_graph: graph supporting ``nodes(data=True)``,
            ``edges(keys=True, data=True)`` and ``nodes[node_id]['type']``.

    Returns:
        tuple ``(data, node_mappings)`` where ``node_mappings[node_type]``
        maps an original node id to its row index for that type.
    """
    data = HeteroData()

    node_types_list = ['patient', 'diagnosis', 'procedure']
    node_mappings = {node_type: {} for node_type in node_types_list}

    # Assign each node the next free index within its own type; nodes of an
    # unlisted type are skipped.
    for node_id, node_attr in nx_graph.nodes(data=True):
        per_type = node_mappings.get(node_attr['type'])
        if per_type is not None and node_id not in per_type:
            per_type[node_id] = len(per_type)

    # One-hot features: an identity matrix per node type (empty placeholder
    # when a type has no nodes).
    for node_type, mapping in node_mappings.items():
        count = len(mapping)
        if count > 0:
            data[node_type].x = torch.eye(count)
        else:
            data[node_type].x = torch.empty((0, 0), dtype=torch.float)

    # relation name -> (source type, relation, destination type)
    edge_types_relations = {
        'has_diagnosis': ('patient', 'has_diagnosis', 'diagnosis'),
        'has_procedure': ('patient', 'has_procedure', 'procedure')
    }

    # Register every edge type up front with an empty (2, 0) index.
    for edge_type in edge_types_relations.values():
        data[edge_type].edge_index = torch.empty((2, 0), dtype=torch.long)

    # Per relation: parallel lists of source and destination indices.
    source_indices = {relation: [] for relation in edge_types_relations}
    target_indices = {relation: [] for relation in edge_types_relations}

    for u, v, _edge_key, edge_attr in nx_graph.edges(keys=True, data=True):
        relation = edge_attr.get('relation')
        if relation not in edge_types_relations:
            continue
        src_type, _, dst_type = edge_types_relations[relation]

        u_type = nx_graph.nodes[u]['type']
        v_type = nx_graph.nodes[v]['type']

        # Orient the edge as (source type -> destination type); edges stored
        # the other way round in the input are flipped.
        if u_type == src_type and v_type == dst_type:
            tail, head = u, v
        elif u_type == dst_type and v_type == src_type:
            tail, head = v, u
        else:
            continue

        if tail in node_mappings[src_type] and head in node_mappings[dst_type]:
            source_indices[relation].append(node_mappings[src_type][tail])
            target_indices[relation].append(node_mappings[dst_type][head])

    # Store the collected indices as (2, num_edges) long tensors.
    for relation, edge_type in edge_types_relations.items():
        indices = torch.tensor(
            [source_indices[relation], target_indices[relation]],
            dtype=torch.long,
        )
        data[edge_type].edge_index = indices.view(2, -1)

    return data, node_mappings

# Create RGCN model for contrastive learning
class RGCN(torch.nn.Module):
    """Two-layer heterogeneous GNN that outputs projected patient embeddings.

    Each layer is a ``HeteroConv`` holding one ``GraphConv`` per edge type in
    ``metadata[1]``, with per-relation results summed. A two-layer MLP then
    projects the ``'patient'`` node embeddings to ``out_channels``.

    Args:
        hidden_channels: width of both convolution layers and the MLP hidden layer.
        out_channels: size of the projected patient embedding.
        metadata: sequence whose element at index 1 lists the edge types as
            ``(src_node_type, relation, dst_node_type)`` triples.
    """

    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()

        first_layer_convs = {}
        second_layer_convs = {}
        for relation_triplet in metadata[1]:
            # in_channels=-1 lets the first layer infer its input size on the
            # first forward pass.
            first_layer_convs[relation_triplet] = GraphConv(-1, hidden_channels)
            second_layer_convs[relation_triplet] = GraphConv(hidden_channels, hidden_channels)

        self.conv1 = torch_geometric.nn.HeteroConv(first_layer_convs, aggr='sum')
        self.conv2 = torch_geometric.nn.HeteroConv(second_layer_convs, aggr='sum')

        # Projection head applied to the patient embeddings only.
        self.patient_proj = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels)
        )

    def forward(self, x_dict, edge_index_dict):
        """Return projected patient embeddings, or ``None`` if the second
        layer produces no ``'patient'`` entry."""
        hidden = self.conv1(x_dict, edge_index_dict)
        hidden = {node_type: F.relu(features) for node_type, features in hidden.items()}

        # No activation after the second layer; the projection head follows.
        hidden = self.conv2(hidden, edge_index_dict)

        if 'patient' not in hidden:
            return None
        return self.patient_proj(hidden['patient'])

# Create a contrastive loss function
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss over pairs of embedding rows.

    Labels use 0 for a similar pair and 1 for a dissimilar pair:
    similar pairs contribute ``distance**2`` and dissimilar pairs contribute
    ``max(0, margin - distance)**2``; the mean over all kept pairs is returned.

    Args:
        margin: distance beyond which dissimilar pairs incur no loss.
    """

    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, embeddings, pairs, labels):
        """Compute the mean loss for ``pairs`` (N x 2 row indices) and ``labels`` (N).

        Pairs with an index at or beyond the number of embedding rows are
        dropped; if none remain, a zero scalar with ``requires_grad=True`` is
        returned.
        """
        target_device = embeddings.device
        pairs = pairs.to(target_device)
        labels = labels.to(target_device)

        # Keep only pairs whose two indices both address an existing row.
        row_count = embeddings.size(0)
        first, second = pairs[:, 0], pairs[:, 1]
        keep = (first < row_count) & (second < row_count)
        first, second, kept_labels = first[keep], second[keep], labels[keep]

        if first.size(0) == 0:
            return torch.tensor(0.0, device=target_device, requires_grad=True)

        # Euclidean distance between the two members of every pair.
        distances = F.pairwise_distance(embeddings[first], embeddings[second])

        pull_term = (1 - kept_labels) * distances.pow(2)
        push_term = kept_labels * (self.margin - distances).clamp(min=0.0).pow(2)

        return (pull_term + push_term).mean()

# Process NetworkX graph
# NOTE: this rebinds `node_mappings`, replacing the partial mapping built in the DEBUG cell.
print("Converting NetworkX graph to HeteroData...")
data, node_mappings = convert_to_hetero_data(ntx_graph)

# Add reverse edges to make the graph effectively undirected for message passing.
# Without them only diagnosis/procedure nodes would receive messages, because
# every stored edge points away from a patient node.
transform = T.ToUndirected()
data = transform(data)

# Create a lookup from original patient node ID to new sequential index
patient_node_lookup = dict(node_mappings.get('patient', {}))

# Process similarity data for training
print("Processing similarity data...")

# Cap the number of similarity rows used for training. Sampling is seeded so
# the same subset is drawn on every run; the full frame is used when it is
# already within the cap.
sample_size = min(50000, len(patient_similarity_df))
if sample_size < len(patient_similarity_df):
    subset_df = patient_similarity_df.sample(n=sample_size, random_state=42)
else:
    subset_df = patient_similarity_df

# Vectorised id -> index lookup (replaces a row-by-row iterrows() loop).
# Series.map yields NaN for ids that are not in the graph; such rows are dropped.
source_idx = subset_df['source_node'].map(patient_node_lookup).to_numpy(dtype=float)
target_idx = subset_df['target_node'].map(patient_node_lookup).to_numpy(dtype=float)
in_graph = ~(np.isnan(source_idx) | np.isnan(target_idx))

sim_pairs_list = (
    np.column_stack((source_idx[in_graph], target_idx[in_graph]))
    .astype(np.int64)
    .tolist()
)
# Use 0 for similar, 1 for dissimilar for the contrastive loss formula used
is_similar = subset_df['is_similar'].to_numpy(dtype=bool)[in_graph]
sim_labels_list = np.where(is_similar, 0.0, 1.0).tolist()

if not sim_pairs_list:
    raise ValueError("No valid training pairs found. Check patient IDs in similarity data and graph.")

sim_pairs = torch.tensor(sim_pairs_list, dtype=torch.long)
sim_labels = torch.tensor(sim_labels_list, dtype=torch.float)

print(f"Created {len(sim_pairs)} training pairs from {sample_size} samples.")

# --- Model, Optimizer, and Training Setup ---
# Seed torch's RNG so weight initialisation, and therefore the loss curve,
# is reproducible across Restart & Run All.
torch.manual_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Ensure data is on the correct device
data = data.to(device)
sim_pairs = sim_pairs.to(device)
sim_labels = sim_labels.to(device)

# Create the model
hidden_channels = 64
out_channels = 32
model = RGCN(hidden_channels, out_channels, data.metadata()).to(device)

# The first-layer GraphConv modules are created with in_channels=-1 (lazy).
# Run one gradient-free forward pass so their weights are materialised before
# the optimizer is attached to model.parameters().
with torch.no_grad():
    model(data.x_dict, data.edge_index_dict)

# Set up optimizer (Adam with L2 weight decay) and the contrastive criterion
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
criterion = ContrastiveLoss(margin=1.0)

# Training loop
def train():
    """Run one full-graph optimisation step and return the loss as a float.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``, ``data``,
    ``sim_pairs`` and ``sim_labels``. Returns ``0.0`` (after a warning) when
    the model yields no patient embeddings; when the loss carries no gradient
    the optimizer step is skipped and the loss value is still returned.
    """
    model.train()
    optimizer.zero_grad()

    patient_embeddings = model(data.x_dict, data.edge_index_dict)
    if patient_embeddings is None or patient_embeddings.size(0) == 0:
        print("Warning: No patient embeddings generated.")
        return 0.0

    loss = criterion(patient_embeddings, sim_pairs, sim_labels)

    # Guard: nothing to back-propagate through.
    if not loss.requires_grad:
        print("Warning: Loss does not require gradients.")
        return loss.item()

    loss.backward()
    optimizer.step()
    return loss.item()

# Run the optimisation loop, reporting the loss after every tenth epoch.
print("Training model...")
num_epochs = 100
for epoch in range(num_epochs):
    loss = train()
    completed_epochs = epoch + 1
    if completed_epochs % 10 != 0:
        continue
    print(f'Epoch: {completed_epochs:03d}, Loss: {loss:.4f}')

print("Training finished.")

Converting NetworkX graph to HeteroData...
Processing similarity data...
Created 50000 training pairs from 50000 samples.
Using device: cpu
Training model...
Epoch: 010, Loss: 0.1938
Epoch: 020, Loss: 0.0502
Epoch: 030, Loss: 0.0296
Epoch: 040, Loss: 0.0219
Epoch: 050, Loss: 0.0174
Epoch: 060, Loss: 0.0152
Epoch: 070, Loss: 0.0136
Epoch: 080, Loss: 0.0124
Epoch: 090, Loss: 0.0116
Epoch: 100, Loss: 0.0111
Training finished.


# Save and Load Model

## Save model

In [39]:
# Save the trained weights (state_dict only) under <cwd>/models.
# os.path.join keeps the path valid on any OS.
MODEL_FILE_NAME = os.path.join(CURR_DIR_PATH, "models", "rgcn_model.pth")
# torch.save does not create missing parent directories, so make sure the
# target folder exists first.
os.makedirs(os.path.dirname(MODEL_FILE_NAME), exist_ok=True)
torch.save(model.state_dict(), MODEL_FILE_NAME)
print(f"Model saved to {MODEL_FILE_NAME}")

Model saved to d:\Repos\ut-health-final-proj\models\rgcn_model.pth


## Load model

In [41]:

# load the model from a file
def load_model(model, file_path, map_location="cpu"):
    """Restore weights saved with ``torch.save(model.state_dict(), ...)``.

    Args:
        model: module whose architecture matches the saved state dict.
        file_path: path of the checkpoint file.
        map_location: where ``torch.load`` places the loaded tensors. The
            default ``"cpu"`` lets a checkpoint written on a GPU machine be
            read on a CPU-only host; ``load_state_dict`` then copies the
            values into the model's own parameters.

    Returns:
        The same ``model`` instance, switched to evaluation mode.
    """
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs.
    state_dict = torch.load(file_path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Rebuild the architecture with the same layer sizes, then restore the saved weights.
# NOTE: this rebinds `model`, replacing the instance trained earlier in the notebook.
model = RGCN(hidden_channels, out_channels, data.metadata()).to(device)
model = load_model(model, MODEL_FILE_NAME)
print(f"Model loaded from {MODEL_FILE_NAME}")


# print model summary
def print_model_summary(model, hetero_data=None):
    """Print the module tree followed by the patient embeddings it produces.

    Args:
        model: module called as ``model(x_dict, edge_index_dict)``.
        hetero_data: object exposing ``x_dict`` and ``edge_index_dict``.
            Defaults to the notebook-level ``data`` when omitted.
    """
    if hetero_data is None:
        hetero_data = data

    print("Model Summary:")
    print(model)
    print("Output Data:")
    # Inference only: with autograd disabled no graph is built and the printed
    # tensor carries no grad_fn.
    with torch.no_grad():
        patient_embeddings = model(hetero_data.x_dict, hetero_data.edge_index_dict)
    print(f"patient_embeddings: {patient_embeddings}")
    print("\nModel Summary Finished.")

print_model_summary(model)
# Confirm that the summary call above completed.
print("Model summary printed.")

Model loaded from d:\Repos\ut-health-final-proj\models\rgcn_model.pth
Model Summary:
RGCN(
  (conv1): HeteroConv(num_relations=4)
  (conv2): HeteroConv(num_relations=4)
  (patient_proj): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=32, bias=True)
  )
)
Output Data:
patient_embeddings: tensor([[-0.2149, -0.3542,  0.1598,  ..., -0.8986, -0.2081, -0.7109],
        [-0.2317, -0.0439,  0.0421,  ..., -0.2855,  0.0356, -0.1715],
        [-0.1430, -0.2358,  0.1536,  ..., -0.4589, -0.0924, -0.4197],
        ...,
        [-0.6093, -0.2219,  0.1109,  ..., -1.7633, -0.8697, -1.2796],
        [-0.3326, -0.2799,  0.2922,  ..., -0.9005, -0.1450, -0.5438],
        [-0.2193, -0.0496,  0.0270,  ..., -0.2097,  0.0538, -0.0965]],
       grad_fn=<AddmmBackward0>)

Model Summary Finished.
Model summary printed.
