In [1]:
import numpy as np
import pandas as pd
import os
import sys
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from matplotlib.ticker import MaxNLocator
import random
import csv
from datetime import datetime
import time

from langchain.schema.runnable import RunnableSequence

In [2]:
# Make the deh experiment modules importable: the parent folder, ../src/deh and the
# current folder are appended to sys.path before importing them.
from pathlib import Path

utils_folder = Path("..")
sys.path.append(str(utils_folder))

utils_folder = Path("../src/deh")
sys.path.append(str(utils_folder))

utils_folder = Path(".")
sys.path.append(str(utils_folder))

import squad_scoring
import deh_prompts
import deh_vector_store
import deh_squad_data
import deh_hyde
import deh_experiments_config
# Copy every name of the config, globals and LLM modules into the notebook namespace so
# they can be used unqualified below (e.g. get_llm from deh_llm; presumably DATA_ROOT and
# the other UPPER_CASE settings come from the config/globals modules - confirm).
# NOTE(review): later updates overwrite earlier ones on name clashes, and a module's
# __dict__ also contains dunder entries (__name__, __doc__, ...) that overwrite the
# notebook's own.
globals().update(deh_experiments_config.__dict__)
import deh_globals
globals().update(deh_globals.__dict__)
# from deh_llm import get_llm
import deh_llm
globals().update(deh_llm.__dict__)

##### Loading SQuAD data

In [3]:
# Path of the persisted questions-with-contexts file (read and written further below)
csv_file_path = f"{DATA_ROOT}/qas_with_contexts.csv"

print("Loading squad data...\n")

# The first row of squad_raw.csv is skipped and replaced by these column names
column_names = ["title", "squad_context", "qid", "question", "is_impossible", "answer"]
squad_raw = pd.read_csv(f"{DATA_ROOT}/squad_raw.csv", names=column_names, skiprows=1)
df_squad_raw = pd.DataFrame(squad_raw)
print(f"Number of raw entries in squad_raw: {len(df_squad_raw)}")

# Distinct titles and distinct contexts, one per row
df_titles = pd.DataFrame(df_squad_raw['title'].unique(), columns=["title"])
print(f"Number of unique titles: {len(df_titles)}")

df_contexts = pd.DataFrame(df_squad_raw['squad_context'].unique(), columns=["squad_context"])
print(f"Number of unique contexts: {len(df_contexts)}")

# One row per question: without the answer column, questions repeated across answer rows collapse
df_qas = (
    df_squad_raw[['title', 'squad_context', 'qid', 'question', 'is_impossible']]
    .drop_duplicates()
    .reset_index(drop=True)
)
print(f"Number of unique questions: {len(df_qas)}")

df_squad_answers = df_squad_raw[['qid', 'question', 'answer']].drop_duplicates()
print(f"Number of unique answers: {len(df_squad_answers)}")
            

Loading squad data...

Number of raw entries in squad_raw: 26232
Number of unique titles: 35
Number of unique contexts: 1204
Number of unique questions: 11858
Number of unique answers: 16209


##### Initialize the Vector Store (Chroma; Milvus not yet included)

In [4]:
# Chunk the SQuAD contexts into the vector store only when configured to do so
contexts = list(df_contexts["squad_context"].values)
if not CHUNK_SQUAD_DATASET:
    print("Chunking not foreseen. Skipping chunking.")
else:
    deh_vector_store.chunk_squad_dataset(contexts, dataset, CHUNK_SIZE, CHUNK_OVERLAP)

# Initialize the Chroma vector store
vector_store = deh_vector_store.get_vector_store(DEFAULT_CHROMA_PREFIX, DEFAULT_CHROMA_COLLECTION)

Chunking not foreseen. Skipping chunking.


##### Loading qas with contexts data (if data is not to be restored)

In [None]:
# Unless the data is going to be restored, read the persisted questions together with
# their contexts from the CSV file (assumed to exist and to be up to date)
if not RESTORE_QAS_WITH_CONTEXTS:
    qas_with_contexts_csv_file_path = f"{DATA_ROOT}/qas_with_contexts.csv"
    df_qas_with_contexts = pd.read_csv(qas_with_contexts_csv_file_path)

    # An 'answer' column (when present) leads to duplicate question rows: drop it, then dedupe
    df_qas_with_contexts = (
        df_qas_with_contexts
        .drop(columns=['answer'], errors='ignore')
        .drop_duplicates()
    )
    print(f"Rows in dataframe df_qas_with_contexts: {len(df_qas_with_contexts)}")

    hyde_articles_cnt = df_qas_with_contexts['hyde_article'].notna().sum()
    hyde_based_contexts_cnt = df_qas_with_contexts['hyde_based_context'].notna().sum()
    print(f"Number of questions with Hyde articles: {hyde_articles_cnt}")
    print(f"Number of questions with Hyde based contexts: {hyde_based_contexts_cnt}")



##### Restore qas with contexts (if configured); alternatively refresh contexts

