In [31]:
import os
import csv
import json
import numpy as np
import pandas as pd

from tqdm import tqdm
from glob import glob
from bs4 import BeautifulSoup
from collections import defaultdict

from sklearn.metrics import confusion_matrix
from common_string import longest_common_substring_percentage

In [81]:
def eval_metrics(y_true, y_pred, flip_labels=False):
    """Compute binary classification metrics from 0/1 labels and predictions.

    Parameters
    ----------
    y_true : sequence of int
        Reference labels, each 0 or 1.
    y_pred : sequence of int
        Predicted labels, each 0 or 1, aligned with ``y_true``.
    flip_labels : bool, default False
        If True, swap the two classes (0 <-> 1) in both inputs before
        scoring, so the original 0 class is evaluated as the positive class.

    Returns
    -------
    dict
        ``flip_labels``, the confusion-matrix counts ``TN``, ``FP``, ``FN``
        and ``TP`` (as floats), and ``specificity``, ``sensitivity``, ``fpr``,
        ``precision`` and ``F1``. A ratio whose denominator is zero is
        reported as ``nan`` instead of raising ``ZeroDivisionError``.
    """
    if flip_labels:
        # |x - 2| - 1 maps 0 -> 1 and 1 -> 0.
        y_true = np.abs(np.array(y_true)-2)-1
        y_pred = np.abs(np.array(y_pred)-2)-1

    # Pin the label order so the matrix is always 2x2. Without it, inputs in
    # which only one class occurs produce a 1x1 matrix and the four-way
    # unpacking below fails.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = map(float, cm.ravel())

    def _ratio(numerator, denominator):
        # Undefined ratios (e.g. precision with no positive predictions)
        # become nan rather than aborting the whole evaluation.
        return numerator/denominator if denominator else float('nan')

    specificity = _ratio(tn, tn+fp)
    sensitivity = _ratio(tp, tp+fn)
    fpr = _ratio(fp, fp+tn)
    precision = _ratio(tp, tp+fp)
    f1 = _ratio(tp, tp+0.5*(fp+fn))

    return {
        'flip_labels': flip_labels,
        'TN': tn,
        'FP': fp,
        'FN': fn,
        'TP': tp,
        'specificity': specificity,
        'sensitivity': sensitivity,
        'fpr': fpr,
        'precision': precision,
        'F1': f1
    }


In [4]:
# Read the UMLS MedDRA export and collect the lower-cased term strings plus a
# code -> term lookup.
umls_file = 'data/umls_meddra_en.csv'

meddra_terms = set()
meddra_code2term = dict()

# newline='' is what the csv module documents for files passed to csv.reader;
# the with-block closes the handle even if a row fails to parse.
with open(umls_file, newline='') as fh:
    reader = csv.reader(fh)
    header = next(reader)

    for row in reader:
        d = dict(zip(header, row))
        term = d['STR'].lower()
        meddra_terms.add(term)
        # When a code appears on several rows, the last row's term is kept.
        meddra_code2term[int(d['CODE'])] = term

In [5]:
# Load the testing set: one gold XML file per drug label.
folder = 'data/TAC2017/'

test_labels = glob(folder+'gold_xml/*')

# drug name -> lower-cased strings of AdverseReaction mentions in section S1
drug2mentions = defaultdict(set)
# drug name -> strings of the file's Reaction elements
drug2reactions = defaultdict(set)

for label in tqdm(test_labels):
    # Drug name is the file name up to its first '.'. basename() keeps this
    # correct when glob returns paths using the platform separator, which a
    # plain split('/') would not handle.
    drug_name = os.path.basename(label).split('.')[0]
    with open(label, 'r') as f:
        soup = BeautifulSoup(f, 'xml')

    for mention in soup.find_all('Mention'):
        if mention['type'] != 'AdverseReaction':
            continue

        section_name = mention['section']
        if section_name != 'S1':
            continue

        mention_str = mention['str'].lower()
        drug2mentions[drug_name].add(mention_str)

    for reaction in soup.find_all('Reaction'):
        # NOTE(review): unlike mentions, reaction strings are stored without
        # lower-casing and without a section filter - confirm this is intended.
        reaction_str = reaction['str']
        drug2reactions[drug_name].add(reaction_str)

len(drug2mentions), len(drug2reactions)

  0%|          | 0/99 [00:00<?, ?it/s]

100%|██████████| 99/99 [00:00<00:00, 217.78it/s]


(99, 99)

