# Link prediction with GNN

## Define model classes

In [1]:
%load_ext autoreload
%autoreload 2  

import dgl 
import torch 
import pandas as pd 
import pickle 
import pytorch_lightning as pl
import random
from pytorch_lightning.loggers import CSVLogger
from tqdm.notebook import tqdm

In [2]:
def open_pickle(f):
    with open(f, 'rb') as fname:
        node_dict = pickle.load(fname)
    return node_dict

def to_pickle(node_dict, f):
    with open(f, 'wb') as fname:
        pickle.dump(node_dict, fname)

In [3]:
import dgl 
import torch 
import pytorch_lightning as pl
import gc
import torchmetrics
import torch.nn as nn 
import torch.nn.functional as F

# set seeds
torch.manual_seed(0)
import random
random.seed(0)
import numpy as np
np.random.seed(0)

"""
    config_dict: in_feat, out_feat, head_size_1, head_size_2, num_heads_1, num_heads_2, dropout
"""
class mini_hgt(pl.LightningModule):
    def __init__(self, config_dict):
        super().__init__()

        self.num_layers = config_dict['num_layers']
        in_feat = config_dict['in_feat']
        num_heads = 1
        out_feat = config_dict['out_feat']
                            
        self.convs = torch.nn.ModuleList()

        for i in range(self.num_layers):
            in_dim = in_feat if i == 0 else (in_feat * num_heads)
            conv = dgl.nn.pytorch.conv.SAGEConv(in_dim, out_feat, "pool")
            self.convs.append(conv)
        self.linear = torch.nn.Linear(in_feat * num_heads, out_feat)
        self.relu = torch.nn.ReLU()

    def forward(self, blocks, x):
        for i in range(self.num_layers):
            b = blocks[i]
            x = self.convs[i](b, x)

            # if not last layer
            if i < self.num_layers - 1:
                x = self.relu(x)

        x = self.linear(x)
        
        return x


In [4]:
class EdgePredModel(pl.LightningModule):
    def __init__(self, homo_hg, hgt_config):
        super().__init__()

        self.lr = hgt_config['lr']
        self.homo_hg = homo_hg
        self.gnn = mini_hgt(hgt_config)

        self.accuracy = torchmetrics.classification.Accuracy(task='binary')
        self.softmax = nn.Softmax(dim=1)
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        self.margin_loss = nn.TripletMarginWithDistanceLoss(distance_function=self.cos, margin=0.2)

        self.h_dict = {}
                 
    def forward(self, input_nodes, blocks):

        x_graph = blocks[0].srcdata['feat']

        # returns dictionary of embeddings
        h = self.gnn(blocks, x_graph)

        for ind, i in zip(input_nodes.tolist(), range(0, h.shape[0])):
            self.h_dict[ind] = h[i]
            
                    
    def training_step(self, train_batch, _):
        
        edge_input_nodes, edge_graph, edge_blocks, edge_batch_samples = train_batch
        
        # edges
        self.forward(edge_input_nodes, edge_blocks)
        edge_loss_score, acc = self.edge_loss(edge_batch_samples)

        self.log('train_loss_edge', edge_loss_score, prog_bar=True, batch_size=len(edge_batch_samples))
        self.log('train_acc_edge', acc, prog_bar=True, batch_size=len(edge_batch_samples)) 
        
        # reset embeddings dict
        self.h_dict = {}
            
        return edge_loss_score
        
    def edge_loss(self, batch_samples):
        pos_list = []
        neg_list = []
        anch_list = []

        for s in batch_samples:
            d_anch = s[0]
            d_pos = s[1]
            d_neg = s[2]
            anch_embed = self.h_dict[d_anch[0].item()]
            pos_embed = self.h_dict[d_pos[0].item()]
            neg_embed = self.h_dict[d_neg[0].item()] 
            
            anch_list.append(anch_embed)
            pos_list.append(pos_embed)
            neg_list.append(neg_embed)       

        anch_vec = torch.stack(anch_list)
        pos_vec = torch.stack(pos_list)
        neg_vec = torch.stack(neg_list)

        loss = self.margin_loss(anch_vec, pos_vec, neg_vec)

        d_pos = self.cos(pos_vec, anch_vec)
        d_neg = self.cos(neg_vec, anch_vec)
        correct = (d_pos < d_neg).sum().item()
        acc = correct/len(anch_vec)

        return loss, acc
    
    def configure_optimizers(self):
        param_list = [{'params': self.gnn.parameters()}]
        optimizer = torch.optim.Adam(param_list, lr=self.lr)
        #lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5, verbose=True)
        return [optimizer] #, [lr_scheduler]