In [None]:
def restore_df_qas_with_contexts_file(csv_file_path, df_qas, hyde_based_context_path=None):
    """Rebuild the qas-with-contexts CSV from ``df_qas`` plus the persisted Hyde articles.

    ``df_qas`` is modified in place: the columns ``hyde_article``, ``question_context``
    and ``hyde_based_context`` are added (as NaN) when missing, and ``hyde_article`` is
    filled from the Hyde CSV by matching on ``qid``. The frame is then written to
    ``csv_file_path``.

    Args:
        csv_file_path: Destination CSV file.
        df_qas: Questions frame with at least a ``qid`` column (mutated in place).
        hyde_based_context_path: Hyde CSV to read; it needs ``qid``, ``hyde_article``
            and ``hyde_based_context`` columns. Defaults to
            ``{HYDE_BASED_CONTEXTS_ROOT}/hyde_based_contexts.csv``.

    Returns:
        ``df_qas`` (the same object), for convenience.
    """
    # Add the context columns (empty) if they do not exist yet
    for column in ('hyde_article', 'question_context', 'hyde_based_context'):
        if column not in df_qas.columns:
            df_qas[column] = np.nan

    # Get Hyde data
    if hyde_based_context_path is None:
        hyde_based_context_path = f"{HYDE_BASED_CONTEXTS_ROOT}/hyde_based_contexts.csv"
    df_hyde_based_contexts = pd.read_csv(hyde_based_context_path)

    print(f"Rows in dataframe df_hyde_based_contexts: {len(df_hyde_based_contexts)}")
    hyde_articles_cnt = df_hyde_based_contexts['hyde_article'].notna().sum()
    hyde_based_contexts_cnt = df_hyde_based_contexts['hyde_based_context'].notna().sum()

    # Look up each question's Hyde article by qid. A left merge followed by assigning the
    # merged column back aligns on the row index, which shifts rows as soon as a qid
    # occurs more than once in the Hyde file; a qid -> article map keeps df_qas' rows intact.
    hyde_article_by_qid = (
        df_hyde_based_contexts
        .drop_duplicates(subset='qid', keep='first')
        .set_index('qid')['hyde_article']
    )
    df_qas['hyde_article'] = df_qas['qid'].map(hyde_article_by_qid)

    print(f"Number of questions with Hyde articles: {hyde_articles_cnt}")
    print(f"Number of questions with Hyde based contexts: {hyde_based_contexts_cnt}")

    df_qas.to_csv(csv_file_path, header=True, index=False)
    return df_qas



In [None]:
# if RESTORE_QAS_WITH_CONTEXTS:
#     restore_df_qas_with_contexts(csv_file_path, df_qas)

#     # Get Hyde data
#     hyde_based_context_path = f"{HYDE_BASED_CONTEXTS_ROOT}/hyde_based_contexts.csv"
#     df_hyde_based_contexts = pd.read_csv(hyde_based_context_path)

#     print(f"Rows in dataframe df_hyde_based_contexts: {len(df_hyde_based_contexts)}")
#     hyde_articles_cnt = df_hyde_based_contexts['hyde_article'].notna().sum()
#     hyde_based_contexts_cnt = df_hyde_based_contexts['hyde_based_context'].notna().sum()

#     # Merge df1 with df2 based on the 'qid' column
#     merged = df_qas.merge(df_hyde_based_contexts, on='qid', how='left', suffixes=('', '_df_hyde_based_contexts'))
#     df_qas['hyde_article'] = merged['hyde_article_df_hyde_based_contexts']

#     print(f"Number of questions with Hyde articles: {hyde_articles_cnt}")
#     print(f"Number of questions with Hyde based contexts: {hyde_based_contexts_cnt}")


In [None]:
# Restore (if configured) and/or refresh the question contexts and the Hyde-based
# contexts, persist them, and keep df_qas_with_contexts (used by the experiments below)
# in sync with what was written to the CSV file.
if RESTORE_QAS_WITH_CONTEXTS:
    # Adds the context columns to df_qas in place and fills hyde_article
    restore_df_qas_with_contexts_file(csv_file_path, df_qas)

# TODO: might be useful to first empty the two context columns before refreshing
if RESTORE_QAS_WITH_CONTEXTS or REFRESH_QUESTION_CONTEXTS or REFRESH_HYDE_CONTEXTS:
    print("Re-Generating contexts for the dataset and persisting the data...")

    # After a restore, df_qas carries all context columns. Otherwise refresh the frame
    # loaded from the CSV file: df_qas has no hyde_article column (KeyError below), and
    # persisting it would overwrite the CSV file without the Hyde data.
    qas_to_refresh = df_qas if RESTORE_QAS_WITH_CONTEXTS else df_qas_with_contexts
    list_of_qas = qas_to_refresh.to_dict(orient='records')

    for i, qa in enumerate(list_of_qas):
        if i % 100 == 0:
            print(f"Processing question {i}...")

        if RESTORE_QAS_WITH_CONTEXTS or REFRESH_QUESTION_CONTEXTS:
            top_docs = vector_store.similarity_search(
                query=qa["question"],
                k=VECTOR_STORE_TOP_K,
            )
            qa["question_context"] = " ".join(top_doc.page_content for top_doc in top_docs)

        if RESTORE_QAS_WITH_CONTEXTS or REFRESH_HYDE_CONTEXTS:
            hyde_article = qa.get("hyde_article")
            # Only questions with a non-empty Hyde article get a (new) Hyde-based context;
            # missing articles arrive from pandas as NaN
            if isinstance(hyde_article, str) and hyde_article:
                top_docs = vector_store.similarity_search(
                    query=hyde_article,
                    k=VECTOR_STORE_TOP_K,
                )
                qa["hyde_based_context"] = " ".join(top_doc.page_content for top_doc in top_docs)

    # Keep the in-memory frame used by the experiments identical to the persisted file
    df_qas_with_contexts = pd.DataFrame(list_of_qas)
    df_qas_with_contexts.to_csv(csv_file_path, header=True, index=False)