In [6]:
# Build the task 3 reference. Every mentioned term of a drug becomes a
# [drug, term, label] row: label 1 when the term is also one of the drug's
# reaction strings, 0 otherwise. Per-drug set sizes are collected in `diffs`.
task3ref = list()
diffs_list = list()

for drug in drug2mentions.keys():
    mentions = drug2mentions[drug]
    reactions = drug2reactions[drug]

    # Mentions confirmed by a reaction string come first, then the rest.
    for rxn in (mentions & reactions):
        task3ref.append([drug, rxn, 1])

    setdiff = mentions - reactions
    for rxn in setdiff:
        task3ref.append([drug, rxn, 0])

    diff = len(mentions) - len(reactions)
    # NOTE(review): despite the name, this intersects the MedDRA terms with
    # all of the drug's mentions, not with `setdiff` - confirm which is meant.
    setdiff_inmeddra = meddra_terms & mentions
    diffs_list.append([
        drug,
        len(mentions),
        len(reactions),
        diff,
        len(setdiff),
        len(setdiff_inmeddra),
    ])

diffs = pd.DataFrame(diffs_list, columns=['drug', 'nmentions', 'nreactions', 'diff', 'setdiff', 'nmeddraexact'])
diffs.shape, len(task3ref)

((99, 6), 4743)

In [82]:
# Load previously saved evaluation metrics, or start with an empty dict when
# the results file does not exist yet.
metrics_fn = 'results/task3/evaluation_metrics.json'
metrics = None
if os.path.exists(metrics_fn):
    # The with-block closes the handle even if the JSON fails to parse.
    with open(metrics_fn) as fh:
        metrics = json.load(fh)
else:
    metrics = dict()

def save_metrics(metrics):
    """Write ``metrics`` as JSON (indent 4) to the module-level ``metrics_fn``.

    An existing file is overwritten. The parent directory is created when it
    does not exist yet, so saving also works on a fresh checkout.
    """
    out_dir = os.path.dirname(metrics_fn)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # The with-block closes the handle even if serialization fails.
    with open(metrics_fn, 'w') as fh:
        json.dump(metrics, fh, indent=4)

save_metrics(metrics)

## OnSIDES BERT

In [72]:
# Load the OnSIDES-BERT scores from the best model, evaluated on the testing set.
pred_filename = 'data/task3/grouped-mean-final-bydrug-PMB_14-AR-125-all_222_TAC_25_2.5e-05_256_32.csv'
ob_pred = pd.read_csv(pred_filename, index_col=0)

# Translate each MedDRA code into its term; an unknown code raises KeyError.
events = [meddra_code2term[code] for code in ob_pred['pt_meddra_id']]

ob_pred.insert(3, "event", events)

# from releases.json file in onsides
threshold = 0.4633
ob_predictions = list()

# build prediction list
for drug, rxn, label in tqdm(task3ref):

    # Boolean masks instead of a string-built query(): a term containing both
    # ' and " can no longer produce a malformed expression.
    p = ob_pred[(ob_pred['drug'] == drug) & (ob_pred['event'] == rxn)]

    # NOTE: leniency is irrelevant here because OnSIDES-BERT only considers 
    # NOTE: terms that are exact matches from the label. So each term from 
    # NOTE: OnSIDES BERT must be present in the reference as mentioned.

    if p.shape[0] == 0:
        # not an exact match or not scored by OnsidesBERT
        pred1 = 0.0
    else:
        # .item() requires exactly one matching row (it raises otherwise) and
        # avoids calling float() on a Series, which pandas has deprecated.
        pred1 = float(p['Pred1'].item())

    ob_predictions.append(1 if pred1 >= threshold else 0)

len(ob_predictions), sum(ob_predictions), len(task3ref)

100%|██████████| 4743/4743 [00:03<00:00, 1321.54it/s]


(4743, 2693, 4743)

In [83]:
# Score the OnSIDES-BERT predictions against the reference labels, with the
# classes flipped.
_drugs, _events, labels = zip(*task3ref)
em = eval_metrics(labels, ob_predictions, flip_labels=True)

# An entry already stored under this file name is kept; only a missing one is
# filled in.
metrics.setdefault(pred_filename, em)

save_metrics(metrics)

## DeepCADRME

In [76]:
# DeepCADRME extractions, restricted to rows of the 'adverse reactions' section.
d_pred_filename = "results/extract/deepcadrme_100_test.csv"
d_pred_all = pd.read_csv(d_pred_filename, index_col=0)
d_pred = d_pred_all.query("section_name == 'adverse reactions'")
d_pred

