diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java index 43e8a2cb7885..866ba6e89095 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java @@ -419,21 +419,28 @@ private static class DefaultResponseSpec implements ResponseSpec { private static final IntPredicate STATUS_CODE_ERROR = value -> value >= 400; + private static final StatusHandler DEFAULT_STATUS_HANDLER = + new StatusHandler(STATUS_CODE_ERROR, ClientResponse::createException); + + private final Mono responseMono; private final Supplier requestSupplier; private final List statusHandlers = new ArrayList<>(1); + DefaultResponseSpec(Mono responseMono, Supplier requestSupplier) { this.responseMono = responseMono; this.requestSupplier = requestSupplier; - this.statusHandlers.add(new StatusHandler(STATUS_CODE_ERROR, ClientResponse::createException)); + this.statusHandlers.add(DEFAULT_STATUS_HANDLER); } + @Override public ResponseSpec onStatus(Predicate statusPredicate, Function> exceptionFunction) { + return onRawStatus(toIntPredicate(statusPredicate), exceptionFunction); } @@ -450,7 +457,8 @@ public ResponseSpec onRawStatus(IntPredicate statusCodePredicate, Assert.notNull(statusCodePredicate, "IntPredicate must not be null"); Assert.notNull(exceptionFunction, "Function must not be null"); - this.statusHandlers.add(0, new StatusHandler(statusCodePredicate, exceptionFunction)); + int index = this.statusHandlers.size() - 1; // Default handler always last + this.statusHandlers.add(index, new StatusHandler(statusCodePredicate, exceptionFunction)); return this; } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java index 1492e9db76fe..7b4a06277413 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java @@ -20,6 +20,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.function.Predicate; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -33,6 +34,7 @@ import org.springframework.core.NamedThreadLocal; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import static org.assertj.core.api.Assertions.assertThat; @@ -309,6 +311,48 @@ public void shouldApplyFiltersAtSubscription() { assertThat(request.headers().getFirst("Custom")).isEqualTo("value"); } + @Test // gh-23880 + public void onStatusHandlersOrderIsPreserved() { + + ClientResponse response = ClientResponse.create(HttpStatus.BAD_REQUEST).build(); + given(exchangeFunction.exchange(any())).willReturn(Mono.just(response)); + + Mono result = this.builder.build().get() + .uri("/path") + .retrieve() + .onStatus(HttpStatus::is4xxClientError, resp -> Mono.error(new IllegalStateException("1"))) + .onStatus(HttpStatus::is4xxClientError, resp -> Mono.error(new IllegalStateException("2"))) + .bodyToMono(Void.class); + + StepVerifier.create(result).expectErrorMessage("1").verify(); + } + + @Test // gh-23880 + @SuppressWarnings("unchecked") + public void onStatusHandlersDefaultHandlerIsLast() { + + ClientResponse response = ClientResponse.create(HttpStatus.BAD_REQUEST).build(); + given(exchangeFunction.exchange(any())).willReturn(Mono.just(response)); + + Predicate predicate1 = mock(Predicate.class); + Predicate predicate2 = mock(Predicate.class); + + given(predicate1.test(HttpStatus.BAD_REQUEST)).willReturn(false); + given(predicate2.test(HttpStatus.BAD_REQUEST)).willReturn(false); + + Mono result = this.builder.build().get() + .uri("/path") + .retrieve() + .onStatus(predicate1, resp -> Mono.error(new IllegalStateException())) + .onStatus(predicate2, resp -> Mono.error(new IllegalStateException())) + .bodyToMono(Void.class); + + StepVerifier.create(result).expectError(WebClientResponseException.class).verify(); + + verify(predicate1).test(HttpStatus.BAD_REQUEST); + verify(predicate2).test(HttpStatus.BAD_REQUEST); + } + private ClientRequest verifyAndGetRequest() { ClientRequest request = this.captor.getValue(); verify(this.exchangeFunction).exchange(request);