diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 02bc68f60c1..380951c7265 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -40,8 +40,10 @@ import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.document.Document; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; @@ -788,6 +790,12 @@ public static final class Builder { private BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient; private Builder() { + try { + region = DefaultAwsRegionProviderChain.builder().build().getRegion(); + } + catch (SdkClientException e) { + logger.warn("Failed to load region from DefaultAwsRegionProviderChain, using US_EAST_1", e); + } } public Builder toolCallingManager(ToolCallingManager toolCallingManager) { diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelTest.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelTest.java new file mode 100644 index 00000000000..4d52311280d --- /dev/null +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelTest.java @@ -0,0 +1,47 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.bedrock.converse; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; + +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class BedrockProxyChatModelTest { + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private DefaultAwsRegionProviderChain.Builder awsRegionProviderBuilder; + + @Test + void shouldIgnoreExceptionAndUseDefault() { + try (MockedStatic mocked = mockStatic(DefaultAwsRegionProviderChain.class)) { + when(awsRegionProviderBuilder.build().getRegion()) + .thenThrow(SdkClientException.builder().message("failed load").build()); + mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder); + BedrockProxyChatModel.builder().build(); + } + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java index 8c84204f319..fc78d8af6bf 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java @@ -30,13 +30,16 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.util.ObjectUtils; import reactor.core.publisher.Flux; import reactor.core.publisher.Sinks; import reactor.core.publisher.Sinks.EmitFailureHandler; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; @@ -148,14 +151,12 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv Assert.hasText(modelId, "Model id must not be empty"); Assert.notNull(credentialsProvider, "Credentials provider must not be null"); - Assert.notNull(region, "Region must not be empty"); Assert.notNull(objectMapper, "Object mapper must not be null"); Assert.notNull(timeout, "Timeout must not be null"); this.modelId = modelId; this.objectMapper = objectMapper; - this.region = region; - + this.region = getRegion(region); this.client = BedrockRuntimeClient.builder() .region(this.region) @@ -339,5 +340,17 @@ public record AmazonBedrockInvocationMetrics( @JsonProperty("outputTokenCount") Long outputTokenCount, @JsonProperty("invocationLatency") Long invocationLatency) { } + + private Region getRegion(Region region) { + if (ObjectUtils.isEmpty(region)) { + try { + return DefaultAwsRegionProviderChain.builder().build().getRegion(); + } catch (SdkClientException e) { + throw new IllegalArgumentException("Region is empty and cannot be loaded from DefaultAwsRegionProviderChain: " + e.getMessage(), e); + } + } else { + return region; + } + } } // @formatter:on diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/api/AbstractBedrockApiTest.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/api/AbstractBedrockApiTest.java new file mode 100644 index 00000000000..700b47b6bbb --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/api/AbstractBedrockApiTest.java @@ -0,0 +1,97 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.bedrock.api; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; + +import java.time.Duration; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +class AbstractBedrockApiTest { + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private DefaultAwsRegionProviderChain.Builder awsRegionProviderBuilder; + + @Mock + private AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class); + + @Mock + private ObjectMapper objectMapper = mock(ObjectMapper.class); + + @Test + void shouldLoadRegionFromAwsDefaults() { + try (MockedStatic mocked = mockStatic(DefaultAwsRegionProviderChain.class)) { + when(awsRegionProviderBuilder.build().getRegion()).thenReturn(Region.AF_SOUTH_1); + mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder); + AbstractBedrockApi testBedrockApi = new TestBedrockApi("modelId", + awsCredentialsProvider, null, objectMapper, Duration.ofMinutes(5)); + assertThat(testBedrockApi.getRegion()).isEqualTo(Region.AF_SOUTH_1); + } + } + + @Test + void shouldThrowIllegalArgumentIfAwsDefaultsFailed() { + try (MockedStatic mocked = mockStatic(DefaultAwsRegionProviderChain.class)) { + when(awsRegionProviderBuilder.build().getRegion()) + .thenThrow(SdkClientException.builder().message("failed load").build()); + mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder); + assertThatThrownBy(() -> new TestBedrockApi("modelId", awsCredentialsProvider, null, objectMapper, + Duration.ofMinutes(5))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("failed load"); + } + } + + private static class TestBedrockApi extends AbstractBedrockApi { + + protected TestBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, + ObjectMapper objectMapper, Duration timeout) { + super(modelId, credentialsProvider, region, objectMapper, timeout); + } + + @Override + protected Object embedding(Object request) { + return null; + } + + @Override + protected Object chatCompletion(Object request) { + return null; + } + + @Override + protected Object internalInvocation(Object request, Class clazz) { + return null; + } + + } + +}