In [None]:
from trulens.core import TruSession

# Create the TruLens session used by the rest of the notebook and reset its
# database.
# NOTE(review): presumably `reset_database()` discards previously logged data —
# confirm against the TruSession docs before pointing this at a shared database.
session = TruSession()
session.reset_database()

### Dataset preprocessing

Datasets that need some preprocessing before they can be used in the `TruBenchmarkExperiment` class:
1. Snowflake IT (internal): open question — use both the rephrased and the regular versions? This dataset should be used for all 3 metrics in the triad
2. SummEval (CNN and DailyMail summarizations with annotation) for groundedness
3. QAGS (CNN and DailyMail with Turkers' annotation) for groundedness
4. QAGS (XSUM with Turkers' annotation) for groundedness
5. MSMARCO V2 for context relevance
6. HotPot QA for answer relevance 



In [None]:
import ast
import csv
import json


# SummEval
def generate_summeval_groundedness_golden_set(file_path):
    """Yield groundedness golden-set records from a SummEval JSON-lines file.

    Every non-blank line of ``file_path`` is parsed as one JSON object. For
    each entry of its ``machine_summaries`` list one record is yielded, paired
    with the ``consistency`` annotation found at the same index.

    Lines that cannot be used (invalid JSON, missing ``machine_summaries`` /
    ``consistency`` keys, or lists of different lengths) are reported on
    stdout and skipped, so one bad line does not abort the whole run.

    Args:
        file_path: Path to the UTF-8 encoded JSON-lines file.

    Yields:
        dict: ``query`` (the ``text`` field, or ``""`` when absent),
        ``expected_response`` (one machine summary), ``expected_score``
        (the consistency annotation rescaled from [1, 5] to [0, 1] and
        rounded to 2 decimals) and ``human_score`` (the raw annotation).
    """

    def calculate_expected_score(normalized_metrics_lst, weights_lst):
        """Return the weighted mean of the normalized metrics, rounded to 2 decimals."""
        if len(normalized_metrics_lst) != len(weights_lst):
            raise ValueError("Each metric needs exactly one weight")
        weighted_total = sum(
            metric * weight
            for metric, weight in zip(normalized_metrics_lst, weights_lst)
        )
        return round(weighted_total / sum(weights_lst), 2)

    # Explicit encoding: the articles contain non-ASCII text and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open(file_path, encoding="utf-8") as f:
        for line in f:
            # Blank lines carry no record; skip them instead of reporting a
            # JSON decode error for each one.
            if not line.strip():
                continue

            # Each line is a separate JSON object
            try:
                row = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"JSON decode error: {e}. Check the line format.")
                continue

            # Ensure the expected keys exist in the JSON
            try:
                machine_summaries = row["machine_summaries"]
                consistency_scores = row["consistency"]
            except KeyError as e:
                print(
                    f"Key error: {e}. Please check if the keys exist in the JSON file."
                )
                continue

            # Explicit check rather than `assert`, which is stripped under -O.
            if len(machine_summaries) != len(consistency_scores):
                print(
                    "Assertion error: Mismatch in lengths of machine_summaries and consistency. "
                    "The lengths of 'machine_summaries' and 'consistency' do not match."
                )
                continue

            # Default to empty string if key not found
            query = row.get("text", "")

            # One record per (summary, annotation) pair
            for machine_summary, human_score in zip(
                machine_summaries, consistency_scores
            ):
                yield {
                    "query": query,
                    "expected_response": machine_summary,
                    "expected_score": calculate_expected_score(
                        [(human_score - 1) / 4],  # Normalize from [1, 5] to [0, 1]
                        [1.0],
                    ),
                    "human_score": human_score,
                }


# Snowflake IT dataset
def generate_snowflake_it_golden_set_answer_relevance(file_path):
    """Yield answer-relevance golden-set records from a Snowflake IT CSV file.

    Args:
        file_path: Path to a UTF-8 CSV file with a header row that contains at
            least the ``query`` and ``expected_response`` columns.

    Yields:
        dict: ``query`` and ``expected_response`` copied from the row, plus a
        constant ``expected_score`` of 1.

    Raises:
        KeyError: If a required column is missing from the header.
    """
    # newline="" hands line-ending handling to the csv module, so newlines
    # embedded in quoted fields are returned unchanged (see the csv docs).
    with open(file_path, mode="r", encoding="utf-8", newline="") as f:
        # DictReader maps every row onto the header names
        for row in csv.DictReader(f):
            yield {
                "query": row["query"],
                "expected_response": row["expected_response"],
                "expected_score": 1,  # always positive example for answer relevance
            }


def generate_snowflake_it_golden_set_context_relevance(file_path):
    """Yield context-relevance golden-set records from a Snowflake IT CSV file.

    The ``golden`` column of every row must hold the Python-literal
    representation of a list. Rows whose ``golden`` value cannot be parsed, or
    parses to something other than a list, are reported on stdout and skipped.
    Rows with more than one golden context are reported but still yielded.

    NOTE(review): the parsed ``golden`` list is only validated and logged; the
    yielded record carries ``expected_response``, not the golden chunks.
    Confirm that this is the intended payload for context relevance.

    Args:
        file_path: Path to a UTF-8 CSV file with a header row that contains at
            least the ``query``, ``expected_response`` and ``golden`` columns.

    Yields:
        dict: ``query`` and ``expected_response`` copied from the row, plus a
        constant ``expected_score`` of 1.

    Raises:
        KeyError: If a required column is missing from the header.
    """
    # newline="" hands line-ending handling to the csv module, so newlines
    # embedded in quoted fields are returned unchanged (see the csv docs).
    with open(file_path, mode="r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            # Convert the 'golden' from a string to a list
            try:
                expected_chunks = ast.literal_eval(row["golden"])
                # Validate the type before calling len(): a literal such as
                # `5` has no length and would otherwise abort the generator
                # with an uncaught TypeError.
                if not isinstance(expected_chunks, list):
                    raise ValueError("Golden column should be a list")
            except (ValueError, SyntaxError, TypeError) as e:
                print(f"Error parsing golden column: {e}")
                continue

            if len(expected_chunks) > 1:
                print(f'query w/ more than one golden contexts: {row["query"]}')
                print(expected_chunks)

            # Yield the required fields
            yield {
                "query": row["query"],
                "expected_response": row["expected_response"],
                "expected_score": 1,  # Static score as per the requirement
            }


def generate_qags_golden_set_groundedness(file_path):
    """Yield groundedness golden-set records from a QAGS JSON-lines file.

    Every non-blank line of ``file_path`` is parsed as one JSON object with an
    ``article`` and a list of ``summary_sentences``. One record is yielded per
    summary sentence; its score is the fraction of worker responses equal to
    "yes" (case-insensitive). Sentences without any worker response are
    reported on stdout and skipped, because no score can be computed for them.

    Args:
        file_path: Path to the UTF-8 encoded JSON-lines file.

    Yields:
        dict: ``query`` (the article), ``expected_response`` (one summary
        sentence) and ``expected_score`` (a float in [0, 1]).

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an expected key is missing from a record.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            # Blank lines carry no record; skipping them avoids a
            # JSONDecodeError that would abort the whole run.
            if not line.strip():
                continue

            # Parse each line as a JSON object
            data = json.loads(line)

            # Extract the article as the query
            query = data["article"]

            # Iterate over the summary_sentences to flatten the structure
            for summary in data["summary_sentences"]:
                expected_response = summary["sentence"]

                responses = [
                    response["response"] for response in summary["responses"]
                ]
                # Guard against division by zero when nobody annotated the sentence
                if not responses:
                    print(
                        f"No worker responses for sentence: {expected_response!r}. Skipping."
                    )
                    continue

                # Count 'yes' as 1 and anything else as 0, then average
                yes_votes = sum(1 for r in responses if r.lower() == "yes")

                # Yield the processed record
                yield {
                    "query": query,
                    "expected_response": expected_response,
                    "expected_score": yes_votes / len(responses),
                }

In [None]:
# from datasets import load_dataset

# Load a dataset from the Hugging Face Hub
# dataset = load_dataset('mteb/summeval', split='test')
# dataset.to_json('summeval_testss.json')

In [None]:
# Materialize the context-relevance generator so the notebook displays every
# record it yields for the Snowflake IT CSV.
# NOTE(review): the absolute path below only exists on one developer's machine;
# resolve it relative to the repository (or a configurable data directory)
# before sharing this notebook.
list(
    generate_snowflake_it_golden_set_context_relevance(
        "/Users/dhuang/Documents/git/trulens/src/benchmark/trulens/benchmark/benchmark_frameworks/experiments/data/snowflake_it_v3.csv"
    )
)

In [None]:
# Materialize the groundedness generator so the notebook displays every record
# it yields for the QAGS CNN/DailyMail JSON-lines file.
# NOTE(review): the absolute path below only exists on one developer's machine;
# resolve it relative to the repository (or a configurable data directory)
# before sharing this notebook.
list(
    generate_qags_golden_set_groundedness(
        "/Users/dhuang/Documents/git/trulens/src/benchmark/trulens/benchmark/benchmark_frameworks/experiments/data/qags_mturk_cnndm.jsonl"
    )
)

In [None]:
# Materialize the groundedness generator so the notebook displays every record
# it yields for the QAGS XSUM JSON-lines file.
# NOTE(review): the absolute path below only exists on one developer's machine;
# resolve it relative to the repository (or a configurable data directory)
# before sharing this notebook.
list(
    generate_qags_golden_set_groundedness(
        "/Users/dhuang/Documents/git/trulens/src/benchmark/trulens/benchmark/benchmark_frameworks/experiments/data/qags_mturk_xsum.jsonl"
    )
)