In [18]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
import openai
import json
from ids import open_ai_api_key
from thefuzz import fuzz, process
import re
import ujson
import logging
import inflect
from collections import Counter, defaultdict
import numpy as np
import evaluate
import spacy
import ast
import pandas as pd

# Authenticate the OpenAI client with the key kept in the local `ids` module
# (keeps the secret out of the notebook itself).
openai.api_key = open_ai_api_key

# seqeval evaluation metric (entity-level precision / recall / F1 on BIO label sequences)
seqeval = evaluate.load("seqeval")

# spacy tokenizer: a blank English pipeline, used here only to tokenize abstracts
nlp = spacy.blank("en")

# Create an inflect engine; used to spell numbers out in words (number_to_words)
# when matching numeric entities against the abstract text.
p = inflect.engine()
# Set up logging configuration
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

# Method: Few-shot prompting

In [3]:
def retrieve_abstract_and_spans(path_to_processed_for_modeling, pmid):
    '''
    Function to retrieve the abstract text and the gold spans of one article
    from the processed_for_modeling.json file

    Parameters:
    ----------
    path_to_processed_for_modeling : str
        The path to the processed_for_modeling.json file
    pmid : int
        The pmid of the article to retrieve

    Returns:
    -------
    tuple (abstract, spans)
        The article's "text" and "spans" fields, or (None, None) when no
        article with this pmid exists (a warning is logged in that case).
    '''
    with open(path_to_processed_for_modeling, "r") as file:
        processed_for_modeling = ujson.load(file)

    for article in processed_for_modeling:
        if article["pmid"] == pmid:
            return article["text"], article["spans"]

    # Previously `abstract` was never bound on this path, so the function
    # raised UnboundLocalError instead of reporting the missing pmid.
    logging.warning("pmid %s not found in the list of documents", pmid)
    return None, None

    
def tagged_abstract(path_to_processed_for_modeling, pmid):
    '''
    Group the annotated span texts of one article by their tag

    Parameters:
    ----------
    path_to_processed_for_modeling : str
        The path to the processed_for_modeling.json file
    pmid : int
        The pmid of the article whose spans are grouped

    Returns:
    -------
    defaultdict(list)
        tag -> list of span texts, in file order. Every article carrying this
        pmid contributes; the result is empty when the pmid is absent.
    '''
    with open(path_to_processed_for_modeling, "r") as handle:
        articles = ujson.load(handle)

    spans_by_tag = defaultdict(list)
    matching_articles = (article for article in articles if article["pmid"] == pmid)
    for article in matching_articles:
        for span in article["spans"]:
            spans_by_tag[span["tag"]].append(span["text"])
    return spans_by_tag


def create_spans_with_surrounding_text(text, spans, window=1):
    '''
    This function creates a new list of spans with the surrounding text
    (+- `window` words) of each span

    Parameters:
    ----------
    text : str
        The abstract text
    spans : list of dict
        The spans extracted from the text : start, end, label, tag and text
    window : int, optional
        Number of whitespace-delimited words of context kept on each side of
        the span (default 1, the historical behaviour)

    Returns:
    -------
    list of dict
        One {"tag", "text", "surrounding"} dict per input span, in input order.
    '''
    # Split the text into alternating runs of non-whitespace ("words"; punctuation
    # stays attached to its word) and whitespace, keeping character offsets.
    # Because words and whitespace alternate, `window` words of context on one
    # side correspond to 2 * window tokens.
    tokens = [(m.group(0), m.start(), m.end()) for m in re.finditer(r'\S+|\s+', text)]
    context = 2 * window

    def find_token_indices(span_start, span_end):
        start_index = 0
        # If span_end lies past the text, fall back to the last token rather
        # than token 0 (which produced an empty/garbled context).
        end_index = len(tokens) - 1
        for i, (_, tok_start, tok_end) in enumerate(tokens):
            if tok_start <= span_start < tok_end:
                start_index = i
            # span_end is exclusive, so it may coincide with the token's end
            if tok_start <= span_end <= tok_end:
                end_index = i
                break
        return start_index, end_index

    new_spans = []
    for span in spans:
        start_index, end_index = find_token_indices(span['start'], span['end'])
        lo = max(0, start_index - context)
        hi = min(len(tokens), end_index + context + 1)
        surrounding_text = ''.join(tok for tok, _, _ in tokens[lo:hi])
        new_spans.append({
            "tag": span['tag'],
            "text": span['text'],
            "surrounding": surrounding_text.strip()
        })

    return new_spans



