diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java new file mode 100644 index 000000000000..c981159b60d1 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -0,0 +1,618 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.net.URI; +import java.nio.charset.Charset; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Predicate; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRequest; +import org.springframework.http.HttpStatusCode; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpRequestInitializer; +import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.client.InterceptingClientHttpRequestFactory; +import org.springframework.http.converter.GenericHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.util.UriBuilder; +import org.springframework.web.util.UriBuilderFactory; + +/** + * Default implementation of {@link RestClient}. + * + * @author Arjen Poutsma + * @since 6.1 + */ +final class DefaultRestClient implements RestClient { + + private static final Log logger = LogFactory.getLog(DefaultRestClient.class); + + private static final String URI_TEMPLATE_ATTRIBUTE = RestClient.class.getName() + ".uriTemplate"; + + + private final ClientHttpRequestFactory clientRequestFactory; + + @Nullable + private volatile ClientHttpRequestFactory interceptingRequestFactory; + + @Nullable + private final List initializers; + + @Nullable + private final List interceptors; + + private final UriBuilderFactory uriBuilderFactory; + + @Nullable + private final HttpHeaders defaultHeaders; + + private final List defaultStatusHandlers; + + private final DefaultRestClientBuilder builder; + + private final List> messageConverters; + + + DefaultRestClient(ClientHttpRequestFactory clientRequestFactory, @Nullable List interceptors, + @Nullable List initializers, + UriBuilderFactory uriBuilderFactory, + @Nullable HttpHeaders defaultHeaders, + @Nullable List statusHandlers, + List> messageConverters, + DefaultRestClientBuilder builder) { + + this.clientRequestFactory = clientRequestFactory; + this.initializers = initializers; + this.interceptors = interceptors; + this.uriBuilderFactory = uriBuilderFactory; + this.defaultHeaders = defaultHeaders; + this.defaultStatusHandlers = (statusHandlers != null) ? new ArrayList<>(statusHandlers) : new ArrayList<>(); + this.messageConverters = messageConverters; + this.builder = builder; + } + + @Override + public RequestHeadersUriSpec get() { + return methodInternal(HttpMethod.GET); + } + + @Override + public RequestHeadersUriSpec head() { + return methodInternal(HttpMethod.HEAD); + } + + @Override + public RequestBodyUriSpec post() { + return methodInternal(HttpMethod.POST); + } + + @Override + public RequestBodyUriSpec put() { + return methodInternal(HttpMethod.PUT); + } + + @Override + public RequestBodyUriSpec patch() { + return methodInternal(HttpMethod.PATCH); + } + + @Override + public RequestHeadersUriSpec delete() { + return methodInternal(HttpMethod.DELETE); + } + + @Override + public RequestHeadersUriSpec options() { + return methodInternal(HttpMethod.OPTIONS); + } + + @Override + public RequestBodyUriSpec method(HttpMethod method) { + Assert.notNull(method, "Method must not be null"); + return methodInternal(method); + } + + private RequestBodyUriSpec methodInternal(HttpMethod httpMethod) { + return new DefaultRequestBodyUriSpec(httpMethod); + } + + @Override + public Builder mutate() { + return new DefaultRestClientBuilder(this.builder); + } + + + private class DefaultRequestBodyUriSpec implements RequestBodyUriSpec { + + private final HttpMethod httpMethod; + + @Nullable + private URI uri; + + @Nullable + private HttpHeaders headers; + + @Nullable + private InternalBody body; + + private final Map attributes = new LinkedHashMap<>(4); + + @Nullable + private Consumer httpRequestConsumer; + + public DefaultRequestBodyUriSpec(HttpMethod httpMethod) { + this.httpMethod = httpMethod; + } + + + @Override + public RequestBodySpec uri(String uriTemplate, Object... uriVariables) { + attribute(URI_TEMPLATE_ATTRIBUTE, uriTemplate); + return uri(DefaultRestClient.this.uriBuilderFactory.expand(uriTemplate, uriVariables)); + } + + @Override + public RequestBodySpec uri(String uriTemplate, Map uriVariables) { + attribute(URI_TEMPLATE_ATTRIBUTE, uriTemplate); + return uri(DefaultRestClient.this.uriBuilderFactory.expand(uriTemplate, uriVariables)); + } + + @Override + public RequestBodySpec uri(String uriTemplate, Function uriFunction) { + attribute(URI_TEMPLATE_ATTRIBUTE, uriTemplate); + return uri(uriFunction.apply(DefaultRestClient.this.uriBuilderFactory.uriString(uriTemplate))); + } + + @Override + public RequestBodySpec uri(Function uriFunction) { + return uri(uriFunction.apply(DefaultRestClient.this.uriBuilderFactory.builder())); + } + + @Override + public RequestBodySpec uri(URI uri) { + this.uri = uri; + return this; + } + + private HttpHeaders getHeaders() { + if (this.headers == null) { + this.headers = new HttpHeaders(); + } + return this.headers; + } + + @Override + public DefaultRequestBodyUriSpec header(String headerName, String... headerValues) { + for (String headerValue : headerValues) { + getHeaders().add(headerName, headerValue); + } + return this; + } + + @Override + public DefaultRequestBodyUriSpec headers(Consumer headersConsumer) { + headersConsumer.accept(getHeaders()); + return this; + } + + @Override + public DefaultRequestBodyUriSpec accept(MediaType... acceptableMediaTypes) { + getHeaders().setAccept(Arrays.asList(acceptableMediaTypes)); + return this; + } + + @Override + public DefaultRequestBodyUriSpec acceptCharset(Charset... acceptableCharsets) { + getHeaders().setAcceptCharset(Arrays.asList(acceptableCharsets)); + return this; + } + + @Override + public DefaultRequestBodyUriSpec contentType(MediaType contentType) { + getHeaders().setContentType(contentType); + return this; + } + + @Override + public DefaultRequestBodyUriSpec contentLength(long contentLength) { + getHeaders().setContentLength(contentLength); + return this; + } + + @Override + public DefaultRequestBodyUriSpec ifModifiedSince(ZonedDateTime ifModifiedSince) { + getHeaders().setIfModifiedSince(ifModifiedSince); + return this; + } + + @Override + public DefaultRequestBodyUriSpec ifNoneMatch(String... ifNoneMatches) { + getHeaders().setIfNoneMatch(Arrays.asList(ifNoneMatches)); + return this; + } + + @Override + public RequestBodySpec attribute(String name, Object value) { + this.attributes.put(name, value); + return this; + } + + @Override + public RequestBodySpec attributes(Consumer> attributesConsumer) { + attributesConsumer.accept(this.attributes); + return this; + } + + @Override + public RequestBodySpec httpRequest(Consumer requestConsumer) { + this.httpRequestConsumer = (this.httpRequestConsumer != null ? + this.httpRequestConsumer.andThen(requestConsumer) : requestConsumer); + return this; + } + + @Override + public RequestBodySpec body(Object body) { + this.body = clientHttpRequest -> writeWithMessageConverters(body, body.getClass(), clientHttpRequest); + return this; + } + + @Override + public RequestBodySpec body(T body, ParameterizedTypeReference bodyType) { + this.body = clientHttpRequest -> writeWithMessageConverters(body, bodyType.getType(), clientHttpRequest); + return this; + } + + @Override + public RequestBodySpec body(StreamingHttpOutputMessage.Body body) { + this.body = request -> body.writeTo(request.getBody()); + return this; + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private void writeWithMessageConverters(Object body, Type bodyType, ClientHttpRequest clientRequest) + throws IOException { + + MediaType contentType = clientRequest.getHeaders().getContentType(); + Class bodyClass = body.getClass(); + + for (HttpMessageConverter messageConverter : DefaultRestClient.this.messageConverters) { + if (messageConverter instanceof GenericHttpMessageConverter genericMessageConverter) { + if (genericMessageConverter.canWrite(bodyType, bodyClass, contentType)) { + logBody(body, contentType, genericMessageConverter); + genericMessageConverter.write(body, bodyType, contentType, clientRequest); + return; + } + } + if (messageConverter.canWrite(bodyClass, contentType)) { + logBody(body, contentType, messageConverter); + messageConverter.write(body, contentType, clientRequest); + return; + } + } + String message = "No HttpMessageConverter for " + bodyClass.getName(); + if (contentType != null) { + message += " and content type \"" + contentType + "\""; + } + throw new RestClientException(message); + } + + private void logBody(Object body, @Nullable MediaType mediaType, HttpMessageConverter converter) { + if (logger.isDebugEnabled()) { + StringBuilder msg = new StringBuilder("Writing ["); + msg.append(body); + msg.append("] "); + if (mediaType != null) { + msg.append("as \""); + msg.append(mediaType); + msg.append("\" "); + } + msg.append("with "); + msg.append(converter.getClass().getName()); + logger.debug(msg.toString()); + } + } + + + @Override + public ResponseSpec retrieve() { + return exchangeInternal(DefaultResponseSpec::new, false); + } + + @Override + public T exchange(ExchangeFunction exchangeFunction) { + return exchangeInternal(exchangeFunction, true); + } + + private T exchangeInternal(ExchangeFunction exchangeFunction, boolean close) { + Assert.notNull(exchangeFunction, "ExchangeFunction must not be null"); + + ClientHttpResponse clientResponse = null; + URI uri = null; + try { + uri = initUri(); + HttpHeaders headers = initHeaders(); + ClientHttpRequest clientRequest = createRequest(uri); + clientRequest.getHeaders().addAll(headers); + if (this.body != null) { + this.body.writeTo(clientRequest); + } + if (this.httpRequestConsumer != null) { + this.httpRequestConsumer.accept(clientRequest); + } + clientResponse = clientRequest.execute(); + return exchangeFunction.exchange(clientRequest, clientResponse); + } + catch (IOException ex) { + throw createResourceAccessException(uri, this.httpMethod, ex); + } + finally { + if (close && clientResponse != null) { + clientResponse.close(); + } + } + } + + private URI initUri() { + return (this.uri != null ? this.uri : DefaultRestClient.this.uriBuilderFactory.expand("")); + } + + private HttpHeaders initHeaders() { + HttpHeaders defaultHeaders = DefaultRestClient.this.defaultHeaders; + if (CollectionUtils.isEmpty(this.headers)) { + return (defaultHeaders != null ? defaultHeaders : new HttpHeaders()); + } + else if (CollectionUtils.isEmpty(defaultHeaders)) { + return this.headers; + } + else { + HttpHeaders result = new HttpHeaders(); + result.putAll(defaultHeaders); + result.putAll(this.headers); + return result; + } + } + + private ClientHttpRequest createRequest(URI uri) throws IOException { + ClientHttpRequestFactory factory; + if (DefaultRestClient.this.interceptors != null) { + factory = DefaultRestClient.this.interceptingRequestFactory; + if (factory == null) { + factory = new InterceptingClientHttpRequestFactory(DefaultRestClient.this.clientRequestFactory, DefaultRestClient.this.interceptors); + DefaultRestClient.this.interceptingRequestFactory = factory; + } + } + else { + factory = DefaultRestClient.this.clientRequestFactory; + } + ClientHttpRequest request = factory.createRequest(uri, this.httpMethod); + if (DefaultRestClient.this.initializers != null) { + DefaultRestClient.this.initializers.forEach(initializer -> initializer.initialize(request)); + } + return request; + } + + private static ResourceAccessException createResourceAccessException(URI url, HttpMethod method, IOException ex) { + StringBuilder msg = new StringBuilder("I/O error on "); + msg.append(method.name()); + msg.append(" request for \""); + String urlString = url.toString(); + int idx = urlString.indexOf('?'); + if (idx != -1) { + msg.append(urlString, 0, idx); + } + else { + msg.append(urlString); + } + msg.append("\": "); + msg.append(ex.getMessage()); + return new ResourceAccessException(msg.toString(), ex); + } + + + + @FunctionalInterface + private interface InternalBody { + + void writeTo(ClientHttpRequest request) throws IOException; + } + } + + private class DefaultResponseSpec implements ResponseSpec { + + private final HttpRequest clientRequest; + + private final ClientHttpResponse clientResponse; + + private final List statusHandlers = new ArrayList<>(1); + + private final int defaultStatusHandlerCount; + + + DefaultResponseSpec(HttpRequest clientRequest, ClientHttpResponse clientResponse) { + this.clientRequest = clientRequest; + this.clientResponse = clientResponse; + this.statusHandlers.addAll(DefaultRestClient.this.defaultStatusHandlers); + this.statusHandlers.add(StatusHandler.defaultHandler(DefaultRestClient.this.messageConverters)); + this.defaultStatusHandlerCount = this.statusHandlers.size(); + } + + @Override + public ResponseSpec onStatus(Predicate statusPredicate, ErrorHandler errorHandler) { + Assert.notNull(statusPredicate, "StatusPredicate must not be null"); + Assert.notNull(errorHandler, "ErrorHandler must not be null"); + + return onStatusInternal(StatusHandler.of(statusPredicate, errorHandler)); + } + + @Override + public ResponseSpec onStatus(ResponseErrorHandler errorHandler) { + Assert.notNull(errorHandler, "ErrorHandler must not be null"); + + return onStatusInternal(StatusHandler.fromErrorHandler(errorHandler)); + } + + private ResponseSpec onStatusInternal(StatusHandler statusHandler) { + Assert.notNull(statusHandler, "StatusHandler must not be null"); + + int index = this.statusHandlers.size() - this.defaultStatusHandlerCount; // Default handlers always last + this.statusHandlers.add(index, statusHandler); + return this; + } + + @Override + public T body(Class bodyType) { + return readWithMessageConverters(bodyType, bodyType); + } + + @Override + public T body(ParameterizedTypeReference bodyType) { + Type type = bodyType.getType(); + Class bodyClass = bodyClass(type); + return readWithMessageConverters(type, bodyClass); + } + + @Override + public ResponseEntity toEntity(Class bodyType) { + return toEntityInternal(bodyType, bodyType); + } + + @Override + public ResponseEntity toEntity(ParameterizedTypeReference bodyType) { + Type type = bodyType.getType(); + Class bodyClass = bodyClass(type); + return toEntityInternal(type, bodyClass); + } + + private ResponseEntity toEntityInternal(Type bodyType, Class bodyClass) { + T body = readWithMessageConverters(bodyType, bodyClass); + try { + return ResponseEntity.status(this.clientResponse.getStatusCode()) + .headers(this.clientResponse.getHeaders()) + .body(body); + } + catch (IOException ex) { + throw new ResourceAccessException("Could not retrieve response status code: " + ex.getMessage(), ex); + } + } + + @Override + public ResponseEntity toBodilessEntity() { + try (this.clientResponse) { + applyStatusHandlers(this.clientRequest, this.clientResponse); + return ResponseEntity.status(this.clientResponse.getStatusCode()) + .headers(this.clientResponse.getHeaders()) + .build(); + } + catch (IOException ex) { + throw new ResourceAccessException("Could not retrieve response status code: " + ex.getMessage(), ex); + } + } + + + @SuppressWarnings("unchecked") + private static Class bodyClass(Type type) { + if (type instanceof Class clazz) { + return (Class) clazz; + } + if (type instanceof ParameterizedType parameterizedType && + parameterizedType.getRawType() instanceof Class rawType) { + return (Class) rawType; + } + return (Class) Object.class; + } + + + @SuppressWarnings({"unchecked", "rawtypes"}) + private T readWithMessageConverters(Type bodyType, Class bodyClass) { + MediaType contentType = getContentType(); + + try (this.clientResponse) { + applyStatusHandlers(this.clientRequest, this.clientResponse); + + for (HttpMessageConverter messageConverter : DefaultRestClient.this.messageConverters) { + if (messageConverter instanceof GenericHttpMessageConverter genericHttpMessageConverter) { + if (genericHttpMessageConverter.canRead(bodyType, bodyClass, contentType)) { + if (logger.isDebugEnabled()) { + logger.debug("Reading to [" + ResolvableType.forType(bodyType) + "]"); + } + return (T) genericHttpMessageConverter.read(bodyType, bodyClass, this.clientResponse); + } + } + if (messageConverter.canRead(bodyClass, contentType)) { + if (logger.isDebugEnabled()) { + logger.debug("Reading to [" + bodyClass.getName() + "] as \"" + contentType + "\""); + } + return (T) messageConverter.read((Class)bodyClass, this.clientResponse); + } + } + throw new UnknownContentTypeException(bodyType, contentType, + this.clientResponse.getStatusCode(), this.clientResponse.getStatusText(), + this.clientResponse.getHeaders(), RestClientUtils.getBody(this.clientResponse)); + } + catch (IOException | HttpMessageNotReadableException ex) { + throw new RestClientException("Error while extracting response for type [" + + ResolvableType.forType(bodyType) + "] and content type [" + contentType + "]", ex); + } + } + + private MediaType getContentType() { + MediaType contentType = this.clientResponse.getHeaders().getContentType(); + if (contentType == null) { + contentType = MediaType.APPLICATION_OCTET_STREAM; + } + return contentType; + } + + private void applyStatusHandlers(HttpRequest request, ClientHttpResponse response) throws IOException { + for (StatusHandler handler : this.statusHandlers) { + if (handler.test(response)) { + handler.handle(request, response); + return; + } + } + } + + + } +} diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java new file mode 100644 index 000000000000..4191770fca49 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java @@ -0,0 +1,375 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Predicate; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatusCode; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpRequestInitializer; +import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.http.converter.ByteArrayHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.ResourceHttpMessageConverter; +import org.springframework.http.converter.StringHttpMessageConverter; +import org.springframework.http.converter.cbor.MappingJackson2CborHttpMessageConverter; +import org.springframework.http.converter.json.GsonHttpMessageConverter; +import org.springframework.http.converter.json.JsonbHttpMessageConverter; +import org.springframework.http.converter.json.KotlinSerializationJsonHttpMessageConverter; +import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.http.converter.smile.MappingJackson2SmileHttpMessageConverter; +import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; +import org.springframework.web.util.DefaultUriBuilderFactory; +import org.springframework.web.util.UriBuilderFactory; + +/** + * Default implementation of {@link RestClient.Builder}. + * + * @author Arjen Poutsma + * @since 6.1 + */ +final class DefaultRestClientBuilder implements RestClient.Builder { + + private static final boolean httpComponentsClientPresent; + + private static final boolean jackson2Present; + + private static final boolean gsonPresent; + + private static final boolean jsonbPresent; + + private static final boolean kotlinSerializationJsonPresent; + + private static final boolean jackson2SmilePresent; + + private static final boolean jackson2CborPresent; + + + static { + ClassLoader loader = DefaultRestClientBuilder.class.getClassLoader(); + httpComponentsClientPresent = ClassUtils.isPresent("org.apache.hc.client5.http.classic.HttpClient", loader); + jackson2Present = ClassUtils.isPresent("com.fasterxml.jackson.databind.ObjectMapper", loader) && + ClassUtils.isPresent("com.fasterxml.jackson.core.JsonGenerator", loader); + gsonPresent = ClassUtils.isPresent("com.google.gson.Gson", loader); + jsonbPresent = ClassUtils.isPresent("jakarta.json.bind.Jsonb", loader); + kotlinSerializationJsonPresent = ClassUtils.isPresent("kotlinx.serialization.json.Json", loader); + jackson2SmilePresent = ClassUtils.isPresent("com.fasterxml.jackson.dataformat.smile.SmileFactory", loader); + jackson2CborPresent = ClassUtils.isPresent("com.fasterxml.jackson.dataformat.cbor.CBORFactory", loader); + } + + @Nullable + private String baseUrl; + + @Nullable + private Map defaultUriVariables; + + @Nullable + private UriBuilderFactory uriBuilderFactory; + + @Nullable + private HttpHeaders defaultHeaders; + + @Nullable + private Consumer> defaultRequest; + + @Nullable + private List statusHandlers; + + @Nullable + private ClientHttpRequestFactory requestFactory; + + @Nullable + private List> messageConverters; + + @Nullable + private List interceptors; + + @Nullable + private List initializers; + + + public DefaultRestClientBuilder() { + } + + public DefaultRestClientBuilder(DefaultRestClientBuilder other) { + Assert.notNull(other, "Other must not be null"); + + this.baseUrl = other.baseUrl; + this.defaultUriVariables = (other.defaultUriVariables != null ? + new LinkedHashMap<>(other.defaultUriVariables) : null); + this.uriBuilderFactory = other.uriBuilderFactory; + + if (other.defaultHeaders != null) { + this.defaultHeaders = new HttpHeaders(); + this.defaultHeaders.putAll(other.defaultHeaders); + } + else { + this.defaultHeaders = null; + } + this.defaultRequest = other.defaultRequest; + this.statusHandlers = (other.statusHandlers != null ? new ArrayList<>(other.statusHandlers) : null); + + this.requestFactory = other.requestFactory; + this.messageConverters = (other.messageConverters != null ? + new ArrayList<>(other.messageConverters) : null); + + this.interceptors = (other.interceptors != null) ? new ArrayList<>(other.interceptors) : null; + this.initializers = (other.initializers != null) ? new ArrayList<>(other.initializers) : null; + } + + public DefaultRestClientBuilder(RestTemplate restTemplate) { + Assert.notNull(restTemplate, "RestTemplate must not be null"); + + if (restTemplate.getUriTemplateHandler() instanceof UriBuilderFactory builderFactory) { + this.uriBuilderFactory = builderFactory; + } + this.statusHandlers = new ArrayList<>(); + this.statusHandlers.add(StatusHandler.fromErrorHandler(restTemplate.getErrorHandler())); + + this.requestFactory = restTemplate.getRequestFactory(); + this.messageConverters = new ArrayList<>(restTemplate.getMessageConverters()); + + if (!CollectionUtils.isEmpty(restTemplate.getInterceptors())) { + this.interceptors = new ArrayList<>(restTemplate.getInterceptors()); + } + if (!CollectionUtils.isEmpty(restTemplate.getClientHttpRequestInitializers())) { + this.initializers = new ArrayList<>(restTemplate.getClientHttpRequestInitializers()); + } + } + + + @Override + public RestClient.Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + @Override + public RestClient.Builder defaultUriVariables(Map defaultUriVariables) { + this.defaultUriVariables = defaultUriVariables; + return this; + } + + @Override + public RestClient.Builder uriBuilderFactory(UriBuilderFactory uriBuilderFactory) { + this.uriBuilderFactory = uriBuilderFactory; + return this; + } + + @Override + public RestClient.Builder defaultHeader(String header, String... values) { + initHeaders().put(header, Arrays.asList(values)); + return this; + } + + @Override + public RestClient.Builder defaultHeaders(Consumer headersConsumer) { + headersConsumer.accept(initHeaders()); + return this; + } + + private HttpHeaders initHeaders() { + if (this.defaultHeaders == null) { + this.defaultHeaders = new HttpHeaders(); + } + return this.defaultHeaders; + } + + @Override + public RestClient.Builder defaultRequest(Consumer> defaultRequest) { + this.defaultRequest = this.defaultRequest != null ? + this.defaultRequest.andThen(defaultRequest) : defaultRequest; + return this; + } + + @Override + public RestClient.Builder defaultStatusHandler(Predicate statusPredicate, RestClient.ResponseSpec.ErrorHandler errorHandler) { + return defaultStatusHandlerInternal(StatusHandler.of(statusPredicate, errorHandler)); + } + + @Override + public RestClient.Builder defaultStatusHandler(ResponseErrorHandler errorHandler) { + return defaultStatusHandlerInternal(StatusHandler.fromErrorHandler(errorHandler)); + } + + private RestClient.Builder defaultStatusHandlerInternal(StatusHandler statusHandler) { + if (this.statusHandlers == null) { + this.statusHandlers = new ArrayList<>(); + } + this.statusHandlers.add(statusHandler); + return this; + } + + @Override + public RestClient.Builder requestInterceptor(ClientHttpRequestInterceptor interceptor) { + Assert.notNull(interceptor, "Interceptor must not be null"); + initInterceptors().add(interceptor); + return this; + } + + @Override + public RestClient.Builder requestInterceptors(Consumer> interceptorsConsumer) { + interceptorsConsumer.accept(initInterceptors()); + return this; + } + + private List initInterceptors() { + if (this.interceptors == null) { + this.interceptors = new ArrayList<>(); + } + return this.interceptors; + } + + @Override + public RestClient.Builder requestInitializer(ClientHttpRequestInitializer initializer) { + Assert.notNull(initializer, "Initializer must not be null"); + initInitializers().add(initializer); + return this; + } + + @Override + public RestClient.Builder requestInitializers(Consumer> initializersConsumer) { + initializersConsumer.accept(initInitializers()); + return this; + } + + private List initInitializers() { + if (this.initializers == null) { + this.initializers = new ArrayList<>(); + } + return this.initializers; + } + + + @Override + public RestClient.Builder requestFactory(ClientHttpRequestFactory requestFactory) { + this.requestFactory = requestFactory; + return this; + } + + @Override + public RestClient.Builder messageConverters(Consumer>> configurer) { + configurer.accept(initMessageConverters()); + return this; + } + + @Override + public RestClient.Builder apply(Consumer builderConsumer) { + builderConsumer.accept(this); + return this; + } + + private List> initMessageConverters() { + if (this.messageConverters == null) { + this.messageConverters = new ArrayList<>(); + this.messageConverters.add(new ByteArrayHttpMessageConverter()); + this.messageConverters.add(new StringHttpMessageConverter()); + this.messageConverters.add(new ResourceHttpMessageConverter(false)); + this.messageConverters.add(new AllEncompassingFormHttpMessageConverter()); + + if (kotlinSerializationJsonPresent) { + this.messageConverters.add(new KotlinSerializationJsonHttpMessageConverter()); + } + if (jackson2Present) { + this.messageConverters.add(new MappingJackson2HttpMessageConverter()); + } + else if (gsonPresent) { + this.messageConverters.add(new GsonHttpMessageConverter()); + } + else if (jsonbPresent) { + this.messageConverters.add(new JsonbHttpMessageConverter()); + } + if (jackson2SmilePresent) { + this.messageConverters.add(new MappingJackson2SmileHttpMessageConverter()); + } + if (jackson2CborPresent) { + this.messageConverters.add(new MappingJackson2CborHttpMessageConverter()); + } + } + return this.messageConverters; + } + + + @Override + public RestClient.Builder clone() { + return new DefaultRestClientBuilder(this); + } + + @Override + public RestClient build() { + ClientHttpRequestFactory requestFactory = initRequestFactory(); + UriBuilderFactory uriBuilderFactory = initUriBuilderFactory(); + HttpHeaders defaultHeaders = copyDefaultHeaders(); + List> messageConverters = (this.messageConverters != null ? + this.messageConverters : initMessageConverters()); + return new DefaultRestClient(requestFactory, + this.interceptors, this.initializers, uriBuilderFactory, + defaultHeaders, + this.statusHandlers, + messageConverters, + new DefaultRestClientBuilder(this) + ); + } + + private ClientHttpRequestFactory initRequestFactory() { + if (this.requestFactory != null) { + return this.requestFactory; + } + else if (httpComponentsClientPresent) { + return new HttpComponentsClientHttpRequestFactory(); + } + else { + return new SimpleClientHttpRequestFactory(); + } + } + + private UriBuilderFactory initUriBuilderFactory() { + if (this.uriBuilderFactory != null) { + return this.uriBuilderFactory; + } + DefaultUriBuilderFactory factory = (this.baseUrl != null ? + new DefaultUriBuilderFactory(this.baseUrl) : new DefaultUriBuilderFactory()); + factory.setDefaultUriVariables(this.defaultUriVariables); + return factory; + } + + @Nullable + private HttpHeaders copyDefaultHeaders() { + if (this.defaultHeaders != null) { + HttpHeaders copy = new HttpHeaders(); + this.defaultHeaders.forEach((key, values) -> copy.put(key, new ArrayList<>(values))); + return HttpHeaders.readOnlyHttpHeaders(copy); + } + else { + return null; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/RestClient.java b/spring-web/src/main/java/org/springframework/web/client/RestClient.java new file mode 100644 index 000000000000..eedc1b2a79e6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/RestClient.java @@ -0,0 +1,751 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.io.OutputStream; +import java.net.URI; +import java.nio.charset.Charset; +import java.time.ZonedDateTime; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Predicate; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRequest; +import org.springframework.http.HttpStatusCode; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpRequestInitializer; +import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.lang.Nullable; +import org.springframework.web.util.DefaultUriBuilderFactory; +import org.springframework.web.util.UriBuilder; +import org.springframework.web.util.UriBuilderFactory; + +/** + * Client to perform HTTP requests, exposing a fluent, synchronous API over + * underlying HTTP client libraries such the JDK {@code HttpClient}, Apache + * HttpComponents, and others. + * + *

Use static factory methods {@link #create()}, {@link #create(String)}, + * or {@link RestClient#builder()} to prepare an instance. To use the same + * configuration as a {@link RestTemplate}, use {@link #create(RestTemplate)} or + * {@link #builder(RestTemplate)}. + * + *

For examples with a response body see: + *

    + *
  • {@link RequestHeadersSpec#retrieve() retrieve()} + *
  • {@link RequestHeadersSpec#exchange(RequestHeadersSpec.ExchangeFunction) exchange(Function<ClientHttpRequest, T>)} + *
+ *

For examples with a request body see: + *

    + *
  • {@link RequestBodySpec#body(Object) body(Object)} + *
  • {@link RequestBodySpec#body(Object, ParameterizedTypeReference) body(Object, ParameterizedTypeReference)} + *
  • {@link RequestBodySpec#body(StreamingHttpOutputMessage.Body) body(Consumer<OutputStream>} + *
+ * + * @author Arjen Poutsma + * @since 6.1 + */ +public interface RestClient { + + /** + * Start building an HTTP GET request. + * @return a spec for specifying the target URL + */ + RequestHeadersUriSpec get(); + + /** + * Start building an HTTP HEAD request. + * @return a spec for specifying the target URL + */ + RequestHeadersUriSpec head(); + + /** + * Start building an HTTP POST request. + * @return a spec for specifying the target URL + */ + RequestBodyUriSpec post(); + + /** + * Start building an HTTP PUT request. + * @return a spec for specifying the target URL + */ + RequestBodyUriSpec put(); + + /** + * Start building an HTTP PATCH request. + * @return a spec for specifying the target URL + */ + RequestBodyUriSpec patch(); + + /** + * Start building an HTTP DELETE request. + * @return a spec for specifying the target URL + */ + RequestHeadersUriSpec delete(); + + /** + * Start building an HTTP OPTIONS request. + * @return a spec for specifying the target URL + */ + RequestHeadersUriSpec options(); + + /** + * Start building a request for the given {@code HttpMethod}. + * @return a spec for specifying the target URL + */ + RequestBodyUriSpec method(HttpMethod method); + + + /** + * Return a builder to create a new {@code RestClient} whose settings are + * replicated from the current {@code RestClient}. + */ + Builder mutate(); + + + // Static, factory methods + + /** + * Create a new {@code RestClient}. + * @see #create(String) + * @see #builder() + */ + static RestClient create() { + return new DefaultRestClientBuilder().build(); + } + + /** + * Variant of {@link #create()} that accepts a default base URL. For more + * details see {@link Builder#baseUrl(String) Builder.baseUrl(String)}. + * @param baseUrl the base URI for all requests + * @see #builder() + */ + static RestClient create(String baseUrl) { + return new DefaultRestClientBuilder().baseUrl(baseUrl).build(); + } + + /** + * Create a new {@code RestClient} based on the configuration of the + * given {@code RestTemplate}. The returned builder is configured with the + * template's + *
    + *
  • {@link RestTemplate#getRequestFactory() ClientHttpRequestFactory},
  • + *
  • {@link RestTemplate#getMessageConverters() HttpMessageConverters},
  • + *
  • {@link RestTemplate#getInterceptors() ClientHttpRequestInterceptors},
  • + *
  • {@link RestTemplate#getClientHttpRequestInitializers() ClientHttpRequestInitializers},
  • + *
  • {@link RestTemplate#getUriTemplateHandler() UriBuilderFactory}, and
  • + *
  • {@linkplain RestTemplate#getErrorHandler() error handler}.
  • + *
+ * @param restTemplate the rest template to base the returned client's + * configuration on + * @return a {@code RestClient} initialized with the {@code restTemplate}'s + * configuration + */ + static RestClient create(RestTemplate restTemplate) { + return new DefaultRestClientBuilder(restTemplate).build(); + } + + /** + * Obtain a {@code RestClient} builder. + */ + static RestClient.Builder builder() { + return new DefaultRestClientBuilder(); + } + + /** + * Obtain a {@code RestClient} builder based on the configuration of the + * given {@code RestTemplate}. The returned builder is configured with the + * template's + *
    + *
  • {@link RestTemplate#getRequestFactory() ClientHttpRequestFactory},
  • + *
  • {@link RestTemplate#getMessageConverters() HttpMessageConverters},
  • + *
  • {@link RestTemplate#getInterceptors() ClientHttpRequestInterceptors},
  • + *
  • {@link RestTemplate#getClientHttpRequestInitializers() ClientHttpRequestInitializers},
  • + *
  • {@link RestTemplate#getUriTemplateHandler() UriBuilderFactory}, and
  • + *
  • {@linkplain RestTemplate#getErrorHandler() error handler}.
  • + *
+ * @param restTemplate the rest template to base the returned builder's + * configuration on + * @return a {@code RestClient} builder initialized with {@code restTemplate}'s + * configuration + */ + static RestClient.Builder builder(RestTemplate restTemplate) { + return new DefaultRestClientBuilder(restTemplate); + } + + + /** + * A mutable builder for creating a {@link RestClient}. + */ + interface Builder { + + /** + * Configure a base URL for requests. Effectively a shortcut for: + *

+ *

+		 * String baseUrl = "https://abc.go.com/v1";
+		 * DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory(baseUrl);
+		 * RestClient client = RestClient.builder().uriBuilderFactory(factory).build();
+		 * 
+ *

The {@code DefaultUriBuilderFactory} is used to prepare the URL + * for every request with the given base URL, unless the URL request + * for a given URL is absolute in which case the base URL is ignored. + *

Note: this method is mutually exclusive with + * {@link #uriBuilderFactory(UriBuilderFactory)}. If both are used, the + * baseUrl value provided here will be ignored. + * @see DefaultUriBuilderFactory#DefaultUriBuilderFactory(String) + * @see #uriBuilderFactory(UriBuilderFactory) + */ + Builder baseUrl(String baseUrl); + + /** + * Configure default URL variable values to use when expanding URI + * templates with a {@link Map}. Effectively a shortcut for: + *

+ *

+		 * Map<String, ?> defaultVars = ...;
+		 * DefaultUriBuilderFactory factory = new DefaultUriBuilderFactory();
+		 * factory.setDefaultVariables(defaultVars);
+		 * RestClient client = RestClient.builder().uriBuilderFactory(factory).build();
+		 * 
+ *

Note: this method is mutually exclusive with + * {@link #uriBuilderFactory(UriBuilderFactory)}. If both are used, the + * defaultUriVariables value provided here will be ignored. + * @see DefaultUriBuilderFactory#setDefaultUriVariables(Map) + * @see #uriBuilderFactory(UriBuilderFactory) + */ + Builder defaultUriVariables(Map defaultUriVariables); + + /** + * Provide a pre-configured {@link UriBuilderFactory} instance. This is + * an alternative to, and effectively overrides the following shortcut + * properties: + *

    + *
  • {@link #baseUrl(String)} + *
  • {@link #defaultUriVariables(Map)}. + *
+ * @param uriBuilderFactory the URI builder factory to use + * @see #baseUrl(String) + * @see #defaultUriVariables(Map) + */ + Builder uriBuilderFactory(UriBuilderFactory uriBuilderFactory); + + /** + * Global option to specify a header to be added to every request, + * if the request does not already contain such a header. + * @param header the header name + * @param values the header values + */ + Builder defaultHeader(String header, String... values); + + /** + * Provides access to every {@link #defaultHeader(String, String...)} + * declared so far with the possibility to add, replace, or remove. + * @param headersConsumer the consumer + */ + Builder defaultHeaders(Consumer headersConsumer); + + /** + * Provide a consumer to customize every request being built. + * @param defaultRequest the consumer to use for modifying requests + */ + Builder defaultRequest(Consumer> defaultRequest); + + /** + * Register a default + * {@linkplain ResponseSpec#onStatus(Predicate, ResponseSpec.ErrorHandler) status handler} + * to apply to every response. Such default handlers are applied in the + * order in which they are registered, and after any others that are + * registered for a specific response. + * @param statusPredicate to match responses with + * @param errorHandler handler that typically, though not necessarily, + * throws an exception + * @return this builder + */ + Builder defaultStatusHandler(Predicate statusPredicate, + ResponseSpec.ErrorHandler errorHandler); + + /** + * Register a default + * {@linkplain ResponseSpec#onStatus(ResponseErrorHandler) status handler} + * to apply to every response. Such default handlers are applied in the + * order in which they are registered, and after any others that are + * registered for a specific response. + * @param errorHandler handler that typically, though not necessarily, + * throws an exception + * @return this builder + */ + Builder defaultStatusHandler(ResponseErrorHandler errorHandler); + + /** + * Add the given request interceptor to the end of the interceptor chain. + * @param interceptor the interceptor to be added to the chain + */ + Builder requestInterceptor(ClientHttpRequestInterceptor interceptor); + + /** + * Manipulate the interceptors with the given consumer. The list provided to + * the consumer is "live", so that the consumer can be used to remove + * interceptors, change ordering, etc. + * @param interceptorsConsumer a function that consumes the interceptors list + * @return this builder + */ + Builder requestInterceptors(Consumer> interceptorsConsumer); + + /** + * Add the given request initializer to the end of the initializer chain. + * @param initializer the initializer to be added to the chain + */ + Builder requestInitializer(ClientHttpRequestInitializer initializer); + + /** + * Manipulate the initializers with the given consumer. The list provided to + * the consumer is "live", so that the consumer can be used to remove + * initializers, change ordering, etc. + * @param initializersConsumer a function that consumes the initializers list + * @return this builder + */ + Builder requestInitializers(Consumer> initializersConsumer); + + /** + * Configure the {@link ClientHttpRequestFactory} to use. This is useful + * for plugging in and/or customizing options of the underlying HTTP + * client library (e.g. SSL). + * @param requestFactory the request factory to use + */ + Builder requestFactory(ClientHttpRequestFactory requestFactory); + + /** + * Configure the message converters for the {@code RestClient} to use. + * @param configurer the configurer to apply + */ + Builder messageConverters(Consumer>> configurer); + + /** + * Apply the given {@code Consumer} to this builder instance. + *

This can be useful for applying pre-packaged customizations. + * @param builderConsumer the consumer to apply + */ + Builder apply(Consumer builderConsumer); + + /** + * Clone this {@code RestClient.Builder}. + */ + Builder clone(); + + /** + * Build the {@link RestClient} instance. + */ + RestClient build(); + } + + + /** + * Contract for specifying the URI for a request. + * @param a self reference to the spec type + */ + interface UriSpec> { + + /** + * Specify the URI using an absolute, fully constructed {@link URI}. + */ + S uri(URI uri); + + /** + * Specify the URI for the request using a URI template and URI variables. + * If a {@link UriBuilderFactory} was configured for the client (e.g. + * with a base URI) it will be used to expand the URI template. + */ + S uri(String uri, Object... uriVariables); + + /** + * Specify the URI for the request using a URI template and URI variables. + * If a {@link UriBuilderFactory} was configured for the client (e.g. + * with a base URI) it will be used to expand the URI template. + */ + S uri(String uri, Map uriVariables); + + /** + * Specify the URI starting with a URI template and finishing off with a + * {@link UriBuilder} created from the template. + */ + S uri(String uri, Function uriFunction); + + /** + * Specify the URI by through a {@link UriBuilder}. + * @see #uri(String, Function) + */ + S uri(Function uriFunction); + } + + + /** + * Contract for specifying request headers leading up to the exchange. + * @param a self reference to the spec type + */ + interface RequestHeadersSpec> { + + /** + * Set the list of acceptable {@linkplain MediaType media types}, as + * specified by the {@code Accept} header. + * @param acceptableMediaTypes the acceptable media types + * @return this builder + */ + S accept(MediaType... acceptableMediaTypes); + + /** + * Set the list of acceptable {@linkplain Charset charsets}, as specified + * by the {@code Accept-Charset} header. + * @param acceptableCharsets the acceptable charsets + * @return this builder + */ + S acceptCharset(Charset... acceptableCharsets); + + /** + * Set the value of the {@code If-Modified-Since} header. + *

The date should be specified as the number of milliseconds since + * January 1, 1970 GMT. + * @param ifModifiedSince the new value of the header + * @return this builder + */ + S ifModifiedSince(ZonedDateTime ifModifiedSince); + + /** + * Set the values of the {@code If-None-Match} header. + * @param ifNoneMatches the new value of the header + * @return this builder + */ + S ifNoneMatch(String... ifNoneMatches); + + /** + * Add the given, single header value under the given name. + * @param headerName the header name + * @param headerValues the header value(s) + * @return this builder + */ + S header(String headerName, String... headerValues); + + /** + * Provides access to every header declared so far with the possibility + * to add, replace, or remove values. + * @param headersConsumer the consumer to provide access to + * @return this builder + */ + S headers(Consumer headersConsumer); + + /** + * Set the attribute with the given name to the given value. + * @param name the name of the attribute to add + * @param value the value of the attribute to add + * @return this builder + */ + S attribute(String name, Object value); + + /** + * Provides access to every attribute declared so far with the + * possibility to add, replace, or remove values. + * @param attributesConsumer the consumer to provide access to + * @return this builder + */ + S attributes(Consumer> attributesConsumer); + + /** + * Callback for access to the {@link ClientHttpRequest} that in turn + * provides access to the native request of the underlying HTTP library. + * This could be useful for setting advanced, per-request options that + * exposed by the underlying library. + * @param requestConsumer a consumer to access the + * {@code ClientHttpRequest} with + * @return {@code ResponseSpec} to specify how to decode the body + */ + S httpRequest(Consumer requestConsumer); + + /** + * Proceed to declare how to extract the response. For example to extract + * a {@link ResponseEntity} with status, headers, and body: + *

+		 * ResponseEntity<Person> entity = client.get()
+		 *     .uri("/persons/1")
+		 *     .accept(MediaType.APPLICATION_JSON)
+		 *     .retrieve()
+		 *     .toEntity(Person.class);
+		 * 
+ *

Or if interested only in the body: + *

+		 * Person person = client.get()
+		 *     .uri("/persons/1")
+		 *     .accept(MediaType.APPLICATION_JSON)
+		 *     .retrieve()
+		 *     .body(Person.class);
+		 * 
+ *

By default, 4xx response code result in a + * {@link HttpClientErrorException} and 5xx response codes in a + * {@link HttpServerErrorException}. To customize error handling, use + * {@link ResponseSpec#onStatus(Predicate, ResponseSpec.ErrorHandler) onStatus} handlers. + */ + ResponseSpec retrieve(); + + /** + * Exchange the {@link ClientHttpResponse} for a type {@code T}. This + * can be useful for advanced scenarios, for example to decode the + * response differently depending on the response status: + *

+		 * Person person = client.get()
+		 *     .uri("/people/1")
+		 *     .accept(MediaType.APPLICATION_JSON)
+		 *     .exchange((request, response) -> {
+		 *         if (response.getStatusCode().equals(HttpStatus.OK)) {
+		 *             return deserialize(response.getBody());
+		 *         }
+		 *         else {
+		 *             throw new BusinessException();
+		 *         }
+		 *     });
+		 * 
+ *

Note: The response is + * {@linkplain ClientHttpResponse#close() closed} after the exchange + * function has been invoked. + * @param exchangeFunction the function to handle the response with + * @param the type the response will be transformed to + * @return the value returned from the exchange function + */ + T exchange(ExchangeFunction exchangeFunction); + + + /** + * Defines the contract for {@link #exchange(ExchangeFunction)}. + * @param the type the response will be transformed to + */ + @FunctionalInterface + interface ExchangeFunction { + + /** + * Exchange the given response into a type {@code T}. + * @param clientRequest the request + * @param clientResponse the response + * @return the exchanged type + * @throws IOException in case of I/O errors + */ + T exchange(HttpRequest clientRequest, ClientHttpResponse clientResponse) throws IOException; + + } + + } + + + + + /** + * Contract for specifying request headers and body leading up to the exchange. + */ + interface RequestBodySpec extends RequestHeadersSpec { + + /** + * Set the length of the body in bytes, as specified by the + * {@code Content-Length} header. + * @param contentLength the content length + * @return this builder + * @see HttpHeaders#setContentLength(long) + */ + RequestBodySpec contentLength(long contentLength); + + /** + * Set the {@linkplain MediaType media type} of the body, as specified + * by the {@code Content-Type} header. + * @param contentType the content type + * @return this builder + * @see HttpHeaders#setContentType(MediaType) + */ + RequestBodySpec contentType(MediaType contentType); + + /** + * Set the body of the request to the given {@code Object}. + * For example: + *

+		 * Person person = ... ;
+		 * ResponseEntity<Void> response = client.post()
+		 *     .uri("/persons/{id}", id)
+		 *     .contentType(MediaType.APPLICATION_JSON)
+		 *     .body(person)
+		 *     .retrieve()
+		 *     .toBodilessEntity();
+		 * 
+ * @param body the body of the response + * @return the built response + */ + RequestBodySpec body(Object body); + + /** + * Set the body of the response to the given {@code Object}. The parameter + * {@code bodyType} is used to capture the generic type. + * @param body the body of the response + * @param bodyType the type of the body, used to capture the generic type + * @return the built response + */ + RequestBodySpec body(T body, ParameterizedTypeReference bodyType); + + /** + * Set the body of the response to the given function that writes to + * an {@link OutputStream}. + * @param body a function that takes an {@code OutputStream} and can + * throw an {@code IOException} + * @return the built response + */ + RequestBodySpec body(StreamingHttpOutputMessage.Body body); + + } + + + /** + * Contract for specifying response operations following the exchange. + */ + interface ResponseSpec { + + /** + * Provide a function to map specific error status codes to an error + * handler. + *

By default, if there are no matching status handlers, responses + * with status codes >= 400 wil throw a + * {@link RestClientResponseException}. + * @param statusPredicate to match responses with + * @param errorHandler handler that typically, though not necessarily, + * throws an exception + * @return this builder + */ + ResponseSpec onStatus(Predicate statusPredicate, + ErrorHandler errorHandler); + + /** + * Provide a function to map specific error status codes to an error + * handler. + *

By default, if there are no matching status handlers, responses + * with status codes >= 400 wil throw a + * {@link RestClientResponseException}. + * @param errorHandler the error handler + * @return this builder + */ + ResponseSpec onStatus(ResponseErrorHandler errorHandler); + + /** + * Extract the body as an object of the given type. + * @param bodyType the type of return value + * @param the body type + * @return the body, or {@code null} if no response body was available + * @throws RestClientResponseException by default when receiving a + * response with a status code of 4xx or 5xx. Use + * {@link #onStatus(Predicate, ErrorHandler)} to customize error response + * handling. + */ + @Nullable + T body(Class bodyType); + + /** + * Extract the body as an object of the given type. + * @param bodyType the type of return value + * @param the body type + * @return the body, or {@code null} if no response body was available + * @throws RestClientResponseException by default when receiving a + * response with a status code of 4xx or 5xx. Use + * {@link #onStatus(Predicate, ErrorHandler)} to customize error response + * handling. + */ + @Nullable + T body(ParameterizedTypeReference bodyType); + + /** + * Return a {@code ResponseEntity} with the body decoded to an Object of + * the given type. + * @param bodyType the expected response body type + * @param response body type + * @return the {@code ResponseEntity} with the decoded body + * @throws RestClientResponseException by default when receiving a + * response with a status code of 4xx or 5xx. Use + * {@link #onStatus(Predicate, ErrorHandler)} to customize error response + * handling. + */ + ResponseEntity toEntity(Class bodyType); + + /** + * Return a {@code ResponseEntity} with the body decoded to an Object of + * the given type. + * @param bodyType the expected response body type + * @param response body type + * @return the {@code ResponseEntity} with the decoded body + * @throws RestClientResponseException by default when receiving a + * response with a status code of 4xx or 5xx. Use + * {@link #onStatus(Predicate, ErrorHandler)} to customize error response + * handling. + */ + ResponseEntity toEntity(ParameterizedTypeReference bodyType); + + /** + * Return a {@code ResponseEntity} without a body. + * @return the {@code ResponseEntity} + * @throws RestClientResponseException by default when receiving a + * response with a status code of 4xx or 5xx. Use + * {@link #onStatus(Predicate, ErrorHandler)} to customize error response + * handling. + */ + ResponseEntity toBodilessEntity(); + + + /** + * Used in {@link #onStatus(Predicate, ErrorHandler)}. + */ + @FunctionalInterface + interface ErrorHandler { + + /** + * Handle the error in the given response. + * @param response the response with the error + * @throws IOException in case of I/O errors + */ + void handle(HttpRequest request, ClientHttpResponse response) throws IOException; + + } + + } + + + /** + * Contract for specifying request headers and URI for a request. + * @param a self reference to the spec type + */ + interface RequestHeadersUriSpec> extends UriSpec, RequestHeadersSpec { + } + + + /** + * Contract for specifying request headers, body and URI for a request. + */ + interface RequestBodyUriSpec extends RequestBodySpec, RequestHeadersUriSpec { + } + + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/RestClientUtils.java b/spring-web/src/main/java/org/springframework/web/client/RestClientUtils.java new file mode 100644 index 000000000000..157d11ace424 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/RestClientUtils.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.nio.charset.Charset; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpMessage; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; +import org.springframework.util.FileCopyUtils; + +/** + * Internal methods shared between types in this package. + * + * @author Arjen Poutsma + * @since 6.1 + */ +abstract class RestClientUtils { + + public static byte[] getBody(HttpInputMessage message) { + try { + return FileCopyUtils.copyToByteArray(message.getBody()); + } + catch (IOException ignore) { + } + return new byte[0]; + } + + @Nullable + public static Charset getCharset(HttpMessage response) { + HttpHeaders headers = response.getHeaders(); + MediaType contentType = headers.getContentType(); + return (contentType != null ? contentType.getCharset() : null); + } +} diff --git a/spring-web/src/main/java/org/springframework/web/client/StatusHandler.java b/spring-web/src/main/java/org/springframework/web/client/StatusHandler.java new file mode 100644 index 000000000000..5083861103a4 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/StatusHandler.java @@ -0,0 +1,156 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.function.Function; +import java.util.function.Predicate; + +import org.springframework.core.ResolvableType; +import org.springframework.core.log.LogFormatUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpRequest; +import org.springframework.http.HttpStatusCode; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.ObjectUtils; + +/** + * Used by {@link DefaultRestClient} and {@link DefaultRestClientBuilder}. + * + * @author Arjen Poutsma + * @since 6.1 + */ +final class StatusHandler { + + private final ResponsePredicate predicate; + + private final RestClient.ResponseSpec.ErrorHandler errorHandler; + + + private StatusHandler(ResponsePredicate predicate, RestClient.ResponseSpec.ErrorHandler errorHandler) { + this.predicate = predicate; + this.errorHandler = errorHandler; + } + + + public static StatusHandler of(Predicate predicate, + RestClient.ResponseSpec.ErrorHandler errorHandler) { + Assert.notNull(predicate, "Predicate must not be null"); + Assert.notNull(errorHandler, "ErrorHandler must not be null"); + + return new StatusHandler(response -> predicate.test(response.getStatusCode()), errorHandler); + } + + public static StatusHandler fromErrorHandler(ResponseErrorHandler errorHandler) { + Assert.notNull(errorHandler, "ErrorHandler must not be null"); + + return new StatusHandler(errorHandler::hasError, (request, response) -> + errorHandler.handleError(request.getURI(), request.getMethod(), response)); + } + + public static StatusHandler defaultHandler(List> messageConverters) { + return new StatusHandler(response -> response.getStatusCode().isError(), + (request, response) -> { + HttpStatusCode statusCode = response.getStatusCode(); + String statusText = response.getStatusText(); + HttpHeaders headers = response.getHeaders(); + byte[] body = RestClientUtils.getBody(response); + Charset charset = RestClientUtils.getCharset(response); + String message = getErrorMessage(statusCode.value(), statusText, body, charset); + RestClientResponseException ex; + + if (statusCode.is4xxClientError()) { + ex = HttpClientErrorException.create(message, statusCode, statusText, headers, body, charset); + } + else if (statusCode.is5xxServerError()) { + ex = HttpServerErrorException.create(message, statusCode, statusText, headers, body, charset); + } + else { + ex = new UnknownHttpStatusCodeException(message, statusCode.value(), statusText, headers, body, charset); + } + if (!CollectionUtils.isEmpty(messageConverters)) { + ex.setBodyConvertFunction(initBodyConvertFunction(response, body, messageConverters)); + } + throw ex; + }); + } + + private static Function initBodyConvertFunction(ClientHttpResponse response, byte[] body, List> messageConverters) { + Assert.state(!CollectionUtils.isEmpty(messageConverters), "Expected message converters"); + return resolvableType -> { + try { + HttpMessageConverterExtractor extractor = + new HttpMessageConverterExtractor<>(resolvableType.getType(), messageConverters); + + return extractor.extractData(new ClientHttpResponseDecorator(response) { + @Override + public InputStream getBody() { + return new ByteArrayInputStream(body); + } + }); + } + catch (IOException ex) { + throw new RestClientException("Error while extracting response for type [" + resolvableType + "]", ex); + } + }; + } + + + private static String getErrorMessage(int rawStatusCode, String statusText, @Nullable byte[] responseBody, + @Nullable Charset charset) { + + String preface = rawStatusCode + " " + statusText + ": "; + + if (ObjectUtils.isEmpty(responseBody)) { + return preface + "[no body]"; + } + + charset = (charset != null ? charset : StandardCharsets.UTF_8); + + String bodyText = new String(responseBody, charset); + bodyText = LogFormatUtils.formatValue(bodyText, -1, true); + + return preface + bodyText; + } + + + + public boolean test(ClientHttpResponse response) throws IOException { + return this.predicate.test(response); + } + + public void handle(HttpRequest request, ClientHttpResponse response) throws IOException { + this.errorHandler.handle(request, response); + } + + + @FunctionalInterface + private interface ResponsePredicate { + + boolean test(ClientHttpResponse response) throws IOException; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/client/package-info.java b/spring-web/src/main/java/org/springframework/web/client/package-info.java index 0b20179c445f..bc1e16e248bd 100644 --- a/spring-web/src/main/java/org/springframework/web/client/package-info.java +++ b/spring-web/src/main/java/org/springframework/web/client/package-info.java @@ -1,6 +1,6 @@ /** * Core package of the client-side web support. - * Provides a RestTemplate class and various callback interfaces. + * Provides the RestTemplate and RestClient. */ @NonNullApi @NonNullFields diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java new file mode 100644 index 000000000000..35ebf7376ad7 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java @@ -0,0 +1,758 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.IOException; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Named; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.HttpStatusCode; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.client.JdkClientHttpRequestFactory; +import org.springframework.http.client.JettyClientHttpRequestFactory; +import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; +import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.util.CollectionUtils; +import org.springframework.web.testfixture.xml.Pojo; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.junit.jupiter.api.Assumptions.assumeFalse; +import static org.junit.jupiter.api.Named.named; + +/** + * Integration tests for {@link RestClient}. + * + * @author Arjen Poutsma + */ +class RestClientIntegrationTests { + + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.METHOD) + @ParameterizedTest(name = "[{index}] {0}") + @MethodSource("clientHttpRequestFactories") + @interface ParameterizedWebClientTest { + } + + static Stream> clientHttpRequestFactories() { + return Stream.of( + named("JDK HttpURLConnection", new SimpleClientHttpRequestFactory()), + named("HttpComponents", new HttpComponentsClientHttpRequestFactory()), + named("OkHttp", new OkHttp3ClientHttpRequestFactory()), + named("Jetty", new JettyClientHttpRequestFactory()), + named("JDK HttpClient", new JdkClientHttpRequestFactory()) + ); + } + + + private MockWebServer server; + + private RestClient restClient; + + + private void startServer(ClientHttpRequestFactory requestFactory) { + this.server = new MockWebServer(); + this.restClient = RestClient + .builder() + .requestFactory(requestFactory) + .baseUrl(this.server.url("/").toString()) + .build(); + } + + @AfterEach + void shutdown() throws IOException { + if (server != null) { + this.server.shutdown(); + } + } + + + @ParameterizedWebClientTest + void retrieve(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> + response.setHeader("Content-Type", "text/plain").setBody("Hello Spring!")); + + String result = this.restClient.get() + .uri("/greeting") + .header("X-Test-Header", "testvalue") + .retrieve() + .body(String.class); + + assertThat(result).isEqualTo("Hello Spring!"); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getHeader("X-Test-Header")).isEqualTo("testvalue"); + assertThat(request.getPath()).isEqualTo("/greeting"); + }); + } + + @ParameterizedWebClientTest + void retrieveJson(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response + .setHeader("Content-Type", "application/json") + .setBody("{\"bar\":\"barbar\",\"foo\":\"foofoo\"}")); + + Pojo result = this.restClient.get() + .uri("/pojo") + .accept(MediaType.APPLICATION_JSON) + .retrieve() + .body(Pojo.class); + + assertThat(result.getFoo()).isEqualTo("foofoo"); + assertThat(result.getBar()).isEqualTo("barbar"); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/pojo"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + }); + } + + @ParameterizedWebClientTest + void retrieveJsonWithParameterizedTypeReference(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + String content = "{\"containerValue\":{\"bar\":\"barbar\",\"foo\":\"foofoo\"}}"; + prepareResponse(response -> response + .setHeader("Content-Type", "application/json").setBody(content)); + + ValueContainer result = this.restClient.get() + .uri("/json").accept(MediaType.APPLICATION_JSON) + .retrieve() + .body(new ParameterizedTypeReference>() {}); + + assertThat(result.getContainerValue()).isNotNull(); + Pojo pojo = result.getContainerValue(); + assertThat(pojo.getFoo()).isEqualTo("foofoo"); + assertThat(pojo.getBar()).isEqualTo("barbar"); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/json"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + }); + } + + @ParameterizedWebClientTest + void retrieveJsonAsResponseEntity(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + String content = "{\"bar\":\"barbar\",\"foo\":\"foofoo\"}"; + prepareResponse(response -> response + .setHeader("Content-Type", "application/json").setBody(content)); + + ResponseEntity result = this.restClient.get() + .uri("/json").accept(MediaType.APPLICATION_JSON) + .retrieve() + .toEntity(String.class); + + assertThat(result.getStatusCode()).isEqualTo(HttpStatus.OK); + assertThat(result.getHeaders().getContentType()).isEqualTo(MediaType.APPLICATION_JSON); + assertThat(result.getHeaders().getContentLength()).isEqualTo(31); + assertThat(result.getBody()).isEqualTo(content); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/json"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + }); + } + + @ParameterizedWebClientTest + void retrieveJsonAsBodilessEntity(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response + .setHeader("Content-Type", "application/json").setBody("{\"bar\":\"barbar\",\"foo\":\"foofoo\"}")); + + ResponseEntity result = this.restClient.get() + .uri("/json").accept(MediaType.APPLICATION_JSON) + .retrieve() + .toBodilessEntity(); + + assertThat(result.getStatusCode()).isEqualTo(HttpStatus.OK); + assertThat(result.getHeaders().getContentType()).isEqualTo(MediaType.APPLICATION_JSON); + assertThat(result.getHeaders().getContentLength()).isEqualTo(31); + assertThat(result.getBody()).isNull(); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/json"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + }); + } + + @ParameterizedWebClientTest + void retrieveJsonArray(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response + .setHeader("Content-Type", "application/json") + .setBody("[{\"bar\":\"bar1\",\"foo\":\"foo1\"},{\"bar\":\"bar2\",\"foo\":\"foo2\"}]")); + + List result = this.restClient.get() + .uri("/pojos") + .accept(MediaType.APPLICATION_JSON) + .retrieve() + .body(new ParameterizedTypeReference<>() {}); + + assertThat(result).hasSize(2); + assertThat(result.get(0).getFoo()).isEqualTo("foo1"); + assertThat(result.get(0).getBar()).isEqualTo("bar1"); + assertThat(result.get(1).getFoo()).isEqualTo("foo2"); + assertThat(result.get(1).getBar()).isEqualTo("bar2"); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/pojos"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + }); + } + + @ParameterizedWebClientTest + void retrieveJsonArrayAsResponseEntityList(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + String content = "[{\"bar\":\"bar1\",\"foo\":\"foo1\"}, {\"bar\":\"bar2\",\"foo\":\"foo2\"}]"; + prepareResponse(response -> response + .setHeader("Content-Type", "application/json").setBody(content)); + + ResponseEntity> result = this.restClient.get() + .uri("/json").accept(MediaType.APPLICATION_JSON) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() {}); + + assertThat(result.getStatusCode()).isEqualTo(HttpStatus.OK); + assertThat(result.getHeaders().getContentType()).isEqualTo(MediaType.APPLICATION_JSON); + assertThat(result.getHeaders().getContentLength()).isEqualTo(58); + assertThat(result.getBody()).hasSize(2); + assertThat(result.getBody().get(0).getFoo()).isEqualTo("foo1"); + assertThat(result.getBody().get(0).getBar()).isEqualTo("bar1"); + assertThat(result.getBody().get(1).getFoo()).isEqualTo("foo2"); + assertThat(result.getBody().get(1).getBar()).isEqualTo("bar2"); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/json"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + }); + } + + @ParameterizedWebClientTest + void retrieveJsonAsSerializedText(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + String content = "{\"bar\":\"barbar\",\"foo\":\"foofoo\"}"; + prepareResponse(response -> response + .setHeader("Content-Type", "application/json").setBody(content)); + + String result = this.restClient.get() + .uri("/json").accept(MediaType.APPLICATION_JSON) + .retrieve() + .body(String.class); + + assertThat(result).isEqualTo(content); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/json"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + }); + } + + @ParameterizedWebClientTest + @SuppressWarnings("rawtypes") + void retrieveJsonNull(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response + .setResponseCode(200) + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody("null")); + + Map result = this.restClient.get() + .uri("/null") + .retrieve() + .body(Map.class); + + assertThat(result).isNull(); + } + + @ParameterizedWebClientTest + void retrieve404(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setResponseCode(404) + .setHeader("Content-Type", "text/plain")); + + assertThatExceptionOfType(HttpClientErrorException.NotFound.class).isThrownBy(() -> + this.restClient.get().uri("/greeting") + .retrieve() + .body(String.class) + ); + + expectRequestCount(1); + expectRequest(request -> assertThat(request.getPath()).isEqualTo("/greeting")); + + } + + @ParameterizedWebClientTest + void retrieve404WithBody(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setResponseCode(404) + .setHeader("Content-Type", "text/plain").setBody("Not Found")); + + assertThatExceptionOfType(HttpClientErrorException.NotFound.class).isThrownBy(() -> + this.restClient.get() + .uri("/greeting") + .retrieve() + .body(String.class) + ); + + expectRequestCount(1); + expectRequest(request -> assertThat(request.getPath()).isEqualTo("/greeting")); + } + + @ParameterizedWebClientTest + void retrieve500(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + String errorMessage = "Internal Server error"; + prepareResponse(response -> response.setResponseCode(500) + .setHeader("Content-Type", "text/plain").setBody(errorMessage)); + + String path = "/greeting"; + try { + this.restClient.get() + .uri(path) + .retrieve() + .body(String.class); + } + catch (HttpServerErrorException ex) { + assertThat(ex.getStatusCode()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR); + assumeFalse(requestFactory instanceof JdkClientHttpRequestFactory, "JDK HttpClient does not expose status text"); + assertThat(ex.getStatusText()).isEqualTo("Server Error"); + assertThat(ex.getResponseHeaders().getContentType()).isEqualTo(MediaType.TEXT_PLAIN); + assertThat(ex.getResponseBodyAsString()).isEqualTo(errorMessage); + } + + expectRequestCount(1); + expectRequest(request -> assertThat(request.getPath()).isEqualTo(path)); + } + + @ParameterizedWebClientTest + void retrieve500AsEntity(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setResponseCode(500) + .setHeader("Content-Type", "text/plain").setBody("Internal Server error")); + + assertThatExceptionOfType(HttpServerErrorException.InternalServerError.class).isThrownBy(() -> + this.restClient.get() + .uri("/").accept(MediaType.APPLICATION_JSON) + .retrieve() + .toEntity(String.class) + ); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + }); + } + + @ParameterizedWebClientTest + void retrieve500AsBodilessEntity(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setResponseCode(500) + .setHeader("Content-Type", "text/plain").setBody("Internal Server error")); + + assertThatExceptionOfType(HttpServerErrorException.InternalServerError.class).isThrownBy(() -> + this.restClient.get() + .uri("/").accept(MediaType.APPLICATION_JSON) + .retrieve() + .toBodilessEntity() + ); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + }); + } + + @ParameterizedWebClientTest + void retrieve555UnknownStatus(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + int errorStatus = 555; + assertThat(HttpStatus.resolve(errorStatus)).isNull(); + String errorMessage = "Something went wrong"; + prepareResponse(response -> response.setResponseCode(errorStatus) + .setHeader("Content-Type", "text/plain").setBody(errorMessage)); + + try { + this.restClient.get() + .uri("/unknownPage") + .retrieve() + .body(String.class); + + } + catch (HttpServerErrorException ex) { + assumeFalse(requestFactory instanceof JdkClientHttpRequestFactory, "JDK HttpClient does not expose status text"); + assertThat(ex.getMessage()).isEqualTo("555 Server Error: \"Something went wrong\""); + assertThat(ex.getStatusText()).isEqualTo("Server Error"); + assertThat(ex.getResponseHeaders().getContentType()).isEqualTo(MediaType.TEXT_PLAIN); + assertThat(ex.getResponseBodyAsString()).isEqualTo(errorMessage); + } + + expectRequestCount(1); + expectRequest(request -> assertThat(request.getPath()).isEqualTo("/unknownPage")); + } + + @ParameterizedWebClientTest + void postPojoAsJson(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setHeader("Content-Type", "application/json") + .setBody("{\"bar\":\"BARBAR\",\"foo\":\"FOOFOO\"}")); + + Pojo result = this.restClient.post() + .uri("/pojo/capitalize") + .accept(MediaType.APPLICATION_JSON) + .contentType(MediaType.APPLICATION_JSON) + .body(new Pojo("foofoo", "barbar")) + .retrieve() + .body(Pojo.class); + + assertThat(result).isNotNull(); + assertThat(result.getFoo()).isEqualTo("FOOFOO"); + assertThat(result.getBar()).isEqualTo("BARBAR"); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/pojo/capitalize"); + assertThat(request.getBody().readUtf8()).isEqualTo("{\"foo\":\"foofoo\",\"bar\":\"barbar\"}"); +// assertThat(request.getHeader(HttpHeaders.CONTENT_LENGTH)).isEqualTo("31"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE)).isEqualTo("application/json"); + }); + } + + @ParameterizedWebClientTest + void statusHandler(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setResponseCode(500) + .setHeader("Content-Type", "text/plain").setBody("Internal Server error")); + + assertThatExceptionOfType(MyException.class).isThrownBy(() -> + this.restClient.get() + .uri("/greeting") + .retrieve() + .onStatus(HttpStatusCode::is5xxServerError, (request, response) -> { + throw new MyException("500 error!"); + }) + .body(String.class) + ); + + expectRequestCount(1); + expectRequest(request -> assertThat(request.getPath()).isEqualTo("/greeting")); + } + + @ParameterizedWebClientTest + void statusHandlerParameterizedTypeReference(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setResponseCode(500) + .setHeader("Content-Type", "text/plain").setBody("Internal Server error")); + + assertThatExceptionOfType(MyException.class).isThrownBy(() -> + this.restClient.get() + .uri("/greeting") + .retrieve() + .onStatus(HttpStatusCode::is5xxServerError, (request, response) -> { + throw new MyException("500 error!"); + }) + .body(new ParameterizedTypeReference() { + }) + ); + + expectRequestCount(1); + expectRequest(request -> assertThat(request.getPath()).isEqualTo("/greeting")); + } + + @ParameterizedWebClientTest + void statusHandlerSuppressedErrorSignal(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setResponseCode(500) + .setHeader("Content-Type", "text/plain").setBody("Internal Server error")); + + String result = this.restClient.get() + .uri("/greeting") + .retrieve() + .onStatus(HttpStatusCode::is5xxServerError, (request, response) -> {}) + .body(String.class); + + assertThat(result).isEqualTo("Internal Server error"); + + expectRequestCount(1); + expectRequest(request -> assertThat(request.getPath()).isEqualTo("/greeting")); + } + + @ParameterizedWebClientTest + void statusHandlerSuppressedErrorSignalWithEntity(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + String content = "Internal Server error"; + prepareResponse(response -> response.setResponseCode(500) + .setHeader("Content-Type", "text/plain").setBody(content)); + + ResponseEntity result = this.restClient.get() + .uri("/").accept(MediaType.APPLICATION_JSON) + .retrieve() + .onStatus(HttpStatusCode::is5xxServerError, (request, response) -> {}) + .toEntity(String.class); + + + assertThat(result.getStatusCode()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR); + assertThat(result.getBody()).isEqualTo(content); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + }); + } + + @ParameterizedWebClientTest + void exchangeForPlainText(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setBody("Hello Spring!")); + + String result = this.restClient.get() + .uri("/greeting") + .header("X-Test-Header", "testvalue") + .exchange((request, response) -> new String(RestClientUtils.getBody(response), StandardCharsets.UTF_8)); + + assertThat(result).isEqualTo("Hello Spring!"); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getHeader("X-Test-Header")).isEqualTo("testvalue"); + assertThat(request.getPath()).isEqualTo("/greeting"); + }); + } + + @ParameterizedWebClientTest + void exchangeFor404(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setResponseCode(404) + .setHeader("Content-Type", "text/plain").setBody("Not Found")); + + String result = this.restClient.get() + .uri("/greeting") + .exchange((request, response) -> new String(RestClientUtils.getBody(response), StandardCharsets.UTF_8)); + + assertThat(result).isEqualTo("Not Found"); + + expectRequestCount(1); + expectRequest(request -> assertThat(request.getPath()).isEqualTo("/greeting")); + } + + @ParameterizedWebClientTest + void requestInitializer(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setHeader("Content-Type", "text/plain") + .setBody("Hello Spring!")); + + RestClient initializedClient = this.restClient.mutate() + .requestInitializer(request -> request.getHeaders().add("foo", "bar")) + .build(); + + String result = initializedClient.get() + .uri("/greeting") + .retrieve() + .body(String.class); + + assertThat(result).isEqualTo("Hello Spring!"); + + expectRequestCount(1); + expectRequest(request -> assertThat(request.getHeader("foo")).isEqualTo("bar")); + } + + @ParameterizedWebClientTest + void requestInterceptor(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setHeader("Content-Type", "text/plain") + .setBody("Hello Spring!")); + + + RestClient interceptedClient = this.restClient.mutate() + .requestInterceptor((request, body, execution) -> { + request.getHeaders().add("foo", "bar"); + return execution.execute(request, body); + }) + .build(); + + String result = interceptedClient.get() + .uri("/greeting") + .retrieve() + .body(String.class); + + assertThat(result).isEqualTo("Hello Spring!"); + + expectRequestCount(1); + expectRequest(request -> assertThat(request.getHeader("foo")).isEqualTo("bar")); + } + + + @ParameterizedWebClientTest + void filterForErrorHandling(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + ClientHttpRequestInterceptor interceptor = (request, body, execution) -> { + ClientHttpResponse response = execution.execute(request, body); + List headerValues = response.getHeaders().get("Foo"); + if (CollectionUtils.isEmpty(headerValues)) { + throw new MyException("Response does not contain Foo header"); + } + else { + return response; + } + }; + + RestClient interceptedClient = this.restClient.mutate().requestInterceptor(interceptor).build(); + + // header not present + prepareResponse(response -> response + .setHeader("Content-Type", "text/plain").setBody("Hello Spring!")); + + assertThatExceptionOfType(MyException.class).isThrownBy(() -> + interceptedClient.get() + .uri("/greeting") + .retrieve() + .body(String.class) + ); + + // header present + + prepareResponse(response -> response.setHeader("Content-Type", "text/plain") + .setHeader("Foo", "Bar") + .setBody("Hello Spring!")); + + String result = interceptedClient.get() + .uri("/greeting") + .retrieve().body(String.class); + + assertThat(result).isEqualTo("Hello Spring!"); + + expectRequestCount(2); + } + + + @ParameterizedWebClientTest + void invalidDomain(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + String url = "http://example.invalid"; + assertThatExceptionOfType(ResourceAccessException.class).isThrownBy(() -> + this.restClient.get().uri(url).retrieve().toBodilessEntity() + ); + + } + + + private void prepareResponse(Consumer consumer) { + MockResponse response = new MockResponse(); + consumer.accept(response); + this.server.enqueue(response); + } + + private void expectRequest(Consumer consumer) { + try { + consumer.accept(this.server.takeRequest()); + } + catch (InterruptedException ex) { + throw new IllegalStateException(ex); + } + } + + private void expectRequestCount(int count) { + assertThat(this.server.getRequestCount()).isEqualTo(count); + } + + + @SuppressWarnings("serial") + private static class MyException extends RuntimeException { + + MyException(String message) { + super(message); + } + } + + + static class ValueContainer { + + private T containerValue; + + + public T getContainerValue() { + return containerValue; + } + + public void setContainerValue(T containerValue) { + this.containerValue = containerValue; + } + } + +}