In [1]:
import sys
sys.path.append("/home/tim/PycharmProjects/medical-lay/src")
import csv
from tqdm import tqdm
import numpy as np
from pathlib import Path
from scipy.spatial.distance import cdist
import json
from config import  TLCPaths



In [2]:
## functions
def get_batched_closest_index_idx(emb, index_embeddings, top_k=64):
  """Return, per query embedding, the positions of its nearest index embeddings.

  The full query-by-index cosine distance matrix is computed in one call and
  every row is sorted, so memory use grows with
  ``len(emb) * len(index_embeddings)``.

  Args:
    emb: 2-D array of query embeddings, one row per query.
    index_embeddings: 2-D array of index embeddings, one row per index entry.
    top_k: number of nearest index entries to keep per query. Defaults to 64,
      the value that was previously hard-coded.

  Returns:
    Integer array with one row per query holding the positions of the
    ``top_k`` closest index rows, ordered by increasing cosine distance. If
    the index has fewer than ``top_k`` rows, all of them are returned.

  Raises:
    ValueError: if ``top_k`` is smaller than 1.
  """
  if top_k < 1:
    # A negative value would silently slice from the end of each row.
    raise ValueError(f"top_k must be >= 1, got {top_k}")
  dist = cdist(emb, index_embeddings, metric="cosine")
  return np.argsort(dist, axis=-1)[:, :top_k]

def get_batched_closest_cuis(emb, index_embedding, index_cuis):
  """Map every query embedding to the CUIs of its nearest index entries.

  Args:
    emb: query embeddings, passed through to ``get_batched_closest_index_idx``.
    index_embedding: index embeddings, passed through to the same helper.
    index_cuis: sequence giving the CUI for each index position.

  Returns:
    ``np.ndarray`` built from one list of CUIs per query, in the order the
    helper returned the index positions.
  """
  neighbour_rows = get_batched_closest_index_idx(emb, index_embedding)
  cui_rows = []
  for neighbour_row in neighbour_rows:
    # Translate index positions into their CUIs, keeping the ranking order.
    cui_rows.append([index_cuis[position] for position in neighbour_row])
  return np.array(cui_rows)

def top_k_acc(y_true, y_pred, k=1):
  """Fraction of samples whose gold label occurs among the first ``k`` predictions.

  Args:
    y_true: iterable of gold labels, one per sample.
    y_pred: 2-D array of ranked predictions, one row per sample.
    k: number of leading predictions per row to consider.

  Returns:
    Mean of the per-sample hit indicators, as returned by ``np.mean``.
  """
  top_k_candidates = y_pred[:, :k]
  hits = []
  # zip stops at the shorter of the two inputs.
  for gold, candidates in zip(y_true, top_k_candidates):
    hits.append(gold in candidates)
  return np.mean(hits)

In [3]:
data_dir = TLCPaths.project_data_path

# Mention-in-sentence embeddings, plus the per-mention token positions.
sentence_embedding_file = Path(data_dir / "embeddings/SAPBERT_XMLR_mention_with_sentence_embeddings.npy")
mention_token_file = Path(data_dir / "embeddings/SAPBERT_XMLR_mention_with_sentence_offsets.npy")
sentence_hidden_states = np.load(sentence_embedding_file)
# NOTE(review): allow_pickle=True unpickles Python objects from the file;
# only load files from a trusted source.
mention_token_indices = np.load(mention_token_file, allow_pickle=True)

# Mention-only embeddings.
mention_embedding_file = Path(data_dir / "embeddings/SAPBERT_XMLR_large_mention_embedding.npy")
mention_hidden_states = np.load(mention_embedding_file)

# Index embeddings: every variant is stored as shards named "<prefix>_<n>.npy".
cls_index_file = Path(data_dir / "embeddings/SAPBERT_XMLR_large_umls_search_index_cls_token")
mean_all_index_file = Path(data_dir / "embeddings/SAPBERT_XMLR_large_umls_search_index_mean_token")
mean_no_cls_index_file = Path(data_dir / "embeddings/SAPBERT_XMLR_large_umls_search_index_mean_no_cls_token")


def _load_index_shards(prefix, n_shards=5):
  """Load ``<prefix>_0.npy`` .. ``<prefix>_<n_shards - 1>.npy`` and concatenate them."""
  return np.concatenate([np.load(f"{prefix}_{shard}.npy") for shard in range(n_shards)])


cls_token_index = _load_index_shards(cls_index_file)
mean_all_token_index = _load_index_shards(mean_all_index_file)
mean_no_cls_token_index = _load_index_shards(mean_no_cls_index_file)

# Concept table: space-delimited rows, '|' as quote character,
# CUI in column 0 and concept name in column 1.
concept_names, concept_cuis = [], []
with open(data_dir / "german_umls_names_and_cuis.csv", newline='') as csvfile:
    for row in csv.reader(csvfile, delimiter=' ', quotechar='|'):
      assert len(row)== 2
      cui, raw_name = row[0], row[1]
      concept_cuis.append(cui)
      concept_names.append(raw_name.strip())


