In [15]:
# Compute the evaluation metrics
from copy import deepcopy
from collections import defaultdict

# git clone https://github.com/davidsbatista/NER-Evaluation
import NER_Evaluation.ner_evaluation.ner_eval
from NER_Evaluation.ner_evaluation.ner_eval import collect_named_entities
from NER_Evaluation.ner_evaluation.ner_eval import compute_metrics
from NER_Evaluation.ner_evaluation.ner_eval import compute_precision_recall_wrapper
# Entity types to score: the participant counts per sex, then every
# SBP/DBP mean and std per sex (same order as the original literal list).
targets = [f'N_{sex}' for sex in ('male', 'female')] + [
    f'{measure}_in_{sex}_{stat}'
    for measure in ('SBP', 'DBP')
    for stat in ('mean', 'std')
    for sex in ('male', 'female')
]

In [16]:
# Read the test set with ground truth.
import pandas as pd

label_file = 'test.txt'
labels = [[]]
with open(label_file) as f:
    for line in f:
        line = line.strip()
        if not line:
            # A blank line closes the current sentence and opens a new one.
            labels.append([])
            continue
        # Non-blank lines hold exactly two whitespace-separated fields;
        # only the second one (the tag) is kept.
        _, label = line.split()
        # use BIO format: E- is folded into I-, S- into B-
        label = label.replace('E-', 'I-').replace('S-', 'B-')
        labels[-1].append(label)
# Drop the empty sentence left open when the file ends with a blank line.
if not labels[-1]:
    labels = labels[:-1]
print(len(labels), len(labels[0]), len(labels[1]))

74 86 27


# Read few shot/DANN predictions

In [17]:
# Read the predictions from a few-shot model: one sentence per line,
# tags separated by whitespace.
fewShot_pred_file = 'predictions-bp-pred.txt'
fewShot_preds = []
with open(fewShot_pred_file) as f:
    for line in f:
        # split() without arguments already ignores surrounding whitespace;
        # E-/S- prefixes are folded into I-/B- to obtain plain BIO tags.
        preds = [x.replace('E-', 'I-').replace('S-', 'B-') for x in line.split()]
        fewShot_preds.append(preds)
print(len(fewShot_preds), len(fewShot_preds[0]), len(fewShot_preds[1]))

74 86 27


# Read LLM predictions

In [41]:
#!pip install openpyxl
llm_pred_file = '/labs/sarkerlab/yguo262/blood_pressure_project/LLM/datasets/ann_050724_new/test.gpt35.xlsx'
llm_df = pd.read_excel(llm_pred_file)
print(llm_df.columns)

# Reorder the LLM rows so that they follow the document order of the
# 'prediction_DANN' sheet.
docID_order = pd.read_excel('bp_test_fewshot_results.xlsm', sheet_name='prediction_DANN')['pmc_s'].values.tolist()

# Row position of every document ID in the LLM sheet (for a duplicated ID the
# last row wins, as in the previous nested scan). Positions are what `iloc`
# below expects.
row_of_doc = {doc_id: pos for pos, doc_id in enumerate(llm_df['pmc_s'])}

# Previously an ID absent from the LLM sheet silently fell back to row 0,
# which misaligned predictions and gold labels; fail loudly instead.
missing_ids = [doc_id for doc_id in docID_order if doc_id not in row_of_doc]
if missing_ids:
    raise KeyError(f'document IDs missing from {llm_pred_file}: {missing_ids}')

new_index = [row_of_doc[doc_id] for doc_id in docID_order]
print(new_index)
llm_df = llm_df.iloc[new_index]

Index(['pmc_s', 'text', 'N_male', 'N_female', 'SBP_in_male_mean',
       'SBP_in_female_mean', 'SBP_in_male_std', 'SBP_in_female_std',
       'DBP_in_male_mean', 'DBP_in_female_mean', 'DBP_in_male_std',
       'DBP_in_female_std', 'Notes', 'positive_case', 'full_prompt',
       'response', 'bp_male_mean', 'bp_female_mean', 'bp_male_std',
       'bp_female_std', 'pred'],
      dtype='object')
[13, 4, 2, 18, 21, 16, 11, 7, 15, 17, 8, 22, 14, 5, 6, 20, 1, 3, 12, 10, 0, 9, 19]


In [42]:
# Convert LLM output into BIO format
#!pip install nltk
from nltk.tokenize import wordpunct_tokenize, sent_tokenize

