# Imports

In [23]:
import pickle
import os
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib
import matplotlib.pyplot as plt
# import torch geometric
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.data import Data, DataLoader
from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GraphConv
# from torch_geometric.nn import MessagePassing
# from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool
# from torch_geometric.utils import degree, to_dense_adj, to_dense_batch, to_undirected
# from torch_geometric.utils import from_networkx, to_networkx
# from torch_geometric.utils import negative_sampling, remove_self_loops, add_self_loops
# from torch_geometric.utils import dropout_adj, to_undirected, add_self_loops
# from torch_geometric.utils import to_dense_adj, to_dense_batch, dense_to_sparse

# print versions
print("torch version: ", torch.__version__)
print("torch geometric version: ", torch_geometric.__version__)
print("torch cuda version: ", torch.version.cuda)
print("torch cuda available: ", torch.cuda.is_available())
# print("torch cuda device count: ", torch.cuda.device_count())
# print("torch cuda current device name: ", torch.cuda.get_device_name(torch.cuda.current_device()))

# print versions
print(f"Pandas version: {pd.__version__}")
print(f"Numpy version: {np.__version__}")
print(f"Matplotlib version: {matplotlib.__version__}")
print(f"NetworkX version: {nx.__version__}")

# set pandas display options to show all columns and rows without truncation
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_columns', None)
print("\nRemoved truncation of columns")


torch version:  2.6.0+cpu
torch geometric version:  2.6.1
torch cuda version:  None
torch cuda available:  False
Pandas version: 2.2.3
Numpy version: 2.2.4
Matplotlib version: 3.10.1
NetworkX version: 3.4.2

Removed truncation of columns


# Load data

## Load networkx graph

In [None]:
# load the networkx graph
def load_graph(graph_path):
    with open(graph_path, 'rb') as f:
        graph = pickle.load(f)
    return graph

CURR_DIR_PATH = os.getcwd()
PICKLE_GRAPH_FILE_NAME = f"{CURR_DIR_PATH}\\pickle\\max_20000_nodes_graph.gpickle"
PICKLE_GRAPH_FILE_NAME = os.path.abspath(PICKLE_GRAPH_FILE_NAME)

print(f"Loading graph from {PICKLE_GRAPH_FILE_NAME}...")
ntx_graph = load_graph(PICKLE_GRAPH_FILE_NAME)
print(f"Graph loaded. Number of nodes: {ntx_graph.number_of_nodes()}, number of edges: {ntx_graph.number_of_edges()}")


Loading graph from d:\Repos\ut-health-final-proj\pickle\max_20000_nodes_graph.gpickle...
Graph loaded. Number of nodes: 17654, number of edges: 172728


In [16]:
# count nodes of type patient, diagnosis and procedure
def count_node_types(graph):
    node_types = {}
    for node in graph.nodes():
        node_type = graph.nodes[node]['type']
        if node_type not in node_types:
            node_types[node_type] = 0
        node_types[node_type] += 1
    return node_types

node_types = count_node_types(ntx_graph)
print(f"Node counts: {node_types}")

# count edges named had_procedure, has_diagnosis
def count_edge_types(graph):
    edge_types = {}
    for u, v, key in graph.edges(keys=True):
        edge_type = graph.edges[u, v, key]['relation']
        if edge_type not in edge_types:
            edge_types[edge_type] = 0
        edge_types[edge_type] += 1
    return edge_types

edge_types = count_edge_types(ntx_graph)
print(f"Edge counts: {edge_types}")

Node counts: {'patient': 11630, 'diagnosis': 4680, 'procedure': 1344}
Edge counts: {'has_diagnosis': 123987, 'has_procedure': 48741}


## Load patient node similarity dataframe

In [8]:
# load the patient similarity dataframe
def load_patient_similarity_df(df_path):
    with open(df_path, 'rb') as f:
        df = pickle.load(f)
    return df

CURR_DIR_PATH = os.getcwd()
PATIENT_SIMILARITY_DF_FILE_NAME = f"{CURR_DIR_PATH}\\pickle\\max_20000_nodes_similarity.gpickle"
PATIENT_SIMILARITY_DF_FILE_NAME = os.path.abspath(PATIENT_SIMILARITY_DF_FILE_NAME)

print(f"Loading patient similarity dataframe from {PATIENT_SIMILARITY_DF_FILE_NAME}...")
patient_similarity_df = load_patient_similarity_df(PATIENT_SIMILARITY_DF_FILE_NAME)
print(f"Patient similarity dataframe loaded. Number of rows: {patient_similarity_df.shape[0]}")