# Annotated mentions: one JSON entry per mention.
with open(data_dir / "TLC_UMLS.json", "r") as f:
    data = json.load(f)

X = [entry["mention"] for entry in data]
Y = [entry["cui"] for entry in data]
X_sent = [entry["mention_sentence"].strip() for entry in data]
# Same list object as Y, not a copy.
Y_sent = Y
mention_offsets = [entry["mention_sentence_spans"] for entry in data]

In [4]:
# Sentence representation: the hidden state at token position 0 of each row.
sentence_cls_token_rep = sentence_hidden_states[:, 0, :]

# Positions used when a mention has no recorded token indices
# (the original author noted that 5 entries are empty).
fallback_token_indices = range(150)

# For each mention, collect the hidden states at its token positions.
mention_token_reps = [
  [
    sentence_hidden_states[row, position, :]
    for position in (span if span else fallback_token_indices)
  ]
  for row, span in enumerate(mention_token_indices)
]

# Average the collected token vectors into one vector per mention.
sentence_mention_token_rep = np.array(
  [np.mean(token_vectors, axis=0) for token_vectors in mention_token_reps]
)

In [5]:
# Three poolings of the mention-only hidden states along axis 1:
# position 0 only, mean over all positions, mean over positions 1 onwards.
# NOTE(review): the means cover every position on axis 1; confirm whether
# padding positions should be excluded.
mention_cls_token_rep = mention_hidden_states[:, 0, :]
mention_mean_token_rep = np.mean(mention_hidden_states, axis=1)
mention_no_cls_token_rep = np.mean(mention_hidden_states[:, 1:, :], axis=1)

In [6]:
# Cut-offs for Acc@k: the powers of two from 1 up to 64.
k_values = [2 ** exponent for exponent in range(7)]

# Filled by the experiment cells, keyed by experiment name.
result_accuracies = dict()
predictions = dict()

# Get predictions
1. Mention with and without sentence context on the CLS-token index

In [None]:
# Query: mention CLS-token representation; index: CLS-token embeddings.
# The accuracies for this run are computed in the following cell.
name = "SAPBERT_cls_token_on_cls_index"
print("mention cls token on cls token index")
Y_pred = get_batched_closest_cuis(
  mention_cls_token_rep,
  cls_token_index,
  index_cuis=concept_cuis,
)

mention cls token on cls token index


In [None]:
# Score the current `Y_pred` at every cut-off and store it under the current `name`.
accs = []
for k in k_values:
  accs.append(round(top_k_acc(Y, Y_pred, k=k), 3))
for k, acc in zip(k_values, accs):
  print(f"Acc@{k}: {acc}")
result_accuracies[name] = accs
predictions[name] = Y_pred

In [None]:
# Query: mean over all mention tokens; index: CLS-token embeddings.
name = "SAPBERT_mean_token_on_cls_index"
print("mention mean of all token on cls token index")
Y_pred = get_batched_closest_cuis(
  mention_mean_token_rep,
  cls_token_index,
  index_cuis=concept_cuis,
)
accs = []
for k in k_values:
  accs.append(round(top_k_acc(Y, Y_pred, k=k), 3))
for k, acc in zip(k_values, accs):
  print(f"Acc@{k}: {acc}")
result_accuracies[name] = accs
predictions[name] = Y_pred

In [None]:
# Query: mean over mention tokens from position 1 onwards; index: CLS-token embeddings.
name = "SAPBERT_mean_no_cls_token_on_cls_index"
print("mention mean of all except cls token on cls token index")
Y_pred = get_batched_closest_cuis(
  mention_no_cls_token_rep,
  cls_token_index,
  index_cuis=concept_cuis,
)
accs = []
for k in k_values:
  accs.append(round(top_k_acc(Y, Y_pred, k=k), 3))
for k, acc in zip(k_values, accs):
  print(f"Acc@{k}: {acc}")
result_accuracies[name] = accs
predictions[name] = Y_pred

In [None]:
# Query: position-0 token of the mention-in-sentence encoding; index: CLS-token embeddings.
name = "SAPBERT_sent_cls_token_on_cls_index"
print("sentence with mention cls token on cls token index")
Y_pred = get_batched_closest_cuis(
  sentence_cls_token_rep,
  cls_token_index,
  index_cuis=concept_cuis,
)
accs = []
for k in k_values:
  accs.append(round(top_k_acc(Y, Y_pred, k=k), 3))
for k, acc in zip(k_values, accs):
  print(f"Acc@{k}: {acc}")
result_accuracies[name] = accs
predictions[name] = Y_pred

