In [1]:
from langfuse_config import langfuse, langfuse_handler
from llm import llm, MODEL_NAME
from graph import graph
import dotenv
from utils import format_prompt
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
import pprint

import os

dotenv.load_dotenv()

* 'allow_population_by_field_name' has been renamed to 'populate_by_name'
* 'smart_union' has been removed


True

In [2]:
# import pprint

# inputs = {
#     "messages": [
#         HumanMessage('what was the percentage change in the net cash from operating activities from 2008 to 2009'),
#     ]
# }

# for error in graph.stream(inputs, config={"callbacks": [langfuse_handler]}):
#     for key, value in error.items():
#         print(f"Output from node '{key}':")
#         print("---")
#         pprint.pprint(value, indent=2, width=80, depth=None)
#     print()
#     print("---")
#     print()

In [3]:
# Fetch the "convfinqa-train" dataset through the Langfuse client; its items
# are iterated by the evaluation loop further down in this cell.
dataset = langfuse.get_dataset("convfinqa-train")
from datetime import datetime
# Flags defined in `nodes`: both are recorded in the run metadata below, and
# DISABLE_GENERATION makes correctness_score() return None immediately.
from nodes import CHEATING_RETRIEVAL, DISABLE_GENERATION
from prompts import eval_prompt_template

from tqdm.auto import tqdm

# NOTE(review): this rebinds the MODEL_NAME that the first cell imported from
# `llm`; from here on the name comes from the LLM_MODEL environment variable
# (falling back to "llama3.1"). Confirm the override is intended.
MODEL_NAME = os.getenv("LLM_MODEL", "llama3.1")

def relative_score(a, b, power=2):
    """
    Similarity of two numbers derived from their relative difference.

    The relative difference ``|a - b| / max(|a|, |b|)`` is raised to ``power``
    (so larger differences are penalized more) and subtracted from 1.

    Args:
        a: First number.
        b: Second number.
        power: Exponent applied to the relative difference (default 2).

    Returns:
        1.0 when ``a == b``; otherwise ``1 - relative_difference ** power``,
        clamped so that it is never below 0.0.
    """
    if a == b:
        return 1.0

    # a != b means at least one magnitude is non-zero, so the denominator
    # cannot be 0 here.
    relative_difference = abs(a - b) / max(abs(a), abs(b))

    # For opposite-sign inputs the relative difference can reach 2, which
    # would make the raw score negative (down to -3 with power=2) and distort
    # any average computed over the scores. Clamp to keep the floor at 0.
    return max(0.0, 1 - relative_difference ** power)
    
def retrieval_precision_score(predicted: list[str], expected: str) -> float:
    """
    Precision of a retrieval step: relevant documents retrieved divided by
    the total number of documents retrieved.

    ConvFinQA has exactly one expected document per question, so the
    numerator is 1 when ``expected`` appears in ``predicted`` and 0 otherwise.
    An empty ``predicted`` list yields 0.
    """
    hit = float(expected in predicted)
    retrieved_count = len(predicted)
    if retrieved_count == 0:
        # Nothing was retrieved: report zero precision instead of dividing by 0.
        return 0
    return hit / retrieved_count

def retrieval_recall_score(predicted: list[str], expected: str) -> float:
    """
    Recall of a retrieval step: relevant documents retrieved divided by the
    number of relevant documents.

    ConvFinQA has exactly one expected document per question, so recall is
    1.0 when that document is among ``predicted`` and 0.0 when it is not.
    """
    was_found = expected in predicted
    return 1.0 if was_found else 0.0


def _parse_numeric_variants(text):
    """
    Parse ``text`` as a number, tolerating ``$``, thousands separators,
    spaces and a ``%`` sign.

    Returns:
        A ``(percent_scaled, percent_stripped)`` tuple of floats: the first
        reads ``%`` as a factor of 1e-2 ("5%" -> 0.05), the second simply
        drops the sign ("5%" -> 5.0).

    Raises:
        ValueError: if the cleaned text is not a valid float literal.
    """
    cleaned = text.replace("$", "").replace(",", "").replace(" ", "")
    return float(cleaned.replace("%", "e-2")), float(cleaned.replace("%", ""))


