In [None]:
# ! pip install dspy-ai trulens-core trulens-providers-openai ragas

In [None]:
import os

# Fail fast if the key is missing, but never leave the secret as the cell's
# last expression: a bare `os.environ[...]` lookup would echo the key into the
# notebook output, where it is saved (and shared) together with the file.
if "OPENAI_API_KEY" not in os.environ:
    raise KeyError("OPENAI_API_KEY")

In [None]:
import dspy

# gpt_4o_mini = dspy.LM('openai/gpt-4o-mini')
# Client for the `llama3.1:8b` model, created through dspy's OllamaLocal wrapper.
local_llama_3 = dspy.OllamaLocal(model="llama3.1:8b")

# start with local model first
# Register the local model as the LM in dspy's global configuration.
dspy.configure(lm=local_llama_3)

In [None]:
local_llama_3("hi there how are you?")

### Prepare datasets for DSPy
We'll use XSum (EXtreme Summarization) for the training, dev, and test sets.


In [None]:
from dspy.datasets.dataset import Dataset
import pandas as pd
from trulens.benchmark.benchmark_frameworks.experiments.dataset_preprocessing import (
    generate_qags_golden_set_groundedness,
)

# entire dataset
# Materialize the QAGS XSum groundedness golden set into a DataFrame.
# The path is relative, so this cell depends on the kernel's working directory.
# NOTE(review): `max_samples_per_bucket=100` presumably caps how many examples
# are kept per score bucket — confirm against the helper's documentation.
xsum_df = pd.DataFrame(
    list(
        generate_qags_golden_set_groundedness(
            "../../src/benchmark/trulens/benchmark/benchmark_frameworks/experiments/data/qags_mturk_xsum.jsonl",
            max_samples_per_bucket=100,
        )
    )
)