def extract_tagged_text(abstract, tagged_text, tag_to_label):
    '''
    Function to extract the position of the tagged text from the abstract given
    its surrounding text

    Parameters:
    ----------
    abstract : str
        The abstract text
    tagged_text : list of dict
        The entities returned by the model, each expected to carry "tag",
        "text" and "surrounding" keys. Items that are not dicts, lack one of
        these keys, or use a tag absent from `tag_to_label` are skipped (and
        logged) instead of aborting the whole extraction.
    tag_to_label : dict
        Mapping of tags to labels

    Returns:
    -------
    list of dict
        {"start", "end", "label", "tag", "text"} character-offset spans,
        sorted by start. Entities whose surrounding text (or whose text within
        it) cannot be found in the abstract are dropped.
    '''
    # '\xa0' -> ' ' is a one-for-one character swap, so offsets computed on the
    # normalised abstract are valid for the original one. Done once, not per entity.
    abstract = abstract.replace('\xa0', ' ')
    output = []

    for entity in tagged_text:
        if not isinstance(entity, dict):
            logging.info("Skipping malformed entity: %r", entity)
            continue
        tag = entity.get("tag")
        entity_text = entity.get("text")
        surrounding_text = entity.get("surrounding")
        if (not isinstance(tag, str) or tag not in tag_to_label
                or entity_text is None or surrounding_text is None):
            logging.info("Skipping entity with unknown tag or missing fields: %r", entity)
            continue

        entity_text = str(entity_text)
        # Handle special characters like non-breaking spaces
        surrounding_text = str(surrounding_text).replace('\xa0', ' ')
        if not entity_text.strip():
            continue

        # Case-insensitive regex search on the original string rather than
        # `.lower().find()`: lower() can change the length of some Unicode
        # strings, which would shift every offset after that character.
        surrounding_match = re.search(re.escape(surrounding_text), abstract, flags=re.IGNORECASE)
        if surrounding_match is None:
            continue

        # Numbers may be spelled out in the abstract ("27" vs "twenty-seven"),
        # so try the numeric form first, then its word form.
        candidates = [entity_text]
        if entity_text.isdigit():
            candidates.append(p.number_to_words(entity_text))

        for candidate in candidates:
            # Search only inside the located surrounding text; match offsets
            # are absolute positions in the abstract.
            entity_match = re.compile(re.escape(candidate), flags=re.IGNORECASE).search(
                abstract, surrounding_match.start(), surrounding_match.end()
            )
            if entity_match is not None:
                output.append({
                    "start": entity_match.start(),
                    "end": entity_match.end(),
                    "label": tag_to_label[tag],
                    "tag": tag,
                    "text": entity_text
                })
                break

    output.sort(key=lambda span: span["start"])
    return output

In [4]:
# Mapping from annotation tag names to the integer labels stored in the
# "label" field of each gold span (e.g. 'Sample Size' -> 7). Labels 0-19 are
# each used exactly once; keys mirror those of `extraction_prompts`.
tag_to_label = {
    "Disease/Condition of Interest" : 8,
    "Dosage" : 9,
    "Drug Intervention" : 1 ,
    "Follow-up period" : 16,
    "Group Characteristic" : 6,
    "Group Name" : 3,
    "Sample Size" : 7,
    "Intervention Administration" : 13,
    "Intervention Duration" : 12,
    "Intervention Frequency" : 11,
    "Non-Pharmaceutical Intervention" : 17,
    "Non-Study Drug" : 14,
    "Outcome (Study Endpoint)" : 2,
    "Side Effects" : 15,
    "Quantitative Measurement" : 0,
    "Statistical Significance" : 5,
    "Study Duration" : 19,
    "Study Years" : 18,
    "Type of Quant. Measure" : 10,
    "Units" : 4,
}

