In [None]:
import pandas as pd
from trulens.benchmark.benchmark_frameworks.experiments.dataset_preprocessing import (
    generate_trec_dl_passage_benchmark,
)

# Cap forwarded to the benchmark generator as `max_samples_per_query_per_score`.
MAX_SAMPLES_PER_QUERY_PER_SCORE = 4

trec_2021_samples = list(
    generate_trec_dl_passage_benchmark(
        max_samples_per_query_per_score=MAX_SAMPLES_PER_QUERY_PER_SCORE,
        dataset_path="msmarco-passage-v2/trec-dl-2021/judged",
    )
)
trec_2022_samples = list(
    generate_trec_dl_passage_benchmark(
        max_samples_per_query_per_score=MAX_SAMPLES_PER_QUERY_PER_SCORE,
        dataset_path="msmarco-passage-v2/trec-dl-2022/judged",
    )
)
trec_combined = trec_2021_samples + trec_2022_samples

trec_combined_df = pd.DataFrame(trec_combined)

# Later cells read these columns (split stratification and TrecDLData
# construction); fail early with a clear message if they are absent.
missing_columns = {"query", "expected_response", "expected_score"} - set(
    trec_combined_df.columns
)
assert not missing_columns, f"Missing expected columns: {missing_columns}"

print(f"Total number of samples: {len(trec_combined_df)}")

In [None]:
from trulens.benchmark.benchmark_frameworks.experiments.dataset_preprocessing import (
    visualize_expected_score_distribution,
)

# Distribution of ground-truth relevance scores over the combined 2021 + 2022 pool.
# Read the column directly instead of iterating row by row with iterrows().
trec_combined_relevance_scores = trec_combined_df["expected_score"].tolist()
visualize_expected_score_distribution(trec_combined_relevance_scores)

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split


def balanced_split(
    df,
    label_column="expected_score",
    train_size=0.6,
    dev_size=0.2,
    test_size=0.2,
    random_state=42,
):
    """
    Splits a DataFrame into train, dev, and test sets with balanced labels.

    Stratification on ``label_column`` happens in two stages: first test vs.
    train+dev, then dev vs. train inside the train+dev remainder. A split
    whose size is 0 is returned as an empty DataFrame with ``df``'s columns
    (``train_test_split`` itself rejects a holdout fraction of 0 or 1).

    Args:
        df (pd.DataFrame): The input DataFrame.
        label_column (str): The column containing the labels to balance on.
        train_size (float): Proportion of the data to use for training.
        dev_size (float): Proportion of the data to use for dev/validation.
        test_size (float): Proportion of the data to use for testing.
        random_state (int): Random seed for reproducibility.

    Returns:
        train_df (pd.DataFrame): Training split.
        dev_df (pd.DataFrame): Development/validation split.
        test_df (pd.DataFrame): Testing split.

    Raises:
        AssertionError: If any size is negative or the sizes do not sum to 1.0.
    """
    assert (
        min(train_size, dev_size, test_size) >= 0
    ), "Sizes must be non-negative"
    assert (
        abs(train_size + dev_size + test_size - 1.0) < 1e-5
    ), "Sizes must sum to 1.0"

    def _split(frame, holdout_fraction):
        """Stratified split of ``frame`` into (rest, holdout); handles 0 and 1."""
        if holdout_fraction <= 0:
            return frame.copy(), frame.iloc[0:0].copy()
        if holdout_fraction >= 1:
            return frame.iloc[0:0].copy(), frame.copy()
        return train_test_split(
            frame,
            test_size=holdout_fraction,
            stratify=frame[label_column],
            random_state=random_state,
        )

    # Step 1: Split train+dev and test
    train_dev_df, test_df = _split(df, test_size)

    # Step 2: Calculate relative size for dev split from train+dev
    # (0 when there is no train+dev data at all, avoiding a division by zero).
    train_dev_size = train_size + dev_size
    dev_relative_size = dev_size / train_dev_size if train_dev_size > 0 else 0.0

    # Step 3: Split train and dev
    train_df, dev_df = _split(train_dev_df, dev_relative_size)

    return train_df, dev_df, test_df