class XSumDataset(Dataset):
    """Sequential 80/10/10 train/dev/test split of a DataFrame, kept as record dicts."""

    def __init__(self, df, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # Split boundaries are computed once; rows keep their existing order
        # (no shuffling happens in this class).
        row_count = len(df)
        train_end = int(row_count * 0.8)
        dev_end = int(row_count * 0.9)

        self._train = df.iloc[:train_end].to_dict(orient="records")
        self._dev = df.iloc[train_end:dev_end].to_dict(orient="records")
        self._test = df.iloc[dev_end:].to_dict(orient="records")


dataset = XSumDataset(xsum_df, input_keys=["query", "expected_response"])
print(dataset.train[:3])


# Flatten each DSPy split into plain dict records (the same three fields per
# example) and expose them as pandas DataFrames for inspection.
_EXAMPLE_FIELDS = ("query", "expected_score", "expected_response")

data_train, data_dev, data_test = [], [], []
for records, split in (
    (data_train, dataset.train),
    (data_dev, dataset.dev),
    (data_test, dataset.test),
):
    for example in split:
        records.append({
            field_name: getattr(example, field_name)
            for field_name in _EXAMPLE_FIELDS
        })

df_train = pd.DataFrame(data_train)
df_dev = pd.DataFrame(data_dev)
df_test = pd.DataFrame(data_test)

In [None]:
df_dev["expected_score"]

In [None]:
import matplotlib.pyplot as plt

df_dev["expected_score"].hist(bins=10)
plt.xlabel("Expected Score")
plt.ylabel("Frequency")
plt.title("Distribution of Expected Scores")
plt.show()

### Define metric for DSPy pipeline

In [None]:
from typing import Union


def evaluate_groundedness_score(
    example, pred, trace=None, use_binary_threshold=True, threshold=0.5
) -> Union[float, bool]:
    """Score a groundedness prediction against the example's expected score.

    Args:
        example: Object exposing ``expected_score``.
        pred: Object exposing ``output.score``.
        trace: When not ``None`` the function returns a plain boolean
            (labels agree / disagree) instead of a float.
        use_binary_threshold: With ``trace`` unset, selects between a 0/1
            agreement score and ``1 - |expected - predicted|``.
        threshold: Cut-off at or above which a score counts as the positive label.

    Returns:
        ``bool`` when ``trace`` is given, otherwise a ``float``.
    """
    expected = example.expected_score
    predicted = pred.output.score

    def labels_agree() -> bool:
        # Both scores fall on the same side of the threshold.
        return bool(expected >= threshold) == bool(predicted >= threshold)

    #  if we're doing bootstrapping, i.e. self-generating good demonstrations of each step
    if trace is not None:
        return labels_agree()

    if not use_binary_threshold:
        return 1.0 - abs(expected - predicted)

    return 1.0 if labels_agree() else 0.0

## Build our DSPy pipeline with signatures, assertions, CoT, a metric, and a dataset to implement TruLens groundedness

In [None]:
from dspy.primitives import Prediction
from dspy.teleprompt import MIPROv2
from pydantic import BaseModel
from pydantic import Field
from trulens.providers.litellm import LiteLLM


# Typed input of the groundedness signature: the source context plus the
# statement to be checked against it.
# NOTE(review): documented with `#` comments rather than a class docstring on
# the assumption that pydantic would surface a docstring in the model's JSON
# schema (and therefore possibly in the prompt) — confirm before changing.
class Input(BaseModel):
    source: str = Field(
        description="Source context from the retrieved documents"
    )
    statement: str = Field(
        description="The generated response to the query that its groundedness shall be evaluated."
    )


# Typed output of the groundedness signature: a free-text justification and a
# score that pydantic validates to lie in [0, 1] (`ge=0, le=1`).
class Output(BaseModel):
    reasons: str = Field(
        description="The reasons for why the groundedness score is given."
    )
    score: float = Field(
        ge=0, le=1, description="The groundedness score for the answer"
    )


# DSPy signature mapping a typed `Input` to a typed `Output`.
# NOTE(review): DSPy is understood to use a Signature's docstring as the task
# instructions sent to the LM, so the docstring below is prompt text, not just
# documentation — confirm before rewording it.
class GroundednessSignature(dspy.Signature):
    """Your task is to evaluate if every sentence in the statement (except the trivial ones like stylistic sentences) is supported or entailed by the source context.
    Generate a score the scale of 0.0 to 1.0 with reasons for the score."""

    input: Input = dspy.InputField()
    output: Output = dspy.OutputField()


# class OriginalStatementsInput(BaseModel):
#     original_statements: List[str] = Field(description="Original statements that need to be refined by removing trivial claims.")

# class RefinedStatementsOutput(BaseModel):
#     refined_statements: List[str] = Field(description="Refined statements after removing trivial claims.")

# class TrivialStatementRemovalSignature(dspy.Signature):
#     """You are a TRIVIAL STATEMENT REMOVAL classifier; providing the refined response by removing the trivial claims from the original response.
#     Consider the following list of statements. Identify and remove sentences that are stylistic, contain trivial pleasantries, or lack substantive information relevant to the main content. Respond only with a list of the remaining statements in the format of a python list of strings.
#     """
#     input: OriginalStatementsInput = dspy.InputField()
#     output: RefinedStatementsOutput = dspy.OutputField()


trulens_ollama_provider = LiteLLM(
    model_engine="ollama/llama3.1:8b", api_base="http://localhost:11434"
)


def tru_groundedness(source, statement):
    """Run the TruLens groundedness feedback and wrap it as a DSPy ``Prediction``.

    The provider result is read as ``(score, {"reasons": ...})``; a missing
    (NaN) score is reported as ``0.0`` so downstream thresholding still works.
    """
    provider_result = (
        trulens_ollama_provider.groundedness_measure_with_cot_reasons(
            source=source, statement=statement
        )
    )

    explanation = provider_result[1]["reasons"]
    raw_score = provider_result[0]
    final_score = 0.0 if pd.isna(raw_score) else raw_score

    return Prediction(
        reasoning=explanation,
        output=Output(reasons=explanation, score=final_score),
    )


class GroundednessDSPy(dspy.Module):
    """DSPy module that scores groundedness with a typed chain-of-thought predictor."""

    def __init__(self):
        super().__init__()
        self.generate_score_and_reasons = dspy.TypedChainOfThought(
            GroundednessSignature
        )

    def forward(self, query: str, expected_response: str):
        """Evaluate ``expected_response`` against ``query`` (used as the source).

        Any exception raised by the predictor is printed and replaced by a
        fallback prediction with score 0.0, so a single bad generation does
        not abort an evaluation loop.
        """
        # Built outside the try block: input validation errors still propagate.
        input_pair = Input(source=query, statement=expected_response)

        try:
            return self.generate_score_and_reasons(input=input_pair)
        except Exception as e:
            print(f"Error: {e}")
            failure_text = "Error: Unable to generate reasons or score"
            return Prediction(
                reasoning=failure_text,
                output=Output(reasons=failure_text, score=0.0),
            )


trainset, devset = dataset.train, dataset.dev


groundedness_dspy = GroundednessDSPy()


# Set up the evaluator, which can be re-used in your code.
# evaluate = Evaluate(devset=devset[:], num_threads=8, display_progress=True,  display_table=True)


def evaluate_on_devset(groundedness_fn, devset):
    """Evaluate a groundedness function on a dev set with binary labels.

    Scores are binarized at 0.5 (>= 0.5 is the positive class) for both the
    expected and the predicted score.

    Args:
        groundedness_fn: Callable ``(query, expected_response)`` returning an
            object exposing ``output.score``.
        devset: Iterable of examples exposing ``query``, ``expected_response``
            and ``expected_score``.

    Returns:
        ``(dspy_metric_scores, precision, recall, f1)`` where
        ``dspy_metric_scores`` holds one ``evaluate_groundedness_score`` value
        per example. Precision, recall and F1 are reported as ``0.0`` when
        their denominator is zero (e.g. the model never predicts the positive
        class, or the dev set is empty) instead of raising ZeroDivisionError.
    """
    tp = tn = fp = fn = 0
    dspy_metric_scores = []
    for example in devset:
        pred = groundedness_fn(example.query, example.expected_response)

        gt_label = 1 if example.expected_score >= 0.5 else 0
        pred_label = 1 if pred.output.score >= 0.5 else 0

        dspy_metric_scores.append(
            evaluate_groundedness_score(
                example, pred, use_binary_threshold=True, threshold=0.5
            )
        )

        # Exactly one confusion-matrix cell is incremented per example.
        if pred_label == 1 and gt_label == 1:
            tp += 1
        elif pred_label == 0 and gt_label == 0:
            tn += 1
        elif pred_label == 1 and gt_label == 0:
            fp += 1
        else:
            fn += 1

    # Guard every ratio: a degenerate run (all-negative predictions, no
    # positive ground truth, empty dev set) must not crash the evaluation.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (
        2 * (precision * recall) / (precision + recall)
        if (precision + recall)
        else 0.0
    )

    print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
    return dspy_metric_scores, precision, recall, f1