In [5]:
# Per-tag extraction instructions interpolated into the few-shot prompt.
# Keys must match those of `tag_to_label`.
# NOTE(review): in the "Non-Study Drug" example the patients are asthmatic, yet
# 'Sertraline' (not 'Fluticasone') is given as the non-study drug -- confirm
# this matches the annotation guideline.
extraction_prompts = {
    "Disease/Condition of Interest": "Return the disease/condition of interest. This is defined as what the drug aims to treat. Example usage: in 'metformin is being tested to treat PCOS', return 'PCOS'.",
    "Drug Intervention": "Return the name(s) of the drug(s) tested. Don't highlight qualifying terms. Example usage: in 'high-dose-aspirin', return 'aspirin'.",
    "Dosage": "Return the dosage of the drug(s) used in the study. Please return only the string of the numerical value used to convey the amount of drug given, excluding units. Example usage: in '2.5 mg Eliquis', return '2.5'. Example usage: in 'two grams', return 'two'.",
    "Sample Size": "Return how many patients were enrolled in the study. Please give only the exact string and no other words. Do not convert words into numbers. For example, 'twenty-seven' stays as 'twenty-seven'. Example usage: in 'a total of 420 patients were enrolled in the study', return '420'. Example usage: in 'study had n=100', return '100'.",
    "Follow-up period": "Return how long participants were tracked or examined after the initial intervention period. Example usage: in 'participants were followed for 1 year after intervention', return '1 year'.",
    "Group Characteristic": "Return the trait(s) used to describe a group/groups of patients in the study. Example usage: in 'post-menopausal women were studied', return 'post-menopausal women'. Example usage: in 'participants with diastolic pressure > 80 mmHg...', return 'diastolic pressure > 80 mmHg'.",
    "Group Name": "Return the group name(s) assigned to treatment or control group in the study. Example usage: in 'first group had 20 mg (group A) and second group had 50 mg (group B)', return 'group A,group B'.",
    "Intervention Administration": "Return the method in which the drug is provided to the patient/subject. Example usage: in '200 mg fluoxetine was given intravenously', return 'intravenously'. Other example matches include 'oral', 'inhaled', 'subcutaneous'.",
    "Intervention Duration": "Return the amount of time the drug is taken/used. Example usage: in '2 g of aspirin was administered for 5 weeks', return '5 weeks'.",
    "Intervention Frequency": "Return how often the drug is taken. Example matches include 'B.I.D', 'daily', 'q8H'.",
    "Non-Pharmaceutical Intervention": "If present, return a treatment that is not drug-related. This is an intervention being tested which is not a drug. Example usage: in 'patients undergoing hysterectomy were given 500 mg advil', return 'hysterectomy'.",
    "Non-Study Drug": "If present, return additionally mentioned drugs that are not being mainly studied. This is a drug supplied or mentioned which is not of relevance to the study's outcomes. Example usage: in 'we tested 87 asthmatic patients on Sertraline with Fluticasone', return 'Sertraline'. Example usage: in 'patients received a single-dose of metoprolol, supplied with intravenous saline for two weeks', return 'saline'.",
    # "is/are" used to be written "is\/are": an invalid escape sequence
    # (SyntaxWarning) that left a literal backslash in the prompt.
    "Outcome (Study Endpoint)": "Return benchmarks that help evaluate the drug's efficacy or success, usually mentioned in the beginning and have quantitative support later in the abstract. The outcome(s) is/are what is being measured or assessed to relay the drug's effects on the condition/disease of interest. Example usage: in 'There was a significant increase in heart rate in group 1 compared to group 2 (85% vs 10%, p = 0.001)', return 'heart rate'.",
    "Side Effects": "If present, return side effects experienced by participants/subjects/patients while taking the study drug. Example usage: in 'Participants experienced fatigue, anxiety, etc. while taking metformin', return 'fatigue,anxiety'.",
    "Quantitative Measurement": "Return numerical values that support the outcome and provide context for understanding. This value is measured in the study to evaluate the drug's effects on the outcomes. Only highlight the number, include parentheses if given as percentage, include +/-. Example usage: in 'There was a significant increase in heart rate in group 1 compared to group 2 (85% vs 10%, p = 0.001)', return '85%,10%'.",
    "Statistical Significance": "Return statistical measurements used to describe quantitative data. Usually a p-value, return all p-values whether significant or nonsignificant. Example usage: in 'There was a significant increase in heart rate in group 1 compared to group 2 (85% vs 10%, p = 0.001)', return 'p = 0.001'.",
    "Study Duration": "Return how long the study is. Example usage: in 'An 8-week, double-crossover, placebo-controlled, clinical trial...', return '8-week'.",
    "Study Years": "Return the years the study takes place during. Example usage: in 'this was a 4 year study conducted from May 2013 to October 2017', return 'May 2013 to October 2017'. You should return the whole blurb (including the months).",
    "Type of Quant. Measure": "Return the classification of a statistic. Example matches include 'hazard ratio', 'confidence interval', simple statistics like 'mean', 'median', 'odds ratio', etc. 'p' is not a type of quant measure. Example usage: in 'The mean (+- SD) in diastolic blood pressure was measured as 95 mmHg (Confidence Interval of 95%: 0.59 - 1.3)', return 'mean,SD,Confidence Interval of 95%'.",
    "Units": "Return the unit of measurement used for a specific dosage or statistic. Example usage: in 'metformin 20 mg/kg/day for 6 months', return 'mg/kg'. If it comes after a number and it's not a true word, chances are it's a unit."
}