In [None]:
# Query: mean of the mention's token vectors inside its sentence; index: CLS-token embeddings.
name = "SAPBERT_sent_mean_token_on_cls_index"
print("mean mention in sentence token on cls token index")
Y_pred = get_batched_closest_cuis(
  sentence_mention_token_rep,
  cls_token_index,
  index_cuis=concept_cuis,
)
accs = []
for k in k_values:
  accs.append(round(top_k_acc(Y, Y_pred, k=k), 3))
for k, acc in zip(k_values, accs):
  print(f"Acc@{k}: {acc}")
result_accuracies[name] = accs
predictions[name] = Y_pred

In [None]:
# Drop the CLS-token index; no later cell shown in this notebook reads it.
# Re-running this cell without reloading the index raises NameError.
del cls_token_index

2. Mention with and without sentence context on the mean-token index

In [None]:
# Every query representation is evaluated against the mean-token index.
# Each entry: (printed description, result key, query representation).
mean_index_experiments = [
  (
    "mention cls token on mean token index",
    "SAPBERT_cls_token_on_mean_index",
    mention_cls_token_rep,
  ),
  (
    "mention mean of all token on mean token index",
    "SAPBERT_mean_all_token_on_mean_index",
    mention_mean_token_rep,
  ),
  (
    "mention mean of all except cls token on mean token index",
    "SAPBERT_mean_no_cls_token_on_mean_index",
    mention_no_cls_token_rep,
  ),
  (
    "sentence with mention cls token on mean token index",
    "SAPBERT_sent_cls_token_on_mean_index",
    sentence_cls_token_rep,
  ),
  (
    "mean mention in sentence token on mean token index",
    "SAPBERT_sent_mean_mention_token_on_mean_index",
    sentence_mention_token_rep,
  ),
]

for description, name, query_rep in mean_index_experiments:
  print(description)
  Y_pred = get_batched_closest_cuis(query_rep, mean_all_token_index, index_cuis=concept_cuis)
  accs = [round(top_k_acc(Y, Y_pred, k=k), 3) for k in k_values]
  for k, acc in zip(k_values, accs):
    print(f"Acc@{k}: {acc}")
  result_accuracies[name] = accs
  predictions[name] = Y_pred

In [None]:
# Drop the mean-token index; no later cell shown in this notebook reads it.
# Re-running this cell without reloading the index raises NameError.
del mean_all_token_index

In [None]:
# 3. Every query representation is evaluated against the mean-without-CLS index.
# Each entry: (printed description, result key, query representation).
mean_no_cls_index_experiments = [
  (
    "mention cls token on mean no cls token index",
    "SAPBERT_cls_token_on_mean_no_cls_index",
    mention_cls_token_rep,
  ),
  (
    "mention mean of all token on mean _no_cls token index",
    "SAPBERT_mean_all_token_on_mean_no_cls_index",
    mention_mean_token_rep,
  ),
  (
    "mention mean of all except cls token on mean _no_cls token index",
    "SAPBERT_mean_no_cls_token_on_mean_no_cls_index",
    mention_no_cls_token_rep,
  ),
  (
    "sentence with mention cls token on mean _no_cls token index",
    "SAPBERT_sent_cls_token_on_mean_no_cls_index",
    sentence_cls_token_rep,
  ),
  (
    "mean mention in sentence token on mean _no_cls token index",
    "SAPBERT_sent_mean_mention_token_on_mean_no_cls_index",
    sentence_mention_token_rep,
  ),
]

for description, name, query_rep in mean_no_cls_index_experiments:
  print(description)
  Y_pred = get_batched_closest_cuis(query_rep, mean_no_cls_token_index, index_cuis=concept_cuis)
  accs = [round(top_k_acc(Y, Y_pred, k=k), 3) for k in k_values]
  for k, acc in zip(k_values, accs):
    print(f"Acc@{k}: {acc}")
  result_accuracies[name] = accs
  predictions[name] = Y_pred

In [None]:
# Persist the accuracies and the ranked predictions as JSON.
accuracies_file = Path(TLCPaths.project_data_path / "sapbert_accuracies.json")
predictions_file = Path(TLCPaths.project_data_path / "sapbert_predictions.json")

# The experiment cells store numpy arrays in `predictions`; json.dump raises
# TypeError on ndarray, so convert them to nested Python lists first.
serializable_predictions = {
  key: np.asarray(ranked_cuis).tolist() for key, ranked_cuis in predictions.items()
}
# The accuracy values come from np.mean; cast to built-in float so encoding
# does not depend on the numpy scalar type.
serializable_accuracies = {
  key: [float(value) for value in values] for key, values in result_accuracies.items()
}

# `json` is already imported in the notebook's first cell.
with open(accuracies_file, "w") as f:
  json.dump(serializable_accuracies, f)
with open(predictions_file, "w") as f:
  json.dump(serializable_predictions, f)
