In [1]:
# IPython autoreload extension: reload imported modules before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import openai
import json
from thefuzz import fuzz, process
import re
import ujson
import logging
import inflect
from collections import Counter, defaultdict
import numpy as np
import evaluate
import spacy
import pandas as pd


# seqeval evaluation
# Loaded through HuggingFace `evaluate`; later cells call
# seqeval.compute(predictions=..., references=...) on lists of BIO tag sequences.
seqeval = evaluate.load("seqeval")

# spacy tokenizer
# spacy.blank creates a pipeline with no trained components; it is used here only
# to tokenize abstracts (see label_tokens_from_offsets).
nlp = spacy.blank("en")

# Create an engine object
# inflect engine: extract_tagged_text uses number_to_words to spell out digit-only
# entities when locating them in the abstract.
p = inflect.engine()
# Set up logging configuration
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# GPT extraction key (snake_case) -> TrialSieve display label.
remapped_keys = {
    "disease": "Disease/Condition of Interest",
    "drug_intervention": "Drug Intervention",
    "drug_dosage": "Dosage",
    "sample_size": "Sample Size",
    "follow_up_period": "Follow-up period",
    "group_characteristic": "Group Characteristic",
    "group_name": "Group Name",
    "intervention_administration": "Intervention Administration",
    "intervention_duration": "Intervention Duration",
    "intervention_frequency": "Intervention Frequency",
    "non_pharmaceutical_intervention": "Non-Pharmaceutical Intervention",
    "non_study_drug": "Non-Study Drug",
    "outcome": "Outcome (Study Endpoint)",
    "qualitative_side_effects": "Qualitative Side Effects",
    "quantitative_measurement": "Quantitative Measurement",
    "statistical_significance": "Statistical Significance",
    "study_duration": "Study Duration",
    "study_years": "Study Years",
    "type_of_quant_measure": "Type of Quant. Measure",
    "units": "Units",
}

# GPT extraction key (snake_case) -> integer class id (0-19).
# NOTE(review): a later cell rebinds tag_to_label with display-name keys and the same ids.
tag_to_label = {
    "disease": 8,
    "drug_dosage": 9,
    "drug_intervention": 1,
    "follow_up_period": 16,
    "group_characteristic": 6,
    "group_name": 3,
    "sample_size": 7,
    "intervention_administration": 13,
    "intervention_duration": 12,
    "intervention_frequency": 11,
    "non_pharmaceutical_intervention": 17,
    "non_study_drug": 14,
    "outcome": 2,
    "qualitative_side_effects": 15,
    "quantitative_measurement": 0,
    "statistical_significance": 5,
    "study_duration": 19,
    "study_years": 18,
    "type_of_quant_measure": 10,
    "units": 4,
}

# The keys GPT is asked to extract: exactly remapped_keys' keys, in the same order.
extraction_keys = list(remapped_keys)


In [4]:
path_to_processed_for_modeling = "../data/processed_for_modeling.json"
# Gold TrialSieve records; each carries at least "pmid", "text", "spans" and "split".
# `documents` and `processed_for_modeling` are two names for the same list.
with open(path_to_processed_for_modeling) as file:
    documents = processed_for_modeling = ujson.load(file)
    

# Few-shot prompting: GPT retrieves an entity as many times as it appears in the abstract