def list_match(a, b):
    """Return True when two equal-length sequences agree element by element.

    Raises:
        AssertionError: if ``a`` and ``b`` have different lengths.
    """
    assert len(a) == len(b)
    return not any(x != y for x, y in zip(a, b))


def find_target(long_list, short_list):
    """Return the index of the first occurrence of ``short_list`` in ``long_list``.

    Args:
        long_list: sequence to search in (e.g. the tokens of a sentence).
        short_list: contiguous sub-sequence to look for.

    Returns:
        The start index of the first match, or -1 when there is none
        (including when ``short_list`` is longer than ``long_list``).
        An empty ``short_list`` matches at index 0.
    """
    width = len(short_list)
    # "+ 1" so the last window, the one ending exactly at the end of
    # long_list, is tested too. Without it a match on the final tokens (or a
    # target as long as the whole sentence) was never found.
    for i in range(len(long_list) - width + 1):
        if list_match(long_list[i:i + width], short_list):
            return i
    return -1
    
def conv_to_BIO(df):
    """Turn the values extracted for each row into per-sentence BIO tag lists.

    For every row, ``row['text']`` is split into sentences and each sentence
    into word/punctuation tokens. When ``row['positive_case']`` is truthy,
    the value of every column named in the module-level ``targets`` is
    tokenised the same way and searched for in each sentence with
    ``find_target``; its first occurrence in a sentence is tagged
    ``B-<target>`` followed by ``I-<target>``. When two targets cover the same
    token, the one later in ``targets`` wins. Every other token, and every
    token of a non-positive row, is tagged ``'O'``.

    Args:
        df: DataFrame with the columns ``'positive_case'``, ``'text'`` and one
            column per entry of ``targets``.

    Returns:
        A list of tag lists, one per sentence in row order, followed by one
        trailing empty list (a new list is opened after every sentence).
    """
    import re  # local import: only needed for the ".0" clean-up below

    BIO_preds = [[]]
    for _, row in df.iterrows():
        label = row['positive_case']
        text = row['text'].strip()
        sentences = sent_tokenize(text)
        all_tokens = [wordpunct_tokenize(sentence) for sentence in sentences]
        # Presumably undoes the float rendering of whole numbers (120.0 ->
        # "120"). Only a ".0" that follows a digit and is NOT followed by
        # another digit is removed; the former str.replace('.0', '') also
        # mangled real decimals (120.05 -> "1205", 0.05 -> "05").
        all_target_tokens = [
            wordpunct_tokenize(re.sub(r'(?<=\d)\.0(?!\d)', '', str(row[target])))
            for target in targets
        ]

        for tokens in all_tokens:
            # token position -> BIO tag; stays empty for non-positive rows
            pos_map = {}
            if label:
                for k, target_tokens in enumerate(all_target_tokens):
                    start_pos = find_target(tokens, target_tokens)
                    if start_pos < 0:
                        continue
                    for offset in range(len(target_tokens)):
                        prefix = 'B-' if offset == 0 else 'I-'
                        pos_map[start_pos + offset] = prefix + targets[k]

            for k in range(len(tokens)):
                BIO_preds[-1].append(pos_map.get(k, 'O'))
            # Open the list for the next sentence.
            BIO_preds.append([])
    return BIO_preds

llm_preds = conv_to_BIO(llm_df)
# conv_to_BIO opens a fresh list after every sentence; remove the unused last one.
if not llm_preds[-1]:
    llm_preds = llm_preds[:-1]
print(len(llm_preds), len(llm_preds[0]), len(llm_preds[1]))

74 86 27


