To reflect the reasons that led to the problematic results through numerical analysis, we will try quantifying and visualizing the challenges in the current pipeline.

In [118]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import seaborn as sns
from scipy.stats import ttest_ind, f_oneway

## Test

In [120]:
path = "../src/data/medmcqa/test.json"
test_set = pd.read_json(path)
test_set['mod_question'] = f"for the given question, choose the correct answer from the options list. \n\nquestion: {test_set['question']}\n\noptions list: [option_a:{test_set['opa']},option_b:{test_set['opb']},option_c:{test_set['opc']},option_d:{test_set['opd']}]\n\n"
test_set

Unnamed: 0,question,answer,answer_index,opa,opb,opc,opd,generated_explanation,mod_question
0,The best finish line for anterior metal cerami...,Shoulder with bevel,opd,Chamfer with bevel,Heavy chamfer,Shoulder,Shoulder with bevel,The best finish line for anterior metal cerami...,0 The best finish line for anterior metal...
1,Access cavity of mandibular 1st molar is:,Rounded triangle,opb,Oval,Rounded triangle,Rhomboid,None of the above,Access cavity of mandibular 1st molar is: Roun...,0 The best finish line for anterior metal...
2,Which of the following muscle is not supplied ...,Superior oblique,opa,Superior oblique,Medial rectus,Inferior rectus,Inferior oblique,Which of the following muscle is not supplied ...,0 The best finish line for anterior metal...
3,A person of eonism derives pleasure from.,Wearing clothes of opposite sex,opa,Wearing clothes of opposite sex,Fondling female body pas,Rubbing genitalia against body of other person,Seeing the opposite paner nude,A person of eonism derives pleasure from. Wear...,0 The best finish line for anterior metal...
4,18 year old female presents with an ovarian ma...,Dysgerminoma,opa,Dysgerminoma,Endodermal sinus tumor,Malignant terratoma,Mucinous cystadeno carcinoma,18 year old female presents with an ovarian ma...,0 The best finish line for anterior metal...
...,...,...,...,...,...,...,...,...,...
828,Mandibular process of each side fuse to form:,Both.,opc,Lower lip.,Lower jaw.,Both.,None.,Mandibular process of each side fuse to form: ...,0 The best finish line for anterior metal...
829,Keyhole-shaped visual field defect is seen in ...,Lateral geniculate body,opc,Optic disk,Optic chiasma,Lateral geniculate body,Occipital lobe,Keyhole-shaped visual field defect is seen in ...,0 The best finish line for anterior metal...
830,A 32 weeks pregnant female presented with labo...,Preterm labour,opa,Preterm labour,IUGR,IUD,Cervical infection,A 32 weeks pregnant female presented with labo...,0 The best finish line for anterior metal...
831,Calcification of roots of deciduous teeth is c...,4 years,opb,2 years,4 years,6 years,8 years,Calcification of roots of deciduous teeth is c...,0 The best finish line for anterior metal...


In [121]:
test_set['mod_question'] = test_set.apply(
    lambda row: f"For the given question, choose the correct answer from the options_list.\n\n"
                f"question: {row['question']}\n\n"
                f"options_list: [option_a:{row['opa']},option_b:{row['opb']},option_c:{row['opc']},option_d:{row['opd']}]\n\n",
    axis=1
)
test_set.to_csv('mini_test.csv',index=False)
test_set['benchmark_answer'] = test_set['mod_question'].apply(lambda x: generate_response(x))