Loading patient similarity dataframe from d:\Repos\ut-health-final-proj\pickle\max_20000_nodes_similarity.gpickle...
Patient similarity dataframe loaded. Number of rows: 1288402


In [9]:
patient_similarity_df.head(5)

Unnamed: 0,patient1,patient2,patient1_id,patient2_id,jaccard_similarity,same_gender,same_age_bucket,is_similar
423,patient-4074,patient-10139,4074,10139,0.304348,False,False,True
9670,patient-4074,patient-26572,4074,26572,0.3,True,True,True
16410,patient-90889,patient-84020,90889,84020,0.344828,True,False,True
17072,patient-90889,patient-67648,90889,67648,0.4,False,False,True
19221,patient-90889,patient-6138,90889,6138,0.333333,True,False,True


In [17]:
# rename column patient1 to source and patient2 to target
patient_similarity_df.rename(columns={'patient1': 'source_node', 'patient2': 'target_node'}, inplace=True)
patient_similarity_df.head(5)

Unnamed: 0,source_node,target_node,patient1_id,patient2_id,jaccard_similarity,same_gender,same_age_bucket,is_similar
423,patient-4074,patient-10139,4074,10139,0.304348,False,False,True
9670,patient-4074,patient-26572,4074,26572,0.3,True,True,True
16410,patient-90889,patient-84020,90889,84020,0.344828,True,False,True
17072,patient-90889,patient-67648,90889,67648,0.4,False,False,True
19221,patient-90889,patient-6138,90889,6138,0.333333,True,False,True


# Create R-GCN and train the model

### DEBUG

In [36]:
# Get node types
node_types = {node_type: [] for node_type in ['patient', 'diagnosis', 'procedure']}
print(f"node_types: {node_types}")

    
# Create mappings from original node IDs to new indices for each node type
node_mappings = {node_type: {} for node_type in node_types}

print(f"node_mappings: {node_mappings}")
i=0
for node, attr in ntx_graph.nodes(data=True):
    node_type = attr['type']
    if i ==0:
        print(f"node: {node}")
        print(f"attr: {attr}")
        print(f"node_type: {node_type}")

    if node not in node_mappings[node_type]:
        node_mappings[node_type][node] = len(node_mappings[node_type])

        i+=1
        print(f"node_mappings[{node_type}][{node}]: {node_mappings[node_type][node]}")
        print(f"node_mappings[node_type][node]: {node_mappings[node_type][node]}")

    if i==3:
        break



node_types: {'patient': [], 'diagnosis': [], 'procedure': []}
node_mappings: {'patient': {}, 'diagnosis': {}, 'procedure': {}}
node: patient-4074
attr: {'gender': 'M', 'age_bucket': 80, 'hadm_id': 137421, 'type': 'patient'}
node_type: patient
node_mappings[patient][patient-4074]: 0
node_mappings[node_type][node]: 0
node_mappings[patient][patient-90889]: 1
node_mappings[node_type][node]: 1
node_mappings[patient][patient-72753]: 2
node_mappings[node_type][node]: 2


## Neural Network

In [None]:
# Convert NetworkX graph to PyTorch Geometric HeteroData
def convert_to_hetero_data(nx_graph):
    data = HeteroData()
    
    # create node types {'patient': [], 'diagnosis': [], 'procedure': []}
    node_types = {node_type: [] for node_type in ['patient', 'diagnosis', 'procedure']}

    
    # Create mappings from original node IDs to new indices for each node type
    # {'patient': {}, 'diagnosis': {}, 'procedure': {}}
    node_mappings = {node_type: {} for node_type in node_types}

    for node, attr in nx_graph.nodes(data=True):
        node_type = attr['type']
        if node not in node_mappings[node_type]:
            node_mappings[node_type][node] = len(node_mappings[node_type])
    
    # Add node features (just using one-hot encoding for simplicity)
    for node_type, mapping in node_mappings.items():
        num_nodes = len(mapping)
        # Use one-hot encoding as node features
        data[node_type].x = torch.eye(num_nodes)
    
    # Add edges
    edge_types = [('patient', 'has_diagnosis', 'diagnosis'), 
                  ('patient', 'has_procedure', 'procedure')]
    
    for src_type, relation, dst_type in edge_types:
        edge_indices = [[], []]
        for u, v, key, attr in nx_graph.edges(keys=True, data=True):
            if attr['relation'] == relation.replace('_', ' '):
                # Find source and destination nodes with correct types
                src_node_type = nx_graph.nodes[u]['type']
                dst_node_type = nx_graph.nodes[v]['type']
                
                if src_node_type == src_type and dst_node_type == dst_type:
                    edge_indices[0].append(node_mappings[src_type][u])
                    edge_indices[1].append(node_mappings[dst_type][v])
                elif src_node_type == dst_type and dst_node_type == src_type:
                    # In case the edge direction is reversed
                    edge_indices[0].append(node_mappings[dst_type][v])
                    edge_indices[1].append(node_mappings[src_type][u])
        
        if len(edge_indices[0]) > 0:
            data[src_type, relation, dst_type].edge_index = torch.tensor(edge_indices)
    
    return data, node_mappings