## Model setup

In [5]:
#
# Set up model params
#
print("Assembling graph")
node_df = pd.read_csv("/n/holylfs06/LABS/mzitnik_lab/Lab/ruthjohnson/kg_paper_revision/connected_node_logml_df.csv", sep='\t')
edge_df = pd.read_csv("/n/holylfs06/LABS/mzitnik_lab/Lab/ruthjohnson/kg_paper_revision/connected_edge_logml_df.csv", sep='\t')

u = torch.tensor(edge_df['node_index_x'].tolist())
v = torch.tensor(edge_df['node_index_y'].tolist())

g = dgl.graph((u,v))
graph_feature_df = pd.read_csv("/n/holylfs06/LABS/mzitnik_lab/Lab/ruthjohnson/kg_paper_revision/graph_feature_logml_df.csv")
g.ndata['feat'] = torch.tensor(graph_feature_df.values, dtype=torch.float32)

ntype_list = node_df['ntype'].unique()
ntype_dict = {}
ntype_index_dict = {}
i = 0
for t in ntype_list:
    ntype_dict[t] = i
    ntype_index_dict[i] = t
    i+=1 

ntype_index_dict = {}
ntype_index_dict['ATC'] = 0
ntype_index_dict['ICD10CM'] = 2
ntype_index_dict['LNC'] = 1
ntype_index_dict['PHECODE'] = 2
ntype_index_dict['RXNORM'] = 0
ntype_index_dict['SNOMEDCT_US'] = 3
ntype_index_dict['UMLS_CUI'] = 3

etypes = edge_df['ntype_x'] + ':' + edge_df['ntype_y']
etype_list = etypes.unique()
etype_dict = {}
i = 0 
for t in etype_list:
    if t.split(':')[1] in ['ATC', 'PHECODE', 'CPT']:
        etype_dict[t] = 1
    else:
        etype_dict[t] = 0

node_df['ntype_index'] = node_df['ntype'].map(ntype_index_dict)
g.ndata['ntype'] = torch.tensor(node_df['ntype_index'].tolist(), dtype=torch.int32)
g.edata['etype'] = torch.tensor((edge_df['ntype_x'] + ':' + edge_df['ntype_y']).map(etype_dict), dtype=torch.int32)

Assembling graph


In [6]:
"""
# Create a mapping for quick lookups
source_target_pairs = edge_df[['node_index_x', 'node_index_y']].apply(tuple, axis=1)
reverse_pairs = edge_df[['node_index_y', 'node_index_x']].apply(tuple, axis=1)

# Create a dictionary for reverse edges
reverse_map = dict(zip(source_target_pairs, edge_df['edge_index']))

# Assign reverse edges using the reverse_pairs
edge_df['reverse_edge_index'] = reverse_pairs.map(reverse_map)
rev_id_dict = edge_df.set_index('edge_index')['reverse_edge_index'].to_dict()
to_pickle(rev_id_dict, "rev_edge_dict_logml.pkl")
"""
rev_id_dict = open_pickle("/n/holylfs06/LABS/mzitnik_lab/Lab/ruthjohnson/kg_paper_revision/model/rev_edge_dict_logml.pkl")

