In [1]:
import os

import dotenv

# Load variables from a local .env file BEFORE importing the project modules
# below, so that anything they read from the environment at import time
# (API keys, model names, ...) is already populated.
# NOTE(review): whether langfuse_config / llm / graph read env vars at import
# time is not visible from this notebook -- loading early is harmless either way.
dotenv.load_dotenv()

from langfuse_config import langfuse, langfuse_handler
from llm import llm
from graph import graph



True

In [2]:
import pprint
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage

# Smoke test: push one question through the graph and dump what every node
# emits, one streamed step at a time.
question_message = HumanMessage('what was the percentage change in the net cash from operating activities from 2008 to 2009')
inputs = {"messages": [question_message]}
stream_config = {"callbacks": [langfuse_handler]}

for step in graph.stream(inputs, config=stream_config):
    # Each streamed step is a mapping of node name -> that node's output.
    for key in step:
        value = step[key]
        print(f"Output from node '{key}':")
        print("---")
        pprint.pprint(value, indent=2, width=80, depth=None)
    # Visual separator between consecutive steps.
    print()
    print("---")
    print()

Langfuse client is disabled since no public_key was provided as a parameter or environment variable 'LANGFUSE_PUBLIC_KEY'. See our docs: https://langfuse.com/docs/sdk/python/low-level-sdk#initialize-client


Output from node 'extract_question':
---
{ 'question': 'what was the percentage change in the net cash from operating '
              'activities from 2008 to 2009',
  'steps': ['extract_question']}

---

