In [11]:
import helper.pubmed_search as pubs
from helper.pubmed_search import QueryExpansionManager
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import random
import time
import json
import math
import csv
import os


def precision_at_k(retrieved_docs, relevant_docs, k=10):
    """Return the precision (as a percentage) over the top-R retrieved documents.

    The cut-off is always R = len(relevant_docs); the ``k`` argument is
    accepted for signature compatibility but is overwritten and therefore has
    no effect on the result.

    Args:
        retrieved_docs: Ranked sequence of retrieved document ids.
        relevant_docs: Collection of ground-truth document ids.
        k: Ignored (see above).

    Returns:
        100 * (number of the first R retrieved ids found in relevant_docs) / R,
        or 0.0 when relevant_docs is empty.
    """
    cutoff = len(relevant_docs)
    hits = sum(1 for doc in retrieved_docs[:cutoff] if doc in relevant_docs)
    if cutoff > 0:
        return 100 * hits / cutoff
    return 0.0


def average_precision_at_k(retrieved_docs, relevant_docs, k=10):
    """Return average precision (as a percentage) over the top-R retrieved documents.

    The cut-off is always R = len(relevant_docs); the ``k`` argument is
    accepted for signature compatibility but is overwritten and has no effect.

    Args:
        retrieved_docs: Ranked sequence of retrieved document ids. May be
            shorter than relevant_docs; missing ranks simply count as misses.
        relevant_docs: Collection of ground-truth document ids.
        k: Ignored (see above).

    Returns:
        100 * (sum of precision@i at every rank i <= R holding a relevant
        document) / R, or 0.0 when there are no relevant documents or no hits.
    """
    k = len(relevant_docs)
    if k == 0:
        return 0.0
    num_relevant = 0
    precision_sum = 0
    # Iterate over the slice rather than indexing 0..k-1 so a retrieved list
    # shorter than the ground truth does not raise IndexError.
    for rank, doc in enumerate(retrieved_docs[:k], start=1):
        if doc in relevant_docs:
            num_relevant += 1
            precision_sum += num_relevant / rank

    return 100*precision_sum / k if num_relevant > 0 else 0.0


def dcg(retrieved_docs, relevant_docs, k):
    """Return the binary-relevance discounted cumulative gain of the top-k results.

    Args:
        retrieved_docs: Ranked sequence of retrieved document ids. May hold
            fewer than k entries; absent ranks contribute nothing.
        relevant_docs: Collection of ground-truth document ids.
        k: Number of leading ranks to score. Values <= 0 score nothing.

    Returns:
        Sum of 1 / log2(rank + 1) over the 1-based ranks (<= k) whose document
        is in relevant_docs, as a float.
    """
    dcg_value = 0.0
    # Slice instead of indexing range(k) so a short retrieved list cannot raise
    # IndexError; max() keeps a negative k from turning into a tail slice.
    for position, doc in enumerate(retrieved_docs[:max(k, 0)]):
        if doc in relevant_docs:
            dcg_value += 1 / math.log2(position + 2)  # position is 0-based, hence + 2
    return dcg_value
def idcg(relevant_docs, k):
    """Return the DCG of a ranking whose first k positions are all relevant.

    Args:
        relevant_docs: Not used by the computation; kept for signature
            compatibility.
        k: Number of leading ranks assumed to be relevant.

    Returns:
        Sum of 1 / log2(rank + 1) for the 1-based ranks 1..k, as a float
        (0.0 when k <= 0).
    """
    ideal_total = 0.0
    # log2 argument runs 2, 3, ..., k + 1 (i.e. 1-based rank + 1).
    for log_argument in range(2, k + 2):
        ideal_total += 1 / math.log2(log_argument)
    return ideal_total
def ndcg_at_k(retrieved_docs, relevant_docs, k=10):
    """Return normalised DCG (as a percentage) over the top-R retrieved documents.

    The cut-off is always R = len(relevant_docs); the ``k`` argument is
    accepted for signature compatibility but is overwritten and has no effect.
    The gain and its normaliser come from the module-level ``dcg`` and ``idcg``
    helpers, both evaluated at the same cut-off R.

    Args:
        retrieved_docs: Ranked sequence of retrieved document ids.
        relevant_docs: Collection of ground-truth document ids.
        k: Ignored (see above).

    Returns:
        100 * dcg / idcg, or 0.0 when relevant_docs is empty or idcg is not
        positive.
    """
    cutoff = len(relevant_docs)
    if cutoff == 0:
        return 0.0

    achieved_gain = dcg(retrieved_docs, relevant_docs, cutoff)
    ideal_gain = idcg(relevant_docs, cutoff)

    if ideal_gain > 0:
        return 100 * achieved_gain / ideal_gain
    return 0.0