# 60 / 20 / 20 stratified split of the combined pool.
train_df, dev_df, test_df = balanced_split(
    trec_combined_df, train_size=0.6, dev_size=0.2, test_size=0.2
)

for split_label, split_frame in (
    ("Train", train_df),
    ("Dev", dev_df),
    ("Test", test_df),
):
    print(f"{split_label} size: {len(split_frame)}")

In [None]:
# Check that the stratified split kept a similar score distribution in every split.
train_scores, dev_scores, test_scores = (
    split_df["expected_score"].tolist()
    for split_df in (train_df, dev_df, test_df)
)
for split_scores in (train_scores, dev_scores, test_scores):
    visualize_expected_score_distribution(split_scores)

### Implement TruLens' `context_relevance_with_cot_reasons` in AdalFlow

In [None]:
# Baseline system prompt for the relevance rater. It is passed to
# ContextRelevanceTaskPipeline as `target_prompt`, where it becomes the
# optimizable system-prompt Parameter. It asks for a 0-3 rating in a
# "Rating: ... / Reasoning: ..." format, which is what parse_output reads.
# NOTE(review): the leading indentation on the continuation lines is part of
# the prompt text sent to the model; edit with care.
cortex_search_custom_prompt = """You are an expert search result rater. You are given a user query and a search result. Your task is to rate the search result based on its relevance to the user query. You should rate the search result on a scale of 0 to 3, where:
    0: The search result has no relevance to the user query.
    1: The search result has low relevance to the user query. In this case the search result may contain some information which seems very slightly related to the user query but not enough information to answer the user query. The search result contains some references or very limited information about some entities present in the user query. In case the query is a statement on a topic, the search result should be tangentially related to it.
    2: The search result has medium relevance to the user query. If the user query is a question, the search result may contain some information that is relevant to the user query but not enough information to answer the user query. If the user query is a search phrase/sentence, either the search result is centered around about most but not all entities present in the user query, or if all the entities are present in the result, the search result while not being centered around it has medium level of relevance. In case the query is a statement on a topic, the search result should be related to the topic.
    3: The search result has high relevance to the user query. If the user query is a question, the search result contains information that can answer the user query. Otherwise if the search query is a search phrase/sentence, it provides relevant information about all entities that are present in the user query and the search result is centered around the entities mentioned in the query. In case the query is a statement on a topic, the search result should be either be directly addressing it or be on the same topic.
    
    You should think step by step about the user query and the search result and rate the search result. You should also provide a reasoning for your rating.
    
    Use the following format:
    Rating: Example Rating
    Reasoning: Example Reasoning
    
    Now given the user query and search result below, rate the search result based on its relevance to the user query and provide a reasoning for your rating.
"""

