Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations as _annotations

import os
from typing import overload

from pydantic_ai.exceptions import UserError
Expand All @@ -8,6 +9,7 @@
try:
import boto3
from botocore.client import BaseClient
from botocore.config import Config
from botocore.exceptions import NoRegionError
except ImportError as _import_error:
raise ImportError(
Expand Down Expand Up @@ -42,6 +44,8 @@ def __init__(
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
aws_session_token: str | None = None,
aws_read_timeout: float | None = None,
aws_connect_timeout: float | None = None,
) -> None: ...

def __init__(
Expand All @@ -52,6 +56,8 @@ def __init__(
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
aws_session_token: str | None = None,
aws_read_timeout: float | None = None,
aws_connect_timeout: float | None = None,
) -> None:
"""Initialize the Bedrock provider.

Expand All @@ -61,17 +67,22 @@ def __init__(
aws_access_key_id: The AWS access key ID.
aws_secret_access_key: The AWS secret access key.
aws_session_token: The AWS session token.
aws_read_timeout: The read timeout for Bedrock client.
aws_connect_timeout: The connect timeout for Bedrock client.
"""
if bedrock_client is not None:
self._client = bedrock_client
else:
try:
read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
self._client = boto3.client( # type: ignore[reportUnknownMemberType]
'bedrock-runtime',
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=region_name,
config=Config(read_timeout=read_timeout, connect_timeout=connect_timeout),
)
except NoRegionError as exc: # pragma: no cover
raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
17 changes: 17 additions & 0 deletions tests/providers/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import cast

import pytest

from ..conftest import TestEnv, try_import

with try_import() as imports_successful:
from mypy_boto3_bedrock_runtime import BedrockRuntimeClient

from pydantic_ai.providers.bedrock import BedrockProvider


Expand All @@ -15,3 +19,16 @@ def test_bedrock_provider(env: TestEnv):
assert isinstance(provider, BedrockProvider)
assert provider.name == 'bedrock'
assert provider.base_url == 'https://bedrock-runtime.us-east-1.amazonaws.com'


def test_bedrock_provider_timeout(env: TestEnv):
env.set('AWS_DEFAULT_REGION', 'us-east-1')
env.set('AWS_READ_TIMEOUT', '1')
env.set('AWS_CONNECT_TIMEOUT', '1')
provider = BedrockProvider()
assert isinstance(provider, BedrockProvider)
assert provider.name == 'bedrock'

config = cast(BedrockRuntimeClient, provider.client).meta.config
assert config.read_timeout == 1 # type: ignore
assert config.connect_timeout == 1 # type: ignore