def extract_subset_of_evaluation_data(evaluation_data, filename, fraction = 0.5, random_seed = 31):
    """Write a reproducible random sample of the evaluation data to a JSON file.

    Re-seeds the global ``random`` module with ``random_seed`` (a side effect
    visible to later callers of ``random``), samples
    ``int(len(evaluation_data) * fraction)`` items without replacement, dumps
    them to ``filename`` as JSON and prints a one-line summary.

    Args:
        evaluation_data: Sequence of JSON-serialisable items to sample from.
        filename: Path of the JSON file to (over)write.
        fraction: Share of the data to keep.
        random_seed: Seed passed to ``random.seed``.
    """
    random.seed(random_seed)
    total_count = len(evaluation_data)
    sample_size = int(total_count*fraction)
    sampled_items = random.sample(evaluation_data, sample_size)
    with open(filename, 'w') as outfile:
        json.dump(sampled_items, outfile)
    print('Extracted', len(sampled_items), 'documents out of', total_count)

def output_full_metrics_to_file(retrieved_ids_list, ground_truth_list, name, file_name):
    """Append the mean precision, average precision and nDCG of a run to a CSV file.

    Each metric is computed per query with ``precision_at_k``,
    ``average_precision_at_k`` and ``ndcg_at_k`` and then averaged over all
    queries. One row ``[name, mean P, mean AP, mean nDCG]`` is appended to
    ``file_name``; a generic ``Column N`` header row is written first when the
    file does not exist yet.

    Args:
        retrieved_ids_list: One ranked list of retrieved ids per query.
        ground_truth_list: One list of relevant ids per query, paired with
            retrieved_ids_list by index. It must be at least as long as
            retrieved_ids_list; surplus entries are ignored.
        name: Label written to the first column of the row.
        file_name: Path of the CSV file to append to.

    Raises:
        IndexError: If ground_truth_list is shorter than retrieved_ids_list.
    """
    def _mean_metric(metric_fn):
        # Average one metric over all (retrieved, ground truth) pairs.
        scores = [
            metric_fn(retrieved_ids_list[i], ground_truth_list[i])
            for i in range(len(retrieved_ids_list))
        ]
        # An empty run (e.g. every query was skipped) yields 0.0 rather than
        # a ZeroDivisionError after a long evaluation.
        return sum(scores) / len(scores) if scores else 0.0

    mean_precision_at_k = _mean_metric(precision_at_k)
    mean_average_precision_at_k = _mean_metric(average_precision_at_k)
    mean_ndcg_at_k = _mean_metric(ndcg_at_k)

    output_list = [name, mean_precision_at_k, mean_average_precision_at_k, mean_ndcg_at_k]

    # Decide on the header before opening: mode 'a' creates the file.
    file_exists = os.path.isfile(file_name)
    # newline='' is what the csv module expects for files it writes to.
    with open(file_name, mode='a', newline='') as file:
        writer = csv.writer(file)
        # Write the header only for a brand-new file
        if not file_exists:
            header = [f"Column {i+1}" for i in range(len(output_list))]
            writer.writerow(header)
        # Write the data row
        writer.writerow(output_list)
    print(f"List successfully written to {file_name}")


def length_check(retrieved_ids_list, ground_truth_list):
    """Report queries whose retrieved list and ground-truth list differ in length.

    Prints 'Different length!' when the two outer lists have different sizes,
    then compares the entries pairwise over their common prefix and prints one
    line per mismatch.

    Args:
        retrieved_ids_list: One list of retrieved ids per query.
        ground_truth_list: One list of relevant ids per query.

    Returns:
        Number of compared index positions whose inner lists differ in length.
    """
    if len(retrieved_ids_list)!=len(ground_truth_list):
        print('Different length!')
    inconsistent_indexes = 0
    # zip stops at the shorter list, so a ground-truth list that is shorter
    # than the retrieved list is reported above instead of raising IndexError.
    for i, (retrieved_ids, ground_truth) in enumerate(zip(retrieved_ids_list, ground_truth_list)):
        retrieved_ids_length = len(retrieved_ids)
        ground_truth_length = len(ground_truth)
        if retrieved_ids_length!=ground_truth_length:
            print('Inconsistency at index:', i, 'where retrieved ids have length:', retrieved_ids_length, 'and ground truth has length', ground_truth_length)
            inconsistent_indexes +=1
    return inconsistent_indexes