In [None]:
# Convert the questions DataFrame to a list of dictionaries and preview the first records
list_of_qas = df_qas.to_dict(orient='records')
list_of_qas[:3]

##### Get Hyde data (articles and Hyde-based contexts)

In [None]:
# hyde_based_context_path = f"{HYDE_BASED_CONTEXTS_ROOT}/hyde_based_contexts.csv"
# df_hyde_based_contexts = pd.read_csv(hyde_based_context_path)

# print(f"Rows in dataframe df_hyde_based_contexts: {len(df_hyde_based_contexts)}")
# hyde_articles_cnt = df_hyde_based_contexts['hyde_article'].notna().sum()
# hyde_based_contexts_cnt = df_hyde_based_contexts['hyde_based_context'].notna().sum()
# print(f"Number of questions with Hyde articles: {hyde_articles_cnt}")
# print(f"Number of questions with Hyde based contexts: {hyde_based_contexts_cnt}")

##### Show names of all dataframes

In [None]:
import copy 

print("Names of Dataframes and their lengths:\n")

# Snapshot the namespace before iterating over it, and list only real DataFrames so a
# non-sized object whose name starts with "df_" cannot make len() fail.
df_dfs = pd.DataFrame(
    [
        (name, len(value))
        for name, value in list(globals().items())
        if name.startswith("df_") and isinstance(value, pd.DataFrame)
    ],
    columns=['df name', 'rows'],
)
df_dfs


##### Initialize the Vector Store (Chroma; Milvus not yet included)

In [None]:
# contexts = list(df_contexts["squad_context"].values)
# if CHUNK_SQUAD_DATASET:    
#     deh_vector_store.chunk_squad_dataset(contexts, dataset, CHUNK_SIZE, CHUNK_OVERLAP)
# else:
#     print("Chunking not foreseen. Skipping chunking.")

In [None]:
# Intiialize the Chroma vector store
# vector_store = deh_vector_store.get_vector_store(DEFAULT_CHROMA_PREFIX, DEFAULT_CHROMA_COLLECTION)

##### Get hyde-based contexts

Always get the Hyde-based contexts that already exist and add them to the qas dataset.

In [None]:
# if not READ_QAS_FROM_FILE:
#     hyde_based_context_path = f"{HYDE_BASED_CONTEXTS_ROOT}/hyde_based_contexts.csv"
#     hyde_based_contexts, questions_already_processed = deh_hyde.get_hyde_based_contexts(hyde_based_context_path)
#     print(f"Number of questions with hyde-based context: {len(questions_already_processed)}")

In [None]:
# if not READ_QAS_FROM_FILE:
#     # Now add the hyde-based contexts to the qas dataset
#     def get_hyde_based_info_from_qid(qid):
#         for hbc in hyde_based_contexts:
#             if hbc["qid"] == qid:
#                 return (hbc["hyde_article"], hbc["hyde_based_context"])
#         return ("", "")

#     for qa in qas:
#         qid = qa["qid"]
#         hyde_article, hyde_based_context = get_hyde_based_info_from_qid(qid)
#         qa["hyde_article"] = hyde_article
#         qa["hyde_based_context"] = hyde_based_context


##### Refresh Contexts (if configured in deh_globals.py)

Questions:

- if configured, either: generate question and hyde contexts and then persist
- or: read the question contexts from a .csv file


In [None]:
def persist_question_contexts(qas, csv_file_path):
    """Write the qas records, including their contexts, to ``csv_file_path`` as CSV.

    The file is overwritten and encoded as UTF-8. Each record is a mapping whose keys
    must be a subset of the header columns below: ``csv.DictWriter`` raises
    ``ValueError`` on unknown keys and writes missing keys as empty cells.

    NOTE(review): the qas frames in this notebook name their context column
    'squad_context', not 'context'; records built from them would be rejected.

    Args:
        qas: Iterable of dict records.
        csv_file_path: Destination file path.
    """
    header = [
        "title", "context", "qid", "question", "is_impossible",
        "answer", "hyde_article", "hyde_based_context", "question_context",
    ]
    with open(csv_file_path, mode="w", newline="", encoding="utf-8") as csv_file:
        record_writer = csv.DictWriter(csv_file, fieldnames=header)
        record_writer.writeheader()
        for record in qas:
            record_writer.writerow(record)

    print(f"Data successfully written to {csv_file_path}")


##### Define functions that are needed for experiments