# Shape of one extracted entity, shown to the model as the required output format.
format_spans = {
    "tag": "tag of the entity extracted",
    "text": "entity extracted",
    "surrounding": "surrounding of the text extracted",
}

In [7]:
# JSON list of processed articles (each with "pmid", "text", "spans" and "split"), used throughout.
path_to_processed_for_modeling = "../data/processed_for_modeling.json"

### Examples for few-shot prompt

In [8]:

# Gold examples interpolated into the few-shot prompt. Each one yields the
# abstract, its raw gold spans and the spans with +-1 word of surrounding context.
# NOTE(review): confirm none of these pmids belong to the "test" split evaluated
# below, otherwise the few-shot examples leak into the evaluation.

def _load_few_shot_example(pmid):
    '''Return (abstract, gold spans, gold spans with surrounding text) for one example pmid.'''
    abstract, true_spans = retrieve_abstract_and_spans(path_to_processed_for_modeling, pmid)
    return abstract, true_spans, create_spans_with_surrounding_text(text=abstract, spans=true_spans)


pmid4080 = 4080
abstract4080, true_spans4080, spans4080 = _load_few_shot_example(pmid4080)

pmid65221 = 65221
abstract65221, true_spans65221, spans65221 = _load_few_shot_example(pmid65221)

pmid29208464 = 29208464
abstract29208464, true_spans29208464, spans29208464 = _load_few_shot_example(pmid29208464)

pmid30872104 = 30872104
abstract30872104, true_spans30872104, spans30872104 = _load_few_shot_example(pmid30872104)

pmid35426326 = 35426326
abstract35426326, true_spans35426326, spans35426326 = _load_few_shot_example(pmid35426326)

# Loaded but not currently used by `prompt`.
pmid7484829 = 7484829
abstract7484829, true_spans7484829, spans7484829 = _load_few_shot_example(pmid7484829)

In [9]:
# Inspect the gold spans of one few-shot example (pmid 35426326).
true_spans35426326