Extract a subset (25%) of the full test-set data and store it in a JSON file for brevity. This has already been done, so the code below is commented out.

In [2]:
# with open ('evaluation/BioASQ-training11b/training11b.json') as training_file:
#     full_evaluation_data = json.load(training_file)
# extract_subset_of_evaluation_data(full_evaluation_data['questions'], 'evaluation/testing_script.json', 0.002)

# 2. Evaluation For Initial Retrieval

This part measures the accuracy of the results returned from the initial query without any query expansion. It covers:
1. Testing the immediate results from a PubMed search.
2. Testing the results after they are re-ranked by an embedding model that compares each article's content to the query.

Extract the subset evaluation data and run evaluation on it.



In [10]:
# Retrieve (and optionally re-rank) documents for every evaluation question
def get_initial_retrieved_and_ground_truth_docs(evaluation_set, remove_stop_words = False, articles_to_retrieve = 20, qe_manager = None):
    """Run every evaluation question through the PubMed search and collect the results.

    For each entry the question text (``entry['body']``) is sent to
    ``pubs.get_query_response``; the ground truth is the last '/'-separated
    segment of each string in ``entry['documents']``. When ``qe_manager`` is
    given, the retrieved articles are re-ranked by the cosine similarity
    between the embedded query and the embedded title + abstract.

    Every non-empty retrieved list is padded or trimmed to the length of its
    ground-truth list (padding value is '' without re-ranking and 'None' with
    re-ranking). An empty list is stored when the response has no 'idlist' or
    the idlist is empty.

    Side effects: prints each query, sleeps 1 second per query when
    ``qe_manager`` is None, and (when re-ranking) overwrites
    'evaluation/simple_retrieval_articles.csv' on every query.

    Args:
        evaluation_set: Iterable of dicts with 'body' and 'documents' keys.
        remove_stop_words: Forwarded to ``pubs.get_query_response``.
        articles_to_retrieve: Forwarded to ``pubs.get_query_response``.
        qe_manager: Optional object providing
            ``embed_mesh_headings_preloaded_model``; enables re-ranking.

    Returns:
        Tuple ``(retrieved_ids_list, ground_truth_list)`` with one inner list
        per evaluation entry, in input order.
    """
    retrieved_ids_list = []
    ground_truth_list = []
    # Run it through the system:
    for idx, entry in enumerate(evaluation_set):
        # # Include this for diagnostics only, otherwise comment it out
        # if idx > 0:
        #     continue
        # Extract out the query
        query = entry['body']        
        # The connection tends to drop during long runs, hence the retry scaffolding below.
        # NOTE(review): the try/except is currently commented out, so `runs` is never read,
        # every path through the `while True` body ends in `break` (it executes exactly once),
        # and any network error propagates to the caller.
        runs = 0
        while True:
            if qe_manager is None:
                time.sleep(1) # rest for 1 second so as not to overwhelm PubMed's API if there is no re-ranking
            # query the question through the database and extract the relevant documents
            # try:
            print('Query number', idx, 'is: ', query)
            if qe_manager is not None:
                query_embedding = qe_manager.embed_mesh_headings_preloaded_model(query)
            query_response = pubs.get_query_response(query, remove_stop_words=remove_stop_words, articles_to_retrieve=articles_to_retrieve)
            # Keep only the trailing path segment of each document reference (presumably the PubMed id - confirm)
            ground_truth = [x.split('/')[-1] for x in entry['documents']] 
            ground_truth_length = len (ground_truth)   
            # NOTE(review): the ground truth is appended before retrieval finishes; if the retry
            # is ever re-enabled, a retried query would append its ground truth twice.
            ground_truth_list.append(ground_truth)   
            # Checks if there is idlist in the results, then append it to the list
            # If not (usually because of pubmed search errors or the preprocessing removed all words from the query), append an empty list                
            if 'idlist' in query_response['esearchresult']:
                retrieved_ids = query_response['esearchresult']['idlist']
                if retrieved_ids == []:
                    retrieved_ids_list.append([])
                    break
                # If we decide to rank the retrieval before returning the list
                if qe_manager is not None:
                    ranking_list = []
                    retrieved_article_details = pubs.get_article_details_from_id(retrieved_ids)
                    for key, value in retrieved_article_details.items():
                        # Missing or None title/abstract fall back to the empty string
                        title = value.get('Title') or ""
                        abstract = value.get('Abstract') or ""
                        # NOTE(review): title and abstract are concatenated without a separator
                        article_title_abstract = title + abstract
                        article_title_abstract_embeddings = qe_manager.embed_mesh_headings_preloaded_model(article_title_abstract)
                        similarity = cosine_similarity(query_embedding, article_title_abstract_embeddings)[0][0]
                        ranking_list.append({'Article': key, 'Similarity': similarity})
                    # If the ranking list is shorter than ground truth, need to pad it
                    # NOTE(review): padding rows (similarity 0) are added before sorting, so they
                    # rank above any real article whose similarity is negative.
                    if len(ranking_list) < ground_truth_length:
                        entries_to_add = ground_truth_length - len(ranking_list)
                        for i in range(entries_to_add):
                            ranking_list.append({'Article': 'None', 'Similarity': 0})
                    ranking_list_df = pd.DataFrame(ranking_list).sort_values(by=['Similarity'], ascending=False)
                    # Diagnostic dump; overwritten on every query
                    ranking_list_df.to_csv('evaluation/simple_retrieval_articles.csv')
                    # If the ranking list is longer than ground truth, need to trim it
                    retrieved_ids_list.append(ranking_list_df.head(ground_truth_length)['Article'].to_list())
                else:
                    # No re-ranking: keep the search order and pad with '' or trim to the ground-truth length
                    if len(retrieved_ids) < ground_truth_length:
                        entries_to_add = ground_truth_length - len(retrieved_ids)
                        for i in range(entries_to_add):
                            retrieved_ids.append('')
                    else: # for the case where the lengths are equal or ground truth is shorter
                        retrieved_ids = retrieved_ids[0:ground_truth_length]
                    retrieved_ids_list.append(retrieved_ids)  
            else: 
                retrieved_ids_list.append([])                      
            break # once information is extracted, stop the rerun
            # # This is just to pre-empt anything that would happen above during long evaluation runs.
            # except Exception as e:
            #     print('Error occured when querying pubmed:', e)
            #     runs +=1
            #     if runs > 3: # after 3 rounds of rerun due to errors, just move on
            #         print('This entry is not counted:', entry['body'])
            #         break 
    return retrieved_ids_list, ground_truth_list