Output from node 'retriever':
---
{ 'documents': [ Document(id='Single_JKHY/2009/page_28.pdf-3', metadata={'id': 'Single_JKHY/2009/page_28.pdf-3', 'qa': "{'question': 'what was the percentage change in the net cash from operating activities from 2008 to 2009', 'answer': '14.1%', 'explanation': '', 'ann_table_rows': [6], 'ann_text_rows': [], 'steps': [{'op': 'minus2-1', 'arg1': '206588', 'arg2': '181001', 'res': '25587'}, {'op': 'divide2-2', 'arg1': '#0', 'arg2': '181001', 'res': '14.1%'}], 'program': 'subtract(206588, 181001), divide(#0, 181001)', 'gold_inds': {'table_6': '2008 the net cash from operating activities of year ended june 30 2009 2008 is $ 206588 ; the net cash from operating activities of year ended june 30 2009 2008 is $ 181001 ; the net cash from operating activities o

In [5]:
# Fetch the dataset named "convfinqa-train" through the `langfuse` client;
# its `.items` are iterated by the evaluation loop further down in this cell.
dataset = langfuse.get_dataset("convfinqa-train")
from datetime import datetime
import time
from nodes import CHEATING_RETRIEVAL, DISABLE_GENERATION
from prompts import eval_prompt_template

from tqdm.auto import tqdm

# Model identifier taken from the LLM_MODEL environment variable, falling back
# to "llama3.1" when it is unset. Used for the LLM-based scoring call and
# recorded in the run metadata.
MODEL_NAME = os.getenv("LLM_MODEL", "llama3.1")

def retrieval_precision_score(predicted: list[str], expected: str) -> float:
    """
    Retrieval precision: relevant documents retrieved / documents retrieved.

    In case of ConvFinQA, we only have 1 expected document, so the numerator
    is 1 when that document is among the retrieved ones and 0 otherwise.

    Args:
        predicted: IDs of the retrieved documents.
        expected: ID of the single relevant document.

    Returns:
        The precision in [0, 1]. Returns 0.0 when nothing was retrieved
        (instead of dividing by zero).
    """
    if not predicted:
        # No documents retrieved: nothing relevant was returned.
        return 0.0
    return float(expected in predicted) / len(predicted)

def retrieval_recall_score(predicted: list[str], expected: str) -> float:
    """
    Retrieval recall: relevant documents retrieved / relevant documents.

    ConvFinQA has exactly one relevant document per question, so recall is
    binary: 1.0 if that document ID appears in `predicted`, 0.0 if it does not.
    """
    was_retrieved = expected in predicted
    return 1.0 if was_retrieved else 0.0


def check_error_llm(input, predicted, expected):
    """
    Score a predicted answer against the expected answer.

    Cheap checks run first; the LLM is only consulted when they are
    inconclusive.

    Args:
        input: The question that was asked (forwarded to the evaluation prompt).
        predicted: The answer produced by the pipeline.
        expected: The reference answer.

    Returns:
        0 if `predicted` is empty while `expected` is not,
        -1 if `expected` is empty while `predicted` is not,
        1 if both are equal, or numerically equal once '$' is removed and
        '%' is read either as "divide by 100" or as a plain suffix,
        otherwise the float parsed from the LLM completion, or -1 when that
        completion cannot be parsed as a number.
    """
    # Base cases, we don't need to use LLM for that
    if predicted == "" and expected != "":
        return 0
    if predicted != "" and expected == "":
        return -1
    if predicted == expected:
        return 1

    # Compare numbers, allow for percentages and dollar signs. Every answer is
    # parsed twice: once with '%' meaning "e-2" (14.1% -> 0.141) and once with
    # '%' simply dropped (14.1% -> 14.1), so "14.1", "0.141" and "14.1%" all
    # match each other.
    try:
        expected_values = {
            float(expected.replace('%', 'e-2').replace("$", "")),
            float(expected.replace('%', '').replace("$", "")),
        }
        predicted_values = {
            float(predicted.replace('%', 'e-2').replace("$", "")),
            float(predicted.replace('%', '').replace("$", "")),
        }
    except (AttributeError, TypeError, ValueError):
        # At least one side is not a plain number: fall through to the LLM.
        pass
    else:
        if predicted_values & expected_values:
            return 1

    # Otherwise, use LLM
    prompt = eval_prompt_template.format(question=input, actual_answer=predicted, expected_answer=expected)
    out = llm.completions.create(model=MODEL_NAME, prompt=prompt, max_tokens=5)
    raw_score = out.choices[0].text
    try:
        return float(raw_score)
    except (TypeError, ValueError):
        # The model did not produce a number; report it and flag as error.
        print(f"Error generating score (generated: {raw_score})")
        return -1

# Run the graph over the dataset, score every item (answer correctness plus
# retrieval precision/recall), attach the scores to the item's trace, and
# print a summary at the end.
answer_correctness_scores = []
retrieval_precision_scores = []
retrieval_recall_scores = []
# A timestamp-based run name groups all items of this execution together.
run_name = f"{datetime.now().strftime('%Y%m%d%H%M%S')}"
for item in tqdm(dataset.items[:1000]):
    # Callback handler obtained from the dataset item; it is passed to
    # graph.invoke below and later used to update/score the trace.
    handler = item.get_langchain_handler(
        run_name=run_name,
        run_description="RAG evaluation with ground-truth documents",
        run_metadata={"model": MODEL_NAME, "cheating_retrieval": CHEATING_RETRIEVAL, "disable_generation": DISABLE_GENERATION,
                    #    "graph": graph.get_graph().to_json()
                       },
    )

    inputs = {
        "messages": [
            HumanMessage(item.input),
        ]
    }

    output = graph.invoke(inputs, config={"callbacks": [handler]})

    question = output['question']
    answer = output['answer']
    generation = output['generation']

    retrieved_doc_ids = [doc.metadata['id'] for doc in output['documents']]
    assert all(retrieved_doc_ids), "Invalid document IDs"
    expected_doc_id = item.metadata['document']['id']

    retrieval_precision = retrieval_precision_score(retrieved_doc_ids, expected_doc_id)
    retrieval_recall = retrieval_recall_score(retrieved_doc_ids, expected_doc_id)

    # Evaluate the output to compare different runs more easily
    correctness = check_error_llm(item.input, answer, item.expected_output)

    # Print input, answer, expected output, and the score in a more readable format
    print(f"Input: {item.input}")
    # print(f"Predicted Answer: {answer}")
    print(f"Expected Document: {expected_doc_id}")
    print(f"Retrieved Documents: {retrieved_doc_ids}")
    # print(f"Expected Answer: {item.expected_output}")
    print(f"Retrieval Precision: {retrieval_precision}")
    print(f"Retrieval Recall: {retrieval_recall}")
    # print(f"Score: {correctness}\n" + "-"*50)

    # Show generation for debugging, when retrieval was correct but answer was not
    if (correctness < .5) and (retrieval_recall > 0):
        print(f"Generation: {generation}")

    print("-"*50)
    handler.trace.update(name=question,
                         input=question,
                         output=answer,
                         metadata={"generation": generation,
                                    "documents": output["documents"]})
    handler.trace.score(
        name="correctness",
        data_type="NUMERIC",
        value=correctness,
        comment=generation,  # reasoning
    )

    # Human-readable list of the retrieved docs for the score comments:
    #   '+' prefix -> the doc is the expected one
    #   no prefix  -> the doc is not the expected one
    #   if the expected doc was not retrieved at all, it is appended at the
    #   end with a '-' prefix
    display_parts = [
        ('+' + doc_id) if doc_id == expected_doc_id else doc_id
        for doc_id in retrieved_doc_ids
    ]
    if expected_doc_id not in retrieved_doc_ids:
        display_parts.append('-' + expected_doc_id)
    doc_selection_display = ", ".join(display_parts)

    handler.trace.score(
        name="retrieval_precision",
        data_type="NUMERIC",
        value=retrieval_precision,
        comment=doc_selection_display,
    )

    handler.trace.score(
        name="retrieval_recall",
        data_type="BOOLEAN",
        value=retrieval_recall,
        comment=doc_selection_display,
    )

    # Record each score exactly once per item, and only after the item was
    # fully processed, so the three lists always stay the same length.
    answer_correctness_scores.append(correctness)
    retrieval_precision_scores.append(retrieval_precision)
    retrieval_recall_scores.append(retrieval_recall)

# Print the final average score in a formatted way
if answer_correctness_scores:
    mean_correctness_score = sum(answer_correctness_scores) / len(answer_correctness_scores)
    mean_retrieval_precision_score = sum(retrieval_precision_scores) / len(retrieval_precision_scores)
    mean_retrieval_recall_score = sum(retrieval_recall_scores) / len(retrieval_recall_scores)

    print(f"{'='*50}")
    print(f"\n{'='*50}\nAverage Correctness: {mean_correctness_score:.2f}")
    print(f"Mean Retrieval Precision: {mean_retrieval_precision_score:.2f}")
    print(f"Mean Retrieval Recall: {mean_retrieval_recall_score:.2f}")
    print(f"{'='*50}")
else:
    # Empty dataset: there is nothing to average.
    print("No dataset items were evaluated; skipping summary.")

# Flush the langfuse client to ensure all data is sent to the server at the end of the experiment run
langfuse.flush()


  0%|          | 0/124 [00:00<?, ?it/s]

Input: what was the difference in percentage cumulative total shareowners 2019 returns for united parcel service inc . versus the standard & poor 2019s 500 index for the five years ended 12/31/10?
Expected Document: Single_UPS/2010/page_33.pdf-4
Retrieved Documents: ['Single_UPS/2010/page_33.pdf-4']
Retrieval Precision: 1.0
Retrieval Recall: 1.0
Generation:  To determine the 5-year return, please notice that each of these lines contains
--------------------------------------------------
Input: what portion of the total debt and capital lease obligations is included in the section of current liabilities in 2011?
Expected Document: Double_ADBE/2011/page_116.pdf
Retrieved Documents: ['Double_ADBE/2011/page_116.pdf']
Retrieval Precision: 1.0
Retrieval Recall: 1.0
Generation:  We need to answer the following question using the context provided by the documents.
 


--------------------------------------------------
Input: as of december 2 , 2011 , what would capital lease obligations be in 

In [4]:
# Display the raw graph state held in `output` for inspection.
# NOTE(review): in the visible cells `output` is only assigned inside the
# evaluation loop (result of graph.invoke), so this cell depends on that cell
# having been run first -- confirm before relying on Restart & Run All.
output

{'messages': [HumanMessage(content='what was the percentage change in the net cash from operating activities from 2008 to 2009', id='98110015-dfaa-44da-b4a9-3b2a4187c166'),
  AIMessage(content=' In order to answer this question I need to follow these steps: \n    ', id='d76cd115-2cdb-471d-937a-eb5579d8a808')],
 'steps': ['extract_question', 'rerank'],
 'question': 'what was the percentage change in the net cash from operating activities from 2008 to 2009',
 'documents': [Document(id='Single_JKHY/2009/page_28.pdf-3', metadata={'id': 'Single_JKHY/2009/page_28.pdf-3', 'qa': "{'question': 'what was the percentage change in the net cash from operating activities from 2008 to 2009', 'answer': '14.1%', 'explanation': '', 'ann_table_rows': [6], 'ann_text_rows': [], 'steps': [{'op': 'minus2-1', 'arg1': '206588', 'arg2': '181001', 'res': '25587'}, {'op': 'divide2-2', 'arg1': '#0', 'arg2': '181001', 'res': '14.1%'}], 'program': 'subtract(206588, 181001), divide(#0, 181001)', 'gold_inds': {'table_