Unnamed: 0,question,answer,answer_index,opa,opb,opc,opd,generated_explanation,mod_question
0,The best finish line for anterior metal cerami...,Shoulder with bevel,opd,Chamfer with bevel,Heavy chamfer,Shoulder,Shoulder with bevel,The best finish line for anterior metal cerami...,"for the given question, choose the correct ans..."
1,Access cavity of mandibular 1st molar is:,Rounded triangle,opb,Oval,Rounded triangle,Rhomboid,None of the above,Access cavity of mandibular 1st molar is: Roun...,"for the given question, choose the correct ans..."
2,Which of the following muscle is not supplied ...,Superior oblique,opa,Superior oblique,Medial rectus,Inferior rectus,Inferior oblique,Which of the following muscle is not supplied ...,"for the given question, choose the correct ans..."
3,A person of eonism derives pleasure from.,Wearing clothes of opposite sex,opa,Wearing clothes of opposite sex,Fondling female body pas,Rubbing genitalia against body of other person,Seeing the opposite paner nude,A person of eonism derives pleasure from. Wear...,"for the given question, choose the correct ans..."
4,18 year old female presents with an ovarian ma...,Dysgerminoma,opa,Dysgerminoma,Endodermal sinus tumor,Malignant terratoma,Mucinous cystadeno carcinoma,18 year old female presents with an ovarian ma...,"for the given question, choose the correct ans..."
...,...,...,...,...,...,...,...,...,...
828,Mandibular process of each side fuse to form:,Both.,opc,Lower lip.,Lower jaw.,Both.,None.,Mandibular process of each side fuse to form: ...,"for the given question, choose the correct ans..."
829,Keyhole-shaped visual field defect is seen in ...,Lateral geniculate body,opc,Optic disk,Optic chiasma,Lateral geniculate body,Occipital lobe,Keyhole-shaped visual field defect is seen in ...,"for the given question, choose the correct ans..."
830,A 32 weeks pregnant female presented with labo...,Preterm labour,opa,Preterm labour,IUGR,IUD,Cervical infection,A 32 weeks pregnant female presented with labo...,"for the given question, choose the correct ans..."
831,Calcification of roots of deciduous teeth is c...,4 years,opb,2 years,4 years,6 years,8 years,Calcification of roots of deciduous teeth is c...,"for the given question, choose the correct ans..."


#### Analyze Entities

In [117]:
path = "../src/data/medmcqa/dfs/short_corpus_entities.csv"
short_qa = pd.read_csv(path, index_col='id')
num_records = short_qa.shape[0]
cols = [(i, col) for i, col in enumerate(short_qa.columns) if col not in ['text', 'id', 'AGE', 'SEX']]
uniques_list = []
unique_entities_dict = {}
print("====================================================")
for i, col in cols:
    not_na = short_qa[~short_qa.iloc[:, i].isna()]

    print(f"Found {not_na.shape[0]} / {num_records} records with {col} != NaN")
    rep_mask = not_na.iloc[:, i].value_counts() >= 2
    len_mask = not_na.iloc[:, i].apply(lambda x: len(x) > 2)
    rep_df = not_na[not_na.iloc[:, i].isin(rep_mask[rep_mask].index)]
    if rep_df.shape[0] == 0:
        print(f"No duplicates found for {col}")
        continue
    rep_df = rep_df[rep_df.iloc[:, i].apply(lambda x: len(x) > 2)]
    uniques_df = rep_df[rep_df[col].duplicated(keep='last')]
    dup_count = rep_df.shape[0] - uniques_df.shape[0]
    uniques_list.extend(uniques_df.index)
    print(f"Found {dup_count} duplicated records based on {col}")
    print("====================================================")

    unique_values = set()
    for items in not_na[col].apply(lambda x: eval(x)):
        unique_values.update(items)
    unique_entities_dict[col] = unique_values
clean_df = short_qa[~short_qa.index.isin(uniques_list)]

print(f"Clean DataFrame : {clean_df.shape[0]} records ({num_records} records originaly)\n")
print("\n")
#print(f"Unique entities: {unique_entities_dict}")

Found 517 / 3330 records with MEDICATION != NaN
Found 25 duplicated records based on MEDICATION
Found 1025 / 3330 records with SIGN_SYMPTOM != NaN
Found 32 duplicated records based on SIGN_SYMPTOM
Found 1449 / 3330 records with BIOLOGICAL_STRUCTURE != NaN
Found 61 duplicated records based on BIOLOGICAL_STRUCTURE
Found 1403 / 3330 records with DISEASE_DISORDER != NaN
Found 66 duplicated records based on DISEASE_DISORDER
Clean DataFrame : 3061 records (3330 records originaly)