In [5]:
# TrialSieve display label (the "tag" values found in the GPT output) -> integer class id.
# Rebinds the snake_case-keyed tag_to_label from the earlier cell; the ids are the same.
# NOTE(review): the id-15 class is keyed "Side Effects" here, while remapped_keys maps
# qualitative_side_effects to "Qualitative Side Effects" -- confirm which name the GPT
# output and the gold spans actually use.
tag_to_label = dict([
    ("Disease/Condition of Interest", 8),
    ("Dosage", 9),
    ("Drug Intervention", 1),
    ("Follow-up period", 16),
    ("Group Characteristic", 6),
    ("Group Name", 3),
    ("Sample Size", 7),
    ("Intervention Administration", 13),
    ("Intervention Duration", 12),
    ("Intervention Frequency", 11),
    ("Non-Pharmaceutical Intervention", 17),
    ("Non-Study Drug", 14),
    ("Outcome (Study Endpoint)", 2),
    ("Side Effects", 15),
    ("Quantitative Measurement", 0),
    ("Statistical Significance", 5),
    ("Study Duration", 19),
    ("Study Years", 18),
    ("Type of Quant. Measure", 10),
    ("Units", 4),
])

# snake_case extraction keys, in prompt order.
extraction_keys = """
    disease drug_intervention drug_dosage sample_size follow_up_period
    group_characteristic group_name intervention_administration
    intervention_duration intervention_frequency non_pharmaceutical_intervention
    non_study_drug outcome qualitative_side_effects quantitative_measurement
    statistical_significance study_duration study_years type_of_quant_measure units
""".split()

In [6]:
def retrieve_abstract_and_spans(path_to_processed_for_modeling, pmid):
    '''
    Function to retrieve the abstract and spans from the processed_for_modeling.json file

    Parameters:
    ----------
    path_to_processed_for_modeling : str
        The path to the processed_for_modeling.json file
    pmid : int
        The pmid of the article to retrieve the spans. Compared with ``==`` against
        each record's "pmid", so its type must match the stored field.

    Returns:
    -------
    (abstract, spans) : tuple
        The "text" and "spans" of the first record whose "pmid" equals ``pmid``,
        or ``(None, None)`` (after logging) when no record matches.
    '''
    with open(path_to_processed_for_modeling, "r") as file:
        articles = ujson.load(file)

    for article in articles:
        if article["pmid"] == pmid:
            return article["text"], article["spans"]

    # No match: return explicit Nones so callers can test for the miss
    # (the abstract variable used to be left unbound here).
    logging.info("pmid not found in the list of documents")
    return None, None

def extract_tagged_text(abstract, tagged_text, label_map=None, engine=None):
    '''
    Function to extract the position of the tagged text from the abstract given its surrounding text

    Each entity is anchored by finding its "surrounding" snippet in the abstract
    (case-insensitive, first occurrence), then finding its "text" inside that snippet
    (case-insensitive, first occurrence). A digit-only entity that does not appear as
    digits is also looked for spelled out in words.

    Parameters:
    ----------
    abstract : str
        The abstract text
    tagged_text : list of dict
        The tagged entities, each with "tag", "text" and "surrounding" keys
    label_map : dict, optional
        Mapping of tags to integer labels. Defaults to the notebook-level
        ``tag_to_label``. Entities whose tag is not a key are skipped and are NOT
        reported in ``missed``.
    engine : object, optional
        Provides ``number_to_words(str) -> str``. Defaults to the notebook-level
        inflect engine ``p``; only consulted for digit-only entities.

    Returns:
    -------
    output : list of dict
        Located entities with "start"/"end" (end-exclusive character offsets into
        ``abstract``), "label", "tag" and the original "text", sorted by "start".
    missed : list of dict
        Entities (as given) whose surrounding text was not found in the abstract,
        whose own text was not found inside the surrounding text, or whose text or
        surrounding text is empty.
    '''
    if label_map is None:
        label_map = tag_to_label

    # '\xa0' -> ' ' swaps one character for one character, so offsets found in the
    # normalized text are valid offsets into the original abstract. Done once, not per entity.
    normalized_abstract = abstract.replace('\xa0', ' ')

    output = []
    missed = []

    for entity in tagged_text:
        if entity["tag"] not in label_map:
            continue

        surrounding_text = entity["surrounding"].replace('\xa0', ' ')
        entity_text = entity["text"].replace('\xa0', ' ')
        if not surrounding_text or not entity_text:
            # An empty pattern "matches" at offset 0 and would yield a bogus span.
            missed.append(entity)
            continue

        # Case-insensitive regex search on the un-lowered strings: str.lower() can change
        # the length of some Unicode text, which would shift the computed offsets.
        anchor = re.search(re.escape(surrounding_text), normalized_abstract, re.IGNORECASE)
        if anchor is None:
            missed.append(entity)
            continue

        match = re.search(re.escape(entity_text), surrounding_text, re.IGNORECASE)
        if match is None and entity_text.isdigit():
            # Numbers such as sample sizes are often written out in the abstract.
            number_engine = p if engine is None else engine
            entity_text_in_words = number_engine.number_to_words(entity_text)
            match = re.search(re.escape(entity_text_in_words), surrounding_text, re.IGNORECASE)

        if match is None:
            # Anchor found but the entity is not inside it: report it instead of
            # dropping it silently, so the loss is visible in error analysis.
            missed.append(entity)
            continue

        output.append({
            "start": anchor.start() + match.start(),
            "end": anchor.start() + match.end(),
            "label": label_map[entity["tag"]],
            "tag": entity["tag"],
            "text": entity["text"],
        })

    output.sort(key=lambda x: x["start"])
    return output, missed


