# Imports

In [23]:
import pickle
import os
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib
import matplotlib.pyplot as plt
# import torch geometric
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.data import Data, DataLoader
from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GraphConv
# from torch_geometric.nn import MessagePassing
# from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool
# from torch_geometric.utils import degree, to_dense_adj, to_dense_batch, to_undirected
# from torch_geometric.utils import from_networkx, to_networkx
# from torch_geometric.utils import negative_sampling, remove_self_loops, add_self_loops
# from torch_geometric.utils import dropout_adj, to_undirected, add_self_loops
# from torch_geometric.utils import to_dense_adj, to_dense_batch, dense_to_sparse

# Report the PyTorch / PyTorch Geometric versions and whether CUDA is usable.
print("torch version: ", torch.__version__)
print("torch geometric version: ", torch_geometric.__version__)
print("torch cuda version: ", torch.version.cuda)
print("torch cuda available: ", torch.cuda.is_available())
# Device queries kept disabled (they are only meaningful when CUDA is available):
# print("torch cuda device count: ", torch.cuda.device_count())
# print("torch cuda current device name: ", torch.cuda.get_device_name(torch.cuda.current_device()))

# Report the versions of the data / plotting / graph libraries, one per line.
print(
    f"Pandas version: {pd.__version__}",
    f"Numpy version: {np.__version__}",
    f"Matplotlib version: {matplotlib.__version__}",
    f"NetworkX version: {nx.__version__}",
    sep="\n",
)

# Lift the pandas limits on column width and on the number of displayed columns
# (the row limit is left at its default). set_option accepts (option, value) pairs.
pd.set_option(
    'display.max_colwidth', None,
    'display.max_columns', None,
)
print("\nRemoved truncation of columns")


torch version:  2.6.0+cpu
torch geometric version:  2.6.1
torch cuda version:  None
torch cuda available:  False
Pandas version: 2.2.3
Numpy version: 2.2.4
Matplotlib version: 3.10.1
NetworkX version: 3.4.2

Removed truncation of columns


# Load data

## Load NetworkX graph

In [None]:
# load the networkx graph
def load_graph(graph_path):
    """Unpickle and return the object stored at ``graph_path``.

    Args:
        graph_path: Path of the pickle file to read (opened in binary mode).

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.

    Note:
        ``pickle.load`` can execute arbitrary code while deserializing, so
        only point this at files from a trusted source.
    """
    with open(graph_path, 'rb') as pickle_file:
        return pickle.load(pickle_file)

# Locate the pickled graph relative to the current working directory.
# os.path.join picks the separator for the host OS; the previous hard-coded
# "\\" separators only resolved to the intended file on Windows.
CURR_DIR_PATH = os.getcwd()
PICKLE_GRAPH_FILE_NAME = os.path.abspath(
    os.path.join(CURR_DIR_PATH, "pickle", "max_20000_nodes_graph.gpickle")
)

print(f"Loading graph from {PICKLE_GRAPH_FILE_NAME}...")
ntx_graph = load_graph(PICKLE_GRAPH_FILE_NAME)
print(f"Graph loaded. Number of nodes: {ntx_graph.number_of_nodes()}, number of edges: {ntx_graph.number_of_edges()}")


Loading graph from d:\Repos\ut-health-final-proj\pickle\max_20000_nodes_graph.gpickle...
Graph loaded. Number of nodes: 17654, number of edges: 172728


In [16]:
# count nodes of type patient, diagnosis and procedure
def count_node_types(graph):
    """Count the nodes of ``graph`` grouped by their ``'type'`` attribute.

    Args:
        graph: Object exposing ``graph.nodes()`` for iteration and
            ``graph.nodes[node]`` for per-node attribute lookup.

    Returns:
        dict mapping each ``'type'`` value to the number of nodes carrying it,
        in order of first appearance.

    Raises:
        KeyError: If a node has no ``'type'`` attribute.
    """
    tally = {}
    for node_id in graph.nodes():
        kind = graph.nodes[node_id]['type']
        tally[kind] = tally.get(kind, 0) + 1
    return tally

# Tally the loaded graph's nodes by their 'type' attribute and print the counts.
node_types = count_node_types(ntx_graph)
print(f"Node counts: {node_types}")

