In [1]:
import csv
import json
import time

import numpy as np
import pysolr
from tqdm import tqdm

from config import TLCPaths
from preprocessing import cistem

In [2]:
# Solr connection settings. Use SOLR_CORE = "wumls-multi-valued" to query the
# multi-valued variant of the index instead of the deduplicated single-valued one.
SOLR_URL = "http://localhost:8983/solr"
SOLR_CORE = "wumls-single-valued-deduplicated"

# Read-only usage: no automatic commits needed.
solr = pysolr.Solr(f"{SOLR_URL}/{SOLR_CORE}", always_commit=False)
solr.ping()

'{\n  "responseHeader":{\n    "zkConnected":null,\n    "status":0,\n    "QTime":80,\n    "params":{\n      "q":"{!lucene}*:*",\n      "distrib":"false",\n      "df":"_text_",\n      "rows":"10",\n      "echoParams":"all",\n      "rid":"localhost-0"}},\n  "status":"OK"}\n'

In [3]:
def create_query_with_edit_distance(term, field="index_term", measure="edit"):
    """Build a Solr function query that scores documents by string distance to ``term``.

    Args:
        term: Query string (typically the CISTEM-stemmed mention).
        field: Indexed field to compare against; defaults to ``index_term``.
        measure: Distance measure passed to Solr's ``strdist``; defaults to
            ``edit`` (Solr also supports ``jw`` and ``ngram``).

    Returns:
        The query string, e.g. ``{!func}strdist('pyelonephriti',index_term,edit)``.
    """
    # The term is embedded in a single-quoted function-query string literal.
    # Escape backslashes first, then single quotes, so a term containing "'"
    # or "\" cannot terminate the literal and break or alter the query.
    escaped = term.replace("\\", "\\\\").replace("'", "\\'")
    return "{!func}strdist('" + escaped + "'," + field + "," + measure + ")"


In [4]:
# Sanity check: print the 10 best-scoring index entries for a single query term
# (given directly in the same form as the index_term values).
search_term = "pyelonephriti"
results = solr.search(
    q=create_query_with_edit_distance(term=search_term),
    fl="*,score",
    rows=10,
)
for hit in results:
    print(hit)

{'cui': ['C0034186'], 'source': ['MDRGER'], 'language': ['GER'], 'name': ['Pyelonephritis'], 'index_term': ['pyelonephriti'], 'id': 'c8b1ada1-826d-46fa-a6fb-186add6196a1', '_version_': 1760194257981800490, 'score': 1.0}
{'cui': ['C0034188'], 'source': ['MDRGER'], 'language': ['GER'], 'name': ['Xanthogranulomatoese Pyelonephritis'], 'index_term': ['xanthogranulomato pyelonephriti'], 'id': 'c01e43d7-a3e5-482d-800a-73b49306f2b7', '_version_': 1760194257982849026, 'score': 1.0}
{'cui': ['C1697444'], 'source': ['MDRGER'], 'language': ['GER'], 'name': ['virale Pyelonephritis'], 'index_term': ['viral pyelonephriti'], 'id': 'c1a048cc-3de9-4231-9820-688e08951389', '_version_': 1760194263649353728, 'score': 1.0}
{'cui': ['C0034188'], 'source': ['MSHGER'], 'language': ['GER'], 'name': ['Pyelonephritis, xanthogranulomatöse'], 'index_term': ['pyelonephritis, xanthogranulomato'], 'id': '1729c369-87a4-4b8a-b12c-1953689789be', '_version_': 1760194257982849025, 'score': 0.9285714}
{'cui': ['C1328529'],

In [5]:

# Load the German TLC evaluation set. JSON is UTF-8 by specification; decode it
# explicitly so umlauts are read correctly regardless of the platform default.
with open(TLCPaths.project_data_path / "TLC_UMLS.json", "r", encoding="utf-8") as f:
    data = json.load(f)

X = [entry["mention"] for entry in data]  # raw mention strings
Y = [entry["cui"] for entry in data]  # gold CUIs, one per mention
X_sent = [entry["mention_sentence"].strip() for entry in data]
Y_sent = Y  # sentence-level inputs share the mention-level labels
mention_offsets = [entry["mention_sentence_spans"] for entry in data]
print(len(X))


6228


In [6]:
# For every mention, query Solr with its CISTEM-stemmed form and keep the CUIs
# and names of the 10 returned index entries, in Solr's ranking order.
y_pred, y_pred_names = [], []
for mention in tqdm(X):
    query = create_query_with_edit_distance(term=cistem.stem(mention))
    results = list(solr.search(q=query, fl="*,score", rows=10))
    y_pred.append([doc["cui"] for doc in results])
    y_pred_names.append([doc["name"] for doc in results])

100%|██████████| 6228/6228 [06:11<00:00, 16.78it/s]  


In [7]:
# Stack the per-mention candidate lists into NumPy arrays. NOTE(review): this only
# yields a regular array if every query returned the same number of hits.
y_pred, y_pred_names = np.array(y_pred), np.array(y_pred_names)

In [8]:
def _candidate_matches(candidate, label):
    """Return True if ``candidate`` is ``label`` or a container holding ``label``."""
    # Solr returns multi-valued fields as lists (e.g. ['C0034186']), so a
    # candidate may be a plain label or a list/array of labels.
    if isinstance(candidate, str):
        return candidate == label
    return label in candidate


def top_k_acc(y_true, y_pred, k=1):
    """Fraction of samples whose gold label appears among the first ``k`` predictions.

    Args:
        y_true: Sequence of gold labels, one per sample.
        y_pred: Per-sample ranked candidates, either a NumPy array whose first
            axis is the sample axis or a (possibly ragged) list of lists. Each
            candidate may be a label or a container of labels.
        k: Number of top-ranked candidates to consider per sample.

    Returns:
        Mean top-k accuracy as a float; NaN when ``y_true`` is empty.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` differ in length.
    """
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"y_true has {len(y_true)} samples but y_pred has {len(y_pred)}"
        )
    # Slice per row (instead of y_pred[:, :k]) so ragged lists and samples with
    # fewer than k candidates are handled too.
    correct_preds = [
        any(_candidate_matches(candidate, true) for candidate in pred[:k])
        for true, pred in zip(y_true, y_pred)
    ]
    if not correct_preds:
        return float("nan")
    return np.mean(correct_preds)


accs = []
k_values = list(range(1, 11)) + [16, 32, 64]
# Largest number of candidates retrieved for any mention. Since top_k_acc only
# looks at the first k candidates, Acc@k cannot change for k beyond this, so
# such rows are flagged instead of being reported as if they were meaningful.
n_candidates = max((len(pred) for pred in y_pred), default=0)
for k in k_values:
    acc = round(top_k_acc(Y, y_pred, k=k), 3)
    accs.append(acc)
    note = f"  (only {n_candidates} candidates retrieved)" if k > n_candidates else ""
    print(f"Acc@{k}:  {acc}{note}")

Acc@1:  0.492
Acc@2:  0.503
Acc@3:  0.539
Acc@4:  0.576
Acc@5:  0.578
Acc@6:  0.578
Acc@7:  0.579
Acc@8:  0.583
Acc@9:  0.594
Acc@10:  0.595
Acc@16:  0.595
Acc@32:  0.595
Acc@64:  0.595


In [11]:
# Rounded accuracies in k_values order, as a single list for copying elsewhere.
print(f"{accs}")

[0.492, 0.503, 0.539, 0.576, 0.578, 0.578, 0.579, 0.583, 0.594, 0.595, 0.595, 0.595, 0.595]


In [12]:
# Export the top-1 prediction for every mention to a timestamped CSV.
# Fields are space-delimited; fields containing spaces are quoted with "|".
# Misclassified mentions are listed first.
timestr = time.strftime("%Y%m%d-%H%M%S")


def _top1(candidates):
    """First value of the best-ranked candidate, or "" if no candidates were returned."""
    if len(candidates) == 0:
        return ""
    # Each candidate is a multi-valued Solr field (a list), so take its first value.
    return candidates[0][0]


rows = []
for mention, names, true_cui, cuis in zip(X, y_pred_names, Y, y_pred):
    pred_name = _top1(names)
    pred_cui = _top1(cuis)
    rows.append([mention, pred_name, true_cui, pred_cui, true_cui == pred_cui])
# Stable sort on the "Correct" flag: False (errors) before True.
rows = sorted(rows, key=lambda row: row[-1])

out_path = TLCPaths.project_data_path.joinpath(f"german_tlc_prediction_SOLR_{timestr}.csv")
# Explicit UTF-8 so German umlauts and other non-ASCII concept names are written
# the same way on every platform.
with open(out_path, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f, delimiter=' ', quotechar='|', quoting=csv.QUOTE_MINIMAL)
    writer.writerow(["mention", "Predicted Concept", "true cui", "pred cui", "Correct"])
    writer.writerows(rows)

In [None]:
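# Interesting misclassified samples from the exported CSV
# (columns: mention, predicted concept, true CUI, predicted CUI, correct):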
# bauchspeicheldrüse |Künstliche Bauchspeicheldrüse| C0183516 C0336563 False
# Blähungen Blähung C1291077 C0016204 False
# Blähen blase C0016204 C0005758 False