In [7]:
def convert_to_seqeval_format(y_true, y_pred, abstract):
    '''
    This function converts the entity data to seqeval format (whitespace-token BIO labels)

    Parameters:
    ----------
    y_true : list of dict
        Human annotated data; each span needs "start" (character offset), "text" and "tag"
    y_pred : list of dict
        Model predictions, same format as ``y_true``
    abstract : str
        The abstract to convert to seqeval format

    Returns:
    -------
    tuple of (list of str, list of str)
        BIO labels for ``y_true`` and ``y_pred``, one per ``abstract.split()`` token.
    '''
    def label_tokens(annotations, abstract):
        # A span's first token is the number of whitespace tokens before its "start"
        # offset (an offset inside a word therefore maps to the following token); its
        # length is the word count of its "text". The "end" offset is not used.
        n_tokens = len(abstract.split())
        labels = ["O"] * n_tokens
        for ann in annotations:
            start_idx = len(abstract[:ann['start']].split())
            n_words = len(ann['text'].split())
            # Skip spans that start past the last token or contain no words, and clip
            # spans that run past the end, instead of raising IndexError.
            if start_idx >= n_tokens or n_words == 0:
                continue
            end_idx = min(start_idx + n_words, n_tokens)
            labels[start_idx] = f"B-{ann['tag']}"
            for i in range(start_idx + 1, end_idx):
                labels[i] = f"I-{ann['tag']}"
        return labels

    return label_tokens(y_true, abstract), label_tokens(y_pred, abstract)

def label_tokens_from_offsets(text, annotations):
    doc = nlp(text)
    tokens = [token.text for token in doc]
    labels = ["O"] * len(tokens)

    for ann in annotations:
        start_char = ann['start']
        end_char = ann['end']
        start_token = next((i for i, token in enumerate(doc) if token.idx >= start_char), None)
        end_token = next((i for i, token in enumerate(doc) if token.idx >= end_char), None)
        
        if start_token is not None and end_token is not None:
            labels[start_token] = f"B-{ann['tag']}"
            for i in range(start_token + 1, end_token):
                labels[i] = f"I-{ann['tag']}"

    return labels