In [None]:
# Compose the prompt and the LLM into one runnable
def get_runnable_chain(current_query_prompt, llm):
    """Return a runnable sequence that formats ``current_query_prompt`` and feeds it to ``llm``."""
    prompt_then_llm = current_query_prompt | llm
    return RunnableSequence(prompt_then_llm)

In [None]:
# Look up the stored Hyde-based context of a question
def get_hyde_based_context(question):
    """Return the Hyde-based context stored for ``question`` in df_qas_with_contexts.

    Matching is on the exact question text; when several rows match, the first wins.

    Returns:
        The ``hyde_based_context`` value of the first matching row, or ``None`` when no
        row holds this question.
    """
    is_match = df_qas_with_contexts["question"] == question
    matching_contexts = df_qas_with_contexts.loc[is_match, "hyde_based_context"]
    if matching_contexts.empty:
        return None
    return matching_contexts.to_numpy()[0]
    

In [None]:
# # generate the LLM answers, using a runnable chain and the sample of questions provided
# def generate_llm_answers(runnable_chain, qas_sample, hyde=False):
    
#     preds = {}

#     sample_size = len(qas_sample)
#     print(f"sample_size: {sample_size}")

#     for i, qa in enumerate(qas_sample):

#         question = qa["question"]
#         if hyde:
#             #context = qa["hyde_context"]
#             context = get_hyde_based_context(question)
#         else:
#             context = qa["vector_store_context"]
            
#         # print(f"question --> {question}")
#         # print(context)
#         response = runnable_chain.invoke({"context": context, "question": question})
                
#         qid = squad_scoring.get_qid_from_question(question, dataset)
        
#         if response.content.upper() == "DONT KNOW":
#             llm_answer = ""
#         else:
#             llm_answer = response.content

#         preds[qid] = llm_answer
#         qas_sample[i]["llm_answer"] = llm_answer

#     return preds


In [None]:
#%%capture

# Score a run's predictions with the SQuAD metrics
def get_squad_metrics(dataset, preds, verbose=False):
    """Return ``(precision, recall, f1)`` of ``preds`` scored against ``dataset``.

    ``verbose`` is accepted for compatibility but currently unused.
    """
    metrics = squad_scoring.calc_squad_metrics(dataset, preds)
    return tuple(metrics[name] for name in ("precision", "recall", "f1"))


In [None]:
# Calculate the mean and a 95% interval for a list of scores
def calculate_mean_confidence_interval(scores_l):
    """Return the mean of ``scores_l`` and a normal-approximation 95 % interval.

    The interval is ``mean +/- 1.96 * s``. Here ``s`` is the sample standard deviation
    (ddof=1). The interval is clipped to the score range [0, 100].

    When ``scores_l`` holds bootstrap replicate means, as passed by ``do_bootstrapping``,
    ``s`` estimates the standard error of the mean. In that case this is the usual
    normal-approximation bootstrap confidence interval. For raw per-question scores the
    interval would instead describe the spread of individual scores.

    Returns:
        ``(mean, (lower, upper))``.
    """
    mean = np.mean(scores_l)
    half_width = 1.96 * np.std(scores_l, ddof=1)
    lower = max(mean - half_width, 0)
    upper = min(mean + half_width, 100)
    return mean, (lower, upper)

In [None]:


# Generate a histogram for a list of scores and persist it
def generate_histogram(scores_l, mean, ci, results_folder_name, experiment_name):
    """Plot a histogram of ``scores_l`` with mean and interval markers and save it.

    The figure is saved to ``results_folder_name`` as
    ``{experiment_name}_{BOOTSTRAPS_N}_{SAMPLE_SIZE}`` (matplotlib's default image format).

    Args:
        scores_l: Values to plot; the x-axis is fixed to 0-100 (F1 scale).
        mean: Mean, marked with a dotted vertical line.
        ci: ``(lower, upper)`` bounds, marked with dash-dot vertical lines.
        results_folder_name: Folder the figure is saved into.
        experiment_name: Used in the title and in the file name.

    Returns:
        The ``matplotlib.pyplot`` module, so callers can call ``.show()`` on it.
    """
    # A fresh figure per call: `plt.clf` without parentheses never cleared anything,
    # so successive histograms could be drawn onto the same axes.
    fig, ax = plt.subplots()

    ax.hist(scores_l, bins=30, density=False, edgecolor='black', alpha=0.6, color='lightblue')
    ax.set_xlim(0, 100)

    ax.set_title(f"F1-Scores for {experiment_name} - (Bootstraps: {BOOTSTRAPS_N} - Sample Size: {SAMPLE_SIZE})", fontsize=10)
    ax.set_xlabel("F1-Score")
    ax.set_ylabel("Frequency")

    # Fixed-width label prefixes and values line up in the monospace legend below
    max_len = 6
    prefix_width = 14

    # Add a vertical line for the mean
    mean_label = f"{mean: .2f}".rjust(max_len)
    ax.axvline(mean, color='red', linestyle='dotted', linewidth=2,
               label=f"{'Mean F1:':<{prefix_width}}{mean_label}")

    # Add vertical lines for the 95% confidence interval
    lower = f"{ci[0]: .2f}".rjust(max_len)
    upper = f"{ci[1]: .2f}".rjust(max_len)
    ax.axvline(ci[0], color='orange', linestyle='dashdot', linewidth=1.5,
               label=f"{'95% CI Lower:':<{prefix_width}}{lower}")
    ax.axvline(ci[1], color='orange', linestyle='dashdot', linewidth=1.5,
               label=f"{'95% CI Upper:':<{prefix_width}}{upper}")

    ax.xaxis.set_major_locator(MultipleLocator(10))  # Major ticks every 10 units
    ax.xaxis.set_minor_locator(MultipleLocator(5))   # Minor ticks every 5 units
    ax.yaxis.set_major_locator(MaxNLocator(nbins=10))

    # Customize grid for major and minor ticks
    ax.grid(which='major', color='gray', linestyle='--', linewidth=0.5)
    ax.grid(which='minor', color='lightgray', linestyle=':', linewidth=0.5)

    # A single legend call: a second legend() call replaces the first one, which
    # previously discarded the monospace font
    ax.legend(loc='upper right', prop={'family': 'monospace', 'size': 10})

    fig.savefig(os.path.join(results_folder_name, f"{experiment_name}_{BOOTSTRAPS_N}_{SAMPLE_SIZE}"))
    return plt