In [None]:
# Hand-written few-shot examples covering every rating level (0-3). They fill
# the `few_shot_demos` slot of few_shot_template through a DEMOS Parameter in
# ContextRelevanceTaskPipeline. Each example uses the same INPUT / OUTPUT and
# "Rating:" / "Reasoning:" layout the rater is asked to produce.
CORTEX_FEW_SHOT_DEMOS = """
### Examples
    Example:
    Example 1:
    INPUT:
    User Query: What is the definition of an accordion?
    Search Result: Accordion definition, Also called piano accordion. a portable wind instrument having a large bellows for forcing air through small metal reeds, a keyboard for the right hand, and buttons for sounding single bass notes or chords for the left hand. a similar instrument having single-note buttons instead of a keyboard.
    OUTPUT:
    Rating: 3
    Reasoning: In this case the search query is a question. The search result directly answers the user question for the definition of an accordion, hence it has high relevance to the user query.
    
    Example 2:
    INPUT:
    User Query: dark horse
    Search Result: Darkhorse is a person who everyone expects to be last in a race. Think of it this way. The person who looks like he can never get laid defies the odds and gets any girl he can by being sly,shy and cunning. Although he\'s not a player, he can really charm the ladies.
    OUTPUT:
    Rating: 3
    Reasoning: In this case the search query is a search phrase mentioning \'dark horse\'. The search result contains information about the term \'dark horse\' and provides a definition for it and is centered around it. Hence it has high relevance to the user query.
    
    Example 3:
    INPUT:
    User Query: Global warming and polar bears
    Search Result: Polar bear The polar bear is a carnivorous bear whose native range lies largely within the Arctic Circle, encompassing the Arctic Ocean, its surrounding seas and surrounding land masses. It is a large bear, approximately the same size as the omnivorous Kodiak bear (Ursus arctos middendorffi).
    OUTPUT:
    Rating: 2
    Reasoning: In this case the search query is a search phrase mentioning two entities \'Global warming\' and \'polar bears\'. The search result contains is centered around the polar bear which is one of the two entities in the search query. Therefore it addresses most of the entities present and hence has medium relevance. 
    
    Example 4:
    INPUT:
    User Query: Snowflake synapse private link
    Search Result: "This site can\'t be reached" error when connecting to Snowflake via Private Connectivity\nThis KB article addresses an issue that prevents connections to Snowflake failing with: "This site can\'t be reached" ISSUE: Attempting to reach Snowflake via Private Connectivity fails with the "This site can\'t be reached" error
    OUTPUT:
    Rating: 1
    Reasoning: In this case the search result is a search query mentioning \'Snowflake synapse private link\'. However the search result doesn\'t contain information about it. However it shows an error message for a generic private link which is tangentially related to the query, since snowflake synapse private link is a type of private link. Hence it has low relevance to the user query.
    
    Example 5:
    INPUT:
    User Query: The Punisher is American.
    Search Result: The Rev(Samuel Smith) is a fictional character, a supervillain appearing in American comic books published by Marvel Comics. Created by Mike Baron and Klaus Janson, the character made his first appearance in The Punisher Vol. 2, #4 (November 1987). He is an enemy of the Punisher.
    OUTPUT:
    Rating: 1
    Reasoning: In this case the search query is a statement concerning the Punisher. However the search result is about a character called Rev, who is an enemy of the Punisher. The search result is tangentially related to the user query but does not address topic about Punisher being an American. Hence it has low relevance to the user query.

    Example 6:
    INPUT:
    User Query: query_history
    Search Result: The function task_history() is not enough for the purposes when the required result set is more than 10k.If we perform UNION between information_schema and account_usage , then we will get more than 10k records along with recent records as from information_schema.query_history to snowflake.account_usage.query_history is 45 mins behind.
    OUTPUT:
    Rating: 1
    Reasoning: In this case the search query mentioning one entity \'query_history\'. The search result is neither centered around it and neither has medium relevance, it only contains an unimportant reference to it. Hence it has low relevance to the user query.
    
    Example 7:
    INPUT:
    User Query: Who directed pulp fiction?
    Search Result: Life on Earth first appeared as early as 4.28 billion years ago, soon after ocean formation 4.41 billion years ago, and not long after the formation of the Earth 4.54 billion years ago.
    OUTPUT:
    Rating: 0
    Reasoning: In the case the search query is a question. However the search result does is completely unrelated to it. Hence the search result is completely irrelevant to the movie pulp fiction. 
    ###"""

In [None]:
import re
from typing import Dict, Optional, Tuple, Union
import warnings

import adalflow as adal
from adalflow.optim.types import ParameterType
from trulens.feedback import generated as feedback_generated

# Jinja-style prompt template rendered by adal.Generator.
# - `system_prompt` and `few_shot_demos` are bound once as Parameters in
#   ContextRelevanceTaskPipeline.__init__.
# - `user_prompt` is supplied per call.
# The "Here are some examples" section is skipped when few_shot_demos is none.
few_shot_template = r"""<START_OF_SYSTEM_PROMPT>
{{system_prompt}}
{# Few shot demos #}
{% if few_shot_demos is not none %}
Here are some examples:
{{few_shot_demos}}
{% endif %}
<END_OF_SYSTEM_PROMPT>
<START_OF_USER>
{{user_prompt}}
<END_OF_USER>
"""