def compute_metrics_v2(pmids, pmid_to_extracted_entities):
    '''
    Score GPT entity extraction against the gold spans with seqeval.

    Parameters:
    ----------
    pmids : list
        PMIDs to evaluate (same type as the "pmid" field of the gold file).
    pmid_to_extracted_entities : dict
        Maps ``str(pmid)`` to the list of GPT-tagged entities
        (dicts with "tag", "text" and "surrounding").

    Returns:
    -------
    metrics : dict
        Overall "accuracy", "precision", "recall", "f1" over all evaluated
        documents, per-class F1 under "class_specific_f1", and the raw seqeval
        output under "detailed_results".
    df_results : pandas.DataFrame
        One row per evaluated document: pmid, accuracy, precision, recall, f1.

    Raises:
    ------
    ValueError
        If no PMID could be evaluated.

    Notes:
    -----
    PMIDs without GPT output (or without a gold record) are printed and skipped,
    not scored as empty predictions, so the overall metrics only cover documents
    that GPT actually processed.
    '''
    # Read the gold file once and index it by PMID instead of re-parsing the whole
    # JSON for every document. setdefault keeps the first record per PMID, like a
    # linear scan that stops at the first match.
    with open(path_to_processed_for_modeling, "r") as file:
        articles_by_pmid = {}
        for article in ujson.load(file):
            articles_by_pmid.setdefault(article["pmid"], article)

    all_y_true_seqeval = []
    all_y_pred_seqeval = []
    results_list = []

    for i, pmid in enumerate(pmids):
        if f'{pmid}' not in pmid_to_extracted_entities:
            print(f"pmid {pmid} not in pmid_to_extracted_entities")
            continue
        article = articles_by_pmid.get(pmid)
        if article is None:
            print(f"pmid {pmid} not in {path_to_processed_for_modeling}")
            continue

        abstract, true_spans = article["text"], article["spans"]
        pred_spans, _missed = extract_tagged_text(abstract=abstract, tagged_text=pmid_to_extracted_entities[f"{pmid}"])

        y_true_seqeval = label_tokens_from_offsets(text=abstract, annotations=true_spans)
        y_pred_seqeval = label_tokens_from_offsets(text=abstract, annotations=pred_spans)
        all_y_true_seqeval.append(y_true_seqeval)
        all_y_pred_seqeval.append(y_pred_seqeval)

        results = seqeval.compute(predictions=[y_pred_seqeval], references=[y_true_seqeval])
        results['pmid'] = pmid
        results_list.append(results)

        if i % 100 == 0:
            print("i :", i)

    if not results_list:
        raise ValueError("No pmid could be evaluated: none has both GPT output and gold spans.")

    # Evaluate all documents jointly with seqeval
    overall_results = seqeval.compute(predictions=all_y_pred_seqeval, references=all_y_true_seqeval)
    # Keys not prefixed with "overall" are the per-entity-type result dicts.
    overall_class_specific_f1 = {
        k: v["f1"] for k, v in overall_results.items() if not k.startswith("overall")
    }

    per_doc_columns = {
        'overall_accuracy': 'accuracy',
        'overall_precision': 'precision',
        'overall_recall': 'recall',
        'overall_f1': 'f1',
    }
    df_results = pd.DataFrame(results_list)[['pmid', *per_doc_columns]].rename(columns=per_doc_columns)

    return {
        "accuracy": overall_results["overall_accuracy"],
        "precision": overall_results["overall_precision"],
        "recall": overall_results["overall_recall"],
        "f1": overall_results["overall_f1"],
        "class_specific_f1": overall_class_specific_f1,
        "detailed_results": overall_results,
    }, df_results


In [8]:
path_to_gpt_output = 'results/2024-07-02/19:22:26/extraction/output.jsonl'

# JSONL: one GPT extraction record per line, each with a "pmid" (stored as a
# string) and an "entities" list. Blank lines are skipped.
with open(path_to_gpt_output, 'r') as file:
    gpt_output = [json.loads(line) for line in file if line.strip()]

# Evaluate on every test-split PMID, not only those GPT produced output for;
# compute_metrics_v2 prints and skips the ones without GPT output.
# (The name gpt_pmids is kept because later cells use it.)
pmids_test = [item['pmid'] for item in documents if item['split'] == 'test']
print(len(pmids_test))
gpt_pmids = pmids_test