# Create RGCN model for contrastive learning
class RGCN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()
        
        # Initialize the GNN encoders
        self.conv1 = torch_geometric.nn.HeteroConv({
            ('patient', 'has_diagnosis', 'diagnosis'): GraphConv(-1, hidden_channels),
            ('patient', 'has_procedure', 'procedure'): GraphConv(-1, hidden_channels),
            ('diagnosis', 'rev_has_diagnosis', 'patient'): GraphConv(-1, hidden_channels),
            ('procedure', 'rev_has_procedure', 'patient'): GraphConv(-1, hidden_channels)
        })
        
        self.conv2 = torch_geometric.nn.HeteroConv({
            ('patient', 'has_diagnosis', 'diagnosis'): GraphConv(hidden_channels, hidden_channels),
            ('patient', 'has_procedure', 'procedure'): GraphConv(hidden_channels, hidden_channels),
            ('diagnosis', 'rev_has_diagnosis', 'patient'): GraphConv(hidden_channels, hidden_channels),
            ('procedure', 'rev_has_procedure', 'patient'): GraphConv(hidden_channels, hidden_channels)
        })
        
        # Projection head for contrastive learning
        self.patient_proj = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels)
        )
        
    def forward(self, x_dict, edge_index_dict):
        # First layer
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        
        # Second layer
        x_dict = self.conv2(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        
        # Project patient embeddings for contrastive learning
        patient_emb = self.patient_proj(x_dict['patient'])
        return patient_emb

# Create a contrastive loss function
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=0.5):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        
    def forward(self, embeddings, pairs, labels):
        # Extract embeddings for pairs
        embeddings1 = embeddings[pairs[:, 0]]
        embeddings2 = embeddings[pairs[:, 1]]
        
        # Calculate Euclidean distance
        distances = F.pairwise_distance(embeddings1, embeddings2)
        
        # Calculate loss
        loss = (1 - labels) * torch.pow(distances, 2) + \
               labels * torch.pow(torch.clamp(self.margin - distances, min=0.0), 2)
        return loss.mean()

# Process NetworkX graph
print("Converting NetworkX graph to HeteroData...")
data, node_mappings = convert_to_hetero_data(ntx_graph)

# Add reverse edges to make the graph undirected
transform = T.ToUndirected()
data = transform(data)

# Find the node IDs in the patient_similarity_df
# Create a lookup from node name to node index
patient_node_lookup = {node: idx for node, idx in node_mappings['patient'].items()}

# Process similarity data for training
print("Processing similarity data...")
sim_pairs = []
sim_labels = []

# Process a subset of pairs for efficiency in this example
subset_df = patient_similarity_df.sample(min(50000, len(patient_similarity_df)))

for _, row in subset_df.iterrows():
    source = row['source_node']
    target = row['target_node']
    if source in patient_node_lookup and target in patient_node_lookup:
        source_idx = patient_node_lookup[source]
        target_idx = patient_node_lookup[target]
        sim_pairs.append([source_idx, target_idx])
        sim_labels.append(1.0 if row['is_similar'] else 0.0)

sim_pairs = torch.tensor(sim_pairs)
sim_labels = torch.tensor(sim_labels)

print(f"Created {len(sim_pairs)} training pairs")

# Create the model
hidden_channels = 64
out_channels = 32
model = RGCN(hidden_channels, out_channels, data.metadata())

# Set up optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = ContrastiveLoss()

# Training loop
def train():
    model.train()
    optimizer.zero_grad()
    
    # Get embeddings
    embeddings = model(data.x_dict, data.edge_index_dict)
    
    # Compute loss
    loss = criterion(embeddings, sim_pairs, sim_labels)
    
    # Backpropagation
    loss.backward()
    optimizer.step()
    
    return loss.item()

# Train the model
print("Training model...")
num_epochs = 50
for epoch in range(num_epochs):
    loss = train()
    if (epoch + 1) % 5 == 0:
        print(f'Epoch: {epoch+1:02d}, Loss: {loss:.4f}')

Converting NetworkX graph to HeteroData...
Processing similarity data...
Created 50000 training pairs
Training model...


KeyError: "Tried to collect 'edge_index' but did not find any occurrences of it in any node and/or edge type"