### 2.1 Immediate Results from PubMed Entrez Search

In [28]:
# Load the evaluation questions used for the baseline run
with open ('evaluation/small subset.json') as infile:
    evaluation = json.load(infile)

# pure retrieval on a fixed dataset (takes 31.7s compared to retrieval + p@k ranking at 33.3s)
# NOTE(review): the label below says "10 articles retrieved", but articles_to_retrieve is not
# passed here, so the function's default applies - confirm which one is intended.
retrieved_ids_list, ground_truth_list = get_initial_retrieved_and_ground_truth_docs(evaluation, remove_stop_words=False)
# Append the mean P / AP / nDCG of this run as one labelled row of the metrics CSV
output_full_metrics_to_file(retrieved_ids_list, ground_truth_list, 'baseline (initial retrieval, 10 articles retrieved, no reranking, 233 documents) trimming', 'evaluation/evaluation_metrics.csv')

Query number 0 is:  What is the function of the AIRE gene at the embryonic stage?
Query number 1 is:  Which type of analysis does DeSeq2 perform?
Query number 2 is:  Which intraflagellar transport (IFT) motor protein has been linked to human skeletal ciliopathies?
Query number 3 is:  Where in the body, is ghrelin secreted?
Query number 4 is:  What is the effect of SAHA treatment in Huntington's disease?
Query number 5 is:  What is membrane scission?
Query number 6 is:  What is the color of the protein Ranasmurfin?
Query number 7 is:  What is the mechanism of action of ocrelizumab for treatment of multiple sclerosis?
Query number 8 is:  What treatment was studied in the KEYNOTE-522 trial?
Query number 9 is:  Is Stat4 a transcription factor?
Query number 10 is:  By which methods can we evaluate the reliability of a phylogenetic tree?
Query number 11 is:  What is the physiological role of LKB1 involved in Peutz-Jeghers syndrome?
Query number 12 is:  Is the Dictyostelium discoideum proteom