In [7]:
n_epochs=10

config = {
    'num_layers': 2,
    'n_neg': 1,
    'batch_size': 500,
    'sampler_n': 20, 
    'lr': 1e-4,
    'in_feat': 128,
    'out_feat': 128,
    'head_size': 512,
    'num_heads': 3, 
}
device = "cpu"

logger = CSVLogger("logs", name="my_exp_name")

trainer = pl.Trainer(max_epochs=n_epochs,
                        log_every_n_steps=1,
                        #precision="bf16-mixed", #"bf16-mixed"
                        accelerator=device,
                        logger=logger
                        )
from dataloader import edge_pred_dataloader

model = EdgePredModel(g, config)

data_module = edge_pred_dataloader(homo_hg=g, homo_hg_dict=config, rev_edge_dict=rev_id_dict)

#trainer.fit(model=model, datamodule=data_module)


/n/home01/ruthjohnson/.local/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /n/home01/ruthjohnson/venv_dgl/lib/python3.10/site-p ...
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/n/home01/ruthjohnson/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


In [118]:
torch.save(model.gnn.state_dict(), "logml_model.pt")

## Final embedding generation

In [8]:
state_dict = torch.load('/n/holylfs06/LABS/mzitnik_lab/Lab/ruthjohnson/kg_paper_revision/model/logml_model.pt')
model.gnn.load_state_dict(state_dict)
model.eval()

EdgePredModel(
  (gnn): mini_hgt(
    (convs): ModuleList(
      (0-1): 2 x SAGEConv(
        (feat_drop): Dropout(p=0.0, inplace=False)
        (fc_pool): Linear(in_features=128, out_features=128, bias=True)
        (fc_neigh): Linear(in_features=128, out_features=128, bias=False)
        (fc_self): Linear(in_features=128, out_features=128, bias=True)
      )
    )
    (linear): Linear(in_features=128, out_features=128, bias=True)
    (relu): ReLU()
  )
  (accuracy): BinaryAccuracy()
  (softmax): Softmax(dim=1)
  (cos): CosineSimilarity()
  (margin_loss): TripletMarginWithDistanceLoss(
    (distance_function): CosineSimilarity()
  )
)

In [22]:
eval_list = node_df.loc[node_df['ntype'] == 'ICD10CM']['node_index'].tolist()

sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10])

h_list = []

with torch.no_grad():
    for node_ind in tqdm(eval_list):

        dataloader = dgl.dataloading.DataLoader(
            model.homo_hg,
            torch.tensor([node_ind]),     # compute embeddings for all the nodes
            sampler,
            shuffle=False,    # remember to set this to False so you can just concatenate embeddings at the end
            batch_size=1,
            device="cpu"
        )

        for inputs, graph, blocks in dataloader:
            inputs = blocks[0].srcdata['feat']
            graph_h = model.gnn(blocks, inputs)
            h_list.append(graph_h)

    all_h = torch.concat(h_list)

  0%|          | 0/11933 [00:00<?, ?it/s]



In [24]:
to_pickle(all_h, "logml_icd_embeds.pkl")

## Visualize embeddings

In [None]:
import umap

reducer = umap.UMAP(n_neighbors=50, min_dist=0.05, metric='cosine')
embedding_standard = reducer.fit_transform(all_h.numpy())
embed_df = pd.DataFrame(embedding_standard)
embed_df['node_index'] = eval_list
embed_df = embed_df.merge(node_df, on='node_index')
embed_df['cat'] = embed_df['node_id'].str.slice(0,1)

In [33]:
import plotly.express as px  

fig_2d = px.scatter(
    embed_df, x=0, y=1, color='cat',
hover_data=["node_id", "node_name"]
)

# hover name
fig_2d.update_layout(
    width=500,
    height=400,
)

fig_2d.update_traces(opacity=.7)
fig_2d.show()