In [None]:
def HITS(original_query: str, retrieved_k_docs: list, n: int):
    """
    Rerank the top k retrieved documents using the HITS algorithm.
    :param n: number of documents to return
    :param original_query:  original query string
    :param retrieved_k_docs: List of dictionaries containing 'id' and 'metadata' for retrieved documents.
                             Each document metadata should contain a list of references or related documents.
    :return: A list of the n most relevant documents (sorted by authority scores).
    """
    G = nx.DiGraph()

    for doc in retrieved_k_docs:
        doc_id = doc['id']
        G.add_node(doc_id, metadata=doc.get('metadata', {}))

    for doc in retrieved_k_docs:
        doc_id = doc['id']
        related_docs = retrieve_from_index(doc['metadata']['text'])
        for r_doc in related_docs:
            related_doc_id = r_doc['id']
            if related_doc_id in G.nodes:
                G.add_edge(doc_id, related_doc_id)

    hits_scores = nx.hits(G, normalized=True)
    authority_scores = hits_scores[1]

    ranked_docs = sorted(retrieved_k_docs, key=lambda doc: authority_scores.get(doc['id'], 0), reverse=True)

    return ranked_docs[:n]

### Analyze the Embedding Space
Visualize Embeddings: Use dimensionality reduction techniques like t-SNE or UMAP to project the embeddings into 2D or 3D space. Highlight embeddings from different corpora to assess clustering or separation.
Similarity Distribution: Compute cosine similarity between embeddings of the retrieved documents and the query. Plot histograms or distributions for "relevant" and "irrelevant" documents to show overlaps or inconsistencies.

In [None]:
def visualize_embeddings(embeddings, labels, method='tsne'):
    """
    Project the embeddings into 2D or 3D space by utilizing dimensionality reduction techniques
    :param embeddings: 
    :param labels: 
    :param method: 
    :return: 
    """
    if method == 'tsne':
        reducer = TSNE(n_components=2, random_state=42)
    elif method == 'pca':
        reducer = PCA(n_components=2)
    else:
        raise ValueError("Unsupported method. Use 'tsne' or 'pca'.")

    reduced = reducer.fit_transform(embeddings)
    plt.figure(figsize=(10, 6))
    for label in set(labels):
        indices = np.where(labels == label)
        plt.scatter(reduced[indices, 0], reduced[indices, 1], label=f'Corpus {label}', alpha=0.6)
    plt.legend()
    plt.title(f'Embedding Visualization ({method.upper()})')
    plt.show()

#### Similarity Distribution: 
Compute cosine similarity between embeddings of the retrieved documents and the query. Plot histograms or distributions for "relevant" and "irrelevant" documents to show overlaps or inconsistencies.

In [None]:
def plot_similarity_distribution(query_embedding, document_embeddings, relevance_labels):
    """
    Computes cosine similarity between a query and document embeddings, and plots similarity 
    distributions for relevant and irrelevant documents.
    :param query_embedding: np.array,(d, ) the embedding vector for the query.
    :param document_embeddings: np.array, shape (n, d) the embedding vectors for the retrieved documents.
    :param relevance_labels: np.array, shape (n, ) binary relevance labels (1 for relevant, 0 for irrelevant).
    """

    similarities = cosine_similarity(document_embeddings, query_embedding.reshape(1, -1)).flatten()
    relevant_similarities = similarities[relevance_labels == 1]
    irrelevant_similarities = similarities[relevance_labels == 0]

    plt.figure(figsize=(10, 6))
    plt.hist(relevant_similarities, bins=20, alpha=0.6, color='green', label='Relevant')
    plt.hist(irrelevant_similarities, bins=20, alpha=0.6, color='red', label='Irrelevant')
    plt.axvline(np.mean(relevant_similarities), color='green', linestyle='--', label='Mean Relevant')
    plt.axvline(np.mean(irrelevant_similarities), color='red', linestyle='--', label='Mean Irrelevant')
    plt.title("Cosine Similarity Distribution for Relevant and Irrelevant Documents")
    plt.xlabel("Cosine Similarity")
    plt.ylabel("Frequency")
    plt.legend()
    plt.show()

In [None]:
def compute_relevance_scores(query_entities, doc_entities_list):
    """
    Compute relevance scores based on the normalized number of shared entities 
    between the query and each document.

    :param query_entities: A list of entities in the query.
    :param doc_entities_list: A list where each element is a list of 
      entities for a specific document.
    :return: list: Normalized relevance scores for each document.
    """
    query_set = set(query_entities)
    scores = []

    for doc_entities in doc_entities_list:
        doc_set = set(doc_entities)
        shared_entities = query_set.intersection(doc_set)

        normalization_factor = len(query_set.union(doc_set))
        score = len(shared_entities) / normalization_factor if normalization_factor > 0 else 0
        scores.append(score)

    return scores