print("Evaluate on the dev set on DSPy baseline model")

evaluate_on_devset(groundedness_dspy, devset)

print("Evaluate on the dev set on TruLens baseline model")

evaluate_on_devset(tru_groundedness, devset)

### Use MIPROv2 without fewshot examples (0-shot only):

In [None]:
# Instruction-only (zero-shot) optimization: MIPROv2 is driven by
# `evaluate_groundedness_score` and compiles a deep copy of the baseline
# program, so `groundedness_dspy` itself is left untouched.
mipro_optimizer = MIPROv2(
    metric=evaluate_groundedness_score,
    auto="light",
    verbose=True,
)
print("Optimizing zero-shot program with MIPRO...")

zeroshot_optimized_program = mipro_optimizer.compile(
    groundedness_dspy.deepcopy(),
    trainset=trainset,
    max_bootstrapped_demos=0,  # ZERO FEW-SHOT EXAMPLES
    max_labeled_demos=0,  # ZERO FEW-SHOT EXAMPLES
    requires_permission_to_run=False,
)

# Save optimize program for future use
zeroshot_optimized_program.save("mipro_zeroshot_optimized")

# Evaluate optimized program
# (only the banner is printed here; the `evaluate_on_devset` call is in the next cell)
print("Evluate 0-shot optimized program...")
# evaluate(zeroshot_optimized_program, devset=devset[:])

In [None]:
evaluate_on_devset(zeroshot_optimized_program, devset)

In [None]:
# Few-shot optimization: same metric and baseline program (deep-copied) as the
# zero-shot run, but with the "medium" auto setting and up to 5 bootstrapped
# plus 5 labeled demonstrations.
fewshot_optimizer = MIPROv2(
    metric=evaluate_groundedness_score,
    auto="medium",
)
fewshot_optimized_program = fewshot_optimizer.compile(
    groundedness_dspy.deepcopy(),
    trainset=trainset,
    max_bootstrapped_demos=5,  # FEW-SHOT EXAMPLES
    max_labeled_demos=5,  # FEW-SHOT EXAMPLES
    requires_permission_to_run=False,
)

# Save optimize program for future use
fewshot_optimized_program.save("mipro_fewshot_optimized")

# Evaluate optimized program
print("Evluate few-shot optimized program...")

evaluate_on_devset(fewshot_optimized_program, devset)

In [None]:
from trulens.providers.litellm import LiteLLM

trulens_ollama_provider = LiteLLM(
    model_engine="ollama/llama3.1:8b", api_base="http://localhost:11434"
)


# NOTE: this redefines (shadows) the earlier `tru_groundedness` in this
# notebook. Unlike that version, it returns the provider's raw result instead
# of a DSPy `Prediction`, so it cannot be passed to `evaluate_on_devset`
# (which reads `pred.output.score`) once this cell has run.
def tru_groundedness(source, statement, filter_trivial_statements=False):
    """Call the TruLens groundedness feedback and return its result unchanged.

    Args:
        source: Source text passed through as ``source``.
        statement: Statement passed through as ``statement``.
        filter_trivial_statements: Forwarded to the provider unchanged.
    """
    return trulens_ollama_provider.groundedness_measure_with_cot_reasons(
        source=source,
        statement=statement,
        filter_trivial_statements=filter_trivial_statements,
    )