##### Define functions for Bootstrapping

In [None]:
# Format the current local time as a compact, sortable string
def get_timestamp_as_string():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    return datetime.now().strftime('%Y%m%d_%H%M%S')

In [None]:
def _context_text(value):
    """Return ``value`` as prompt text, mapping a missing value (None/NaN) to ""."""
    if value is None:
        return ""
    if isinstance(value, float) and np.isnan(value):
        return ""
    return value


def _select_context(qa, experiment_name):
    """Return the context string that ``experiment_name`` feeds to the LLM for ``qa``.

    BASIC_RAG_HYDE uses the Hyde-based context; FULL_RAG uses the question context
    followed by the Hyde-based context; every other experiment uses the question context.
    """
    if experiment_name == "BASIC_RAG_HYDE":
        return _context_text(qa.get("hyde_based_context"))
    question_context = _context_text(qa.get("question_context"))
    if experiment_name == "FULL_RAG":
        # Missing contexts come from pandas as NaN; str + NaN would raise TypeError
        return question_context + "\n\n" + _context_text(qa.get("hyde_based_context"))
    return question_context


def calculate_scores(qas, query_prompt_idx, experiment_name, context_needed=False, suppress_answers=False):
    """Answer every question in ``qas`` with the LLM and score each answer individually.

    Each dict in ``qas`` is updated in place with the keys ``<name>_llm_answer``,
    ``<name>_f1``, ``<name>_precision`` and ``<name>_recall``, where ``<name>`` is
    ``experiment_name.lower()``.

    Args:
        qas: List of question dicts with at least ``qid`` and ``question`` (plus the
            context columns when ``context_needed``).
        query_prompt_idx: Index into ``deh_prompts.query_prompts``.
        experiment_name: Experiment name; also selects which context is used.
        context_needed: Whether the prompt takes a ``context`` input.
        suppress_answers: When True, a "DONT KNOW" answer leaves the question unscored;
            otherwise it is scored as an empty answer.

    Returns:
        Dict mapping qid -> LLM answer for every question that was not suppressed.
    """
    # Create the prompt -> LLM chain for this experiment
    current_query_prompt = deh_prompts.query_prompts[query_prompt_idx]
    print(f"current_query_prompt = {current_query_prompt.template}\n")
    llm = get_llm(current_query_prompt)
    runnable_chain = get_runnable_chain(current_query_prompt, llm)

    prefix = experiment_name.lower()
    result_keys = [f"{prefix}_{suffix}" for suffix in ("llm_answer", "f1", "precision", "recall")]

    all_preds = {}
    for i, qa in enumerate(qas):
        if i % 10 == 0:
            print(f"Processing question {i}...")

        qid = qa["qid"]
        question = qa["question"]

        if context_needed:
            context = _select_context(qa, experiment_name)
            response = runnable_chain.invoke({"question": question, "context": context})
        else:
            response = runnable_chain.invoke({"question": question})

        llm_answer = response.content

        # "DONT KNOW" marks an abstention; tolerate case and surrounding whitespace
        if llm_answer.strip().upper() == "DONT KNOW":
            if suppress_answers:
                # Leave the question unscored, and drop results an earlier run of this
                # experiment left on the same dict so they are not taken for this run's
                for key in result_keys:
                    qa.pop(key, None)
                continue
            llm_answer = ""

        all_preds[qid] = llm_answer

        # Score this single answer on its own to get per-question metrics
        scores = squad_scoring.calc_squad_metrics(dataset, {qid: llm_answer})

        qa[f"{prefix}_llm_answer"] = llm_answer
        qa[f"{prefix}_f1"] = scores["f1"]
        qa[f"{prefix}_precision"] = scores["precision"]
        qa[f"{prefix}_recall"] = scores["recall"]

    return all_preds
        