# count edges named had_procedure, has_diagnosis
def count_edge_types(graph):
    """Count the edges of ``graph`` grouped by their ``'relation'`` attribute.

    Args:
        graph: Object whose ``graph.edges(keys=True)`` yields
            ``(u, v, key)`` triples and whose ``graph.edges[u, v, key]``
            returns that edge's attribute mapping.

    Returns:
        dict mapping each ``'relation'`` value to the number of edges carrying
        it, in order of first appearance.

    Raises:
        KeyError: If an edge has no ``'relation'`` attribute.
    """
    tally = {}
    for src, dst, edge_key in graph.edges(keys=True):
        relation = graph.edges[src, dst, edge_key]['relation']
        tally[relation] = tally.get(relation, 0) + 1
    return tally

# Tally the loaded graph's edges by their 'relation' attribute and print the counts.
edge_types = count_edge_types(ntx_graph)
print(f"Edge counts: {edge_types}")

Node counts: {'patient': 11630, 'diagnosis': 4680, 'procedure': 1344}
Edge counts: {'has_diagnosis': 123987, 'has_procedure': 48741}


## Load patient node similarity dataframe

In [8]:
# load the patient similarity dataframe
def load_patient_similarity_df(df_path):
    """Unpickle and return the object stored at ``df_path``.

    Args:
        df_path: Path of the pickle file to read (opened in binary mode).

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.

    Note:
        ``pickle.load`` can execute arbitrary code while deserializing, so
        only point this at files from a trusted source.
    """
    with open(df_path, 'rb') as pickle_file:
        return pickle.load(pickle_file)

# Locate the pickled similarity table relative to the current working directory.
# os.path.join picks the separator for the host OS; the previous hard-coded
# "\\" separators only resolved to the intended file on Windows.
CURR_DIR_PATH = os.getcwd()
PATIENT_SIMILARITY_DF_FILE_NAME = os.path.abspath(
    os.path.join(CURR_DIR_PATH, "pickle", "max_20000_nodes_similarity.gpickle")
)

print(f"Loading patient similarity dataframe from {PATIENT_SIMILARITY_DF_FILE_NAME}...")
patient_similarity_df = load_patient_similarity_df(PATIENT_SIMILARITY_DF_FILE_NAME)
print(f"Patient similarity dataframe loaded. Number of rows: {patient_similarity_df.shape[0]}")

Loading patient similarity dataframe from d:\Repos\ut-health-final-proj\pickle\max_20000_nodes_similarity.gpickle...
Patient similarity dataframe loaded. Number of rows: 1288402


In [9]:
# Preview the first five rows of the loaded patient similarity dataframe.
patient_similarity_df.head(5)

Unnamed: 0,patient1,patient2,patient1_id,patient2_id,jaccard_similarity,same_gender,same_age_bucket,is_similar
423,patient-4074,patient-10139,4074,10139,0.304348,False,False,True
9670,patient-4074,patient-26572,4074,26572,0.3,True,True,True
16410,patient-90889,patient-84020,90889,84020,0.344828,True,False,True
17072,patient-90889,patient-67648,90889,67648,0.4,False,False,True
19221,patient-90889,patient-6138,90889,6138,0.333333,True,False,True


In [17]:
# Rename column patient1 to source_node and patient2 to target_node.
# The renamed frame is rebound to the same name rather than using inplace=True,
# so the object that was originally loaded is not mutated behind earlier cells.
patient_similarity_df = patient_similarity_df.rename(
    columns={'patient1': 'source_node', 'patient2': 'target_node'}
)
patient_similarity_df.head(5)

Unnamed: 0,source_node,target_node,patient1_id,patient2_id,jaccard_similarity,same_gender,same_age_bucket,is_similar
423,patient-4074,patient-10139,4074,10139,0.304348,False,False,True
9670,patient-4074,patient-26572,4074,26572,0.3,True,True,True
16410,patient-90889,patient-84020,90889,84020,0.344828,True,False,True
17072,patient-90889,patient-67648,90889,67648,0.4,False,False,True
19221,patient-90889,patient-6138,90889,6138,0.333333,True,False,True