class ContextRelevanceTaskPipeline(adal.Component):
    """AdalFlow task that rates a search result's relevance to a query (0-3).

    The system prompt (``target_prompt``) is wrapped in an optimizable
    ``adal.Parameter``. ``CORTEX_FEW_SHOT_DEMOS`` is attached as a DEMOS
    parameter. The generator output is parsed by ``parse_output`` into a
    ``(rating, reasoning)`` tuple.
    """

    def __init__(
        self,
        target_prompt: str,
        model_client: adal.ModelClient,
        model_kwargs: Dict,
    ):
        """
        Args:
            target_prompt: Initial system prompt; marked ``requires_opt=True``.
            model_client: AdalFlow model client used by the generator.
            model_kwargs: Keyword arguments handed to the generator for the
                model client.
        """
        super().__init__()

        system_prompt = adal.Parameter(
            data=target_prompt,
            role_desc="To give task instruction to the language model in the system prompt",
            requires_opt=True,
            param_type=ParameterType.PROMPT,
            instruction_to_optimizer="You can try to show examples to see if it helps. Make sure the model is being very critical and strict with its ratings, avoid being lenient / false positves to get better results.",
        )
        # The demos are held fixed (requires_opt=False), so the hand-written
        # CORTEX_FEW_SHOT_DEMOS are used as given.
        # NOTE(review): an earlier inline comment claimed this had been changed
        # to True. Flip it only if the demo optimizer should replace these demos.
        few_shot_demos = adal.Parameter(
            data=CORTEX_FEW_SHOT_DEMOS,
            role_desc="To provide few shot demos to the language model",
            requires_opt=False,
            param_type=ParameterType.DEMOS,
        )

        self.evaluate_relevance = adal.Generator(
            model_client=model_client,
            model_kwargs=model_kwargs,
            template=few_shot_template,
            prompt_kwargs={
                "system_prompt": system_prompt,
                "few_shot_demos": few_shot_demos,
            },
            use_cache=True,
            output_processors=self.parse_output,
        )

    @adal.fun_to_component
    def parse_output(response: str):
        """Extract ``(rating, reasoning)`` from the rater's text output.

        Each line is stripped before matching, so indented output is still
        recognized (the few-shot demos the model imitates are indented). The
        rating is the first integer after ``Rating:``. The reasoning is the
        rest of the ``Reasoning:`` line, including any further colons.
        Either value is ``None`` when it is missing or cannot be parsed.
        """
        rating = None
        reasoning = None
        if not response:
            return rating, reasoning

        for line in response.strip().split("\n"):
            line = line.strip()
            if line.startswith("Rating:"):
                match = re.search(r"\d+", line[len("Rating:"):])
                if match is None:
                    # Keep parsing so a bad rating line does not discard the
                    # reasoning that follows it.
                    print(
                        f"Error parsing response: no integer rating in {line!r}"
                    )
                else:
                    rating = int(match.group())
            elif line.startswith("Reasoning:"):
                reasoning = line[len("Reasoning:"):].strip()

        return rating, reasoning

    @adal.fun_to_component
    def parse_trulens_output(response: str) -> Tuple[float, Dict]:
        """Parse TruLens-style CoT output (Criteria / Supporting Evidence / Score).

        This processor is not wired into this pipeline's generator, which uses
        ``parse_output``. It is kept for experiments with the TruLens format.

        Returns:
            ``(score, reason)``:
            - ``score`` is the 0-3 rating divided by 3. It is ``-1`` when an
              evidence block is present but no line mentions "Score", and
              ``0`` for an empty response.
            - ``reason`` is ``{"reason": str}``, or ``{}`` when only a bare
              score could be extracted.
        """
        score, reason = None, None
        if response and "Supporting Evidence" in response:
            lines = response.split("\n")

            # The last line mentioning "Score" determines the score.
            score = -1
            for line in lines:
                if "Score" in line:
                    score = (
                        feedback_generated.re_configured_rating(
                            line,
                            min_score_val=0,
                            max_score_val=3,
                        )
                    ) / 3

            # Collect the (possibly multi-line) Criteria and Supporting
            # Evidence sections in a single pass. This pass used to be nested
            # inside the score loop, re-parsing the whole response once per
            # line with identical results.
            criteria_lines = []
            supporting_evidence_lines = []
            collecting_criteria = False
            collecting_evidence = False
            for line in lines:
                if "Criteria:" in line:
                    criteria_lines.append(line.split("Criteria:", 1)[1].strip())
                    collecting_criteria = True
                    collecting_evidence = False
                elif "Supporting Evidence:" in line:
                    supporting_evidence_lines.append(
                        line.split("Supporting Evidence:", 1)[1].strip()
                    )
                    collecting_evidence = True
                    collecting_criteria = False
                elif collecting_criteria:
                    criteria_lines.append(line.strip())
                elif collecting_evidence:
                    supporting_evidence_lines.append(line.strip())

            criteria = "\n".join(criteria_lines).strip()
            supporting_evidence = "\n".join(supporting_evidence_lines).strip()
            reason = {
                "reason": (
                    f"Criteria: {criteria}\n"
                    f"Supporting Evidence: {supporting_evidence}"
                )
            }
        elif not response:
            score = 0
            reason = {"reason": "No response generated."}
        else:
            score = (
                feedback_generated.re_configured_rating(
                    response,
                    min_score_val=0,
                    max_score_val=3,
                )
            ) / 3
            warnings.warn(
                "No supporting evidence provided. Returning score only.",
                UserWarning,
            )
            reason = {}

        # Rescale a "Score: N" quoted inside the reason text to the same
        # [0, 1] scale as `score`.
        score_pattern = re.compile(r"Score:\s*([0-9.]+)")
        match = score_pattern.search(reason.get("reason", ""))
        if match:
            original_reason_score = float(match.group(1))
            normalized_reason_score = original_reason_score / 3

            # Ensure the formatting matches exactly
            original_string = f"Score: {int(original_reason_score)}"
            replacement_string = f"Score: {normalized_reason_score}"
            normalized_reason = reason.copy()
            normalized_reason["reason"] = normalized_reason["reason"].replace(
                original_string, replacement_string
            )
            return score, normalized_reason

        return score, reason

    def call(
        self, query: str, context: str, id: Optional[str] = None
    ) -> Union[adal.GeneratorOutput, adal.Parameter]:
        """Rate ``context`` (a search result) for relevance to ``query``.

        Args:
            query: The user query.
            context: The retrieved passage to rate.
            id: Optional example id forwarded to the generator.

        Returns:
            Whatever ``self.evaluate_relevance`` returns (annotated as a
            GeneratorOutput or a Parameter).
        """
        # NOTE(review): the continuation lines of this literal carry their
        # source indentation into the prompt. They are kept byte-identical so
        # prompts (and the generator cache, use_cache=True) match earlier runs.
        user_prompt = """INPUT:
                        User Query: {query}
                        Search Result: {context}
                        OUTPUT:\n""".format(query=query, context=context)

        return self.evaluate_relevance(
            prompt_kwargs={"user_prompt": user_prompt}, id=id
        )