In [None]:
def persist_results(results_folder_name, experiment_name, df):
    """Write ``df`` to ``<results_folder_name>/qas_<experiment name, lower-cased>_scores.csv``."""
    file_name = f"qas_{experiment_name.lower()}_scores.csv"
    df.to_csv(f"{results_folder_name}/{file_name}", header=True, index=False)

In [None]:
def clip_scores(scores_l, clip_perc):
    """Return ``scores_l`` sorted, with ``clip_perc`` percent trimmed from each end.

    The number trimmed per end is ``int(len(scores_l) * clip_perc / 100)``. It is printed
    as ``clip_cnt``. When it is 0, the full sorted list is returned.
    """
    ordered = sorted(scores_l)
    total = len(ordered)

    clip_cnt = int(total * clip_perc / 100)
    print(f"clip_cnt: {clip_cnt}")

    # Slicing [clip_cnt:total - clip_cnt] covers the no-clipping case (clip_cnt == 0) too
    return ordered[clip_cnt:total - clip_cnt]


In [None]:
def do_bootstrapping(scores_l, results_folder_name, experiment_name, bootstraps_n=BOOTSTRAPS_N, seed=None):
    """Bootstrap the mean of ``scores_l``, then plot and save the distribution.

    The scores are first trimmed by 2.5 % at each end (``clip_scores``). Next,
    ``bootstraps_n`` resamples with replacement are drawn and their means collected.
    The mean and 95 % interval of those bootstrap means come from
    ``calculate_mean_confidence_interval``, and ``generate_histogram`` plots them.

    Args:
        scores_l: Per-question scores (e.g. F1 values).
        results_folder_name: Folder the histogram is saved into.
        experiment_name: Used for the plot title and file name.
        bootstraps_n: Number of bootstrap resamples.
        seed: Optional seed for a private random generator, which makes the resamples
            reproducible. ``None`` keeps using the global ``random`` module state.

    Returns:
        ``(bootstraps_mean, ci)``, or ``None`` when ``scores_l`` is empty.
    """
    if len(scores_l) == 0:
        # Nothing to resample (e.g. every answer was suppressed); the bootstrap means
        # would all be NaN and the histogram could not be drawn
        print("No scores to bootstrap. Skipping bootstrapping.")
        return None

    rng = random if seed is None else random.Random(seed)

    # NOTE(review): each resample has the size of the *unclipped* score list (n) but is
    # drawn from the clipped list
    n = len(scores_l)
    clipped_scores = clip_scores(scores_l, 2.5)

    mu_hats = []
    for i in range(bootstraps_n):
        if i % 100 == 0:
            print(f"Processing sample {i}...")
        bootstrap_sample = rng.choices(clipped_scores, k=n)  # sample with replacement
        mu_hats.append(np.mean(bootstrap_sample))

    bootstraps_mean, ci = calculate_mean_confidence_interval(mu_hats)
    figure_plt = generate_histogram(mu_hats, bootstraps_mean, ci, results_folder_name, experiment_name)
    figure_plt.show()
    return bootstraps_mean, ci

In [None]:
def create_results_folder(experiment_name):
    """Ensure ``{RESULTS_ROOT}/{experiment_name}/results_<YYYYMMDD_HHMMSS>`` exists; return its path."""
    results_folder_name = f"{RESULTS_ROOT}/{experiment_name}/results_{get_timestamp_as_string()}"
    if os.path.exists(results_folder_name):
        return results_folder_name
    os.makedirs(results_folder_name, exist_ok=True)
    return results_folder_name

In [None]:
def conduct_experiment(qas, experiment_name, query_prompt_idx, 
                       context_needed=False, hyde_context_needed=False, 
                       suppress_answers=False):
    """Run one experiment end to end and record its wall-clock time.

    Steps:
    1. Create a timestamped results folder.
    2. Answer and score every question in ``qas``; ``calculate_scores`` adds
       ``<experiment>_*`` keys to each dict in place.
    3. Persist the per-question results as CSV.
    4. Bootstrap the per-question F1 scores.
    5. Append a timing entry to the global ``execution_times_l`` list.

    Args:
        qas: List of question dicts (mutated in place).
        experiment_name: Experiment name; also selects the context in ``calculate_scores``.
        query_prompt_idx: Index into ``deh_prompts.query_prompts``.
        context_needed: Whether the prompt takes a ``context`` input.
        hyde_context_needed: Currently unused; the Hyde context is chosen from
            ``experiment_name`` inside ``calculate_scores``.
        suppress_answers: Leave questions answered with "DONT KNOW" unscored.

    Returns:
        The execution time in seconds.
    """
    start_time = time.perf_counter()

    print(f"============= Creating results folder for {experiment_name} =============")
    results_folder_name = create_results_folder(experiment_name)

    print(f"============= Calculating scores for {experiment_name} =============")
    print(f"SAMPLE_SIZE: {SAMPLE_SIZE}\n")
    calculate_scores(qas, query_prompt_idx, experiment_name, context_needed, suppress_answers)

    print(f"============= Persisting results for {experiment_name} =============")
    df = pd.DataFrame(qas)
    persist_results(results_folder_name, experiment_name, df)

    print(f"\n============= Bootstrapping for {experiment_name} =============")
    print(f"BOOTSTRAPS_N: {BOOTSTRAPS_N}")

    # The F1 column does not exist when every answer was suppressed
    f1_column = f"{experiment_name.lower()}_f1"
    f1_scores = df[f1_column].dropna().tolist() if f1_column in df.columns else []
    if f1_scores:
        do_bootstrapping(f1_scores, results_folder_name, experiment_name, BOOTSTRAPS_N)
    else:
        print(f"No F1 scores for {experiment_name}. Skipping bootstrapping.")

    execution_time = time.perf_counter() - start_time
    execution_times_l.append({
        "experiment_name": experiment_name,
        "execution_time": execution_time,
        "sample_size": SAMPLE_SIZE,
        "bootstrap_n": BOOTSTRAPS_N,
    })
    return execution_time