# Keyed by the GPT record's string pmid: look up with str(pmid).
pmid_to_output = {entry['pmid']: entry for entry in gpt_output}
pmid_to_extracted_entities = {entry['pmid']: entry['entities'] for entry in gpt_output}

238


In [9]:
# GPT entities for one example document; keys are string PMIDs.
pmid_to_extracted_entities['58651']

[{'tag': 'Drug Intervention',
  'text': 'Horse antihuman thymocyte globulin',
  'surrounding': 'Horse antihuman thymocyte globulin (HAHTG) combined'},
 {'tag': 'Drug Intervention',
  'text': 'HAHTG',
  'surrounding': 'Horse antihuman thymocyte globulin (HAHTG) combined'},
 {'tag': 'Non-Study Drug',
  'text': 'prednisone',
  'surrounding': 'combined with prednisone and'},
 {'tag': 'Non-Study Drug',
  'text': 'azathioprine',
  'surrounding': 'prednisone and azathioprine (lmuran)'},
 {'tag': 'Sample Size',
  'text': '50',
  'surrounding': 'controlled sutdy in 50 renal'},
 {'tag': 'Group Characteristic',
  'text': 'renal allograft recipients',
  'surrounding': '50 renal allograft recipients. Side'},
 {'tag': 'Drug Intervention',
  'text': 'HAHTG',
  'surrounding': 'Side effects of HAHTG administration'},
 {'tag': 'Intervention Administration',
  'text': 'intravenously',
  'surrounding': 'HAHTG administration given intravenously were'},
 {'tag': 'Group Name',
  'text': 'treated group',
  's

In [10]:
# Score the GPT extractions on every PMID in gpt_pmids (the test split).
results2, df2 = compute_metrics_v2(
    gpt_pmids,
    pmid_to_extracted_entities,
)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


i : 0
pmid 7535080 not in pmid_to_extracted_entities
pmid 9616339 not in pmid_to_extracted_entities
pmid 10423597 not in pmid_to_extracted_entities
pmid 10938493 not in pmid_to_extracted_entities
pmid 11286949 not in pmid_to_extracted_entities
pmid 12213517 not in pmid_to_extracted_entities
i : 100
pmid 14758137 not in pmid_to_extracted_entities
pmid 15066135 not in pmid_to_extracted_entities
pmid 16138566 not in pmid_to_extracted_entities
pmid 22717420 not in pmid_to_extracted_entities
pmid 23695545 not in pmid_to_extracted_entities
pmid 26596670 not in pmid_to_extracted_entities
pmid 27624700 not in pmid_to_extracted_entities
i : 200
pmid 29622587 not in pmid_to_extracted_entities
pmid 30354702 not in pmid_to_extracted_entities
pmid 30388256 not in pmid_to_extracted_entities
pmid 32649216 not in pmid_to_extracted_entities
pmid 33509804 not in pmid_to_extracted_entities
pmid 33566107 not in pmid_to_extracted_entities
pmid 33685298 not in pmid_to_extracted_entities
pmid 33725664 not in

In [11]:
# Overall and per-class metrics returned by compute_metrics_v2.
results2

{'accuracy': 0.8394551747236798,
 'precision': np.float64(0.5554556893762358),
 'recall': np.float64(0.4721882640586797),
 'f1': np.float64(0.5104485008672668),
 'class_specific_f1': {'Disease/Condition of Interest': np.float64(0.3696969696969697),
  'Dosage': np.float64(0.746268656716418),
  'Drug Intervention': np.float64(0.49166666666666664),
  'Follow-up period': np.float64(0.2413793103448276),
  'Group Characteristic': np.float64(0.1584454409566517),
  'Group Name': np.float64(0.3183673469387755),
  'Intervention Administration': np.float64(0.31654676258992803),
  'Intervention Duration': np.float64(0.5350877192982456),
  'Intervention Frequency': np.float64(0.496124031007752),
  'Non-Pharmaceutical Intervention': np.float64(0.13793103448275862),
  'Non-Study Drug': np.float64(0.35294117647058826),
  'Outcome (Study Endpoint)': np.float64(0.22361984626135573),
  'Quantitative Measurement': np.float64(0.6522315510384445),
  'Sample Size': np.float64(0.7661375661375661),
  'Side Eff

In [12]:
# Per-document metrics: one row per evaluated PMID.
df2

Unnamed: 0,pmid,accuracy,precision,recall,f1
0,58651,0.995763,0.971429,1.000000,0.985507
1,527347,0.813559,1.000000,0.428571,0.600000
2,779588,0.897196,0.388889,0.388889,0.388889
3,1366257,0.847059,0.481481,0.406250,0.440678
4,1703608,0.723602,0.666667,0.367347,0.473684
...,...,...,...,...,...
205,32847676,0.846591,0.500000,0.483871,0.491803
206,33009489,0.871111,0.536585,0.458333,0.494382
207,33222712,0.886364,1.000000,0.500000,0.666667
208,33242419,0.846834,0.685714,0.461538,0.551724


## Error Analysis Verification

In [30]:
# Documents whose per-document F1 lies strictly between 0.10 and 0.15, for manual inspection.
filtered_df2 = df2[df2["f1"].between(0.10, 0.15, inclusive="neither")]
filtered_df2

Unnamed: 0,pmid,accuracy,precision,recall,f1
184,7484829,0.829971,0.125,0.142857,0.133333
1073,22717420,0.819639,0.135135,0.138889,0.136986
1190,25448925,0.476071,0.142857,0.098039,0.116279
1225,26510933,0.768293,0.375,0.085714,0.139535
1327,28729361,0.841146,0.125,0.125,0.125
1331,28929323,0.87664,0.136364,0.142857,0.139535
1498,33849926,0.615764,0.121212,0.142857,0.131148
1534,35483753,0.692494,0.108108,0.181818,0.135593


In [None]:
# Re-run the per-document evaluation for a single PMID to inspect its errors.
pmid = 33583283
abstract, true_spans = retrieve_abstract_and_spans(path_to_processed_for_modeling, pmid)
pred_spans, missed = extract_tagged_text(abstract, pmid_to_extracted_entities[str(pmid)])

y_true_seqeval, y_pred_seqeval = (
    label_tokens_from_offsets(text=abstract, annotations=spans)
    for spans in (true_spans, pred_spans)
)
results = seqeval.compute(references=[y_true_seqeval], predictions=[y_pred_seqeval])

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# Raw abstract text of the document under inspection.
abstract

"CT visual quantitative evaluation of hypertensive patients with coronavirus disease (COVID-19): Potential influence of angiotensin converting enzyme inhibitors / angiotensin receptor blockers on severity of lung involvement.\nThere is not enough data on the effect of angiotensin-converting enzyme inhibitors (ACEIs)/angiotensin receptor blockers (ARBs) on lung involvement in patients with COVID-19 pneumonia and hypertension (HT). Our aim was to compare the lung involvement of the HT patients hospitalized for COVID-19 using ACEIs/ARBs with the patients taking other anti-HT medications. : Patients who have a diagnosis of HT among the patients treated for laboratory-confirmed COVID-19 between 31 March 2020 and 28 May 2020 were included in the study. One hundred and twenty-four patients were divided into two as ACEIs/ARBs group (n\xa0=\xa075) and non-ACEIs/ARBs group (n\xa0=\xa049) according to the anti-HT drug used. The chest CT involvement areas of these two groups were evaluated quantit

In [None]:
# Raw GPT entities (tag, text, surrounding) for the inspected PMID.
pmid_to_extracted_entities[str(pmid)]

[{'tag': 'Disease/Condition of Interest',
  'text': 'hypertension',
  'surrounding': 'patients with COVID-19 pneumonia and hypertension (HT)'},
 {'tag': 'Disease/Condition of Interest',
  'text': 'COVID-19',
  'surrounding': 'patients with COVID-19 pneumonia and hypertension (HT)'},
 {'tag': 'Drug Intervention',
  'text': 'angiotensin-converting enzyme inhibitors',
  'surrounding': 'effect of angiotensin-converting enzyme inhibitors (ACEIs)'},
 {'tag': 'Drug Intervention',
  'text': 'angiotensin receptor blockers',
  'surrounding': 'angiotensin receptor blockers (ARBs)'},
 {'tag': 'Sample Size',
  'text': 'One hundred and twenty-four',
  'surrounding': 'One hundred and twenty-four patients were divided'},
 {'tag': 'Group Name',
  'text': 'ACEIs/ARBs group',
  'surrounding': 'as ACEIs/ARBs group (n = 75)'},
 {'tag': 'Group Name',
  'text': 'non-ACEIs/ARBs group',
  'surrounding': 'non-ACEIs/ARBs group (n = 49)'},
 {'tag': 'Quantitative Measurement',
  'text': '75',
  'surrounding': 'ACE

In [None]:
# Spans located by extract_tagged_text, sorted by start offset.
pred_spans

[{'start': 267,
  'end': 307,
  'label': 1,
  'tag': 'Drug Intervention',
  'text': 'angiotensin-converting enzyme inhibitors'},
 {'start': 316,
  'end': 345,
  'label': 1,
  'tag': 'Drug Intervention',
  'text': 'angiotensin receptor blockers'},
 {'start': 390,
  'end': 398,
  'label': 8,
  'tag': 'Disease/Condition of Interest',
  'text': 'COVID-19'},
 {'start': 413,
  'end': 425,
  'label': 8,
  'tag': 'Disease/Condition of Interest',
  'text': 'hypertension'},
 {'start': 697,
  'end': 726,
  'label': 18,
  'tag': 'Study Years',
  'text': '31 March 2020 and 28 May 2020'},
 {'start': 755,
  'end': 782,
  'label': 7,
  'tag': 'Sample Size',
  'text': 'One hundred and twenty-four'},
 {'start': 817,
  'end': 833,
  'label': 3,
  'tag': 'Group Name',
  'text': 'ACEIs/ARBs group'},
 {'start': 839,
  'end': 841,
  'label': 0,
  'tag': 'Quantitative Measurement',
  'text': '75'},
 {'start': 847,
  'end': 867,
  'label': 3,
  'tag': 'Group Name',
  'text': 'non-ACEIs/ARBs group'},
 {'start':

In [None]:
# Entities returned in the `missed` list by extract_tagged_text for this PMID.
missed

[{'tag': 'Quantitative Measurement',
  'text': '3.2%',
  'surrounding': '4 (%3.2) asymptomatic'},
 {'tag': 'Statistical Significance',
  'text': 'p < .001',
  'surrounding': 'other anti-HT medication group (mean±SD, 4.40 ± 1.89) (p < .001)'},
 {'tag': 'Statistical Significance',
  'text': 'p < .001',
  'surrounding': 'common type (mean±SD, 5.76 ± 3.07) (p < .001)'}]

In [None]:
# seqeval results (per-class and overall) for this single document.
results

{'Disease/Condition of Interest': {'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'number': 1},
 'Drug Intervention': {'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'number': 4},
 'Group Characteristic': {'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'number': 5},
 'Group Name': {'precision': 1.0,
  'recall': 0.3333333333333333,
  'f1': 0.5,
  'number': 6},
 'Intervention Duration': {'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'number': 1},
 'Outcome (Study Endpoint)': {'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'number': 4},
 'Quantitative Measurement': {'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'number': 5},
 'Sample Size': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 4},
 'Statistical Significance': {'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'number': 2},
 'Study Years': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0},
 'Type of Quant. Measure': {'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'number': 4},
 'overall_prec