diff --git a/pydantic_ai_slim/pydantic_ai/providers/bedrock.py b/pydantic_ai_slim/pydantic_ai/providers/bedrock.py index 79c4e679ba..1a8980e057 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/providers/bedrock.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import os from typing import overload from pydantic_ai.exceptions import UserError @@ -8,6 +9,7 @@ try: import boto3 from botocore.client import BaseClient + from botocore.config import Config from botocore.exceptions import NoRegionError except ImportError as _import_error: raise ImportError( @@ -42,6 +44,8 @@ def __init__( aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, aws_session_token: str | None = None, + aws_read_timeout: float | None = None, + aws_connect_timeout: float | None = None, ) -> None: ... def __init__( @@ -52,6 +56,8 @@ def __init__( aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, aws_session_token: str | None = None, + aws_read_timeout: float | None = None, + aws_connect_timeout: float | None = None, ) -> None: """Initialize the Bedrock provider. @@ -61,17 +67,22 @@ def __init__( aws_access_key_id: The AWS access key ID. aws_secret_access_key: The AWS secret access key. aws_session_token: The AWS session token. + aws_read_timeout: The read timeout for Bedrock client. + aws_connect_timeout: The connect timeout for Bedrock client. """ if bedrock_client is not None: self._client = bedrock_client else: try: + read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300)) + connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60)) self._client = boto3.client( # type: ignore[reportUnknownMemberType] 'bedrock-runtime', aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, region_name=region_name, + config=Config(read_timeout=read_timeout, connect_timeout=connect_timeout), ) except NoRegionError as exc: # pragma: no cover raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc diff --git a/tests/providers/test_bedrock.py b/tests/providers/test_bedrock.py index 5fce0072c5..9f62fe6879 100644 --- a/tests/providers/test_bedrock.py +++ b/tests/providers/test_bedrock.py @@ -1,8 +1,12 @@ +from typing import cast + import pytest from ..conftest import TestEnv, try_import with try_import() as imports_successful: + from mypy_boto3_bedrock_runtime import BedrockRuntimeClient + from pydantic_ai.providers.bedrock import BedrockProvider @@ -15,3 +19,16 @@ def test_bedrock_provider(env: TestEnv): assert isinstance(provider, BedrockProvider) assert provider.name == 'bedrock' assert provider.base_url == 'https://bedrock-runtime.us-east-1.amazonaws.com' + + +def test_bedrock_provider_timeout(env: TestEnv): + env.set('AWS_DEFAULT_REGION', 'us-east-1') + env.set('AWS_READ_TIMEOUT', '1') + env.set('AWS_CONNECT_TIMEOUT', '1') + provider = BedrockProvider() + assert isinstance(provider, BedrockProvider) + assert provider.name == 'bedrock' + + config = cast(BedrockRuntimeClient, provider.client).meta.config + assert config.read_timeout == 1 # type: ignore + assert config.connect_timeout == 1 # type: ignore