Skip to content
16 changes: 10 additions & 6 deletions src/google/adk/models/anthropic_llm.py
Original file line number Diff line number Diff line change
@@ -244,9 +244,13 @@ class Claude(BaseLlm):

Attributes:
model: The name of the Claude model.
project_id: Optional Google Cloud project ID. If not provided, uses GOOGLE_CLOUD_PROJECT environment variable.
location: Optional Google Cloud location. If not provided, uses GOOGLE_CLOUD_LOCATION environment variable.
"""

model: str = "claude-3-5-sonnet-v2@20241022"
project_id: Optional[str] = None
location: Optional[str] = None

@staticmethod
@override
@@ -289,16 +293,16 @@ async def generate_content_async(

@cached_property
def _anthropic_client(self) -> AnthropicVertex:
if (
"GOOGLE_CLOUD_PROJECT" not in os.environ
or "GOOGLE_CLOUD_LOCATION" not in os.environ
):
project = self.project_id or os.environ.get("GOOGLE_CLOUD_PROJECT")
location = self.location or os.environ.get("GOOGLE_CLOUD_LOCATION")

if not project or not location:
raise ValueError(
"GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using"
" Anthropic on Vertex."
)

return AnthropicVertex(
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
region=os.environ["GOOGLE_CLOUD_LOCATION"],
project_id=project,
region=location,
)
22 changes: 18 additions & 4 deletions src/google/adk/models/google_llm.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
import sys
from typing import AsyncGenerator
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

@@ -53,9 +54,14 @@ class Gemini(BaseLlm):

Attributes:
model: The name of the Gemini model.
project_id: Optional Google Cloud project ID. If not provided, uses GOOGLE_CLOUD_PROJECT environment variable.
location: Optional Google Cloud location. If not provided, uses GOOGLE_CLOUD_LOCATION environment variable.

"""

model: str = 'gemini-1.5-flash'
project_id: Optional[str] = None
location: Optional[str] = None

@staticmethod
@override
@@ -184,14 +190,22 @@ async def generate_content_async(

@cached_property
def api_client(self) -> Client:
"""Provides the api client.
"""Provides the api client with per-instance configuration support.

Returns:
The api client.
"""
return Client(
http_options=types.HttpOptions(headers=self._tracking_headers)
)
if self.project_id or self.location:
return Client(
vertexai=True,
project=self.project_id,
location=self.location,
http_options=types.HttpOptions(headers=self._tracking_headers),
)
else:
return Client(
http_options=types.HttpOptions(headers=self._tracking_headers)
)

@cached_property
def _api_backend(self) -> GoogleLLMVariant:
102 changes: 102 additions & 0 deletions tests/unittests/models/test_vertex_per_agent_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import patch

from src.google.adk.models.anthropic_llm import Claude
from src.google.adk.models.google_llm import Gemini


def test_claude_custom_config():
claude = Claude(project_id="test-project-claude", location="us-central1")

assert claude.project_id == "test-project-claude"
assert claude.location == "us-central1"


def test_gemini_custom_config():
gemini = Gemini(project_id="test-project-gemini", location="europe-west1")

assert gemini.project_id == "test-project-gemini"
assert gemini.location == "europe-west1"


def test_claude_per_instance_configuration():
claude1 = Claude(project_id="project-1", location="us-central1")
claude2 = Claude(project_id="project-2", location="europe-west1")
claude3 = Claude()

assert claude1.project_id == "project-1"
assert claude1.location == "us-central1"

assert claude2.project_id == "project-2"
assert claude2.location == "europe-west1"

assert claude3.project_id is None
assert claude3.location is None


def test_gemini_per_instance_configuration():
gemini1 = Gemini(project_id="project-1", location="us-central1")
gemini2 = Gemini(project_id="project-2", location="europe-west1")
gemini3 = Gemini()

assert gemini1.project_id == "project-1"
assert gemini1.location == "us-central1"

assert gemini2.project_id == "project-2"
assert gemini2.location == "europe-west1"

assert gemini3.project_id is None
assert gemini3.location is None


def test_backward_compatibility():
claude = Claude()
gemini = Gemini()

assert claude.project_id is None
assert claude.location is None
assert gemini.project_id is None
assert gemini.location is None


@patch.dict(
"os.environ",
{
"GOOGLE_CLOUD_PROJECT": "env-project",
"GOOGLE_CLOUD_LOCATION": "env-location",
},
)
def test_claude_fallback_to_env_vars():
claude = Claude()

cache_key = f"{claude.project_id or 'default'}:{claude.location or 'default'}"
assert cache_key == "default:default"


def test_mixed_configuration():
claude_custom = Claude(project_id="custom-project", location="us-west1")
claude_default = Claude()

key_custom = (
f"{claude_custom.project_id or 'default'}:{claude_custom.location or 'default'}"
)
key_default = (
f"{claude_default.project_id or 'default'}:{claude_default.location or 'default'}"
)

assert key_custom != key_default
assert key_custom == "custom-project:us-west1"
assert key_default == "default:default"