##### Creating a sample from qas for bootstrapping and bootstrapping with Hyde (will be used for all experiments)

In [None]:
# Full sample: questions drawn with replacement, regardless of Hyde availability.
# NOTE(review): no random_state is passed, so the samples differ between runs.
df_qas_for_bootstrapping = df_qas_with_contexts[['qid', 'question', 'question_context', 'hyde_based_context']]
df_qas_for_bootstrapping_sample = df_qas_for_bootstrapping.sample(n=SAMPLE_SIZE, replace=True)
ldict_qas_for_boostrapping_sample = df_qas_for_bootstrapping_sample.to_dict(orient='records')

# Hyde sample: only questions that have a Hyde-based context
has_hyde_context = df_qas_for_bootstrapping['hyde_based_context'].notna()
df_qas_for_hyde_bootstrapping = df_qas_for_bootstrapping.loc[has_hyde_context]
df_qas_for_hyde_bootstrapping_sample = df_qas_for_hyde_bootstrapping.sample(n=SAMPLE_SIZE, replace=True)
ldict_qas_for_hyde_bootstrapping_sample = df_qas_for_hyde_bootstrapping_sample.to_dict(orient='records')

# Index 0 -> full sample, index 1 -> Hyde sample (chosen per experiment via sample_ldicts_idx)
sample_ldicts = [ldict_qas_for_boostrapping_sample, ldict_qas_for_hyde_bootstrapping_sample]

##### Conducting all experiments

In [None]:
# Run every experiment whose "include" flag is set in df_experiments. Each row selects
# the question sample (index into sample_ldicts), the prompt index, and the context and
# suppression flags passed to conduct_experiment.
# NOTE(review): df_experiments and execution_times_l are not defined in this notebook;
# presumably they come from the deh_experiments_config / deh_globals namespace import.
# The loop variables (e.g. experiment_name) stay defined as globals after the loop.
for index, row in df_experiments.iterrows():
    if not row["include"]:
        continue

    sample_ldict = sample_ldicts[row["sample_ldicts_idx"]]
    experiment_name = row["name"]
    query_prompt_idx = row["query_prompt_idx"]
    context_needed = row["context_needed"]
    hyde_context_needed = row["hyde_context_needed"]
    suppress_answsers = row["suppress_answers"]
    
    conduct_experiment(sample_ldict, experiment_name, query_prompt_idx,
                       context_needed, hyde_context_needed, suppress_answsers)

# Summarize the wall-clock time (seconds) of each experiment. Floats are displayed with
# two decimals until the display option is reset in the next cell.
print(f"\n================== Execution times (in seconds): ==================================\n")
df_execution_times = pd.DataFrame(execution_times_l, 
                                  columns=['experiment_name', 'execution_time', 'sample_size', 'bootstrap_n'])

pd.options.display.float_format = "{:.2f}".format
df_execution_times.head(100)                

In [None]:
# Restore pandas' default float display (the previous cell switched it to two decimals)
pd.reset_option("display.float_format")

## Judges 

In [None]:
# def generate_llm_judges_score(runnable_chain, qas_sample, hyde=False):
    
#     preds = {}

#     sample_size = len(qas_sample)
#     print(f"sample_size: {sample_size}")

#     for i, qa in enumerate(qas_sample):

#         question = qa["question"]
#         if hyde:
#             #context = qa["hyde_context"]
#             context = get_hyde_based_context(question)
#         else:
#             context = qa["vector_store_context"]
            
#         # print(f"question --> {question}")
#         # print(context)
#         response = runnable_chain.invoke({"context": context, "question": question})
                
#         qid = squad_scoring.get_qid_from_question(question, dataset)
        
#         if response.content.upper() == "DONT KNOW":
#             llm_answer = ""
#         else:
#             llm_answer = response.content

#         preds[qid] = llm_answer
#         qas_sample[i]["llm_answer"] = llm_answer

#     return preds

In [None]:
# Prompt used by the LLM judges. The judge chain is invoked with context, question and
# answer inputs, and its reply is parsed by convert_string_to_answer_tuple.
# NOTE(review): magic index 4 into deh_prompts.query_prompts - confirm it is the judge prompt.
judge_current_query_prompt = deh_prompts.query_prompts[4]
print(f"judge_current_query_prompt = {judge_current_query_prompt.template}\n")