In [None]:
import os

from adalflow.components.model_client.openai_client import AzureOpenAIClient

# Model configuration for the task pipeline. train() also uses it as the
# teacher / text-optimizer / backward-engine config.
# NOTE(review): AzureOpenAIClient() presumably reads its endpoint and API key
# from the environment; confirm this for your setup.
if "AZURE_OPENAI_DEPLOYMENT" not in os.environ:
    raise KeyError(
        "AZURE_OPENAI_DEPLOYMENT is not set; export it with the name of your "
        "Azure OpenAI deployment before running this notebook."
    )

az_gpt_4o_model = {
    "model_client": AzureOpenAIClient(),
    "model_kwargs": {
        "model": os.environ["AZURE_OPENAI_DEPLOYMENT"],
        "max_tokens": 4000,
        "temperature": 0.0,
        "top_p": 0.99,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "stop": None,
    },
}

# Standalone pipeline with the baseline prompt, plus one smoke-test query.
task_pipeline = ContextRelevanceTaskPipeline(
    target_prompt=cortex_search_custom_prompt,
    model_client=az_gpt_4o_model["model_client"],
    model_kwargs=az_gpt_4o_model["model_kwargs"],
)
print(task_pipeline)

output = task_pipeline(
    query="Is apple safe to eat?",
    context="All fruits are edible",
)
output