def correctness_score(input, predicted, expected):
    """
    Score how well ``predicted`` matches ``expected`` for the question ``input``.

    Resolution order:
      1. Returns None when DISABLE_GENERATION is set.
      2. Trivial cases on the lower-cased, stripped strings: empty prediction
         with a non-empty expectation -> 0; non-empty prediction with an empty
         expectation -> None; exact match -> 1.
      3. If both strings parse as numbers, the best ``relative_score`` over
         the percent-scaled / percent-stripped readings of each side.
      4. Otherwise the LLM is asked for a score via ``eval_prompt_template``.

    Args:
        input: The question text (only used to build the LLM prompt).
        predicted: The generated answer.
        expected: The ground-truth answer.

    Returns:
        A numeric score, or None when no score could be determined
        (generation disabled, no ground truth, or unparsable LLM output).
    """
    if DISABLE_GENERATION:
        return None

    predicted = predicted.lower().strip()
    expected = expected.lower().strip()

    # Base cases, we don't need to use the LLM for these
    if predicted == "" and expected != "":
        return 0
    if predicted != "" and expected == "":
        return None
    if predicted == expected:
        return 1

    # Compare numbers, allowing for percentages and dollar signs. Only a
    # failed parse (ValueError) sends us to the LLM; any other error is a
    # real bug and is allowed to surface instead of being swallowed.
    try:
        predicted_scaled, predicted_plain = _parse_numeric_variants(predicted)
        expected_scaled, expected_plain = _parse_numeric_variants(expected)
    except ValueError:
        pass
    else:
        # A "%" may or may not have been meant as a scale factor on either
        # side, so take the most favourable pairing of the two readings.
        return max(
            relative_score(predicted_scaled, expected_scaled),
            relative_score(predicted_plain, expected_plain),
            relative_score(predicted_scaled, expected_plain),
            relative_score(predicted_plain, expected_scaled),
        )

    # Otherwise, use LLM
    print("Using LLM for score generation...")
    prompt = eval_prompt_template.format(question=input, actual_answer=predicted, expected_answer=expected)
    out = llm.completions.create(model=MODEL_NAME, prompt=format_prompt(prompt), max_tokens=10, temperature=0)
    out_text = out.choices[0].text.replace("<OUTPUT>", "").replace("</OUTPUT>", "")
    try:
        return float(out_text)
    except ValueError:
        # The model did not produce a parsable number; report it and give up.
        print(f"Error generating score (generated: {out_text})")
        return None

# Per-item scores accumulated over the run; one entry per evaluated item.
answer_correctness_scores = []
retrieval_precision_scores = []
retrieval_recall_scores = []
reranker_precision_scores = []
reranker_recall_scores = []
context_precision_scores = []
context_recall_scores = []


def _format_doc_selection(retrieved_ids, expected_id):
    """
    Build the comment attached to the retrieval scores.

    Retrieved ids are listed in order, with a '+' prefix on the expected
    document. If the expected document was not retrieved at all, it is
    appended at the end with a '-' prefix.
    """
    parts = [('+' + doc_id) if doc_id == expected_id else doc_id for doc_id in retrieved_ids]
    if expected_id not in retrieved_ids:
        parts.append('-' + expected_id)
    return ", ".join(parts)