In [43]:
def compute(gold_labels, pred_labels):
    """Aggregate NER evaluation counts over all sentences and derive precision/recall.

    Each gold/predicted tag sequence pair is scored with ``compute_metrics``
    for the entity types in the module-level ``targets``; the per-sentence
    counts are summed, overall and per entity type, and precision and recall
    are then computed from the summed counts.

    Args:
        gold_labels: list of gold tag sequences, one per sentence.
        pred_labels: list of predicted tag sequences, aligned with
            ``gold_labels``.

    Returns:
        A tuple ``(results, evaluation_agg_entities_type)``: the overall
        metrics per evaluation schema ('strict', 'ent_type', 'partial',
        'exact'), and the same structure per entity type.

    Raises:
        ValueError: if the two inputs contain a different number of sentences
            (``zip`` would otherwise silently score only the common prefix).
    """
    if len(gold_labels) != len(pred_labels):
        raise ValueError(
            f'gold and predicted labels differ in number of sentences: '
            f'{len(gold_labels)} vs {len(pred_labels)}'
        )

    metrics_results = {'correct': 0, 'incorrect': 0, 'partial': 0,
                       'missed': 0, 'spurious': 0, 'possible': 0, 'actual': 0, 'precision': 0, 'recall': 0}

    # overall results
    results = {'strict': deepcopy(metrics_results),
               'ent_type': deepcopy(metrics_results),
               'partial': deepcopy(metrics_results),
               'exact': deepcopy(metrics_results)
              }

    # results aggregated by entity type
    evaluation_agg_entities_type = {e: deepcopy(results) for e in targets}

    for true_ents, pred_ents in zip(gold_labels, pred_labels):

        # compute results for one message
        tmp_results, tmp_agg_results = compute_metrics(
            collect_named_entities(true_ents), collect_named_entities(pred_ents), targets
        )

        # aggregate overall results
        for eval_schema in results:
            for metric in metrics_results:
                results[eval_schema][metric] += tmp_results[eval_schema][metric]

        # aggregate results by entity type
        for e_type in targets:
            for eval_schema in tmp_agg_results[e_type]:
                for metric in tmp_agg_results[e_type][eval_schema]:
                    evaluation_agg_entities_type[e_type][eval_schema][metric] += tmp_agg_results[e_type][eval_schema][metric]

    # Precision and recall depend only on the summed counts, so they are
    # derived once after the loop instead of after every sentence.
    results = compute_precision_recall_wrapper(results)
    for e_type in targets:
        evaluation_agg_entities_type[e_type] = compute_precision_recall_wrapper(evaluation_agg_entities_type[e_type])
    return results, evaluation_agg_entities_type
    

In [44]:
# Compute the metrics
# Both systems are scored against the same gold labels; compute() returns the
# overall results first and the per-entity-type results second.
fewShot_results, fewShot_evaluation_agg_entities_type = compute(labels, fewShot_preds)
llm_results, llm_evaluation_agg_entities_type = compute(labels, llm_preds)
# Show the overall results of the two systems side by side.
print(fewShot_results, llm_results)