In [None]:
# Put the exploratory `task_pipeline` into train mode.
# NOTE(review): the Trainer below builds its own pipeline inside
# ContextRelevanceAdalComponentOnTrecDL, so this only affects manual calls on
# `task_pipeline`.
task_pipeline.train()  # set to train mode

In [None]:
# Removed a duplicate `task_pipeline.train()` call: the previous cell already
# switches `task_pipeline` to train mode.

In [None]:
# Peek at the training split (first rows only, not the whole frame).
train_df.head()

In [None]:
from dataclasses import dataclass
from dataclasses import field
import uuid

from adalflow.datasets.types import Example


@dataclass
class TrecDLData(Example):
    __doc__ = """A dataclass for representing examples in the TREC DL (passage retrieval) dataset."""

    # Example identifier; every instance gets a fresh UUID4 string by default.
    id: str = field(
        default_factory=lambda: str(uuid.uuid4()),
        metadata={"desc": "The unique identifier of the example", "type": "id"},
    )
    # The user's query text.
    query: Optional[str] = field(
        default=None,
        metadata={"desc": "The query from user."},
    )
    # The passage to be rated against `query` (sent to the task as `context`).
    expected_response: Optional[str] = field(
        default=None,
        metadata={"desc": "The retrieved context for the query."},
    )
    # Ground-truth relevance rating for (query, expected_response).
    expected_score: Optional[float] = field(
        default=None,
        metadata={"desc": "The expected relevance score for the answer."},
    )


def _to_trec_dl_dataset(split_df: pd.DataFrame) -> list:
    """Convert one split DataFrame into a list of TrecDLData examples.

    ``expected_score`` is multiplied by 3 to match the rater's 0-3 rating
    scale. This assumes the preprocessing normalized scores to [0, 1]; TODO
    confirm against generate_trec_dl_passage_benchmark.
    """
    return [
        TrecDLData(
            query=row["query"],
            expected_response=row["expected_response"],
            expected_score=row["expected_score"] * 3,
        )
        for _, row in split_df.iterrows()
    ]


train_dataset = _to_trec_dl_dataset(train_df)
val_dataset = _to_trec_dl_dataset(dev_df)
test_dataset = _to_trec_dl_dataset(test_df)


def context_relevance_eval_fn(y: float, y_gt: float) -> float:
    """Exact-match accuracy for one rating: 1.0 when ``y == y_gt``, else 0.0."""
    is_match = bool(y == y_gt)
    return float(is_match)


def weighted_relevance_loss(
    y: Optional[float],
    y_gt: Optional[float],
    false_positive_weight: float = 3.0,
) -> float:
    """
    Score a predicted rating against ground truth, penalizing over-rating most.

    Despite the name, the value is a score in [0, 1] where HIGHER IS BETTER
    (it is used as the eval_fn of adal.EvalFnToTextLoss):

    * exact match                      -> 1.0
    * under-rating (y < y_gt)          -> 1 - 1 / (false_positive_weight + 1)
    * over-rating / false positive     -> 1 - false_positive_weight / (false_positive_weight + 1)
    * missing prediction / ground truth (None) -> 0.0

    Args:
        y: Predicted rating.
        y_gt: Ground-truth rating.
        false_positive_weight: Penalty for over-rating relative to
            under-rating (which has weight 1). Values >= 0 keep the result in [0, 1].

    Returns:
        The score as a float.
    """
    if y is None or y_gt is None:
        # An unparseable rating cannot be compared; give it the worst score.
        return 0.0

    if y == y_gt:
        # Correct predictions get the best score. The previous version
        # returned 0.0 here, which ranked exact matches below every mismatch.
        return 1.0

    # Over-rating (false positive) costs more than under-rating.
    penalty = false_positive_weight if y > y_gt else 1.0

    # Normalize the penalty so the score stays in [0, 1].
    return 1.0 - penalty / (false_positive_weight + 1.0)