In [None]:
query_embedding = model.encode(query).tolist()
documents_embeddings, documents_scores, relevanc_labels = [], [], []
docs_dict = dict()
for i, doc in retrived_docs:
    doc_id = doc['id']
    doc_embedding = doc['embedding']
    document_score = doc['score']
    relevance_score = get_relevance_score(doc)
    document_embeddings.append(doc['embedding'])
    relevance_labels.append(relevance_score)
    docs_dict[doc_id] = [document_score, relevance_score]

docs_df = pd.from_dict(docs_dict)
plot_similarity_distribution(query_embedding, document_embeddings, relevanc_labels)

#### Token Distribution

In [None]:
def plot_token_length_distribution(token_lengths, labels):
    plt.figure(figsize=(10, 6))
    sns.boxplot(x=labels, y=token_lengths)
    plt.xlabel("Corpus")
    plt.ylabel("Token Length")
    plt.title("Token Length Distribution by Corpus")
    plt.show()

#### Noise in Embeddings

In [None]:
def compute_intra_corpus_similarity(embeddings, labels):
    results = {}
    for label in set(labels):
        corpus_embeddings = embeddings[labels == label]
        similarities = np.inner(corpus_embeddings, corpus_embeddings.T)
        results[label] = np.mean(similarities)
    return results


def compute_cross_corpus_similarity(embeddings, labels):
    unique_labels = set(labels)
    results = {}
    for i, label1 in enumerate(unique_labels):
        for j, label2 in enumerate(unique_labels):
            if j <= i:
                continue
            corpus1 = embeddings[labels == label1]
            corpus2 = embeddings[labels == label2]
            similarities = np.inner(corpus1, corpus2.T)
            results[f'{label1}-{label2}'] = np.mean(similarities)
    return results

#### Relevance Scores

In [None]:
def plot_relevance_vs_similarity(scores, relevance_labels):
    plt.figure(figsize=(10, 6))
    plt.scatter(scores, relevance_labels, alpha=0.6)
    plt.xlabel("Cosine Similarity Score")
    plt.ylabel("Relevance")
    plt.title("Relevance vs Cosine Similarity")
    plt.show()

#### Impact of Query Characteristics

In [None]:
def analyze_query_performance(queries, scores, relevance, query_lengths):
    import pandas as pd
    df = pd.DataFrame({
        'query': queries,
        'length': query_lengths,
        'score': scores,
        'relevance': relevance
    })
    grouped = df.groupby('length').mean()
    grouped[['score', 'relevance']].plot(kind='bar', figsize=(10, 6))
    plt.title("Query Performance by Length")
    plt.show()

#### Experiment with Baseline Models

In [None]:
def compare_with_baseline(baseline_results, semantic_results, relevance_labels):
    import sklearn.metrics as metrics
    baseline_pr = metrics.precision_recall_curve(relevance_labels, baseline_results)
    semantic_pr = metrics.precision_recall_curve(relevance_labels, semantic_results)

    plt.plot(baseline_pr[1], baseline_pr[0], label="Baseline")
    plt.plot(semantic_pr[1], semantic_pr[0], label="Semantic Search")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.legend()
    plt.show()

#### Statistical Tests

In [None]:
def perform_statistical_tests(data1, data2):
    t_stat, t_p = ttest_ind(data1, data2, equal_var=False)
    f_stat, f_p = f_oneway(data1, data2)
    return {"t_test": {"statistic": t_stat, "p_value": t_p},
            "anova": {"statistic": f_stat, "p_value": f_p}}

#### Document Contribution by Corpus

In [None]:
def plot_document_contribution(corpus_labels, relevance_labels):
    df = pd.DataFrame({
        'corpus': corpus_labels,
        'relevance': relevance_labels
    })
    relevance_by_corpus = df.groupby('corpus').mean()
    relevance_by_corpus.plot(kind='bar', figsize=(10, 6))
    plt.title("Document Contribution and Relevance by Corpus")
    plt.ylabel("Average Relevance")
    plt.show()