### 2.2 Reranking of Initial Results from PubMed Entrez Search

In this section, we will first retrieve the documents, embed their contents and the query, do a reranking of their order using cosine similarity between the embedded query and embedded document content, then output the re-ranked list and ground truth list.

You can select whichever model you want to evaluate the reranking, but if you choose 'openai', please pass your api key as argument (api_key = 'your api key') into QueryExpansionManager(model_name, 'helper/descriptors.json').

Note that there is no query expansion or suggestion at this stage.

In [49]:
# The API key is read from a local config module rather than hard-coded in the notebook
from helper.config import api_key
# Alternative embedding models; uncomment exactly one model_name
# model_name = "sentence-transformers/all-mpnet-base-v2"
# model_name = "w601sxs/b1ade-embed"  
# model_name = "dmis-lab/biobert-v1.1"  
model_name = 'openai'
# NOTE(review): the variable name (and the metrics label below) mention "Blade", but the
# model selected above is 'openai' - confirm the label before comparing runs.
blade_embed_qe_manager = QueryExpansionManager(model_name, 'helper/descriptors.json', api_key=api_key)

with open ('evaluation/small subset.json') as infile:
    evaluation = json.load(infile)

# Retrieve 10 articles per question, re-rank them with the embedding model, then log the metrics
retrieved_ids_list, ground_truth_list = get_initial_retrieved_and_ground_truth_docs(evaluation, remove_stop_words=False, articles_to_retrieve=10, qe_manager=blade_embed_qe_manager)
output_full_metrics_to_file(retrieved_ids_list, ground_truth_list, 'initial retrieval with Blade reranking, 10 articles, trimming', 'evaluation/evaluation_metrics.csv')

Query number 0 is:  What is the function of the AIRE gene at the embryonic stage?
Query number 1 is:  Which type of analysis does DeSeq2 perform?
Query number 2 is:  Which intraflagellar transport (IFT) motor protein has been linked to human skeletal ciliopathies?
Query number 3 is:  Where in the body, is ghrelin secreted?
Query number 4 is:  What is the effect of SAHA treatment in Huntington's disease?
Query number 5 is:  What is membrane scission?
Query number 6 is:  What is the color of the protein Ranasmurfin?
Query number 7 is:  What is the mechanism of action of ocrelizumab for treatment of multiple sclerosis?
Query number 8 is:  What treatment was studied in the KEYNOTE-522 trial?
Query number 9 is:  Is Stat4 a transcription factor?
Query number 10 is:  By which methods can we evaluate the reliability of a phylogenetic tree?
Query number 11 is:  What is the physiological role of LKB1 involved in Peutz-Jeghers syndrome?
Query number 12 is:  Is the Dictyostelium discoideum proteom

ConnectTimeout: HTTPSConnectionPool(host='eutils.ncbi.nlm.nih.gov', port=443): Max retries exceeded with url: /entrez/eutils/efetch.fcgi?db=pubmed&id=39171136%2C39129620%2C37750650%2C37082918%2C36605074%2C35463710%2C32950946%2C32824919%2C29430308%2C28243770&retmode=xml&rettype=abstract (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x0000016BEDF5E310>, 'Connection to eutils.ncbi.nlm.nih.gov timed out. (connect timeout=None)'))

# 3. Evaluation for Query Expansion and Suggestion

In [29]:
import pandas as pd
import helper.pubmed_search as pubs

