diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index 2305d9554d..0d35236576 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -431,32 +431,41 @@ def infer_model(model: Model | KnownModelName) -> Model:
         raise UserError(f'Unknown model: {model}')


-def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
-    """Cached HTTPX async client so multiple agents and calls can share the same client.
+def cached_async_http_client(*, provider: str | None = None, timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
+    """Cached HTTPX async client that creates a separate client for each provider.
+
+    The client is cached based on the provider parameter. If provider is None, it's used for non-provider specific
+    requests (like downloading images). Multiple agents and calls can share the same client when they use the same provider.

     There are good reasons why in production you should use a `httpx.AsyncClient` as an async context manager as
     described in [encode/httpx#2026](https://github.com/encode/httpx/pull/2026), but when experimenting or showing
-    examples, it's very useful not to, this allows multiple Agents to use a single client.
+    examples, it's very useful not to.

     The default timeouts match those of OpenAI,
     see .
     """
-    client = _cached_async_http_client(timeout=timeout, connect=connect)
+    client = _cached_async_http_client(provider=provider, timeout=timeout, connect=connect)
     if client.is_closed:
         # This happens if the context manager is used, so we need to create a new client.
         _cached_async_http_client.cache_clear()
-        client = _cached_async_http_client(timeout=timeout, connect=connect)
+        client = _cached_async_http_client(provider=provider, timeout=timeout, connect=connect)
     return client


 @cache
-def _cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
+def _cached_async_http_client(provider: str | None, timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     return httpx.AsyncClient(
+        transport=_cached_async_http_transport(),
         timeout=httpx.Timeout(timeout=timeout, connect=connect),
         headers={'User-Agent': get_user_agent()},
     )


+@cache
+def _cached_async_http_transport() -> httpx.AsyncHTTPTransport:
+    return httpx.AsyncHTTPTransport()
+
+
 @cache
 def get_user_agent() -> str:
     """Get the user agent string for the HTTP client."""
diff --git a/pydantic_ai_slim/pydantic_ai/providers/anthropic.py b/pydantic_ai_slim/pydantic_ai/providers/anthropic.py
index d2b603f145..1041ae20b5 100644
--- a/pydantic_ai_slim/pydantic_ai/providers/anthropic.py
+++ b/pydantic_ai_slim/pydantic_ai/providers/anthropic.py
@@ -70,4 +70,5 @@ def __init__(
         if http_client is not None:
             self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
         else:
-            self._client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
+            http_client = cached_async_http_client(provider='anthropic')
+            self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
diff --git a/pydantic_ai_slim/pydantic_ai/providers/azure.py b/pydantic_ai_slim/pydantic_ai/providers/azure.py
index 72831b4f31..0e04a35fe0 100644
--- a/pydantic_ai_slim/pydantic_ai/providers/azure.py
+++ b/pydantic_ai_slim/pydantic_ai/providers/azure.py
@@ -97,7 +97,7 @@ def __init__(
                 'Must provide one of the `api_version` argument or the `OPENAI_API_VERSION` environment variable'
             )

-        http_client = http_client or cached_async_http_client()
+        http_client = http_client or cached_async_http_client(provider='azure')
         self._client = AsyncAzureOpenAI(
             azure_endpoint=azure_endpoint,
             api_key=api_key,
diff --git a/pydantic_ai_slim/pydantic_ai/providers/cohere.py b/pydantic_ai_slim/pydantic_ai/providers/cohere.py
index 4119be3c56..f7236a9efc 100644
--- a/pydantic_ai_slim/pydantic_ai/providers/cohere.py
+++ b/pydantic_ai_slim/pydantic_ai/providers/cohere.py
@@ -66,6 +66,5 @@ def __init__(
         if http_client is not None:
             self._client = AsyncClientV2(api_key=api_key, httpx_client=http_client, base_url=base_url)
         else:
-            self._client = AsyncClientV2(
-                api_key=api_key, httpx_client=cached_async_http_client(), base_url=base_url
-            )
+            http_client = cached_async_http_client(provider='cohere')
+            self._client = AsyncClientV2(api_key=api_key, httpx_client=http_client, base_url=base_url)
diff --git a/pydantic_ai_slim/pydantic_ai/providers/deepseek.py b/pydantic_ai_slim/pydantic_ai/providers/deepseek.py
index 850a1fa44f..2906ed1cfd 100644
--- a/pydantic_ai_slim/pydantic_ai/providers/deepseek.py
+++ b/pydantic_ai_slim/pydantic_ai/providers/deepseek.py
@@ -65,4 +65,5 @@ def __init__(
         elif http_client is not None:
             self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
         else:
-            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client())
+            http_client = cached_async_http_client(provider='deepseek')
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
diff --git a/pydantic_ai_slim/pydantic_ai/providers/google_gla.py b/pydantic_ai_slim/pydantic_ai/providers/google_gla.py
index c545c22e60..282f33a50f 100644
--- a/pydantic_ai_slim/pydantic_ai/providers/google_gla.py
+++ b/pydantic_ai_slim/pydantic_ai/providers/google_gla.py
@@ -39,7 +39,7 @@ def __init__(self, api_key: str | None = None, http_client: httpx.AsyncClient |
                 'to use the Google GLA provider.'
             )

-        self._client = http_client or cached_async_http_client()
+        self._client = http_client or cached_async_http_client(provider='google-gla')
         self._client.base_url = self.base_url
         # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
         self._client.headers['X-Goog-Api-Key'] = api_key
diff --git a/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py b/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py
index af4af520b8..e4fe923237 100644
--- a/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py
+++ b/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py
@@ -97,7 +97,7 @@ def __init__(
         if service_account_file and service_account_info:
             raise ValueError('Only one of `service_account_file` or `service_account_info` can be provided.')

-        self._client = http_client or cached_async_http_client()
+        self._client = http_client or cached_async_http_client(provider='google-vertex')
         self.service_account_file = service_account_file
         self.service_account_info = service_account_info
         self.project_id = project_id
diff --git a/pydantic_ai_slim/pydantic_ai/providers/groq.py b/pydantic_ai_slim/pydantic_ai/providers/groq.py
index 1e33acdc2a..0a61e58b44 100644
--- a/pydantic_ai_slim/pydantic_ai/providers/groq.py
+++ b/pydantic_ai_slim/pydantic_ai/providers/groq.py
@@ -71,6 +71,5 @@ def __init__(
         elif http_client is not None:
             self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client)
         else:
-            self._client = AsyncGroq(
-                base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client()
-            )
+            http_client = cached_async_http_client(provider='groq')
+            self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client)
diff --git a/pydantic_ai_slim/pydantic_ai/providers/mistral.py b/pydantic_ai_slim/pydantic_ai/providers/mistral.py
index 0d02e2bdbc..e491ecf2ea 100644
--- a/pydantic_ai_slim/pydantic_ai/providers/mistral.py
+++ b/pydantic_ai_slim/pydantic_ai/providers/mistral.py
@@ -69,4 +69,5 @@ def __init__(
         elif http_client is not None:
             self._client = Mistral(api_key=api_key, async_client=http_client)
         else:
-            self._client = Mistral(api_key=api_key, async_client=cached_async_http_client())
+            http_client = cached_async_http_client(provider='mistral')
+            self._client = Mistral(api_key=api_key, async_client=http_client)
diff --git a/pydantic_ai_slim/pydantic_ai/providers/openai.py b/pydantic_ai_slim/pydantic_ai/providers/openai.py
index f9037a1eea..6b3cb91848 100644
--- a/pydantic_ai_slim/pydantic_ai/providers/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/providers/openai.py
@@ -63,4 +63,5 @@ def __init__(
         elif http_client is not None:
             self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
         else:
-            self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
+            http_client = cached_async_http_client(provider='openai')
+            self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
diff --git a/tests/conftest.py b/tests/conftest.py
index f5e02d8a67..f5f0abe1eb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -200,7 +200,18 @@ def vcr_config():
 @pytest.fixture(autouse=True)
 async def close_cached_httpx_client() -> AsyncIterator[None]:
     yield
-    await cached_async_http_client().aclose()
+    for provider in [
+        'openai',
+        'anthropic',
+        'azure',
+        'google-gla',
+        'google-vertex',
+        'groq',
+        'mistral',
+        'cohere',
+        'deepseek',
+    ]:
+        await cached_async_http_client(provider=provider).aclose()


 @pytest.fixture(scope='session')
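
For reviewers, a minimal usage sketch of the behaviour this patch aims for (not part of the diff). It assumes the patch is applied and that `cached_async_http_client` is importable as `pydantic_ai.models.cached_async_http_client`, i.e. from the module edited above; no network requests are made.

```python
# Illustrative sketch only, under the assumptions stated in the note above.
import asyncio

from pydantic_ai.models import cached_async_http_client


async def main() -> None:
    openai_client = cached_async_http_client(provider='openai')
    anthropic_client = cached_async_http_client(provider='anthropic')

    # Same provider -> same cached instance (functools.cache keys on the keyword arguments).
    assert cached_async_http_client(provider='openai') is openai_client

    # Different providers -> distinct clients, so e.g. GoogleGLAProvider setting base_url and
    # the X-Goog-Api-Key header on its client no longer affects other providers' clients.
    assert openai_client is not anthropic_client

    # Both clients are built on the transport returned by the cached
    # _cached_async_http_transport(), so they still share one underlying connection pool.

    await openai_client.aclose()
    await anthropic_client.aclose()


asyncio.run(main())
```

The per-provider split stops one provider mutating a client another provider is using, while the cached transport keeps connection pooling common to all of them; if a cached client was closed (e.g. by being used as a context manager), `cached_async_http_client` clears the cache and builds a fresh one on the next call.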