In [None]:
gt_scores = df_train["gt_scores"] = (
    df_train["expected_score"].apply(lambda x: 1 if x >= 0.5 else 0).to_list()
)

In [None]:
# Run the TruLens groundedness feedback (with trivial-statement filtering) on
# every training row and binarize its score at 0.5.
predicted_scores = []
for i, row in df_train.iterrows():
    expected_score = row["expected_score"]
    groundedness_output = tru_groundedness(
        row["query"], row["expected_response"], filter_trivial_statements=True
    )

    print(groundedness_output)

    # Element 0 of the provider result is used as the score.
    predicted_scores.append(1 if groundedness_output[0] >= 0.5 else 0)


# Sanity check: both lists should have one entry per training row.
print(len(gt_scores), len(predicted_scores))

In [None]:
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

# Binary precision/recall/F1 of the thresholded predictions against the
# thresholded ground truth. `predicted_scores` comes from the loop that ran
# with `filter_trivial_statements=True`.
precision = precision_score(gt_scores, predicted_scores)
recall = recall_score(gt_scores, predicted_scores)
f1 = f1_score(gt_scores, predicted_scores)

print(
    f"(with trivial filtering) Precision with: {precision}, Recall: {recall}, F1: {f1}"
)

In [None]:
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

# NOTE(review): `gt_scores` and `predicted_scores` are not recomputed between
# the previous metrics cell and this one, so on a top-to-bottom run this
# prints the same numbers without the "(with trivial filtering)" label. Re-run
# the prediction loop with `filter_trivial_statements=False` first if an
# unfiltered baseline is intended.
precision = precision_score(gt_scores, predicted_scores)
recall = recall_score(gt_scores, predicted_scores)
f1 = f1_score(gt_scores, predicted_scores)

print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")

### Implement TruLens' `groundedness_measure_with_cot_reasons` in AdalFlow

In [None]:
import re
from typing import Dict, Optional, Tuple, Union
import warnings

import adalflow as adal
from adalflow.optim.types import ParameterType
from dspy.datasets.dataset import Dataset
import nltk
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from trulens.benchmark.benchmark_frameworks.experiments.dataset_preprocessing import (
    generate_qags_golden_set_groundedness,
)
from trulens.feedback import generated as feedback_generated
from trulens.feedback.v2.feedback import Groundedness

# entire dataset
xsum_df = pd.DataFrame(
    list(
        generate_qags_golden_set_groundedness(
            "../../src/benchmark/trulens/benchmark/benchmark_frameworks/experiments/data/qags_mturk_xsum.jsonl",
            max_samples_per_bucket=100,
        )
    )
)

nltk.download("punkt_tab", quiet=True)