In [None]:
def convert_string_to_answer_tuple(input_string):
    """Parse a judge reply such as ``"(Yes, 0.8)"`` into ``("YES", 0.8)``.

    Surrounding whitespace and parentheses are ignored, and the comma may have any
    whitespace around it. A reply that is not exactly a YES/NO answer plus a finite
    numeric score counts as a negative verdict.

    Args:
        input_string: Raw text returned by the judge LLM.

    Returns:
        ``(answer, score)`` where ``answer`` is ``"YES"`` or ``"NO"`` and ``score`` is a
        float, or ``("NO", 0)`` when the reply cannot be parsed.
    """
    invalid = ("NO", 0)

    # "(Yes, 0.8)" -> ["Yes", "0.8"]; splitting on "," alone also accepts "(Yes,0.8)"
    parts = [part.strip() for part in input_string.strip().strip("()").split(",")]
    if len(parts) != 2:
        return invalid

    answer = parts[0].upper()  # "YES" or "NO"
    if answer not in ("YES", "NO"):
        return invalid

    try:
        score = float(parts[1])
    except ValueError:
        return invalid
    # float() also accepts "nan"/"inf", which would poison the averaged verdict scores
    if not np.isfinite(score):
        return invalid

    return (answer, score)

In [None]:
def get_majority_verdict(judge_verdicts):
    """Combine several judges' verdicts on one question by majority vote.

    Args:
        judge_verdicts: Pair ``(answers, scores)``. ``answers`` holds each judge's
            "YES"/"NO"; ``scores`` holds the matching numeric scores in the same order.

    Returns:
        ``("NO", 0.0)`` when NO votes outnumber YES votes or there is no YES vote.
        Otherwise ``("YES", mean score of the judges that answered YES)``. Ties go to YES.
    """
    judge_answers = judge_verdicts[0]
    judge_scores = judge_verdicts[1]
    yes_count = judge_answers.count('YES')
    no_count = judge_answers.count('NO')

    if no_count > yes_count or yes_count == 0:
        return "NO", 0.0

    # Average only the YES judges' scores; a dissenting judge's score (e.g. "(No, 0.3)")
    # must not count towards the YES confidence
    yes_scores = [score for answer, score in zip(judge_answers, judge_scores) if answer == 'YES']
    return "YES", sum(yes_scores) / yes_count


In [None]:
def get_majority_verditcs(all_judge_answers, all_judge_scores):
    """Return one majority verdict per question, printing each question's votes.

    Args:
        all_judge_answers: One list per judge, with its "YES"/"NO" answer per question.
        all_judge_scores: One list per judge, with its score per question (same layout).

    Returns:
        List of ``get_majority_verdict`` results, in question order.
    """
    # zip(*...) turns the judge-major lists into per-question tuples
    per_question_votes = zip(zip(*all_judge_answers), zip(*all_judge_scores))
    verdicts = []
    for votes in per_question_votes:
        verdict = get_majority_verdict(votes)
        verdicts.append(verdict)
        print(votes)
        print(verdict)
        print("")
    return verdicts


In [None]:
judge_llms = [MISTRAL_LATEST, GEMMA2_9B, QWEN2_5_7B]

# Judge the answers of the last included experiment. It is derived from df_experiments
# instead of relying on the loop variable `experiment_name` leaking out of the
# experiments cell. The experiment's own sample is used so that its answers are present.
included_experiments = df_experiments[df_experiments["include"].astype(bool)]
if included_experiments.empty:
    raise ValueError("No included experiment to judge (df_experiments['include'] is all False)")
judged_experiment = included_experiments.iloc[-1]
answer_key = judged_experiment["name"].lower() + "_llm_answer"
judged_qas = sample_ldicts[judged_experiment["sample_ldicts_idx"]]

all_judge_answers = []
all_judge_scores = []

print("--------------------------------------------------------")
for judge_llm in judge_llms:
    judge_chain = judge_current_query_prompt | get_llm(judge_current_query_prompt, True, judge_llm)
    judge_answers = []
    judge_scores = []

    for i, qa in enumerate(judged_qas):
        if i % 5 == 0:
            print(f"Processing question {i}...")

        response = judge_chain.invoke({
            "context": qa["question_context"],
            "question": qa["question"],
            # Suppressed ("DONT KNOW") questions carry no answer; judge them as empty
            "answer": qa.get(answer_key, ""),
        })
        verdict_answer, verdict_score = convert_string_to_answer_tuple(response.content)
        judge_answers.append(verdict_answer)
        judge_scores.append(verdict_score)

    # One list per judge, all in the same question order
    all_judge_answers.append(judge_answers)
    all_judge_scores.append(judge_scores)

final_judges_verdicts = get_majority_verditcs(all_judge_answers, all_judge_scores)

In [None]:
# Count the questions whose majority-verdict score is at or below THRESHOLD
# (NO verdicts carry a score of 0.0, so they are always counted)
THRESHOLD = 0.7
below_threshold = [score <= THRESHOLD for _, score in final_judges_verdicts]

sum(below_threshold)

In [None]:
# Number of judged questions (get_majority_verditcs returns one verdict per question)
len(final_judges_verdicts)