Unnamed: 0,drug_name,section_name,gpt_output
0,IMPAVIDO,adverse reactions,"nausea, vomiting, diarrhea, headache, decrease..."
3,LIVALO,adverse reactions,"rhabdomyolysis, myoglobinuria, acute renal fai..."
5,XENAZINE,adverse reactions,"depression, suicidality, akathisia, restlessne..."
8,LINZESS,adverse reactions,"diarrhea, abdominal pain, flatulence, abdomina..."
11,OPSUMIT,adverse reactions,"embryo fetal toxicity, hepatotoxicity, decreas..."
...,...,...,...
223,AUBAGIO,adverse reactions,"hepatotoxicity, bone marrow effects, immunosup..."
226,POMALYST,adverse reactions,"fetal risk, venous, arterial thromboembolism, ..."
229,SURFAXIN,adverse reactions,"endotracheal tube reflux, pallor, endotracheal..."
231,ARZERRA,adverse reactions,"infusion reactions, hepatitis b virus reactiva..."


In [22]:
# Map each drug to its list of extracted terms once, instead of re-querying
# the frame for every (drug, reaction) pair.
d_extractions = dict()
for name, output in zip(d_pred['drug_name'], d_pred['gpt_output']):
    # Keep the first row per drug. A non-string output (an empty CSV cell is
    # read back as NaN) means nothing was extracted.
    if name not in d_extractions:
        d_extractions[name] = output.split(', ') if isinstance(output, str) else []

d_predictions = list()

for drug, rxn, label in tqdm(task3ref):

    # A drug without any extraction row is scored as "nothing predicted"
    # rather than aborting the evaluation with an IndexError.
    extractions = d_extractions.get(drug, [])

    # strict
    # if rxn in extractions:
    # lenient
    if any(longest_common_substring_percentage(rxn, x) > 0.8 for x in extractions):
        d_predictions.append(1)
    else:
        d_predictions.append(0)

len(d_predictions), sum(d_predictions), len(task3ref)

100%|██████████| 4743/4743 [00:05<00:00, 821.12it/s] 


(4743, 4560, 4743)

In [85]:
_, _, labels = zip(*task3ref)
em = eval_metrics(labels, d_predictions, flip_labels=True)

if not d_pred_filename in metrics:
    metrics[d_pred_filename] = em

save_metrics(metrics)

## OnSIDES LLM

In [95]:
# Score every OnSIDES-LLM test run that has no stored metrics yet.
# os.listdir returns names in arbitrary order; sorting makes the processing
# order reproducible, and with it which run's `ol_predictions` is left behind
# after the loop.
# NOTE(review): the snapshot cell reads `ol_predictions`; when every run is
# already present in `metrics`, this loop does not (re)define it.
test_runs = sorted(f for f in os.listdir('results/extract') if '_test_' in f)

# The reference labels are the same for every run.
_drugs, _events, labels = zip(*task3ref)

for runfile in tqdm(test_runs):
    ol_pred_fn = os.path.join('results', 'extract', runfile)

    if ol_pred_fn in metrics:
        continue

    ol_pred = pd.read_csv(ol_pred_fn, index_col=0).query("section_name == 'adverse reactions'")

    # Map each drug to its list of extracted terms once per run. Keep the
    # first row per drug; a non-string output (an empty CSV cell is read back
    # as NaN) means nothing was extracted.
    ol_extractions = dict()
    for name, output in zip(ol_pred['drug_name'], ol_pred['gpt_output']):
        if name not in ol_extractions:
            ol_extractions[name] = output.split(', ') if isinstance(output, str) else []

    ol_predictions = list()

    for drug, rxn, label in task3ref:

        # A drug without any extraction row is scored as "nothing predicted"
        # rather than aborting the run with an IndexError.
        extractions = ol_extractions.get(drug, [])

        # strict
        # if rxn in extractions:
        # lenient
        if any(longest_common_substring_percentage(rxn, x) > 0.8 for x in extractions):
            ol_predictions.append(1)
        else:
            ol_predictions.append(0)

    em = eval_metrics(labels, ol_predictions, flip_labels=True)

    # Runs already in `metrics` were skipped above, so this is always new.
    metrics[ol_pred_fn] = em

    # Save after every run so finished runs survive an interruption.
    save_metrics(metrics)


 25%|██▌       | 2/8 [00:07<00:21,  3.65s/it]