# All items of this execution share one run name (a timestamp).
run_name = f"{datetime.now().strftime('%Y%m%d%H%M%S')}"
for item in tqdm(dataset.items[:1000]):
    handler = item.get_langchain_handler(
        run_name=run_name,
        run_description="RAG evaluation with ground-truth documents",
        run_metadata={"model": MODEL_NAME, "cheating_retrieval": CHEATING_RETRIEVAL, "disable_generation": DISABLE_GENERATION,
                    #    "graph": graph.get_graph().to_json()
                       },
    )

    # Run the graph on the dataset item's input, passing the item-bound
    # handler as a callback.
    inputs = {
        "messages": [
            HumanMessage(item.input),
        ]
    }

    output = graph.invoke(inputs, config={"callbacks": [handler]})

    question = output['question']
    answer = output['answer']
    generation = output['generation']

    # --- Retrieval / reranker / context metrics against the single expected document ---
    retrieved_doc_ids = [doc.metadata['id'] for doc in output['documents']]
    assert all(retrieved_doc_ids), "Invalid document IDs"
    expected_doc_id = item.metadata['document']['id']

    retrieval_precision = retrieval_precision_score(retrieved_doc_ids, expected_doc_id)
    retrieval_recall = retrieval_recall_score(retrieved_doc_ids, expected_doc_id)

    reranked_doc_ids = [doc.metadata['id'] for doc in output['reranked_documents']]

    reranker_precision = retrieval_precision_score(reranked_doc_ids, expected_doc_id)
    reranker_recall = retrieval_recall_score(reranked_doc_ids, expected_doc_id)

    source_doc_ids = output['sources']
    print(source_doc_ids)

    context_precision = retrieval_precision_score(source_doc_ids, expected_doc_id)
    context_recall = retrieval_recall_score(source_doc_ids, expected_doc_id)

    # Each list receives exactly one entry per item.
    retrieval_precision_scores.append(retrieval_precision)
    retrieval_recall_scores.append(retrieval_recall)

    reranker_precision_scores.append(reranker_precision)
    reranker_recall_scores.append(reranker_recall)

    context_precision_scores.append(context_precision)
    context_recall_scores.append(context_recall)

    # Evaluate the output to compare different runs more easily
    correctness = correctness_score(item.input, answer, item.expected_output)
    answer_correctness_scores.append(correctness)

    # Print input, answer, expected output, and the scores in a readable format.
    # Deltas in parentheses are relative to the previous pipeline stage.
    print(f"Input: {item.input}")
    print(f"Predicted Answer: {answer}")
    print(f"Expected Answer: {item.expected_output}")
    print(f"Retrieved Documents: {retrieved_doc_ids}")
    print(f"Expected Document: {expected_doc_id}")
    print(f"Retrieval Precision: {retrieval_precision:.2%}")
    print(f"Retrieval Recall: {retrieval_recall:.2%}")
    print(f"Reranker Precision: {reranker_precision:.2%} ({reranker_precision - retrieval_precision:+.2%})")
    print(f"Reranker Recall: {reranker_recall:.2%} ({reranker_recall - retrieval_recall:+.2%})")
    print(f"Context Precision: {context_precision:.2%} ({context_precision - reranker_precision:+.2%})")
    print(f"Context Recall: {context_recall:.2%} ({context_recall - reranker_recall:+.2%})")
    # Test against None explicitly so that a score of 0 is still shown as a percentage.
    if correctness is not None:
        print(f"Correctness: {correctness:.2%}")
    else:
        print(f"Correctness: {correctness}")
    print("-"*50)

    # Show the generation for debugging: whenever there is no score or a zero
    # score, or when the score is low although the expected document was retrieved.
    if not correctness or ((correctness < .5) and (retrieval_recall > 0)):
        print(f"Generation: {generation}")

    print("-"*50)
    handler.trace.update(name="eval",
                         input=question,
                         output=answer,
                         metadata={"generation": generation,
                                    "documents": output["documents"]})
    if correctness is not None:
        handler.trace.score(
            name="correctness",
            data_type="NUMERIC",
            value=correctness,
            comment=generation,  # reasoning
        )

    # Retrieved ids with '+' on the correct one; if the correct document was
    # not retrieved, it is appended with a '-' prefix.
    doc_selection_display = _format_doc_selection(retrieved_doc_ids, expected_doc_id)

    handler.trace.score(
        name="retrieval_precision",
        data_type="NUMERIC",
        value=retrieval_precision,
        comment=doc_selection_display,
    )

    handler.trace.score(
        name="retrieval_recall",
        data_type="BOOLEAN",
        value=retrieval_recall,
        comment=doc_selection_display,
    )

    print(f"{'='*50}")

    # Print the running averages over all items evaluated so far.
    answer_correctness_scores_non_none = [c for c in answer_correctness_scores if c is not None]
    if len(answer_correctness_scores_non_none) > 0:
        mean_correctness_score = sum(answer_correctness_scores_non_none) / len(answer_correctness_scores_non_none)
        print(f"Average Correctness: {mean_correctness_score:.2%}")

    mean_retrieval_precision_score = sum(retrieval_precision_scores) / len(retrieval_precision_scores)
    mean_retrieval_recall_score = sum(retrieval_recall_scores) / len(retrieval_recall_scores)
    mean_reranker_precision_score = sum(reranker_precision_scores) / len(reranker_precision_scores)
    mean_reranker_recall_score = sum(reranker_recall_scores) / len(reranker_recall_scores)
    mean_context_precision_score = sum(context_precision_scores) / len(context_precision_scores)
    mean_context_recall_score = sum(context_recall_scores) / len(context_recall_scores)

    print(f"Mean Retrieval Precision: {mean_retrieval_precision_score:.2%}")
    print(f"Mean Retrieval Recall: {mean_retrieval_recall_score:.2%}")
    print(f"Mean Reranker Precision: {mean_reranker_precision_score:.2%} ({mean_reranker_precision_score - mean_retrieval_precision_score:+.2%})")
    print(f"Mean Reranker Recall: {mean_reranker_recall_score:.2%} ({mean_reranker_recall_score - mean_retrieval_recall_score:+.2%})")
    print(f"Mean Context Precision: {mean_context_precision_score:.2%} ({mean_context_precision_score - mean_reranker_precision_score:+.2%})")
    print(f"Mean Context Recall: {mean_context_recall_score:.2%} ({mean_context_recall_score - mean_reranker_recall_score:+.2%})")
    print(f"{'='*50}")