class XSumDataset(Dataset):
    """Shuffled 80/10/10 train/dev/test split of a DataFrame, kept as record dicts.

    Redefines the sequential-split ``XSumDataset`` declared earlier in this
    notebook.
    """

    def __init__(self, df, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # Set the random seed for reproducibility
        # NOTE: this seeds NumPy's global RNG, a side effect that persists
        # after the constructor returns.
        np.random.seed(42)

        # Shuffle the dataframe
        df = df.sample(frac=1, random_state=42).reset_index(drop=True)

        # Split into 80% train, 10% dev, 10% test
        # (the 20% hold-out is halved into dev and test)
        train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
        dev_df, test_df = train_test_split(
            temp_df, test_size=0.5, random_state=42
        )

        # Store as dictionaries for easy access
        self._train = train_df.to_dict(orient="records")
        self._dev = dev_df.to_dict(orient="records")
        self._test = test_df.to_dict(orient="records")


dataset = XSumDataset(xsum_df, input_keys=["query", "expected_response"])
print(dataset.train[:3])


# Flatten each split of the shuffled dataset into plain dict records (the same
# three fields per example) and expose them as pandas DataFrames.
_EXAMPLE_FIELDS = ("query", "expected_score", "expected_response")

data_train, data_dev, data_test = [], [], []
for records, split in (
    (data_train, dataset.train),
    (data_dev, dataset.dev),
    (data_test, dataset.test),
):
    for example in split:
        records.append({
            field_name: getattr(example, field_name)
            for field_name in _EXAMPLE_FIELDS
        })

df_train = pd.DataFrame(data_train)
df_dev = pd.DataFrame(data_dev)
df_test = pd.DataFrame(data_test)

print(
    f"len(df_train): {len(df_train)}; len(df_dev): {len(df_dev)}; len(df_test): {len(df_test)}"
)

few_shot_template = r"""<START_OF_SYSTEM_PROMPT>
{{system_prompt}}
{# Few shot demos #}
{% if few_shot_demos is not none %}
Here are some examples:
{{few_shot_demos}}
{% endif %}
<END_OF_SYSTEM_PROMPT>
<START_OF_USER>
{{user_prompt}}
<END_OF_USER>
"""


class GroundednessTaskPipeline(adal.Component):
    """AdalFlow task pipeline that asks an LLM to judge groundedness.

    The system prompt (initialised from TruLens' ``Groundedness.system_prompt``)
    and the few-shot demos are trainable ``adal.Parameter`` objects; the raw
    LLM response is turned into ``(score, reason)`` by
    ``parse_single_groundedness_output``.
    """

    def __init__(self, model_client: adal.ModelClient, model_kwargs: Dict):
        super().__init__()

        system_prompt = adal.Parameter(
            data=Groundedness.system_prompt,
            role_desc="To give task instruction to the language model in the system prompt",
            requires_opt=True,
            param_type=ParameterType.PROMPT,
        )
        few_shot_demos = adal.Parameter(
            data=None,
            role_desc="To provide few shot demos to the language model",
            requires_opt=True,  # Changed to True for few-shot learning
            param_type=ParameterType.DEMOS,
        )

        self.evaluate_hypothesis = adal.Generator(
            model_client=model_client,
            model_kwargs=model_kwargs,
            template=few_shot_template,
            prompt_kwargs={
                "system_prompt": system_prompt,
                "few_shot_demos": few_shot_demos,
            },
            use_cache=True,
            output_processors=self.parse_single_groundedness_output,
        )

    @adal.fun_to_component
    def parse_single_groundedness_output(response: str) -> Tuple[float, Dict]:
        """Parse a raw LLM response into ``(score, reason)``.

        Args:
            response: Raw text produced by the model (may be empty/None).

        Returns:
            A tuple ``(score, reason)``:

            * Response containing "Supporting Evidence": ``score`` is the
              rating of the LAST line mentioning "Score", rescaled from 0-3
              to 0-1 (or ``-1`` if no such line exists); ``reason`` is
              ``{"reason": "Criteria: ...\\nSupporting Evidence: ..."}``.
            * Empty response: ``(0, {"reason": "No response generated."})``.
            * Any other response: the rating of the whole response rescaled
              to 0-1, with an empty reason dict and a ``UserWarning``.

            Every ``Score: <n>`` occurrence inside the reason text is rescaled
            from the 0-3 range to 0-1 as well.
        """
        if response and "Supporting Evidence" in response:
            lines = response.split("\n")

            # Pass 1: the last line mentioning "Score" wins; -1 flags "none found".
            score = -1
            for line in lines:
                if "Score" in line:
                    score = (
                        feedback_generated.re_configured_rating(
                            line,
                            min_score_val=0,
                            max_score_val=3,
                        )
                    ) / 3

            # Pass 2 (run once, not once per line): gather the text following
            # each "Criteria:" / "Supporting Evidence:" marker, including any
            # continuation lines up to the next marker.
            criteria_lines = []
            supporting_evidence_lines = []
            active_section = None
            for line in lines:
                if "Criteria:" in line:
                    criteria_lines.append(
                        line.split("Criteria:", 1)[1].strip()
                    )
                    active_section = criteria_lines
                elif "Supporting Evidence:" in line:
                    supporting_evidence_lines.append(
                        line.split("Supporting Evidence:", 1)[1].strip()
                    )
                    active_section = supporting_evidence_lines
                elif active_section is not None:
                    active_section.append(line.strip())

            criteria = "\n".join(criteria_lines).strip()
            supporting_evidence = "\n".join(supporting_evidence_lines).strip()
            reason = {
                "reason": (
                    f"Criteria: {criteria}\n"
                    f"Supporting Evidence: {supporting_evidence}"
                )
            }
        elif not response:
            score = 0
            reason = {"reason": "No response generated."}
        else:
            score = (
                feedback_generated.re_configured_rating(
                    response,
                    min_score_val=0,
                    max_score_val=3,
                )
            ) / 3
            warnings.warn(
                "No supporting evidence provided. Returning score only.",
                UserWarning,
            )
            reason = {}

        if "reason" not in reason:
            return score, reason

        # Rescale every "Score: <n>" in the reason text from 0-3 to 0-1. The
        # pattern only accepts well-formed numbers, and each match is replaced
        # as a whole, so decimals such as "2.5" are not left half-rewritten and
        # later scores with a different value are normalized too.
        score_pattern = re.compile(r"Score:\s*([0-9]+(?:\.[0-9]+)?)")

        def _rescale(match):
            return f"Score: {float(match.group(1)) / 3}"

        normalized_reason = reason.copy()
        normalized_reason["reason"] = score_pattern.sub(
            _rescale, reason["reason"]
        )
        return score, normalized_reason

    def call(
        self,
        premise: str,
        hypothesis: str,
        id: Optional[str] = None,
    ) -> Union[adal.GeneratorOutput, adal.Parameter]:
        """Ask the generator whether ``hypothesis`` is supported by ``premise``.

        Args:
            premise: Source text inserted after ``SOURCE:`` in the user prompt.
            hypothesis: Statement inserted after ``Hypothesis:``.
            id: Optional identifier forwarded to the generator call.

        Returns:
            Whatever ``self.evaluate_hypothesis`` returns for the built prompt.
        """
        # TODO - add trivial statement prompt to be another parameter to optimize

        # def evaluate_hypothesis(index, hypothesis):
        user_prompt = """SOURCE: {premise}

        Hypothesis: {hypothesis}

        Please answer with the template below for all statement sentences:

        Criteria: <Statement Sentence>
        Supporting Evidence: <Identify and describe the location in the source where the information matches the statement. Provide a detailed, human-readable summary indicating the path or key details. if nothing matches, say NOTHING FOUND. For the case where the statement is an abstention, say ABSTENTION>
        Score: <Output a number based on the scoring output space / range>
        """.format(premise=premise, hypothesis=hypothesis)

        return self.evaluate_hypothesis(
            prompt_kwargs={"user_prompt": user_prompt}, id=id
        )

        # groundedness_scores = {}
        # reasons_str = ""
        # hypotheses = sent_tokenize(statement)
        # results = []

        # with ThreadPoolExecutor() as executor:
        #     futures = [
        #         executor.submit(evaluate_hypothesis, i, hypothesis)
        #         for i, hypothesis in enumerate(hypotheses)
        #     ]

        #     for future in as_completed(futures):
        #         results.append(future.result())

        # results.sort(key=lambda x: x[0])  # Sort results by index

        # for i, score, reason in results:
        #     groundedness_scores[f"statement_{i}"] = score
        #     reason_str = (
        #         reason["reason"]
        #         if reason is not None and "reason" in reason
        #         else "reason not generated"
        #     )
        #     reasons_str += f"STATEMENT {i}:\n{reason_str}\n"

        # # Calculate the average groundedness score from the scores dictionary
        # average_groundedness_score = float(
        #     np.mean(list(groundedness_scores.values()))
        # )

        # return average_groundedness_score, {"reasons": reasons_str}

In [None]:
from adalflow.components.model_client.ollama_client import OllamaClient
from adalflow.components.model_client.openai_client import OpenAIClient

# Model configurations. Each dict has exactly the two keys that
# `GroundednessTaskPipeline(**config)` expects: `model_client` and `model_kwargs`.

# OpenAI gpt-4o-mini with deterministic sampling (temperature 0.0).
gpt_mini_model = {
    "model_client": OpenAIClient(),
    "model_kwargs": {
        "model": "gpt-4o-mini",
        "max_tokens": 2000,
        "temperature": 0.0,
        "top_p": 0.99,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "stop": None,
    },
}

# llama3.1:8b through the Ollama client; no sampling options are set here.
llama_3_1_model = {
    "model_client": OllamaClient(),
    "model_kwargs": {"model": "llama3.1:8b"},
}

# OpenAI gpt-4o, same sampling settings as gpt-4o-mini but a larger token budget.
gpt_4o_model = {
    "model_client": OpenAIClient(),
    "model_kwargs": {
        "model": "gpt-4o",
        "max_tokens": 4000,
        "temperature": 0.0,
        "top_p": 0.99,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "stop": None,
    },
}


task_pipeline = GroundednessTaskPipeline(**gpt_mini_model)
print(task_pipeline)

In [None]:
output = task_pipeline(
    premise="All fruits not edible", hypothesis=" Apple is edible"
)
output

### Start auto prompt optimization with AdalFlow

In [None]:
task_pipeline.train()  # set to train mode

In [None]:
from dataclasses import dataclass
from dataclasses import field
import uuid

from adalflow.datasets.types import Example


@dataclass
class XSumData(Example):
    __doc__ = """A dataclass for representing examples in the XSum dataset."""

    # `default_factory` generates a fresh UUID for every instance. A plain
    # `default=str(uuid.uuid4())` is evaluated once, when the class is defined,
    # which would give every example the same id.
    id: str = field(
        metadata={"desc": "The unique identifier of the example", "type": "id"},
        default_factory=lambda: str(uuid.uuid4()),
    )
    # Source text the response is checked against.
    query: Optional[str] = field(
        metadata={"desc": "The source context from the retrieved documents."},
        default=None,
    )

    # Response whose groundedness is evaluated.
    expected_response: Optional[str] = field(
        metadata={
            "desc": "The generated response to the query that its groundedness shall be evaluated."
        },
        default=None,
    )

    # Ground-truth groundedness score for the response.
    expected_score: Optional[float] = field(
        metadata={"desc": "The expected groundedness score for the answer."},
        default=None,
    )

    # __input_fields__ = [
    #     "id",
    #     "query",
    #     "expected_response",
    #     "expected_score"
    # ]  # follow this order too.
    # __output_fields__ = ["expected_score"]


def _xsum_examples(split_df):
    """Wrap every row of a split DataFrame in an ``XSumData`` example."""
    examples = []
    for _, row in split_df.iterrows():
        examples.append(
            XSumData(
                query=row["query"],
                expected_response=row["expected_response"],
                expected_score=row["expected_score"],
            )
        )
    return examples


# One list of examples per split, built in train -> dev -> test order.
train_dataset = _xsum_examples(df_train)
val_dataset = _xsum_examples(df_dev)
test_dataset = _xsum_examples(df_test)


def groundedness_eval_fn(y: float, y_gt: float) -> float:
    """Return 1 when prediction and ground truth fall on the same side of 0.5, else 0."""
    predicted_positive = bool(y >= 0.5)
    actual_positive = bool(y_gt >= 0.5)
    return int(predicted_positive == actual_positive)


def weighted_groundedness_loss(
    y: float, y_gt: float, false_positive_weight: float = 2.0
) -> float:
    """
    Loss in [0, 1] that penalizes false positives more heavily.

    Both values are binarized at 0.5. A correct prediction costs 0.0; a false
    positive costs ``false_positive_weight / (false_positive_weight + 1)`` and
    a false negative costs ``1 / (false_positive_weight + 1)``.
    """
    predicted_positive = bool(y >= 0.5)
    actual_positive = bool(y_gt >= 0.5)

    # Correct predictions carry no loss.
    if predicted_positive == actual_positive:
        return 0.0

    # A mismatch with a positive prediction is a false positive; the only
    # other possible mismatch is a false negative.
    raw_penalty = false_positive_weight if predicted_positive else 1.0

    # Dividing by the sum of both penalties keeps the result within [0, 1].
    return raw_penalty / (false_positive_weight + 1.0)


class GroundednessAdalComponent(adal.AdalComponent):
    """AdalFlow training wrapper around ``GroundednessTaskPipeline``.

    Evaluation uses ``groundedness_eval_fn`` (binary agreement at 0.5). The
    text loss is built from ``weighted_groundedness_loss``, converted to a
    higher-is-better score because ``EvalFnToTextLoss`` expects its eval
    function to return a score, not a loss.
    """

    def __init__(
        self,
        model_client: adal.ModelClient,
        model_kwargs: Dict,
        backward_engine_model_config: Optional[Dict] = None,
        teacher_model_config: Optional[Dict] = None,
        text_optimizer_model_config: Optional[Dict] = None,
    ):
        task = GroundednessTaskPipeline(model_client, model_kwargs)
        # eval_fn = AnswerMatchAcc(type="exact_match").compute_single_item
        eval_fn = groundedness_eval_fn
        loss_fn = adal.EvalFnToTextLoss(
            # `weighted_groundedness_loss` returns 0.0 for a correct prediction
            # and a positive penalty for errors. EvalFnToTextLoss treats the
            # value as a score where higher is better, so the penalty is
            # flipped here: correct -> 1.0, false negative -> 2/3,
            # false positive -> 1/3 (with false_positive_weight=2.0).
            eval_fn=lambda y, y_gt: 1.0
            - weighted_groundedness_loss(y, y_gt, false_positive_weight=2.0),
            eval_fn_desc=(
                "Weighted score in [0, 1], higher is better: "
                "1 if y_binary == y_gt_binary, else 1 minus a weighted penalty "
                "that is larger for false positives."
            ),
        )

        super().__init__(task=task, eval_fn=eval_fn, loss_fn=loss_fn)
        # Kept for the configure_* hooks below; each hook unpacks its config
        # with `**`, so the matching config must be provided before training.
        self.backward_engine_model_config = backward_engine_model_config
        self.teacher_model_config = teacher_model_config
        self.text_optimizer_model_config = text_optimizer_model_config

    def prepare_task(self, sample: XSumData):
        """Return the task callable and its kwargs for one sample."""
        return self.task.call, {
            "premise": sample.query,
            "hypothesis": sample.expected_response,
            "id": sample.id,
        }

    def prepare_loss(self, sample: XSumData, pred: adal.Parameter):
        """Return the loss callable and its kwargs for one sample."""
        # prepare the gt and pred for the loss function
        y_gt = adal.Parameter(
            name="y_gt",
            data=sample.expected_score,
            eval_input=sample.expected_score,
            requires_opt=False,
        )

        # Element 0 of the parsed response is the groundedness score.
        pred.eval_input = pred.full_response.data[0]
        return self.loss_fn, {"kwargs": {"y": pred, "y_gt": y_gt}}

    def prepare_eval(self, sample: XSumData, y_pred: adal.GeneratorOutput):
        """Return the eval callable and its kwargs for one sample.

        When the parsed output is missing or its score is not a float, -1 is
        passed as the prediction, which binarizes to the negative label.
        """
        y_label = -1
        if (
            y_pred
            and y_pred.data
            and len(y_pred.data) > 0
            and isinstance(y_pred.data[0], float)
        ):
            y_label = y_pred.data[0]
        return self.eval_fn, {"y": y_label, "y_gt": sample.expected_score}

    def configure_backward_engine(self):
        """Configure the backward engine from ``backward_engine_model_config``."""
        super().configure_backward_engine_helper(
            **self.backward_engine_model_config
        )

    def configure_teacher_generator(self):
        """Configure the teacher generator from ``teacher_model_config``."""
        super().configure_teacher_generator_helper(**self.teacher_model_config)

    def configure_optimizers(self):
        """Return the text optimizers followed by the demo optimizers."""
        to = super().configure_text_optimizer_helper(
            **self.text_optimizer_model_config
        )
        do = super().configure_demo_optimizer_helper()  # Add demo optimizer
        return to + do  # Return both text and demo optimizers

In [None]:
def diagnose(
    model_client: adal.ModelClient,
    model_kwargs: Dict,
) -> Dict:
    """Run ``Trainer.diagnose`` on the train, val and test datasets in turn.

    Uses the notebook-level ``train_dataset``, ``val_dataset`` and
    ``test_dataset``. Despite the ``Dict`` annotation, nothing is returned.
    """
    # Resolve the three datasets up front, before any component is built.
    splits = (
        ("train", train_dataset),
        ("val", val_dataset),
        ("test", test_dataset),
    )

    trainer = adal.Trainer(
        adaltask=GroundednessAdalComponent(model_client, model_kwargs)
    )
    for split_name, split_data in splits:
        trainer.diagnose(dataset=split_data, split=split_name)

In [None]:
diagnose(**llama_3_1_model)

In [None]:
def train(
    train_batch_size=4,  # larger batch size is not that effective, probably because of llm's lost in the middle
    raw_shots: int = 0,
    bootstrap_shots: int = 2,
    max_steps=1,
    num_workers=4,
    strategy="random",
    optimization_order="sequential",
    debug=False,
    resume_from_ckpt=None,
    exclude_input_fields_from_bootstrap_demos=False,
):
    """Build a ``GroundednessAdalComponent`` and fit it with ``adal.Trainer``.

    The task model is ``llama_3_1_model``; the teacher, text optimizer and
    backward engine all use ``gpt_4o_model``. Training data comes from the
    notebook-level ``train_dataset``, ``val_dataset`` and ``test_dataset``.
    Every parameter is forwarded unchanged to ``adal.Trainer`` except
    ``resume_from_ckpt``, which goes to ``trainer.fit``, and ``debug``, which
    goes to both. Nothing is returned.
    """
    adal_component = GroundednessAdalComponent(
        **llama_3_1_model,
        teacher_model_config=gpt_4o_model,
        text_optimizer_model_config=gpt_4o_model,
        backward_engine_model_config=gpt_4o_model,
    )
    print(adal_component)
    trainer = adal.Trainer(
        train_batch_size=train_batch_size,
        adaltask=adal_component,
        strategy=strategy,
        max_steps=max_steps,
        num_workers=num_workers,
        raw_shots=raw_shots,
        bootstrap_shots=bootstrap_shots,
        debug=debug,
        weighted_sampling=True,  # fixed here, not exposed as a parameter
        optimization_order=optimization_order,
        exclude_input_fields_from_bootstrap_demos=exclude_input_fields_from_bootstrap_demos,
    )
    print(trainer)

    # train_dataset, val_dataset, test_dataset = load_datasets()
    trainer.fit(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        debug=debug,
        resume_from_ckpt=resume_from_ckpt,
    )

In [None]:
train(
    debug=False,
    max_steps=12,
    strategy="constrained",
    raw_shots=0,
    bootstrap_shots=1,
    exclude_input_fields_from_bootstrap_demos=True,
)