def get_reretrieval_and_ground_truth_docs(evaluation_set, qe_manager):
    """Run every evaluation question through retrieval, query expansion and re-retrieval.

    For each entry the question (``entry['body']``) is sent to
    ``qe_manager.get_entries_from_query``, which yields heading entries and
    article entries (records with at least 'name' and 'suitability'). Entries
    above the suitability thresholds are turned into an expanded query via
    ``pubs.create_semantic_neighbourhood_query``; its results are merged with
    the initial articles, de-duplicated by 'name' and sorted by 'suitability'.
    The top ids are padded with '' (or trimmed) to the ground-truth length.

    A query that raises is retried; after more than 3 failures the entry is
    skipped and contributes to neither returned list, so the two lists stay
    index-aligned.

    Side effects: prints progress and overwrites several diagnostic CSV files
    under 'evaluation/' on every query.

    Args:
        evaluation_set: Iterable of dicts with 'body' and 'documents' keys.
        qe_manager: Object providing ``get_entries_from_query``.

    Returns:
        Tuple ``(retrieved_ids_list, ground_truth_list)`` with one inner list
        per successfully processed entry.
    """
    # Suitability cut-offs for entries that feed the expanded query
    title_abstract_threshold = 20
    heading_threshold = 40

    retrieved_ids_list = []
    ground_truth_list = []
    # Run it through the system:
    for idx, entry in enumerate(evaluation_set):
        query = entry['body']
        # The connection tends to drop during long runs, so failed queries are retried
        runs = 0
        while True:
            # query the question through the database and extract the relevant documents
            try:
                print('Query number', idx, 'is: ', query)
                ground_truth = [x.split('/')[-1] for x in entry['documents']]
                ground_truth_length = len(ground_truth)

                # Get the initial round of documents
                heading_entries, article_entries = qe_manager.get_entries_from_query(query)

                print('Initial article length:', len(article_entries))
                # Deal with articles
                if article_entries!=[]:
                    article_entries_df = pd.DataFrame(article_entries).drop_duplicates(['name']).sort_values(by=['suitability'], ascending=False)
                    article_entries_df.to_csv('evaluation/article_entries.csv')
                    filtered_article_entries = article_entries_df[article_entries_df['suitability'] > title_abstract_threshold].to_dict(orient='records')
                else:
                    retrieved_ids_list.append([])
                    print('Benchmark length is:', ground_truth_length)
                    ground_truth_list.append(ground_truth)
                    break # if no articles returned, move on

                if heading_entries!=[]:
                    heading_entries_df = pd.DataFrame(heading_entries).drop_duplicates(['name']).sort_values(by=['suitability'], ascending=False)
                    heading_entries_df.to_csv('evaluation/heading_entries.csv')
                    filtered_heading_entries = heading_entries_df[heading_entries_df['suitability'] > heading_threshold].to_dict(orient='records')
                else:
                    # Without this the name was undefined on the first query (NameError) or,
                    # worse, still held the previous query's headings.
                    # NOTE(review): assumes create_semantic_neighbourhood_query accepts an
                    # empty heading list - confirm against the helper.
                    filtered_heading_entries = []

                requery = pubs.create_semantic_neighbourhood_query(filtered_heading_entries, filtered_article_entries)
                print('Re-Query is: ', requery)
                if requery != '':
                    _, requeried_article_entries = qe_manager.get_entries_from_query(requery, requery = True, articles_to_retrieve = 20)
                else:
                    requeried_article_entries = []

                print('Initial requery length is:', len(requeried_article_entries))
                # And finally, get the list of retrieved_ids to be returned.
                # article_entries is known to be non-empty here (the empty case left the loop above).
                if requeried_article_entries != []:
                    requeried_article_entries_df = pd.DataFrame(requeried_article_entries)
                    requeried_article_entries_df.to_csv('evaluation/requeried_article_entries.csv')
                    combined_article_entries_df = pd.concat([article_entries_df, requeried_article_entries_df]).drop_duplicates(['name']).sort_values(by=['suitability'], ascending=False)
                    combined_article_entries_df.to_csv('evaluation/combined_article_entries_df.csv')
                    ranked_articles_df = combined_article_entries_df
                else:
                    # No expanded results: fall back to the initial ranking
                    ranked_articles_df = article_entries_df
                retrieved_ids = ranked_articles_df.head(ground_truth_length)['name'].to_list()
                # Pad so the retrieved list is exactly as long as the ground truth
                # (a non-positive multiplier adds nothing).
                retrieved_ids += [''] * (ground_truth_length - len(retrieved_ids))

                retrieved_ids_list.append(retrieved_ids)
                ground_truth_list.append(ground_truth)
                print('Length of retrieved list: ', len(retrieved_ids))
                print('Ground truth length is:', ground_truth_length)
                break # once information is extracted, stop the rerun
            # Broad on purpose: any failure during a long evaluation run triggers a retry
            except Exception as e:
                print('Error occured when querying pubmed:', e)
                runs +=1
                if runs > 3: # after more than 3 failed attempts, give up on this entry
                    print('This entry is not counted:', entry['body'])
                    break
    return retrieved_ids_list, ground_truth_list

In [30]:
from helper.pubmed_search import QueryExpansionManager  
import json 