# Flush the langfuse client to ensure all data is sent to the server at the end of the experiment run
langfuse.flush()


  0%|          | 0/124 [00:00<?, ?it/s]

Langfuse client is disabled since no public_key was provided as a parameter or environment variable 'LANGFUSE_PUBLIC_KEY'. See our docs: https://langfuse.com/docs/sdk/python/low-level-sdk#initialize-client


['Single_UPS/2009/page_33.pdf', 'Double_UPS/2009/page_33.pdf', 'Single_UPS/2009/page_33.pdf-2']
Input: what was the difference in percentage cumulative total shareowners 2019 returns for united parcel service inc . versus the standard & poor 2019s 500 index for the five years ended 12/31/10?
Predicted Answer: NO ANSWER
Expected Answer: -1.42%
Retrieved Documents: ['Double_UNP/2014/page_21.pdf', 'Double_UPS/2009/page_33.pdf', 'Single_UPS/2009/page_33.pdf-2', 'Single_UPS/2009/page_33.pdf-1', 'Single_UPS/2009/page_33.pdf-4', 'Single_UPS/2007/page_32.pdf-2', 'Single_UPS/2015/page_108.pdf-3', 'Single_UPS/2017/page_111.pdf-3', 'Single_UPS/2012/page_51.pdf-1', 'Single_UPS/2012/page_51.pdf-4', 'Double_HUM/2017/page_45.pdf', 'Single_C/2016/page_333.pdf-3', 'Single_AAPL/2015/page_24.pdf-3', 'Double_AAPL/2015/page_24.pdf', 'Double_AAPL/2013/page_27.pdf', 'Single_AAPL/2013/page_27.pdf-3', 'Single_AAPL/2013/page_27.pdf-2']
Expected Document: Single_UPS/2010/page_33.pdf-4
Retrieval Precision: 0.00%


In [None]:
# Display whatever `output` is currently bound to in the kernel.
# NOTE(review): this relies on state left behind by an earlier cell (e.g. the
# last graph invocation) — confirm which run produced it before trusting it.
output

{'messages': [HumanMessage(content='what was the percentage change in the net cash from operating activities from 2008 to 2009', id='b86af28e-83bf-4cd1-9f62-47a8170469c6'),
  AIMessage(content='[GENERATION DISABLED]', id='129b2dcd-e097-4d58-bb4e-0eabf7288551')],
 'steps': ['extract_question',
  "retrieve('what was the percentage change in the net cash from operating activities from 2008 to 2009')",
  'rerank'],
 'question': 'what was the percentage change in the net cash from operating activities from 2008 to 2009',
 'documents': [Document(metadata={'id': 'Single_INTC/2018/page_48.pdf-3', 'qa': "{'question': 'what was the percentage change in net cash provided by operating activities between 2017 and 2018?', 'answer': '33%', 'explanation': '', 'ann_table_rows': [1], 'ann_text_rows': [], 'steps': [{'op': 'minus2-1', 'arg1': '29432', 'arg2': '22110', 'res': '7322'}, {'op': 'divide2-2', 'arg1': '#0', 'arg2': '22110', 'res': '33%'}], 'program': 'subtract(29432, 22110), divide(#0, 22110)', 