diff --git a/docs/evals.md b/docs/evals.md index 7f698e955b..68ec8184ee 100644 --- a/docs/evals.md +++ b/docs/evals.md @@ -653,10 +653,12 @@ async def main(): print(output_file.read_text()) """ # yaml-language-server: $schema=questions_cases_schema.json + name: null cases: - name: Easy Capital Question inputs: question: What is the capital of France? + context: null metadata: difficulty: easy category: Geography @@ -668,6 +670,7 @@ async def main(): - name: Challenging Landmark Question inputs: question: Which world-famous landmark is located on the banks of the Seine River? + context: null metadata: difficulty: hard category: Landmarks @@ -676,6 +679,7 @@ async def main(): confidence: 0.9 evaluators: - EqualsExpected + evaluators: [] """ ``` @@ -713,11 +717,13 @@ async def main(): """ { "$schema": "questions_cases_schema.json", + "name": null, "cases": [ { "name": "Easy Capital Question", "inputs": { - "question": "What is the capital of France?" + "question": "What is the capital of France?", + "context": null }, "metadata": { "difficulty": "easy", @@ -734,7 +740,8 @@ async def main(): { "name": "Challenging Landmark Question", "inputs": { - "question": "Which world-famous landmark is located on the banks of the Seine River?" + "question": "Which world-famous landmark is located on the banks of the Seine River?", + "context": null }, "metadata": { "difficulty": "hard", @@ -748,7 +755,8 @@ async def main(): "EqualsExpected" ] } - ] + ], + "evaluators": [] } """ ``` diff --git a/pydantic_evals/pydantic_evals/dataset.py b/pydantic_evals/pydantic_evals/dataset.py index ecc7697fdf..75a264b041 100644 --- a/pydantic_evals/pydantic_evals/dataset.py +++ b/pydantic_evals/pydantic_evals/dataset.py @@ -646,7 +646,7 @@ def to_file( context: dict[str, Any] = {'use_short_form': True} if fmt == 'yaml': - dumped_data = self.model_dump(mode='json', by_alias=True, exclude_defaults=True, context=context) + dumped_data = self.model_dump(mode='json', by_alias=True, context=context) content = yaml.dump(dumped_data, sort_keys=False) if schema_ref: # pragma: no branch yaml_language_server_line = f'{_YAML_SCHEMA_LINE_PREFIX}{schema_ref}' @@ -654,7 +654,7 @@ def to_file( path.write_text(content) else: context['$schema'] = schema_ref - json_data = self.model_dump_json(indent=2, by_alias=True, exclude_defaults=True, context=context) + json_data = self.model_dump_json(indent=2, by_alias=True, context=context) path.write_text(json_data + '\n') @classmethod @@ -724,6 +724,7 @@ class Case(BaseModel, extra='forbid'): # pyright: ignore[reportUnusedClass] # evaluators: list[Union[tuple(evaluator_schema_types)]] = [] # pyright: ignore # noqa UP007 class Dataset(BaseModel, extra='forbid'): + name: str | None = None cases: list[Case] if evaluator_schema_types: # pragma: no branch evaluators: list[Union[tuple(evaluator_schema_types)]] = [] # pyright: ignore # noqa UP007 diff --git a/tests/evals/test_dataset.py b/tests/evals/test_dataset.py index 450bcacfd8..d2b2abc008 100644 --- a/tests/evals/test_dataset.py +++ b/tests/evals/test_dataset.py @@ -4,7 +4,7 @@ import sys from dataclasses import dataclass, field from pathlib import Path -from typing import Any +from typing import Any, Literal import pytest import yaml @@ -863,6 +863,38 @@ async def test_serialization_to_json(example_dataset: Dataset[TaskInput, TaskOut assert (tmp_path / schema).exists() +def test_serializing_parts_with_discriminators(tmp_path: Path): + class Foo(BaseModel): + foo: str + kind: Literal['foo'] = 'foo' + + class Bar(BaseModel): + bar: str + kind: Literal['bar'] = 'bar' + + items = [Foo(foo='foo'), Bar(bar='bar')] + + dataset = Dataset[list[Foo | Bar]](cases=[Case(inputs=items)]) + yaml_path = tmp_path / 'test_cases.yaml' + dataset.to_file(yaml_path) + + loaded_dataset = Dataset[list[Foo | Bar]].from_file(yaml_path) + assert loaded_dataset == snapshot( + Dataset( + name='test_cases', + cases=[ + Case( + name=None, + inputs=[ + Foo(foo='foo'), + Bar(bar='bar'), + ], + ) + ], + ) + ) + + def test_serialization_errors(tmp_path: Path): with pytest.raises(ValueError) as exc_info: Dataset[TaskInput, TaskOutput, TaskMetadata].from_file(tmp_path / 'test_cases.abc')