class ContextRelevanceAdalComponentOnTrecDL(adal.AdalComponent):
    """AdalComponent that trains and evaluates ContextRelevanceTaskPipeline on TrecDLData.

    * Evaluation (``eval_fn``): exact rating match via
      ``context_relevance_eval_fn``.
    * Training signal (``loss_fn``): ``weighted_relevance_loss`` with
      ``false_positive_weight=3.0``, so over-rating costs more than
      under-rating.
    * Optimizers: a text optimizer plus a demo optimizer.
    """

    def __init__(
        self,
        target_prompt: str,
        model_client: adal.ModelClient,
        model_kwargs: Dict,
        backward_engine_model_config: Optional[Dict] = None,
        teacher_model_config: Optional[Dict] = None,
        text_optimizer_model_config: Optional[Dict] = None,
    ):
        """
        Args:
            target_prompt: Initial system prompt for the task pipeline.
            model_client: Model client for the task pipeline.
            model_kwargs: Model kwargs for the task pipeline.
            backward_engine_model_config, teacher_model_config,
            text_optimizer_model_config: Dicts unpacked with ``**`` into the
                matching ``configure_*_helper`` calls. They must be set before
                the corresponding ``configure_*`` method runs.
        """
        task = ContextRelevanceTaskPipeline(
            target_prompt, model_client, model_kwargs
        )
        eval_fn = context_relevance_eval_fn
        loss_fn = adal.EvalFnToTextLoss(
            eval_fn=lambda y, y_gt: weighted_relevance_loss(
                y, y_gt, false_positive_weight=3.0
            ),
            eval_fn_desc="Give a lower score when the model gives higher rating than ground truth to avoid being too lenient (y higher than y_gt)",
        )

        super().__init__(task=task, eval_fn=eval_fn, loss_fn=loss_fn)
        self.backward_engine_model_config = backward_engine_model_config
        self.teacher_model_config = teacher_model_config
        self.text_optimizer_model_config = text_optimizer_model_config

    def prepare_task(self, sample: TrecDLData):
        """Map a sample to the task's ``call`` kwargs (the passage is the context)."""
        return self.task.call, {
            "query": sample.query,
            "context": sample.expected_response,
            "id": sample.id,
        }

    def prepare_loss(self, sample: TrecDLData, pred: adal.Parameter):
        """Build the ground-truth Parameter and attach the predicted rating to ``pred``."""
        y_gt = adal.Parameter(
            name="y_gt",
            data=sample.expected_score,
            eval_input=sample.expected_score,
            requires_opt=False,
        )

        # full_response.data is parse_output's (rating, reasoning) tuple, or
        # None if generation/parsing failed. Use the same -1 sentinel as
        # prepare_eval when no numeric rating is available. The old default
        # of 0 scored failed parses as correct on ground-truth-0 samples, and
        # a None rating would crash the comparison in the loss.
        data = pred.full_response.data if pred and pred.full_response else None
        rating = data[0] if data else None
        pred.eval_input = rating if isinstance(rating, (int, float)) else -1

        return self.loss_fn, {"kwargs": {"y": pred, "y_gt": y_gt}}

    def prepare_eval(self, sample: TrecDLData, y_pred: adal.GeneratorOutput):
        """Extract the numeric rating (or -1 when missing) for exact-match eval."""
        y_label = -1
        if (
            y_pred
            and y_pred.data
            and len(y_pred.data) > 0
            and isinstance(y_pred.data[0], (int, float))
        ):
            y_label = y_pred.data[0]

        return self.eval_fn, {"y": y_label, "y_gt": sample.expected_score}

    def configure_backward_engine(self):
        super().configure_backward_engine_helper(
            **self.backward_engine_model_config
        )

    def configure_teacher_generator(self):
        super().configure_teacher_generator_helper(**self.teacher_model_config)

    def configure_optimizers(self):
        """Return the text optimizer(s) followed by the demo optimizer(s)."""
        to = super().configure_text_optimizer_helper(
            **self.text_optimizer_model_config
        )
        do = super().configure_demo_optimizer_helper()
        return to + do