In [87]:
# # ol_pred_fn = "results/extract/OpenAI_gpt-4-1106-preview_fatal-prompt-v2_pharmexpert-v1_temp0_test_run0.csv"
# # ol_pred_fn = 'results/extract/OpenAI_gpt-4-1106-preview_only-positives-v0_pharmexpert-v0_temp0_test_run0.csv'
# ol_pred_fn = "results/extract/OpenAI_gpt-4-1106-preview_gpt-written-prompt_pharmexpert-v0_temp0_test_run0.csv"
# ol_pred = pd.read_csv(ol_pred_fn, index_col=0).query("section_name == 'adverse reactions'")

In [88]:
# ol_predictions = list()

# for drug, rxn, label in tqdm(task3ref):

#     extractions = list(ol_pred.query(f"drug_name == '{drug}'")['gpt_output'].str.split(', '))[0]

#     # strict
#     # if rxn in extractions:
#     # lenient
#     if any([longest_common_substring_percentage(rxn, x) > 0.8 for x in extractions]):
#         ol_predictions.append(1)
#     else:
#         ol_predictions.append(0)

# len(ol_predictions), sum(ol_predictions), len(task3ref)

100%|██████████| 4743/4743 [00:03<00:00, 1207.15it/s]


(4743, 3642, 4743)

In [90]:
# _, _, labels = zip(*task3ref)
# em = eval_metrics(labels, ol_predictions, flip_labels=True)

# if not ol_pred_fn in metrics:
#     metrics[ol_pred_fn] = em

# save_metrics(metrics)

## Evaluation Snapshot

In [19]:
# Compile results: one row per reference (drug, event, label) together with
# the prediction of each of the three systems.
d, e, l = zip(*task3ref)
column_values = (d, e, l, ob_predictions, d_predictions, ol_predictions)
df_data = zip(*column_values)

predictions = pd.DataFrame(df_data, columns=["drug", "event", "label", "OB", "D", "OL"])

# Flip all the labels: |x - 2| - 1 turns 0 into 1 and 1 into 0.
flip_labels = True
if flip_labels:
    for colname in ('label', 'OB', 'D', 'OL'):
        predictions[colname] = (predictions[colname] - 2).abs() - 1

predictions

Unnamed: 0,drug,event,label,OB,D,OL
0,IMPAVIDO,flatulence,0,0,0,0
1,IMPAVIDO,arthritis,0,0,0,0
2,IMPAVIDO,"platelet count < 150,000",0,1,1,1
3,IMPAVIDO,somnolence,0,0,0,0
4,IMPAVIDO,asthenia,0,0,0,0
...,...,...,...,...,...,...
4738,ESBRIET,insomnia,0,0,0,0
4739,ESBRIET,increases of alt,0,1,0,1
4740,ESBRIET,gastro-esophageal reflux disease,0,1,1,0
4741,ESBRIET,anorexia,0,1,0,0


In [20]:
# Print every metric for the three systems, one blank-line-separated group per
# metric. Each system's confusion matrix is computed once and reused.
system_keys = ('D', 'OB', 'OL')
counts = {
    key: confusion_matrix(predictions['label'], predictions[key]).ravel()
    for key in system_keys
}

# (printed title, formula over tn, fp, fn, tp); titles are padded to 11
# characters so the colons line up.
report = (
    ('Specificity', lambda tn, fp, fn, tp: tn/(tn+fp)),
    ('Recall/Sens', lambda tn, fp, fn, tp: tp/(tp+fn)),
    ('FPR', lambda tn, fp, fn, tp: fp/(fp+tn)),
    ('Precision', lambda tn, fp, fn, tp: tp/(tp+fp)),
    ('F1', lambda tn, fp, fn, tp: tp/(tp+0.5*(fp+fn))),
)

for position, (title, formula) in enumerate(report):
    if position:
        # Blank line between metric groups, none before the first.
        print()
    for key in system_keys:
        tn, fp, fn, tp = counts[key]
        print(f"{key:2s} {title:11s}: {formula(tn, fp, fn, tp):5.3f}")


D  Specificity: 0.962
OB Specificity: 0.575
OL Specificity: 0.774

D  Recall/Sens: 0.065
OB Recall/Sens: 0.857
OL Recall/Sens: 0.597

D  FPR        : 0.038
OB FPR        : 0.425
OL FPR        : 0.226

D  Precision  : 0.027
OB Precision  : 0.032
OL Precision  : 0.042

D  F1         : 0.038
OB F1         : 0.062
OL F1         : 0.078
