diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml
index 0e5583d231..838436f255 100644
--- a/.github/workflows/run_tests.yml
+++ b/.github/workflows/run_tests.yml
@@ -62,7 +62,7 @@ jobs:
       with:
         args: check --fix-only
     - name: Run tests with pytest
-      run: poetry run pytest tests/
+      run: poetry run pytest tests/ --ignore=tests/reliability
   build_poetry:
     name: Build Poetry
diff --git a/tests/reliability/README.md b/tests/reliability/README.md
new file mode 100644
index 0000000000..f660d26b25
--- /dev/null
+++ b/tests/reliability/README.md
@@ -0,0 +1,58 @@
+# DSPy Reliability Tests
+
+This directory contains reliability tests for DSPy programs. The purpose of these tests is to verify that DSPy programs reliably produce expected outputs across multiple large language models (LLMs), regardless of model size or capability. These tests are designed to ensure that DSPy programs maintain robustness and accuracy across diverse LLM configurations.
+
+### Overview
+
+Each test in this directory executes a DSPy program using various LLMs. By running the same tests across different models, these tests help validate that DSPy programs handle a wide range of inputs effectively and produce reliable outputs, even in cases where the model might struggle with the input or task.
+
+### Key Features
+
+- **Diverse LLMs**: Each DSPy program is tested with multiple LLMs, ranging from smaller models to more advanced, high-performance models. This approach allows us to assess the consistency and generality of DSPy program outputs across different model capabilities.
+- **Challenging and Adversarial Tests**: Some of the tests are intentionally challenging or adversarial, crafted to push the boundaries of DSPy. These challenging cases allow us to gauge the robustness of DSPy and identify areas for potential improvement.
+- **Cross-Model Compatibility**: By testing with different LLMs, we aim to ensure that DSPy programs perform well across model types and configurations, reducing model-specific edge cases and enhancing program versatility.
+
+### Running the Tests
+
+- First, populate the configuration file `reliability_conf.yaml` (located in this directory) with the necessary LiteLLM model/provider names and access credentials for (1) each LLM you want to test and (2) the LLM judge used to assess the correctness of outputs in certain test cases. These should be placed in the `litellm_params` section for each model in the defined `model_list`. You can also use `litellm_params` to specify values for LLM hyperparameters such as `temperature`. Any model that lacks configured `litellm_params` in the configuration file will be ignored during testing.
+
+  The configuration must also specify a DSPy adapter to use when testing, e.g. `"chat"` (for `dspy.ChatAdapter`) or `"json"` (for `dspy.JSONAdapter`).
+
+  An example of `reliability_conf.yaml`:
+
+  ```yaml
+  adapter: chat
+  model_list:
+    # The model to use for judging the correctness of program
+    # outputs throughout reliability test suites. We recommend using
+    # a high quality model as the judge, such as OpenAI GPT-4o
+    - model_name: "judge"
+      litellm_params:
+        model: "openai/gpt-4o"
+        api_key: ""
+    - model_name: "gpt-4o"
+      litellm_params:
+        model: "openai/gpt-4o"
+        api_key: ""
+    - model_name: "claude-3.5-sonnet"
+      litellm_params:
+        model: "anthropic/claude-3.5"
+        api_key: ""
+  ```
+
+- Second, to run the tests, run the following command from this directory:
+
+  ```bash
+  pytest .
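+  # (Illustrative) Standard pytest selection flags also work here; for example,
+  # run a single test module, or use -k to select the parametrized test IDs for
+  # one configured model (substring match, so "gpt-4o" also matches "gpt-4o-mini"):
+  pytest test_pydantic_models.py
+  pytest . -k "gpt-4o"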
+  ```
+
+  This will execute all tests for the configured models and display detailed results for each model configuration. Tests are set up to mark expected failures for known challenging cases where a specific model might struggle, while actual (unexpected) DSPy reliability issues are flagged as failures (see below).
+
+### Known Failing Models
+
+Some tests may be expected to fail with certain models, especially in challenging cases. These known failures are logged but do not affect the overall test result. This setup allows us to keep track of model-specific limitations without obstructing general test outcomes. Models that are known to fail a particular test case are specified using the `@known_failing_models` decorator. For example:
+
+```
+@known_failing_models(["llama-3.2-3b-instruct"])
+def test_program_with_complex_deeply_nested_output_structure():
+    ...
+```
diff --git a/tests/reliability/__init__.py b/tests/reliability/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/reliability/conftest.py b/tests/reliability/conftest.py
new file mode 100644
index 0000000000..5f6d662fc0
--- /dev/null
+++ b/tests/reliability/conftest.py
@@ -0,0 +1,84 @@
+import os
+
+import pytest
+
+import dspy
+from tests.conftest import clear_settings
+from tests.reliability.utils import parse_reliability_conf_yaml
+
+# Standard list of models that should be used for periodic DSPy reliability testing
+MODEL_LIST = [
+    "gpt-4o",
+    "gpt-4o-mini",
+    "gpt-4-turbo",
+    "gpt-o1-preview",
+    "gpt-o1-mini",
+    "claude-3.5-sonnet",
+    "claude-3.5-haiku",
+    "gemini-1.5-pro",
+    "gemini-1.5-flash",
+    "llama-3.1-405b-instruct",
+    "llama-3.1-70b-instruct",
+    "llama-3.1-8b-instruct",
+    "llama-3.2-3b-instruct",
+]
+
+
+def pytest_generate_tests(metafunc):
+    """
+    Hook to parameterize reliability test cases with each model defined in the
+    reliability tests YAML configuration.
+    """
+    known_failing_models = getattr(metafunc.function, "_known_failing_models", [])
+
+    if "configure_model" in metafunc.fixturenames:
+        params = [(model, model in known_failing_models) for model in MODEL_LIST]
+        ids = [f"{model}" for model, _ in params]  # Custom IDs for display
+        metafunc.parametrize("configure_model", params, indirect=True, ids=ids)
+
+
+@pytest.fixture(autouse=True)
+def configure_model(request):
+    """
+    Fixture to configure the DSPy library with a particular configured model and adapter
+    before executing a test case.
+    """
+    module_dir = os.path.dirname(os.path.abspath(__file__))
+    conf_path = os.path.join(module_dir, "reliability_conf.yaml")
+    reliability_conf = parse_reliability_conf_yaml(conf_path)
+
+    if reliability_conf.adapter.lower() == "chat":
+        adapter = dspy.ChatAdapter()
+    elif reliability_conf.adapter.lower() == "json":
+        adapter = dspy.JSONAdapter()
+    else:
+        raise ValueError(f"Unknown adapter specification '{reliability_conf.adapter}' in reliability_conf.yaml")
+
+    model_name, should_ignore_failure = request.param
+    model_params = reliability_conf.models.get(model_name)
+    if model_params:
+        lm = dspy.LM(**model_params)
+        dspy.configure(lm=lm, adapter=adapter)
+    else:
+        pytest.skip(
+            f"Skipping test because no reliability testing YAML configuration was found" f" for model {model_name}."
+ ) + + # Store `should_ignore_failure` flag on the request node for use in post-test handling + request.node.should_ignore_failure = should_ignore_failure + request.node.model_name = model_name + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_makereport(item, call): + """ + Hook to conditionally ignore failures in a given test case for known failing models. + """ + outcome = yield + rep = outcome.get_result() + + should_ignore_failure = getattr(item, "should_ignore_failure", False) + + if should_ignore_failure and rep.failed: + rep.outcome = "passed" + rep.wasxfail = "Ignoring failure for known failing model" diff --git a/tests/reliability/reliability_conf.yaml b/tests/reliability/reliability_conf.yaml new file mode 100644 index 0000000000..ef33cd94e2 --- /dev/null +++ b/tests/reliability/reliability_conf.yaml @@ -0,0 +1,75 @@ +adapter: chat +model_list: + # The model to use for judging the correctness of program + # outputs throughout reliability test suites. We recommend using + # a high quality model as the judge, such as OpenAI GPT-4o + - model_name: "judge" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" + - model_name: "gpt-4o" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" + - model_name: "gpt-4o-mini" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" + - model_name: "gpt-4-turbo" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" + - model_name: "gpt-o1" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" + - model_name: "gpt-o1-mini" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" + - model_name: "claude-3.5-sonnet" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" + - model_name: "claude-3.5-haiku" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" + - model_name: "gemini-1.5-pro" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" + - model_name: "gemini-1.5-flash" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" + - model_name: "llama-3.1-405b-instruct" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" + - model_name: "llama-3.1-70b-instruct" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" + - model_name: "llama-3.1-8b-instruct" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" + - model_name: "llama-3.2-3b-instruct" + litellm_params: + # model: "/" + # api_key: "api key" + # api_base: "" diff --git a/tests/reliability/test_pydantic_models.py b/tests/reliability/test_pydantic_models.py new file mode 100644 index 0000000000..5292d2036b --- /dev/null +++ b/tests/reliability/test_pydantic_models.py @@ -0,0 +1,89 @@ +from enum import Enum +from typing import List + +import pydantic + +import dspy +from tests.reliability.utils import assert_program_output_correct, known_failing_models + + +def test_qa_with_pydantic_answer_model(): + class Answer(pydantic.BaseModel): + value: str + certainty: float = pydantic.Field( + description="A value between 0 and 1 indicating the model's confidence in the answer." + ) + comments: List[str] = pydantic.Field( + description="At least two comments providing additional details about the answer." 
+ ) + + class QA(dspy.Signature): + question: str = dspy.InputField() + answer: Answer = dspy.OutputField() + + program = dspy.Predict(QA) + answer = program(question="What is the capital of France?").answer + + assert_program_output_correct( + program_output=answer.value, + grading_guidelines="The answer should be Paris. Answer should not contain extraneous information.", + ) + assert_program_output_correct( + program_output=answer.comments, grading_guidelines="The comments should be relevant to the answer" + ) + assert answer.certainty >= 0 + assert answer.certainty <= 1 + assert len(answer.comments) >= 2 + + +def test_color_classification_using_enum(): + Color = Enum("Color", ["RED", "GREEN", "BLUE"]) + + class Colorful(dspy.Signature): + text: str = dspy.InputField() + color: Color = dspy.OutputField() + + program = dspy.Predict(Colorful) + color = program(text="The sky is blue").color + + assert color == Color.BLUE + + +def test_entity_extraction_with_multiple_primitive_outputs(): + class ExtractEntityFromDescriptionOutput(pydantic.BaseModel): + entity_hu: str = pydantic.Field(description="The extracted entity in Hungarian, cleaned and lowercased.") + entity_en: str = pydantic.Field(description="The English translation of the extracted Hungarian entity.") + is_inverted: bool = pydantic.Field( + description="Boolean flag indicating if the input is connected in an inverted way." + ) + categories: str = pydantic.Field(description="English categories separated by '|' to which the entity belongs.") + review: bool = pydantic.Field( + description="Boolean flag indicating low confidence or uncertainty in the extraction." + ) + + class ExtractEntityFromDescription(dspy.Signature): + """Extract an entity from a Hungarian description, provide its English translation, categories, and an inverted flag.""" + + description: str = dspy.InputField(description="The input description in Hungarian.") + entity: ExtractEntityFromDescriptionOutput = dspy.OutputField( + description="The extracted entity and its properties." + ) + + program = dspy.ChainOfThought(ExtractEntityFromDescription) + + extracted_entity = program(description="A kávé egy növényi eredetű ital, amelyet a kávébabból készítenek.").entity + assert_program_output_correct( + program_output=extracted_entity.entity_hu, + grading_guidelines="The translation of the text into English should be equivalent to 'coffee'", + ) + assert_program_output_correct( + program_output=extracted_entity.entity_hu, + grading_guidelines="The text should be equivalent to 'coffee'", + ) + assert_program_output_correct( + program_output=extracted_entity.categories, + grading_guidelines=( + "The text should contain English language categories that apply to the word 'coffee'." + " The categories should be separated by the character '|'." + ), + ) diff --git a/tests/reliability/utils.py b/tests/reliability/utils.py new file mode 100644 index 0000000000..223e6a8e3b --- /dev/null +++ b/tests/reliability/utils.py @@ -0,0 +1,118 @@ +import os +from contextlib import contextmanager +from functools import lru_cache, wraps +from typing import Any, Dict, List + +import pydantic +import pytest +import yaml + +import dspy + +JUDGE_MODEL_NAME = "judge" + + +def assert_program_output_correct(program_output: Any, grading_guidelines: str): + """ + With the help of an LLM judge, assert that the specified output of a DSPy program is correct, + according to the specified grading guidelines. + + Args: + program_output: The output of a DSPy program. 
+        grading_guidelines: The grading guidelines for judging the correctness of the
+            program output.
+    """
+    with _judge_dspy_configuration():
+        judge_response = _get_judge_program()(
+            program_output=str(program_output),
+            guidelines=grading_guidelines,
+        ).judge_response
+        assert judge_response.correct, f"Output: {program_output}. Reason incorrect: {judge_response.justification}"
+
+
+def known_failing_models(models: List[str]):
+    """
+    Decorator to allow specific test cases to fail for certain models. This is useful when a
+    model is known to be unable to perform a specific task (e.g. output formatting with complex
+    schemas) to the required standard.
+
+    Args:
+        models: List of model names for which the test case is allowed to fail.
+    """
+
+    def decorator(test_func):
+        test_func._known_failing_models = models
+
+        @wraps(test_func)
+        def wrapper(*args, **kwargs):
+            return test_func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+@contextmanager
+def _judge_dspy_configuration():
+    """
+    Context manager that temporarily configures DSPy to use the LLM judge model defined in
+    reliability_conf.yaml.
+    """
+    module_dir = os.path.dirname(os.path.abspath(__file__))
+    conf_path = os.path.join(module_dir, "reliability_conf.yaml")
+    reliability_conf = parse_reliability_conf_yaml(conf_path)
+    judge_params = reliability_conf.models.get(JUDGE_MODEL_NAME)
+    if judge_params is None:
+        raise ValueError(f"No LiteLLM configuration found for judge model: {JUDGE_MODEL_NAME}")
+
+    with dspy.settings.context(lm=dspy.LM(**judge_params)):
+        yield
+
+
+def _get_judge_program():
+    class JudgeResponse(pydantic.BaseModel):
+        correct: bool = pydantic.Field(description="Whether or not the program output is correct")
+        justification: str = pydantic.Field(description="Justification for the correctness assessment of the program output")
+
+    class JudgeSignature(dspy.Signature):
+        program_output: str = dspy.InputField(description="The output of an AI program / model to be judged")
+        guidelines: str = dspy.InputField(
+            description=(
+                "Grading guidelines for judging the correctness of the program output."
+                " If the output satisfies the guidelines, the judge will return correct=True."
+            )
+        )
+        judge_response: JudgeResponse = dspy.OutputField()
+
+    return dspy.Predict(JudgeSignature)
+
+
+class ReliabilityTestConf(pydantic.BaseModel):
+    adapter: str
+    models: Dict[str, Any]
+
+
+@lru_cache(maxsize=None)
+def parse_reliability_conf_yaml(conf_file_path: str) -> ReliabilityTestConf:
+    try:
+        with open(conf_file_path, "r") as file:
+            conf = yaml.safe_load(file)
+
+        model_dict = {}
+        for conf_entry in conf["model_list"]:
+            model_name = conf_entry.get("model_name")
+            if model_name is None:
+                raise ValueError("Model name missing in reliability_conf.yaml")
+
+            litellm_params = conf_entry.get("litellm_params")
+            if litellm_params is not None:
+                model_dict[model_name] = litellm_params
+            else:
+                print(
+                    f"Skipping all test cases for model {model_name} without LiteLLM parameters"
+                    f" ('litellm_params' section of conf file is missing)."
+                )
+
+        adapter = conf.get("adapter")
+        if adapter is None:
+            raise ValueError("No adapter configuration found in reliability_conf.yaml")
+
+        return ReliabilityTestConf(adapter=adapter, models=model_dict)
+    except Exception as e:
+        raise ValueError(f"Error parsing LiteLLM configuration file: {conf_file_path}") from e
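For context, a new test added to this suite would follow the same pattern as `test_pydantic_models.py`: define a signature, run a program against the fixture-configured model, and grade the output with the LLM judge via `assert_program_output_correct`. The sketch below is illustrative only and is not part of this change; the test name, input passage, grading guidelines, and the `known_failing_models` entry are assumptions.

```python
import dspy
from tests.reliability.utils import assert_program_output_correct, known_failing_models


# Hypothetical test case; the failing-model list, passage, and guidelines are placeholders.
@known_failing_models(["llama-3.2-3b-instruct"])
def test_summarization_in_one_sentence():
    class Summarize(dspy.Signature):
        """Summarize the given passage in a single sentence."""

        passage: str = dspy.InputField()
        summary: str = dspy.OutputField()

    # The autouse `configure_model` fixture has already set the LM and adapter,
    # so the program runs against whichever model this parametrized case targets.
    program = dspy.Predict(Summarize)
    summary = program(
        passage=(
            "The Eiffel Tower, completed in 1889 for the World's Fair, is a wrought-iron"
            " lattice tower in Paris and one of the most visited monuments in the world."
        )
    ).summary

    # Grade the free-form output with the configured LLM judge.
    assert_program_output_correct(
        program_output=summary,
        grading_guidelines=(
            "The output should be a single sentence identifying the Eiffel Tower as a"
            " tower or monument in Paris."
        ),
    )
```

Because `configure_model` is an autouse fixture and `pytest_generate_tests` parametrizes it, such a test is automatically run once per model configured in `reliability_conf.yaml`, with failures ignored for any models listed in its `known_failing_models` decorator.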