[{'start': 708, 'end': 712, 'label': 7, 'tag': 'Sample Size', 'text': 'four'},
 {'start': 752,
  'end': 756,
  'label': 1,
  'tag': 'Drug Intervention',
  'text': 'L-T4'},
 {'start': 811,
  'end': 859,
  'label': 2,
  'tag': 'Outcome (Study Endpoint)',
  'text': 'no significant difference between the two groups'},
 {'start': 881, 'end': 884, 'label': 3, 'tag': 'Group Name', 'text': 'SCH'},
 {'start': 889, 'end': 894, 'label': 3, 'tag': 'Group Name', 'text': 'TPOAb'},
 {'start': 929,
  'end': 945,
  'label': 2,
  'tag': 'Outcome (Study Endpoint)',
  'text': 'live births rate'},
 {'start': 975,
  'end': 981,
  'label': 2,
  'tag': 'Outcome (Study Endpoint)',
  'text': 'higher'},
 {'start': 1019,
  'end': 1024,
  'label': 0,
  'tag': 'Quantitative Measurement',
  'text': '79.5%'},
 {'start': 1029,
  'end': 1034,
  'label': 0,
  'tag': 'Quantitative Measurement',
  'text': '70.8%'},
 {'start': 1148,
  'end': 1173,
  'label': 2,
  'tag': 'Outcome (Study Endpoint)',
  'text': 'no significant

### GPT prompting function

In [10]:
def prompt(abstract, model = "gpt-3.5-turbo-0125") : 
    '''
    Extract entities from `abstract` with a few-shot prompt sent to an OpenAI chat model.

    The prompt shows five worked examples -- the abstracts of pmids 65221, 4080,
    29208464, 30872104 and 35426326, each followed by its gold spans with
    surrounding text -- and then asks for the same extraction on `abstract`.
    The module-level `extraction_prompts` (instructions per tag) and
    `format_spans` (expected entity shape) are interpolated as well.

    Parameters:
    ----------
    abstract : str
        The abstract to annotate
    model : str
        Name of the OpenAI chat model to call (default "gpt-3.5-turbo-0125")

    Returns:
    -------
    str
        The message content of the first completion choice. The prompt asks for
        a JSON object with an "entities" key, but the text is returned unparsed
        and unvalidated.
    '''
    # NOTE(review): `extraction_prompts` is interpolated six times (once per
    # example and once for the target abstract), which makes the prompt long;
    # abstract7484829 is loaded in the examples cell but not used here.
    # NOTE(review): answers longer than max_tokens would be cut off and leave
    # invalid JSON -- confirm 3072 is enough for long abstracts.
    completion = openai.chat.completions.create(
        model = model,
        messages=[
            {
                "role": "user",
                "content": f""" 
                Task : Looking at the following abstract: {abstract65221}.\
                Retrieve all labels that are in : {extraction_prompts}. \
                If the same entity appears multiple times in the text, it should be extracted each time it appears.\
                Answer : {spans65221}. \
                Task : Looking at the following abstract: {abstract4080}.\
                Retrieve all labels that are in : {extraction_prompts}. \
                If the same entity appears multiple times in the text, it should be extracted each time it appears.\
                Answer : {spans4080}. \
                Task : Looking at the following abstract: {abstract29208464}.\
                Retrieve all labels that are in : {extraction_prompts}. \
                If the same entity appears multiple times in the text, it should be extracted each time it appears.\
                Answer : {spans29208464}. \
                Task : Looking at the following abstract: {abstract30872104}.\
                Retrieve all labels that are in : {extraction_prompts}. \
                If the same entity appears multiple times in the text, it should be extracted each time it appears.\
                Answer : {spans30872104}. \
                Task : Looking at the following abstract: {abstract35426326}.\
                Retrieve all labels that are in : {extraction_prompts}. \
                If the same entity appears multiple times in the text, it should be extracted each time it appears.\
                Answer : {spans35426326}. \
                Task : Looking at the following abstract: {abstract}.\
                Retrieve all labels that are in : {extraction_prompts}. \
                If the same entity appears multiple times in the text, it should be extracted each time it appears. \
                Answer : 
                Entities must be extracted in this format : {format_spans}.\
                Return the answer in a JSON with "entities" as key and don’t output anything beyond the required JSON file.
                """
            },
        ],
        max_tokens= 3072,
        temperature= 0,
    )
    # Only the first choice is used; its content is returned as raw text.
    return completion.choices[0].message.content


## Evaluation with the seqeval library

In [21]:
def convert_to_seqeval_format(y_true, y_pred, abstract):
    '''
    This function converts the entity data to seqeval format, labelling the
    whitespace-delimited words of the abstract with BIO tags

    Parameters:
    ----------
    y_true : list
        Human annotated data (dicts with "start", "text" and "tag")
    y_pred : list
        Model predictions (same shape as y_true)
    abstract : str
        The abstract to convert to seqeval format

    Returns:
    -------
    tuple (list of str, list of str)
        BIO labels for y_true and y_pred, one per word of `abstract`.
    '''
    def label_tokens(annotations, abstract):
        tokens = abstract.split()
        labels = ["O"] * len(tokens)
        for ann in annotations:
            start = ann['start']
            # Number of words that begin before the span start.
            start_idx = len(abstract[:start].split())
            # A span starting mid-word has that partial word counted above;
            # step back so the span begins on the word that contains it.
            if (0 < start < len(abstract)
                    and not abstract[start - 1].isspace()
                    and not abstract[start].isspace()):
                start_idx -= 1
            if start_idx >= len(tokens):
                continue  # span starts after the last word: nothing to label
            # Clamp so a span running past the last word cannot raise IndexError.
            end_idx = min(start_idx + len(ann['text'].split()), len(tokens))
            labels[start_idx] = f"B-{ann['tag']}"
            for i in range(start_idx + 1, end_idx):
                labels[i] = f"I-{ann['tag']}"
        return labels

    y_true_seqeval = label_tokens(y_true, abstract)
    y_pred_seqeval = label_tokens(y_pred, abstract)
    return y_true_seqeval, y_pred_seqeval

def label_tokens_from_offsets(text, annotations, tokenizer=None):
    '''
    Convert character-offset annotations into token-level BIO labels for seqeval

    Parameters:
    ----------
    text : str
        The abstract text
    annotations : list of dict
        Spans with "start" / "end" character offsets (end exclusive) and "tag"
    tokenizer : callable, optional
        Maps text to a sequence of tokens exposing `.idx` (character offset)
        and `.text`; defaults to the module-level spaCy pipeline `nlp`

    Returns:
    -------
    list of str
        One "O" / "B-<tag>" / "I-<tag>" label per token.
    '''
    doc = (nlp if tokenizer is None else tokenizer)(text)
    tokens = list(doc)
    labels = ["O"] * len(tokens)

    for ann in annotations:
        start_char = ann['start']
        end_char = ann['end']
        # First token ending after the span start, i.e. the token containing it
        # (the old `idx >= start` test skipped a token the span starts inside).
        start_token = next(
            (i for i, tok in enumerate(tokens) if tok.idx + len(tok.text) > start_char), None
        )
        # First token beginning at or after the span end. If there is none the
        # span runs to the end of the text; previously such spans were dropped.
        end_token = next(
            (i for i, tok in enumerate(tokens) if tok.idx >= end_char), len(tokens)
        )

        if start_token is None or start_token >= end_token:
            continue  # span lies past the last token or is empty
        labels[start_token] = f"B-{ann['tag']}"
        for i in range(start_token + 1, end_token):
            labels[i] = f"I-{ann['tag']}"

    return labels


def _parse_entities(raw_output, pmid=None):
    '''
    Parse a model answer into the list stored under its "entities" key.

    Strips an optional Markdown code fence (```json ... ```) around the JSON.
    Returns [] and logs a warning when the answer is empty, is not valid JSON,
    or has no "entities" list, so one bad answer does not abort the whole run.
    '''
    text = (raw_output or "").strip()
    # Real prefix/suffix removal; str.strip('```json') strips a character *set*.
    fenced = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        text = fenced.group(1)
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        logging.warning("pmid %s: model output is not valid JSON; scoring it as no predictions", pmid)
        return []
    entities = parsed.get("entities") if isinstance(parsed, dict) else None
    if not isinstance(entities, list):
        logging.warning("pmid %s: model output has no 'entities' list; scoring it as no predictions", pmid)
        return []
    return entities


def compute_metrics(pmids, model="gpt-3.5-turbo-0125", tag_to_label=None) :
    '''
    Prompt `model` on each abstract and score its entities against the gold spans with seqeval

    Parameters:
    ----------
    pmids : list of int
        The pmids of the abstracts to evaluate (one API call each)
    model : str
        Name of the OpenAI chat model passed to `prompt`
    tag_to_label : dict
        Mapping of tags to labels (required)

    Returns:
    -------
    tuple (dict, pandas.DataFrame)
        Overall metrics over all abstracts (accuracy, precision, recall, f1,
        per-class f1 and the raw seqeval output), and one row of metrics per pmid.

    Raises:
    ------
    ValueError
        If `tag_to_label` is None, or if none of the pmids could be evaluated.
    '''
    # Fail before spending any API calls instead of crashing on the first entity.
    if tag_to_label is None:
        raise ValueError("tag_to_label is required to map predicted tags to labels")

    all_y_true_seqeval = []
    all_y_pred_seqeval = []
    results_list = []

    for i, pmid in enumerate(pmids):
        logging.info("Evaluating abstract %d/%d (pmid %s)", i + 1, len(pmids), pmid)
        abstract, true_spans = retrieve_abstract_and_spans(
            path_to_processed_for_modeling=path_to_processed_for_modeling, pmid=pmid
        )
        if abstract is None:
            logging.warning("pmid %s: no abstract found, skipping", pmid)
            continue

        # Parse the answer the same way for every model (not only gpt-4o).
        entities = _parse_entities(prompt(abstract, model), pmid)
        pred_spans = extract_tagged_text(abstract=abstract, tagged_text=entities, tag_to_label=tag_to_label)

        y_true_seqeval = label_tokens_from_offsets(text=abstract, annotations=true_spans)
        y_pred_seqeval = label_tokens_from_offsets(text=abstract, annotations=pred_spans)
        all_y_true_seqeval.append(y_true_seqeval)
        all_y_pred_seqeval.append(y_pred_seqeval)

        results = seqeval.compute(predictions=[y_pred_seqeval], references=[y_true_seqeval])
        results['pmid'] = pmid
        results['model'] = model
        results_list.append(results)

    if not results_list:
        raise ValueError("none of the given pmids could be evaluated")

    # Evaluate using seqeval over all abstracts at once.
    overall_results = seqeval.compute(predictions=all_y_pred_seqeval, references=all_y_true_seqeval)
    # Non-"overall_*" keys are the per-entity-type entries.
    overall_class_specific_f1 = {
        k: v["f1"] for k, v in overall_results.items() if not k.startswith("overall")
    }

    per_abstract_columns = {
        'overall_accuracy': 'accuracy',
        'overall_precision': 'precision',
        'overall_recall': 'recall',
        'overall_f1': 'f1',
    }
    df_results = (
        pd.DataFrame(results_list)[['pmid', 'model', *per_abstract_columns]]
        .rename(columns=per_abstract_columns)
    )

    return {
        "accuracy": overall_results["overall_accuracy"],
        "precision": overall_results["overall_precision"],
        "recall": overall_results["overall_recall"],
        "f1": overall_results["overall_f1"],
        "class_specific_f1": overall_class_specific_f1,
        "detailed_results": overall_results,
    }, df_results

In [22]:
# Load every processed article once; the evaluation split is selected from it below.
with open(path_to_processed_for_modeling) as processed_file:
    data = ujson.load(processed_file)

In [23]:
# Keep only the held-out test articles, then collect their pmids.
test_data = list(filter(lambda article: article["split"] == "test", data))
pmids_test = [article["pmid"] for article in test_data]

In [24]:
# Number of test abstracts to evaluate (one paid API call each).
# Set to None to evaluate the whole test split.
N_EVAL_ABSTRACTS = 1
pmids_test = pmids_test[:N_EVAL_ABSTRACTS]
pmids_test

[58651]

In [27]:
# Run the few-shot extraction with gpt-4o and score it against the gold spans.
results, df = compute_metrics(
    pmids=pmids_test,
    model="gpt-4o",
    tag_to_label=tag_to_label,
)
results

i : 0
pmid : 58651


INFO: HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'accuracy': 0.7838983050847458,
 'precision': np.float64(0.5416666666666666),
 'recall': np.float64(0.38235294117647056),
 'f1': np.float64(0.44827586206896547),
 'class_specific_f1': {'Disease/Condition of Interest': np.float64(0.0),
  'Drug Intervention': np.float64(0.25),
  'Follow-up period': np.float64(0.0),
  'Group Characteristic': np.float64(1.0),
  'Group Name': np.float64(0.0),
  'Intervention Administration': np.float64(0.0),
  'Non-Study Drug': np.float64(0.0),
  'Outcome (Study Endpoint)': np.float64(0.0),
  'Quantitative Measurement': np.float64(0.8),
  'Sample Size': np.float64(0.7142857142857143),
  'Statistical Significance': np.float64(0.6666666666666666)},
 'detailed_results': {'Disease/Condition of Interest': {'precision': np.float64(0.0),
   'recall': np.float64(0.0),
   'f1': np.float64(0.0),
   'number': np.int64(0)},
  'Drug Intervention': {'precision': np.float64(0.25),
   'recall': np.float64(0.25),
   'f1': np.float64(0.25),
   'number': np.int64(4)},
  'Fol

In [28]:
# Per-abstract metrics returned by compute_metrics (one row per pmid).
df

Unnamed: 0,pmid,model,accuracy,precision,recall,f1
0,58651,gpt-4o,0.783898,0.541667,0.382353,0.448276