In [None]:
def diagnose(
    model_client: adal.ModelClient,
    model_kwargs: Dict,
    max_samples: Optional[int] = None,
) -> None:
    """Run ``adal.Trainer.diagnose`` with the baseline prompt on every split.

    Args:
        model_client: AdalFlow model client for the task pipeline.
        model_kwargs: Model keyword arguments for the client.
        max_samples: If given, only the first ``max_samples`` examples of each
            split are diagnosed (e.g. 10 for a quick check of the code path).
            ``None`` (the default) diagnoses the full splits.
    """
    adal_component = ContextRelevanceAdalComponentOnTrecDL(
        cortex_search_custom_prompt, model_client, model_kwargs
    )
    trainer = adal.Trainer(adaltask=adal_component)

    for split, dataset in (
        ("train", train_dataset),
        ("val", val_dataset),
        ("test", test_dataset),
    ):
        if max_samples is not None:
            dataset = dataset[:max_samples]
        trainer.diagnose(dataset=dataset, split=split)


# Baseline diagnostics for the unoptimized prompt on train / val / test.
diagnose(
    model_client=az_gpt_4o_model["model_client"],
    model_kwargs=az_gpt_4o_model["model_kwargs"],
)

In [None]:
def train(
    train_batch_size=4,  # larger batch size is not that effective, probably because of llm's lost in the middle
    raw_shots: int = 0,
    bootstrap_shots: int = 1,
    max_steps=1,
    num_workers=4,
    strategy="random",
    optimization_order="sequential",
    debug=False,
    resume_from_ckpt=None,
    exclude_input_fields_from_bootstrap_demos=False,
    model_config: Optional[Dict] = None,
):
    """Optimize the relevance-rating prompt on the TREC DL splits with adal.Trainer.

    All arguments except ``model_config`` are forwarded to ``adal.Trainer``.
    ``debug`` and ``resume_from_ckpt`` are also passed to ``Trainer.fit``.

    Args:
        model_config: ``{"model_client": ..., "model_kwargs": ...}`` used for
            the task model and as the teacher, text-optimizer and
            backward-engine config. Defaults to ``az_gpt_4o_model``.
    """
    if model_config is None:
        model_config = az_gpt_4o_model

    adal_component = ContextRelevanceAdalComponentOnTrecDL(
        target_prompt=cortex_search_custom_prompt,
        **model_config,
        teacher_model_config=model_config,
        text_optimizer_model_config=model_config,
        backward_engine_model_config=model_config,
    )
    print(adal_component)
    trainer = adal.Trainer(
        train_batch_size=train_batch_size,
        adaltask=adal_component,
        strategy=strategy,
        max_steps=max_steps,
        num_workers=num_workers,
        raw_shots=raw_shots,
        bootstrap_shots=bootstrap_shots,
        debug=debug,
        weighted_sampling=True,
        optimization_order=optimization_order,
        exclude_input_fields_from_bootstrap_demos=exclude_input_fields_from_bootstrap_demos,
    )
    print(trainer)

    trainer.fit(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        debug=debug,
        resume_from_ckpt=resume_from_ckpt,
    )

In [None]:
# Constrained-strategy run: 15 steps, one raw + one bootstrapped demo,
# with input fields stripped from the bootstrapped demos.
train(
    strategy="constrained",
    max_steps=15,
    train_batch_size=4,
    raw_shots=1,
    bootstrap_shots=1,
    exclude_input_fields_from_bootstrap_demos=True,
    debug=False,
)