Skip to content

Commit 90fc3fd

Browse files
sviluppjan-simlDouweM
authored
Fix deprecated kwargs validation to prevent silent failures (#2047)
Co-authored-by: J S <jan@swap-commerce.com> Co-authored-by: Douwe Maan <douwe@pydantic.dev>
1 parent 44d6da3 commit 90fc3fd

File tree

4 files changed

+154
-14
lines changed

4 files changed

+154
-14
lines changed

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131

3232
from pydantic_graph._utils import AbstractSpan
3333

34+
from . import exceptions
35+
3436
AbstractSpan = AbstractSpan
3537

3638
if TYPE_CHECKING:
@@ -415,6 +417,20 @@ def merge_json_schema_defs(schemas: list[dict[str, Any]]) -> tuple[list[dict[str
415417
return rewritten_schemas, all_defs
416418

417419

420+
def validate_empty_kwargs(_kwargs: dict[str, Any]) -> None:
421+
"""Validate that no unknown kwargs remain after processing.
422+
423+
Args:
424+
_kwargs: Dictionary of remaining kwargs after specific ones have been processed.
425+
426+
Raises:
427+
UserError: If any unknown kwargs remain.
428+
"""
429+
if _kwargs:
430+
unknown_kwargs = ', '.join(f'`{k}`' for k in _kwargs.keys())
431+
raise exceptions.UserError(f'Unknown keyword arguments: {unknown_kwargs}')
432+
433+
418434
def strip_markdown_fences(text: str) -> str:
419435
if text.startswith('{'):
420436
return text

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -294,41 +294,43 @@ def __init__(
294294
self.name = name
295295
self.model_settings = model_settings
296296

297-
if 'result_type' in _deprecated_kwargs: # pragma: no cover
298-
if output_type is not str:
297+
if 'result_type' in _deprecated_kwargs:
298+
if output_type is not str: # pragma: no cover
299299
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
300300
warnings.warn('`result_type` is deprecated, use `output_type` instead', DeprecationWarning)
301-
output_type = _deprecated_kwargs['result_type']
301+
output_type = _deprecated_kwargs.pop('result_type')
302302

303303
self.output_type = output_type
304304

305305
self.instrument = instrument
306306

307307
self._deps_type = deps_type
308308

309-
self._deprecated_result_tool_name = _deprecated_kwargs.get('result_tool_name')
310-
if self._deprecated_result_tool_name is not None: # pragma: no cover
309+
self._deprecated_result_tool_name = _deprecated_kwargs.pop('result_tool_name', None)
310+
if self._deprecated_result_tool_name is not None:
311311
warnings.warn(
312312
'`result_tool_name` is deprecated, use `output_type` with `ToolOutput` instead',
313313
DeprecationWarning,
314314
)
315315

316-
self._deprecated_result_tool_description = _deprecated_kwargs.get('result_tool_description')
317-
if self._deprecated_result_tool_description is not None: # pragma: no cover
316+
self._deprecated_result_tool_description = _deprecated_kwargs.pop('result_tool_description', None)
317+
if self._deprecated_result_tool_description is not None:
318318
warnings.warn(
319319
'`result_tool_description` is deprecated, use `output_type` with `ToolOutput` instead',
320320
DeprecationWarning,
321321
)
322-
result_retries = _deprecated_kwargs.get('result_retries')
323-
if result_retries is not None: # pragma: no cover
324-
if output_retries is not None:
322+
result_retries = _deprecated_kwargs.pop('result_retries', None)
323+
if result_retries is not None:
324+
if output_retries is not None: # pragma: no cover
325325
raise TypeError('`output_retries` and `result_retries` cannot be set at the same time.')
326326
warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning)
327327
output_retries = result_retries
328328

329329
default_output_mode = (
330330
self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None
331331
)
332+
_utils.validate_empty_kwargs(_deprecated_kwargs)
333+
332334
self._output_schema = _output.OutputSchema[OutputDataT].build(
333335
output_type,
334336
default_mode=default_output_mode,
@@ -469,7 +471,9 @@ async def main():
469471
if output_type is not str:
470472
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
471473
warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning)
472-
output_type = _deprecated_kwargs['result_type']
474+
output_type = _deprecated_kwargs.pop('result_type')
475+
476+
_utils.validate_empty_kwargs(_deprecated_kwargs)
473477

474478
async with self.iter(
475479
user_prompt=user_prompt,
@@ -635,7 +639,9 @@ async def main():
635639
if output_type is not str:
636640
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
637641
warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning)
638-
output_type = _deprecated_kwargs['result_type']
642+
output_type = _deprecated_kwargs.pop('result_type')
643+
644+
_utils.validate_empty_kwargs(_deprecated_kwargs)
639645

640646
deps = self._get_deps(deps)
641647
new_message_index = len(message_history) if message_history else 0
@@ -872,7 +878,9 @@ def run_sync(
872878
if output_type is not str:
873879
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
874880
warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning)
875-
output_type = _deprecated_kwargs['result_type']
881+
output_type = _deprecated_kwargs.pop('result_type')
882+
883+
_utils.validate_empty_kwargs(_deprecated_kwargs)
876884

877885
return get_event_loop().run_until_complete(
878886
self.run(
@@ -988,7 +996,9 @@ async def main():
988996
if output_type is not str:
989997
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
990998
warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning)
991-
output_type = _deprecated_kwargs['result_type']
999+
output_type = _deprecated_kwargs.pop('result_type')
1000+
1001+
_utils.validate_empty_kwargs(_deprecated_kwargs)
9921002

9931003
yielded = False
9941004
async with self.iter(

tests/test_agent.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3125,3 +3125,79 @@ def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
31253125

31263126
with pytest.raises(UserError, match='Output tools are not supported by the model.'):
31273127
agent.run_sync('Hello')
3128+
3129+
3130+
def test_deprecated_kwargs_validation_agent_init():
3131+
"""Test that invalid kwargs raise UserError in Agent constructor."""
3132+
with pytest.raises(UserError, match='Unknown keyword arguments: `usage_limits`'):
3133+
Agent('test', usage_limits='invalid') # type: ignore[call-arg]
3134+
3135+
with pytest.raises(UserError, match='Unknown keyword arguments: `invalid_kwarg`'):
3136+
Agent('test', invalid_kwarg='value') # type: ignore[call-arg]
3137+
3138+
with pytest.raises(UserError, match='Unknown keyword arguments: `foo`, `bar`'):
3139+
Agent('test', foo='value1', bar='value2') # type: ignore[call-arg]
3140+
3141+
3142+
def test_deprecated_kwargs_validation_agent_run():
3143+
"""Test that invalid kwargs raise UserError in Agent.run method."""
3144+
agent = Agent('test')
3145+
3146+
with pytest.raises(UserError, match='Unknown keyword arguments: `invalid_kwarg`'):
3147+
agent.run_sync('test', invalid_kwarg='value') # type: ignore[call-arg]
3148+
3149+
with pytest.raises(UserError, match='Unknown keyword arguments: `foo`, `bar`'):
3150+
agent.run_sync('test', foo='value1', bar='value2') # type: ignore[call-arg]
3151+
3152+
3153+
def test_deprecated_kwargs_still_work():
3154+
"""Test that valid deprecated kwargs still work with warnings."""
3155+
import warnings
3156+
3157+
with warnings.catch_warnings(record=True) as w:
3158+
warnings.simplefilter('always')
3159+
3160+
agent = Agent('test', result_type=str) # type: ignore[call-arg]
3161+
assert len(w) == 1
3162+
assert issubclass(w[0].category, DeprecationWarning)
3163+
assert '`result_type` is deprecated' in str(w[0].message)
3164+
assert agent.output_type is str
3165+
3166+
with warnings.catch_warnings(record=True) as w:
3167+
warnings.simplefilter('always')
3168+
3169+
agent = Agent('test', result_tool_name='test_tool') # type: ignore[call-arg]
3170+
assert len(w) == 1
3171+
assert issubclass(w[0].category, DeprecationWarning)
3172+
assert '`result_tool_name` is deprecated' in str(w[0].message)
3173+
3174+
with warnings.catch_warnings(record=True) as w:
3175+
warnings.simplefilter('always')
3176+
3177+
agent = Agent('test', result_tool_description='test description') # type: ignore[call-arg]
3178+
assert len(w) == 1
3179+
assert issubclass(w[0].category, DeprecationWarning)
3180+
assert '`result_tool_description` is deprecated' in str(w[0].message)
3181+
3182+
with warnings.catch_warnings(record=True) as w:
3183+
warnings.simplefilter('always')
3184+
3185+
agent = Agent('test', result_retries=3) # type: ignore[call-arg]
3186+
assert len(w) == 1
3187+
assert issubclass(w[0].category, DeprecationWarning)
3188+
assert '`result_retries` is deprecated' in str(w[0].message)
3189+
3190+
3191+
def test_deprecated_kwargs_mixed_valid_invalid():
3192+
"""Test that mix of valid deprecated and invalid kwargs raises error for invalid ones."""
3193+
import warnings
3194+
3195+
with pytest.raises(UserError, match='Unknown keyword arguments: `usage_limits`'):
3196+
with warnings.catch_warnings():
3197+
warnings.simplefilter('ignore', DeprecationWarning) # Ignore the deprecation warning for result_type
3198+
Agent('test', result_type=str, usage_limits='invalid') # type: ignore[call-arg]
3199+
3200+
with pytest.raises(UserError, match='Unknown keyword arguments: `foo`, `bar`'):
3201+
with warnings.catch_warnings():
3202+
warnings.simplefilter('ignore', DeprecationWarning) # Ignore the deprecation warning for result_tool_name
3203+
Agent('test', result_tool_name='test', foo='value1', bar='value2') # type: ignore[call-arg]

tests/test_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
merge_json_schema_defs,
2121
run_in_executor,
2222
strip_markdown_fences,
23+
validate_empty_kwargs,
2324
)
2425

2526
from .models.mock_async_stream import MockAsyncStream
@@ -503,3 +504,40 @@ def test_strip_markdown_fences():
503504
== '{"foo": "bar"}'
504505
)
505506
assert strip_markdown_fences('No JSON to be found') == 'No JSON to be found'
507+
508+
509+
def test_validate_empty_kwargs_empty():
510+
"""Test that empty dict passes validation."""
511+
validate_empty_kwargs({})
512+
513+
514+
def test_validate_empty_kwargs_with_unknown():
515+
"""Test that unknown kwargs raise UserError."""
516+
with pytest.raises(UserError, match='Unknown keyword arguments: `unknown_arg`'):
517+
validate_empty_kwargs({'unknown_arg': 'value'})
518+
519+
520+
def test_validate_empty_kwargs_multiple_unknown():
521+
"""Test that multiple unknown kwargs are properly formatted."""
522+
with pytest.raises(UserError, match='Unknown keyword arguments: `arg1`, `arg2`'):
523+
validate_empty_kwargs({'arg1': 'value1', 'arg2': 'value2'})
524+
525+
526+
def test_validate_empty_kwargs_message_format():
527+
"""Test that the error message format matches expected pattern."""
528+
with pytest.raises(UserError) as exc_info:
529+
validate_empty_kwargs({'test_arg': 'test_value'})
530+
531+
assert 'Unknown keyword arguments: `test_arg`' in str(exc_info.value)
532+
533+
534+
def test_validate_empty_kwargs_preserves_order():
535+
"""Test that multiple kwargs preserve order in error message."""
536+
kwargs = {'first': '1', 'second': '2', 'third': '3'}
537+
with pytest.raises(UserError) as exc_info:
538+
validate_empty_kwargs(kwargs)
539+
540+
error_msg = str(exc_info.value)
541+
assert '`first`' in error_msg
542+
assert '`second`' in error_msg
543+
assert '`third`' in error_msg

0 commit comments

Comments
 (0)