# Alternative embedding models; uncomment exactly one model_name
# model_name1 = "sentence-transformers/all-mpnet-base-v2"
# model_name = "w601sxs/b1ade-embed"  
model_name = "dmis-lab/biobert-v1.1" 
# NOTE(review): the variable name mentions "blade", but the model selected above is BioBERT.
blade_embed_qe_manager = QueryExpansionManager(model_name, 'helper/descriptors.json')

# OpenAI alternative (needs an API key from helper.config)
# from helper.config import api_key
# blade_embed_qe_manager = QueryExpansionManager('openai', 'helper/descriptors.json', api_key)

# Larger evaluation subset (disabled)
# with open ('evaluation/subset.json') as infile:
#     evaluation = json.load(infile)
with open ('evaluation/small subset.json') as infile:
    evaluation = json.load(infile)
    
# Run retrieval + query expansion + re-retrieval, then append the mean metrics to the CSV
# NOTE(review): the timing comment and the "no reranking" label look copied from the baseline
# cell - confirm they still describe this run.
# pure retrieval on a fixed dataset (takes 31.7s compared to retrieval + p@k ranking at 33.3s)
retrieved_ids_list, ground_truth_list = get_reretrieval_and_ground_truth_docs(evaluation, blade_embed_qe_manager)
output_full_metrics_to_file(retrieved_ids_list, ground_truth_list, 'biobert (re retrieval, 10 articles retrieved, no reranking, 233 documents) trimming', 'evaluation/evaluation_metrics.csv')



Query number 0 is:  What is the function of the AIRE gene at the embryonic stage?
Initial article length: 9
Re-Query is:  ("Cell Proliferation"[MeSH Terms] ) OR ("Gene Expression Regulation, Developmental"[MeSH Terms] ) OR ("Thymus Gland"[MeSH Terms] AND "Epithelial Cells"[MeSH Terms] )
Initial requery length is: 20
Length of retrieved list:  5
Ground truth length is: 5
Query number 1 is:  Which type of analysis does DeSeq2 perform?
Initial article length: 10
Re-Query is:  ("Papillomavirus Infections"[MeSH Terms] AND "Head and Neck Neoplasms"[MeSH Terms] AND "MicroRNAs"[MeSH Terms] ) OR ("Endogenous Retroviruses"[MeSH Terms] AND "Leukocytes, Mononuclear"[MeSH Terms] AND "B-Lymphocytes"[MeSH Terms] AND "Antibodies, Monoclonal, Humanized"[MeSH Terms] ) OR ("Arthropod Antennae"[MeSH Terms] AND "Gene Expression Profiling"[MeSH Terms] AND "Calliphoridae"[MeSH Terms] ) OR ("Fontan Procedure"[MeSH Terms] AND "Liver Diseases"[MeSH Terms] AND "Heart Defects, Congenital"[MeSH Terms] ) OR ("Diabe

In [20]:
# Diagnostic: number of ground-truth lists collected so far
len(ground_truth_list)

237

In [27]:
# Diagnostic: skip the first ground-truth list, then keep the next 187
# NOTE(review): first_half is not used by any visible cell - confirm it is still needed.
first_half = ground_truth_list[1:][0:187]

In [26]:
# Diagnostic: compare per-query lengths after skipping the first ground-truth list
# NOTE(review): the [1:] offset is a manual alignment attempt that depends on leftover kernel
# state - confirm on a fresh run.
length_check(retrieved_ids_list, ground_truth_list[1:])

Different length!
Inconsistency at index: 0 where retrieved ids have length: 5 and ground truth has length 15
Inconsistency at index: 1 where retrieved ids have length: 15 and ground truth has length 4
Inconsistency at index: 2 where retrieved ids have length: 4 and ground truth has length 18
Inconsistency at index: 3 where retrieved ids have length: 18 and ground truth has length 4
Inconsistency at index: 4 where retrieved ids have length: 4 and ground truth has length 15
Inconsistency at index: 5 where retrieved ids have length: 15 and ground truth has length 2
Inconsistency at index: 6 where retrieved ids have length: 2 and ground truth has length 20
Inconsistency at index: 7 where retrieved ids have length: 20 and ground truth has length 6
Inconsistency at index: 8 where retrieved ids have length: 6 and ground truth has length 4
Inconsistency at index: 9 where retrieved ids have length: 4 and ground truth has length 13
Inconsistency at index: 10 where retrieved ids have length: 13 

145