diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index 7f15006d8b6..7bd54073e28 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -24,3 +24,22 @@ jobs: - name: Run tests run: | ./mvnw --batch-mode test + spring_7: + name: Compile with Spring 7 + runs-on: ubuntu-latest + if: ${{ github.repository_owner == 'spring-projects' }} + steps: + - name: Checkout source code + uses: actions/checkout@v4 + + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'maven' + + - name: Run tests + run: | + ./mvnw --batch-mode -Dspring-framework.version=7.0.0-RC2 -Dkotlin.version=2.2.0 compile + diff --git a/auto-configurations/common/spring-ai-autoconfigure-retry/src/main/java/org/springframework/ai/retry/autoconfigure/SpringAiRetryAutoConfiguration.java b/auto-configurations/common/spring-ai-autoconfigure-retry/src/main/java/org/springframework/ai/retry/autoconfigure/SpringAiRetryAutoConfiguration.java index 1d7eed38da9..1d37c94ae65 100644 --- a/auto-configurations/common/spring-ai-autoconfigure-retry/src/main/java/org/springframework/ai/retry/autoconfigure/SpringAiRetryAutoConfiguration.java +++ b/auto-configurations/common/spring-ai-autoconfigure-retry/src/main/java/org/springframework/ai/retry/autoconfigure/SpringAiRetryAutoConfiguration.java @@ -17,6 +17,7 @@ package org.springframework.ai.retry.autoconfigure; import java.io.IOException; +import java.net.URI; import java.nio.charset.StandardCharsets; import org.slf4j.Logger; @@ -30,6 +31,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; +import org.springframework.http.HttpMethod; import org.springframework.http.client.ClientHttpResponse; import org.springframework.lang.NonNull; import org.springframework.retry.RetryCallback; @@ -87,6 +89,14 @@ public boolean hasError(@NonNull ClientHttpResponse response) throws IOException } @Override + public void handleError(URI url, HttpMethod method, ClientHttpResponse response) throws IOException { + handleError(response); + } + + // On purposes commented out so that the code can compile both with Spring 6 + // and Spring 7 + // @Override + @SuppressWarnings("removal") public void handleError(@NonNull ClientHttpResponse response) throws IOException { if (!response.getStatusCode().isError()) { return; diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index 497d43acc4a..a9713034dc9 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -176,7 +176,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletio return this.restClient.post() .uri(this.completionsPath) .headers(headers -> { - headers.addAll(additionalHttpHeader); + additionalHttpHeader.forEach(headers::addAll); addDefaultHeadersIfMissing(headers); }) .body(chatRequest) @@ -217,7 +217,7 @@ public Flux chatCompletionStream(ChatCompletionRequest c return this.webClient.post() .uri(this.completionsPath) .headers(headers -> { - headers.addAll(additionalHttpHeader); + additionalHttpHeader.forEach(headers::addAll); addDefaultHeadersIfMissing(headers); }) // @formatter:off .body(Mono.just(chatRequest), ChatCompletionRequest.class) @@ -256,7 +256,8 @@ public Flux chatCompletionStream(ChatCompletionRequest c } private void addDefaultHeadersIfMissing(HttpHeaders headers) { - if (!headers.containsKey(HEADER_X_API_KEY)) { + List apiKeyHeaders = headers.get(HEADER_X_API_KEY); + if (apiKeyHeaders == null) { String apiKeyValue = this.apiKey.getValue(); if (StringUtils.hasText(apiKeyValue)) { headers.add(HEADER_X_API_KEY, apiKeyValue); diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java index 13415829854..af8682f5848 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java @@ -94,7 +94,7 @@ public DeepSeekApi(String baseUrl, ApiKey apiKey, MultiValueMap Consumer finalHeaders = h -> { h.setBearerAuth(apiKey.getValue()); h.setContentType(MediaType.APPLICATION_JSON); - h.addAll(headers); + headers.forEach(h::addAll); }; this.restClient = restClientBuilder.baseUrl(baseUrl) .defaultHeaders(finalHeaders) @@ -153,7 +153,7 @@ public Flux chatCompletionStream(ChatCompletionRequest chat return this.webClient.post() .uri(this.getEndpoint(chatRequest)) - .headers(headers -> headers.addAll(additionalHttpHeader)) + .headers(headers -> additionalHttpHeader.forEach(headers::addAll)) .body(Mono.just(chatRequest), ChatCompletionRequest.class) .retrieve() .bodyToFlux(String.class) diff --git a/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/api/ElevenLabsApi.java b/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/api/ElevenLabsApi.java index 10ce0349070..46495c29088 100644 --- a/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/api/ElevenLabsApi.java +++ b/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/api/ElevenLabsApi.java @@ -70,7 +70,7 @@ private ElevenLabsApi(String baseUrl, ApiKey apiKey, MultiValueMapgeneratedAnnotation}} +@Component("{{invokerPackage}}.ApiClient") +public class ApiClient { + public enum CollectionFormat { + CSV(","), TSV("\t"), SSV(" "), PIPES("|"), MULTI(null); + + private final String separator; + private CollectionFormat(String separator) { + this.separator = separator; + } + + private String collectionToString(Collection collection) { + return StringUtils.collectionToDelimitedString(collection, separator); + } + } + + private boolean debugging = false; + + private HttpHeaders defaultHeaders = new HttpHeaders(); + + private String basePath = "{{basePath}}"; + + private RestTemplate restTemplate; + + private Map authentications; + + private DateFormat dateFormat; + + public ApiClient() { + this.restTemplate = buildRestTemplate(); + init(); + } + + @Autowired + public ApiClient(RestTemplate restTemplate) { + this.restTemplate = restTemplate; + init(); + } + + protected void init() { + // Use RFC3339 format for date and datetime. + // See http://xml2rfc.ietf.org/public/rfc/html/rfc3339.html#anchor14 + this.dateFormat = new RFC3339DateFormat(); + + // Use UTC as the default time zone. + this.dateFormat.setTimeZone(TimeZone.getTimeZone("UTC")); + + // Set default User-Agent. + setUserAgent("Java-SDK"); + + // Setup authentications (key: authentication name, value: authentication). + authentications = new HashMap();{{#authMethods}}{{#isBasic}} + authentications.put("{{name}}", new HttpBasicAuth());{{/isBasic}}{{#isApiKey}} + authentications.put("{{name}}", new ApiKeyAuth({{#isKeyInHeader}}"header"{{/isKeyInHeader}}{{^isKeyInHeader}}"query"{{/isKeyInHeader}}, "{{keyParamName}}"));{{/isApiKey}}{{#isOAuth}} + authentications.put("{{name}}", new OAuth());{{/isOAuth}}{{/authMethods}} + // Prevent the authentications from being modified. + authentications = Collections.unmodifiableMap(authentications); + } + + /** + * Get the current base path + * @return String the base path + */ + public String getBasePath() { + return basePath; + } + + /** + * Set the base path, which should include the host + * @param basePath the base path + * @return ApiClient this client + */ + public ApiClient setBasePath(String basePath) { + this.basePath = basePath; + return this; + } + + /** + * Get authentications (key: authentication name, value: authentication). + * @return Map the currently configured authentication types + */ + public Map getAuthentications() { + return authentications; + } + + /** + * Get authentication for the given name. + * + * @param authName The authentication name + * @return The authentication, null if not found + */ + public Authentication getAuthentication(String authName) { + return authentications.get(authName); + } + + /** + * Helper method to set username for the first HTTP basic authentication. + * @param username the username + */ + public void setUsername(String username) { + for (Authentication auth : authentications.values()) { + if (auth instanceof HttpBasicAuth) { + ((HttpBasicAuth) auth).setUsername(username); + return; + } + } + throw new RuntimeException("No HTTP basic authentication configured!"); + } + + /** + * Helper method to set password for the first HTTP basic authentication. + * @param password the password + */ + public void setPassword(String password) { + for (Authentication auth : authentications.values()) { + if (auth instanceof HttpBasicAuth) { + ((HttpBasicAuth) auth).setPassword(password); + return; + } + } + throw new RuntimeException("No HTTP basic authentication configured!"); + } + + /** + * Helper method to set API key value for the first API key authentication. + * @param apiKey the API key + */ + public void setApiKey(String apiKey) { + for (Authentication auth : authentications.values()) { + if (auth instanceof ApiKeyAuth) { + ((ApiKeyAuth) auth).setApiKey(apiKey); + return; + } + } + throw new RuntimeException("No API key authentication configured!"); + } + + /** + * Helper method to set API key prefix for the first API key authentication. + * @param apiKeyPrefix the API key prefix + */ + public void setApiKeyPrefix(String apiKeyPrefix) { + for (Authentication auth : authentications.values()) { + if (auth instanceof ApiKeyAuth) { + ((ApiKeyAuth) auth).setApiKeyPrefix(apiKeyPrefix); + return; + } + } + throw new RuntimeException("No API key authentication configured!"); + } + + /** + * Helper method to set access token for the first OAuth2 authentication. + * @param accessToken the access token + */ + public void setAccessToken(String accessToken) { + for (Authentication auth : authentications.values()) { + if (auth instanceof OAuth) { + ((OAuth) auth).setAccessToken(accessToken); + return; + } + } + throw new RuntimeException("No OAuth2 authentication configured!"); + } + + /** + * Set the User-Agent header's value (by adding to the default header map). + * @param userAgent the user agent string + * @return ApiClient this client + */ + public ApiClient setUserAgent(String userAgent) { + addDefaultHeader("User-Agent", userAgent); + return this; + } + + /** + * Add a default header. + * + * @param name The header's name + * @param value The header's value + * @return ApiClient this client + */ + public ApiClient addDefaultHeader(String name, String value) { + if (defaultHeaders.get(name) != null) { + defaultHeaders.remove(name); + } + defaultHeaders.add(name, value); + return this; + } + + public void setDebugging(boolean debugging) { + List currentInterceptors = this.restTemplate.getInterceptors(); + if(debugging) { + if (currentInterceptors == null) { + currentInterceptors = new ArrayList(); + } + ClientHttpRequestInterceptor interceptor = new ApiClientHttpRequestInterceptor(); + currentInterceptors.add(interceptor); + this.restTemplate.setInterceptors(currentInterceptors); + } else { + if (currentInterceptors != null && !currentInterceptors.isEmpty()) { + Iterator iter = currentInterceptors.iterator(); + while (iter.hasNext()) { + ClientHttpRequestInterceptor interceptor = iter.next(); + if (interceptor instanceof ApiClientHttpRequestInterceptor) { + iter.remove(); + } + } + this.restTemplate.setInterceptors(currentInterceptors); + } + } + this.debugging = debugging; + } + + /** + * Check that whether debugging is enabled for this API client. + * @return boolean true if this client is enabled for debugging, false otherwise + */ + public boolean isDebugging() { + return debugging; + } + + /** + * Get the date format used to parse/format date parameters. + * @return DateFormat format + */ + public DateFormat getDateFormat() { + return dateFormat; + } + + /** + * Set the date format used to parse/format date parameters. + * @param dateFormat Date format + * @return API client + */ + public ApiClient setDateFormat(DateFormat dateFormat) { + this.dateFormat = dateFormat; + {{#threetenbp}} + for(HttpMessageConverter converter:restTemplate.getMessageConverters()){ + if(converter instanceof AbstractJackson2HttpMessageConverter){ + ObjectMapper mapper = ((AbstractJackson2HttpMessageConverter)converter).getObjectMapper(); + mapper.setDateFormat(dateFormat); + } + } + {{/threetenbp}} + return this; + } + + /** + * Parse the given string into Date object. + */ + public Date parseDate(String str) { + try { + return dateFormat.parse(str); + } catch (ParseException e) { + throw new RuntimeException(e); + } + } + + /** + * Format the given Date object into string. + */ + public String formatDate(Date date) { + return dateFormat.format(date); + } + + /** + * Format the given parameter object into string. + * @param param the object to convert + * @return String the parameter represented as a String + */ + public String parameterToString(Object param) { + if (param == null) { + return ""; + } else if (param instanceof Date) { + return formatDate( (Date) param); + } else if (param instanceof Collection) { + StringBuilder b = new StringBuilder(); + for(Object o : (Collection) param) { + if(b.length() > 0) { + b.append(","); + } + b.append(String.valueOf(o)); + } + return b.toString(); + } else { + return String.valueOf(param); + } + } + + /** + * Converts a parameter to a {@link MultiValueMap} for use in REST requests + * @param collectionFormat The format to convert to + * @param name The name of the parameter + * @param value The parameter's value + * @return a Map containing the String value(s) of the input parameter + */ + public MultiValueMap parameterToMultiValueMap(CollectionFormat collectionFormat, String name, Object value) { + final MultiValueMap params = new LinkedMultiValueMap(); + + if (name == null || name.isEmpty() || value == null) { + return params; + } + + if(collectionFormat == null) { + collectionFormat = CollectionFormat.CSV; + } + + Collection valueCollection = null; + if (value instanceof Collection) { + valueCollection = (Collection) value; + } else { + params.add(name, parameterToString(value)); + return params; + } + + if (valueCollection.isEmpty()){ + return params; + } + + if (collectionFormat.equals(CollectionFormat.MULTI)) { + for (Object item : valueCollection) { + params.add(name, parameterToString(item)); + } + return params; + } + + List values = new ArrayList(); + for(Object o : valueCollection) { + values.add(parameterToString(o)); + } + params.add(name, collectionFormat.collectionToString(values)); + + return params; + } + + /** + * Check if the given {@code String} is a JSON MIME. + * @param mediaType the input MediaType + * @return boolean true if the MediaType represents JSON, false otherwise + */ + public boolean isJsonMime(String mediaType) { + // "* / *" is default to JSON + if ("*/*".equals(mediaType)) { + return true; + } + + try { + return isJsonMime(MediaType.parseMediaType(mediaType)); + } catch (InvalidMediaTypeException e) { + } + return false; + } + + /** + * Check if the given MIME is a JSON MIME. + * JSON MIME examples: + * application/json + * application/json; charset=UTF8 + * APPLICATION/JSON + * @param mediaType the input MediaType + * @return boolean true if the MediaType represents JSON, false otherwise + */ + public boolean isJsonMime(MediaType mediaType) { + return mediaType != null && (MediaType.APPLICATION_JSON.isCompatibleWith(mediaType) || mediaType.getSubtype().matches("^.*\\+json[;]?\\s*$")); + } + + /** + * Select the Accept header's value from the given accepts array: + * if JSON exists in the given array, use it; + * otherwise use all of them (joining into a string) + * + * @param accepts The accepts array to select from + * @return List The list of MediaTypes to use for the Accept header + */ + public List selectHeaderAccept(String[] accepts) { + if (accepts.length == 0) { + return null; + } + for (String accept : accepts) { + MediaType mediaType = MediaType.parseMediaType(accept); + if (isJsonMime(mediaType)) { + return Collections.singletonList(mediaType); + } + } + return MediaType.parseMediaTypes(StringUtils.arrayToCommaDelimitedString(accepts)); + } + + /** + * Select the Content-Type header's value from the given array: + * if JSON exists in the given array, use it; + * otherwise use the first one of the array. + * + * @param contentTypes The Content-Type array to select from + * @return MediaType The Content-Type header to use. If the given array is empty, JSON will be used. + */ + public MediaType selectHeaderContentType(String[] contentTypes) { + if (contentTypes.length == 0) { + return MediaType.APPLICATION_JSON; + } + for (String contentType : contentTypes) { + MediaType mediaType = MediaType.parseMediaType(contentType); + if (isJsonMime(mediaType)) { + return mediaType; + } + } + return MediaType.parseMediaType(contentTypes[0]); + } + + /** + * Select the body to use for the request + * @param obj the body object + * @param formParams the form parameters + * @param contentType the content type of the request + * @return Object the selected body + */ + protected Object selectBody(Object obj, MultiValueMap formParams, MediaType contentType) { + boolean isForm = MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType) || MediaType.APPLICATION_FORM_URLENCODED.isCompatibleWith(contentType); + return isForm ? formParams : obj; + } + + /** + * Invoke API by sending HTTP request with the given options. + * + * @param the return type to use + * @param path The sub-path of the HTTP URL + * @param method The request method + * @param queryParams The query parameters + * @param body The request body object + * @param headerParams The header parameters + * @param formParams The form parameters + * @param accept The request's Accept header + * @param contentType The request's Content-Type header + * @param authNames The authentications to apply + * @param returnType The return type into which to deserialize the response + * @return ResponseEntity<T> The response of the chosen type + */ + public ResponseEntity invokeAPI(String path, HttpMethod method, MultiValueMap queryParams, Object body, HttpHeaders headerParams, MultiValueMap formParams, List accept, MediaType contentType, String[] authNames, ParameterizedTypeReference returnType) throws RestClientException { + updateParamsForAuth(authNames, queryParams, headerParams); + + final UriComponentsBuilder builder = UriComponentsBuilder.fromUriString(basePath).path(path); + if (queryParams != null) { + builder.queryParams(queryParams); + } + + final BodyBuilder requestBuilder = RequestEntity.method(method, builder.build().toUri()); + if(accept != null) { + requestBuilder.accept(accept.toArray(new MediaType[accept.size()])); + } + if(contentType != null) { + requestBuilder.contentType(contentType); + } + + addHeadersToRequest(headerParams, requestBuilder); + addHeadersToRequest(defaultHeaders, requestBuilder); + + RequestEntity requestEntity = requestBuilder.body(selectBody(body, formParams, contentType)); + + ResponseEntity responseEntity = restTemplate.exchange(requestEntity, returnType); + + if (responseEntity.getStatusCode().is2xxSuccessful()) { + return responseEntity; + } else { + // The error handler built into the RestTemplate should handle 400 and 500 series errors. + throw new RestClientException("API returned " + responseEntity.getStatusCode() + " and it wasn't handled by the RestTemplate error handler"); + } + } + + /** + * Add headers to the request that is being built + * @param headers The headers to add + * @param requestBuilder The current request + */ + protected void addHeadersToRequest(HttpHeaders headers, BodyBuilder requestBuilder) { + for (Entry> entry : headers.headerSet()) { + List values = entry.getValue(); + for(String value : values) { + if (value != null) { + requestBuilder.header(entry.getKey(), value); + } + } + } + } + + /** + * Build the RestTemplate used to make HTTP requests. + * @return RestTemplate + */ + protected RestTemplate buildRestTemplate() { + {{#withXml}}List> messageConverters = new ArrayList>(); + messageConverters.add(new MappingJackson2HttpMessageConverter()); + XmlMapper xmlMapper = new XmlMapper(); + xmlMapper.configure(ToXmlGenerator.Feature.WRITE_XML_DECLARATION, true); + messageConverters.add(new MappingJackson2XmlHttpMessageConverter(xmlMapper)); + + RestTemplate restTemplate = new RestTemplate(messageConverters); + {{/withXml}}{{^withXml}}RestTemplate restTemplate = new RestTemplate();{{/withXml}} + {{#threetenbp}} + for(HttpMessageConverter converter:restTemplate.getMessageConverters()){ + if(converter instanceof AbstractJackson2HttpMessageConverter){ + ObjectMapper mapper = ((AbstractJackson2HttpMessageConverter)converter).getObjectMapper(); + ThreeTenModule module = new ThreeTenModule(); + module.addDeserializer(Instant.class, CustomInstantDeserializer.INSTANT); + module.addDeserializer(OffsetDateTime.class, CustomInstantDeserializer.OFFSET_DATE_TIME); + module.addDeserializer(ZonedDateTime.class, CustomInstantDeserializer.ZONED_DATE_TIME); + mapper.registerModule(module); + } + } + {{/threetenbp}} + // This allows us to read the response more than once - Necessary for debugging. + restTemplate.setRequestFactory(new BufferingClientHttpRequestFactory(restTemplate.getRequestFactory())); + return restTemplate; + } + + /** + * Update query and header parameters based on authentication settings. + * + * @param authNames The authentications to apply + * @param queryParams The query parameters + * @param headerParams The header parameters + */ + private void updateParamsForAuth(String[] authNames, MultiValueMap queryParams, HttpHeaders headerParams) { + for (String authName : authNames) { + Authentication auth = authentications.get(authName); + if (auth == null) { + throw new RestClientException("Authentication undefined: " + authName); + } + auth.applyToParams(queryParams, headerParams); + } + } + + private class ApiClientHttpRequestInterceptor implements ClientHttpRequestInterceptor { + private final Log log = LogFactory.getLog(ApiClientHttpRequestInterceptor.class); + + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException { + logRequest(request, body); + ClientHttpResponse response = execution.execute(request, body); + logResponse(response); + return response; + } + + private void logRequest(HttpRequest request, byte[] body) throws UnsupportedEncodingException { + log.info("URI: " + request.getURI()); + log.info("HTTP Method: " + request.getMethod()); + log.info("HTTP Headers: " + headersToString(request.getHeaders())); + log.info("Request Body: " + new String(body, StandardCharsets.UTF_8)); + } + + private void logResponse(ClientHttpResponse response) throws IOException { + log.info("HTTP Status Code: " + response.getStatusCode().value()); + log.info("Status Text: " + response.getStatusText()); + log.info("HTTP Headers: " + headersToString(response.getHeaders())); + log.info("Response Body: " + bodyToString(response.getBody())); + } + + private String headersToString(HttpHeaders headers) { + StringBuilder builder = new StringBuilder(); + for(Entry> entry : headers.headerSet()) { + builder.append(entry.getKey()).append("=["); + for(String value : entry.getValue()) { + builder.append(value).append(","); + } + builder.setLength(builder.length() - 1); // Get rid of trailing comma + builder.append("],"); + } + builder.setLength(builder.length() - 1); // Get rid of trailing comma + return builder.toString(); + } + + private String bodyToString(InputStream body) throws IOException { + StringBuilder builder = new StringBuilder(); + BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(body, StandardCharsets.UTF_8)); + String line = bufferedReader.readLine(); + while (line != null) { + builder.append(line).append(System.lineSeparator()); + line = bufferedReader.readLine(); + } + bufferedReader.close(); + return builder.toString(); + } + } +} \ No newline at end of file diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 070d4e4b5c6..51fcf131c67 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -145,7 +145,7 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap he Consumer finalHeaders = h -> { h.setContentType(MediaType.APPLICATION_JSON); h.set(HTTP_USER_AGENT_HEADER, SPRING_AI_USER_AGENT); - h.addAll(headers); + headers.forEach(h::addAll); }; this.restClient = restClientBuilder.clone() .baseUrl(baseUrl) @@ -204,7 +204,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest return this.restClient.post() .uri(this.completionsPath) .headers(headers -> { - headers.addAll(additionalHttpHeader); + additionalHttpHeader.forEach(headers::addAll); addDefaultHeadersIfMissing(headers); }) .body(chatRequest) @@ -243,7 +243,7 @@ public Flux chatCompletionStream(ChatCompletionRequest chat return this.webClient.post() .uri(this.completionsPath) .headers(headers -> { - headers.addAll(additionalHttpHeader); + additionalHttpHeader.forEach(headers::addAll); addDefaultHeadersIfMissing(headers); }) // @formatter:on .body(Mono.just(chatRequest), ChatCompletionRequest.class) @@ -328,7 +328,8 @@ public ResponseEntity> embeddings(EmbeddingRequest< } private void addDefaultHeadersIfMissing(HttpHeaders headers) { - if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) { + List authorizationHeaders = headers.get(HttpHeaders.AUTHORIZATION); + if (authorizationHeaders == null && !(this.apiKey instanceof NoopApiKey)) { headers.setBearerAuth(this.apiKey.getValue()); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java index cd89852d244..9c54aed73a9 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java @@ -73,7 +73,7 @@ public OpenAiAudioApi(String baseUrl, ApiKey apiKey, MultiValueMap authHeaders = h -> h.addAll(headers); + Consumer authHeaders = h -> headers.forEach(h::addAll); // @formatter:off this.restClient = restClientBuilder.clone() diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiFileApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiFileApi.java index 535b6a9a043..3525bb6ba43 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiFileApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiFileApi.java @@ -51,7 +51,7 @@ public class OpenAiFileApi { public OpenAiFileApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { - Consumer authHeaders = h -> h.addAll(headers); + Consumer authHeaders = h -> headers.forEach(h::addAll); this.restClient = restClientBuilder.clone() .baseUrl(baseUrl) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java index fe82e30e56b..795d0e22790 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java @@ -68,7 +68,7 @@ public OpenAiImageApi(String baseUrl, ApiKey apiKey, MultiValueMap { h.setContentType(MediaType.APPLICATION_JSON); - h.addAll(headers); + headers.forEach(h::addAll); }) .defaultStatusHandler(responseErrorHandler) .defaultRequest(requestHeadersSpec -> { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java index a90bd1aad8e..31021675a6f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java @@ -71,7 +71,7 @@ public OpenAiModerationApi(String baseUrl, ApiKey apiKey, MultiValueMap { h.setContentType(MediaType.APPLICATION_JSON); - h.addAll(headers); + headers.forEach(h::addAll); }) .defaultStatusHandler(responseErrorHandler) .defaultRequest(requestHeadersSpec -> { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java index 7a4d344755d..b7ad61960a3 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java @@ -71,22 +71,18 @@ public static RateLimit extractAiResponseHeaders(ResponseEntity response) { private static Duration getHeaderAsDuration(ResponseEntity response, String headerName) { var headers = response.getHeaders(); - if (headers.containsKey(headerName)) { - var values = headers.get(headerName); - if (!CollectionUtils.isEmpty(values)) { - return DurationFormatter.TIME_UNIT.parse(values.get(0)); - } + var values = headers.get(headerName); + if (!CollectionUtils.isEmpty(values)) { + return DurationFormatter.TIME_UNIT.parse(values.get(0)); } return null; } private static Long getHeaderAsLong(ResponseEntity response, String headerName) { var headers = response.getHeaders(); - if (headers.containsKey(headerName)) { - var values = headers.get(headerName); - if (!CollectionUtils.isEmpty(values)) { - return parseLong(headerName, values.get(0)); - } + var values = headers.get(headerName); + if (!CollectionUtils.isEmpty(values)) { + return parseLong(headerName, values.get(0)); } return null; } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java index f07d93a159c..4a8871382fb 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java @@ -174,7 +174,7 @@ private ZhiPuAiApi(String baseUrl, ApiKey apiKey, MultiValueMap Consumer authHeaders = h -> { h.setContentType(MediaType.APPLICATION_JSON); - h.addAll(headers); + headers.forEach(h::addAll); }; this.restClient = restClientBuilder.clone() @@ -223,7 +223,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest return this.restClient.post() .uri(this.completionsPath) .headers(headers -> { - headers.addAll(additionalHttpHeader); + additionalHttpHeader.forEach(headers::addAll); addDefaultHeadersIfMissing(headers); }) .body(chatRequest) @@ -260,7 +260,7 @@ public Flux chatCompletionStream(ChatCompletionRequest chat return this.webClient.post() .uri(this.completionsPath) .headers(headers -> { - headers.addAll(additionalHttpHeader); + additionalHttpHeader.forEach(headers::addAll); addDefaultHeadersIfMissing(headers); }) // @formatter:on .body(Mono.just(chatRequest), ChatCompletionRequest.class) @@ -330,7 +330,8 @@ public ResponseEntity> embeddings(EmbeddingRequest< } private void addDefaultHeadersIfMissing(HttpHeaders headers) { - if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) { + List authorizationHeaders = headers.get(HttpHeaders.AUTHORIZATION); + if (authorizationHeaders == null && !(this.apiKey instanceof NoopApiKey)) { headers.setBearerAuth(this.apiKey.getValue()); } } diff --git a/pom.xml b/pom.xml index 9f47b177abc..2ca60b4438a 100644 --- a/pom.xml +++ b/pom.xml @@ -270,6 +270,7 @@ 3.5.7 + 6.2.12 4.3.4 1.0.0-beta.16 1.1.0 @@ -978,6 +979,13 @@ pom import + + org.springframework + spring-framework-bom + ${spring-framework.version} + pom + import + io.modelcontextprotocol.sdk mcp-bom diff --git a/spring-ai-client-chat/src/main/kotlin/org/springframework/ai/chat/client/ChatClientExtensions.kt b/spring-ai-client-chat/src/main/kotlin/org/springframework/ai/chat/client/ChatClientExtensions.kt index 40a4c6ffd84..7128603bbc7 100644 --- a/spring-ai-client-chat/src/main/kotlin/org/springframework/ai/chat/client/ChatClientExtensions.kt +++ b/spring-ai-client-chat/src/main/kotlin/org/springframework/ai/chat/client/ChatClientExtensions.kt @@ -25,8 +25,8 @@ import org.springframework.core.ParameterizedTypeReference * @author Josh Long */ -inline fun ChatClient.CallResponseSpec.entity(): T = +inline fun ChatClient.CallResponseSpec.entity(): T = entity(object : ParameterizedTypeReference() {}) as T -inline fun ChatClient.CallResponseSpec.responseEntity(): ResponseEntity = +inline fun ChatClient.CallResponseSpec.responseEntity(): ResponseEntity = responseEntity(object : ParameterizedTypeReference() {}) diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java index 3dc125fd2a8..144a4fae399 100644 --- a/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java @@ -57,7 +57,9 @@ public void handleError(URI url, HttpMethod method, @NonNull ClientHttpResponse handleError(response); } - @Override + // On purposes commented out so that the code can compile both with Spring 6 and + // Spring 7 + // @Override @SuppressWarnings("removal") public void handleError(@NonNull ClientHttpResponse response) throws IOException { if (response.getStatusCode().isError()) {