{'ent_type': {'correct': 16, 'incorrect': 11, 'partial': 0, 'missed': 50, 'spurious': 15, 'possible': 77, 'actual': 42, 'precision': 0.38095238095238093, 'recall': 0.2077922077922078}, 'partial': {'correct': 27, 'incorrect': 0, 'partial': 0, 'missed': 50, 'spurious': 15, 'possible': 77, 'actual': 42, 'precision': 0.6428571428571429, 'recall': 0.35064935064935066}, 'strict': {'correct': 16, 'incorrect': 11, 'partial': 0, 'missed': 50, 'spurious': 15, 'possible': 77, 'actual': 42, 'precision': 0.38095238095238093, 'recall': 0.2077922077922078}, 'exact': {'correct': 27, 'incorrect': 0, 'partial': 0, 'missed': 50, 'spurious': 15, 'possible': 77, 'actual': 42, 'precision': 0.6428571428571429, 'recall': 0.35064935064935066}} {'ent_type': {'correct': 40, 'incorrect': 0, 'partial': 0, 'missed': 37, 'spurious': 4, 'possible': 77, 'actual': 44, 'precision': 0.9090909090909091, 'recall': 0.5194805194805194}, 'partial': {'correct': 40, 'incorrect': 0, 'partial': 0, 'missed': 37, 'spurious': 4, 'po

In [45]:
def get_f1(p, r):
    """Return the F1 score 2*p*r/(p+r) for precision ``p`` and recall ``r``.

    Returns 0 when ``p + r`` is 0, so the division is never by zero.
    """
    if p+r == 0:
        return 0
    return 2*p*r/(p+r)
    
def fmt_agg_entities(evaluation_agg_entities_type, metric='ent_type'):
    """Tabulate precision, recall and F1 per entity type (a format good for Excel).

    Args:
        evaluation_agg_entities_type: mapping
            ``entity type -> evaluation schema -> metrics dict``; the metrics
            dict must contain 'precision' and 'recall'.
        metric: evaluation schema to report (default 'ent_type').

    Returns:
        A DataFrame with one row per entry of the module-level ``targets``
        and the columns 'varible', 'precision', 'recall', 'f'.
    """
    # Collect all rows first and build the frame once, instead of enlarging
    # it row by row through .loc.
    rows = []
    for target in targets:
        scores = evaluation_agg_entities_type[target][metric]
        p = scores['precision']
        r = scores['recall']
        # get_f1 already returns 0 when p + r == 0.
        rows.append((target, p, r, get_f1(p, r)))
    # The 'varible' header is kept unchanged: it ends up in the exported
    # spreadsheets.
    return pd.DataFrame(rows, columns=['varible', 'precision', 'recall', 'f'])
# Per-entity scores of the few-shot model: saved as a spreadsheet and displayed.
output = fmt_agg_entities(fewShot_evaluation_agg_entities_type, metric='ent_type')
output.to_excel('fewshot_result.xlsx')
output

Unnamed: 0,varible,precision,recall,f
0,N_male,0.2,0.5,0.285714
1,N_female,0.0,0.0,0.0
2,SBP_in_male_mean,0.1,0.25,0.142857
3,SBP_in_female_mean,0.25,0.625,0.357143
4,SBP_in_male_std,0.117647,0.333333,0.173913
5,SBP_in_female_std,0.0,0.0,0.0
6,DBP_in_male_mean,0.055556,0.125,0.076923
7,DBP_in_female_mean,0.0625,0.125,0.083333
8,DBP_in_male_std,0.0625,0.2,0.095238
9,DBP_in_female_std,0.0,0.0,0.0


In [46]:
# Per-entity scores of the LLM: saved as a spreadsheet and displayed.
output = fmt_agg_entities(llm_evaluation_agg_entities_type, metric='ent_type')
output.to_excel('llm_result.xlsx')
output

Unnamed: 0,varible,precision,recall,f
0,N_male,0.5,0.5,0.5
1,N_female,0.5,0.444444,0.470588
2,SBP_in_male_mean,0.5,0.5,0.5
3,SBP_in_female_mean,0.5,0.5,0.5
4,SBP_in_male_std,0.5,0.666667,0.571429
5,SBP_in_female_std,0.555556,0.555556,0.555556
6,DBP_in_male_mean,0.5,0.5,0.5
7,DBP_in_female_mean,0.5,0.5,0.5
8,DBP_in_male_std,0.428571,0.6,0.5
9,DBP_in_female_std,0.5,0.5,0.5


In [40]:
# Evaluate LLM in batch
#!pip install openpyxl

# Document order to evaluate in, taken from the 'prediction_DANN' sheet.
docID_order = pd.read_excel('bp_test_fewshot_results.xlsm', sheet_name='prediction_DANN')['pmc_s'].values.tolist()

for model in ['gpt35']:
    llm_pred_file = f'/labs/sarkerlab/yguo262/blood_pressure_project/LLM/datasets/ann_050724_new/test.{model}.xlsx'
    llm_df = pd.read_excel(llm_pred_file)

    # Derive the row order from THIS model's file. It used to be computed
    # once, before the loop, from whatever `llm_df` an earlier cell had left
    # behind, which fails on a fresh kernel and misaligns rows if the model
    # files are not all in the same order.
    row_of_doc = {doc_id: pos for pos, doc_id in enumerate(llm_df['pmc_s'])}
    # An ID absent from the sheet previously fell back to row 0 silently.
    missing_ids = [doc_id for doc_id in docID_order if doc_id not in row_of_doc]
    if missing_ids:
        raise KeyError(f'document IDs missing from {llm_pred_file}: {missing_ids}')
    new_index = [row_of_doc[doc_id] for doc_id in docID_order]
    print(new_index)
    llm_df = llm_df.iloc[new_index]

    llm_preds = conv_to_BIO(llm_df)
    # conv_to_BIO leaves one unused empty list at the end; drop it.
    if len(llm_preds[-1]) == 0:
        llm_preds = llm_preds[:-1]
    print(len(llm_preds), len(llm_preds[0]), len(llm_preds[1]))

    llm_results, llm_evaluation_agg_entities_type = compute(labels, llm_preds)
    precision = llm_results['ent_type']['precision']
    recall = llm_results['ent_type']['recall']
    f = get_f1(precision, recall)
    print(f'{model}\t{precision}\t{recall}\t{f}')
    output = fmt_agg_entities(llm_evaluation_agg_entities_type, metric='ent_type')
    output.to_excel(f'llm_result.{model}.xlsx')



[13, 4, 2, 18, 21, 16, 11, 7, 15, 17, 8, 22, 14, 5, 6, 20, 1, 3, 12, 10, 0, 9, 19]
74 86 27
gpt35	0.9090909090909091	0.5194805194805194	0.6611570247933884
