Skip to content

Commit

Permalink
Preserve order of onStatus handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
rstoyanchev authored and pull[bot] committed Oct 31, 2019
1 parent 9b09ee4 commit 3b3eb0f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -419,21 +419,28 @@ private static class DefaultResponseSpec implements ResponseSpec {

private static final IntPredicate STATUS_CODE_ERROR = value -> value >= 400;

private static final StatusHandler DEFAULT_STATUS_HANDLER =
new StatusHandler(STATUS_CODE_ERROR, ClientResponse::createException);


private final Mono<ClientResponse> responseMono;

private final Supplier<HttpRequest> requestSupplier;

private final List<StatusHandler> statusHandlers = new ArrayList<>(1);


DefaultResponseSpec(Mono<ClientResponse> responseMono, Supplier<HttpRequest> requestSupplier) {
this.responseMono = responseMono;
this.requestSupplier = requestSupplier;
this.statusHandlers.add(new StatusHandler(STATUS_CODE_ERROR, ClientResponse::createException));
this.statusHandlers.add(DEFAULT_STATUS_HANDLER);
}


@Override
public ResponseSpec onStatus(Predicate<HttpStatus> statusPredicate,
Function<ClientResponse, Mono<? extends Throwable>> exceptionFunction) {

return onRawStatus(toIntPredicate(statusPredicate), exceptionFunction);
}

Expand All @@ -450,7 +457,8 @@ public ResponseSpec onRawStatus(IntPredicate statusCodePredicate,

Assert.notNull(statusCodePredicate, "IntPredicate must not be null");
Assert.notNull(exceptionFunction, "Function must not be null");
this.statusHandlers.add(0, new StatusHandler(statusCodePredicate, exceptionFunction));
int index = this.statusHandlers.size() - 1; // Default handler always last
this.statusHandlers.add(index, new StatusHandler(statusCodePredicate, exceptionFunction));
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Predicate;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -33,6 +34,7 @@

import org.springframework.core.NamedThreadLocal;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -309,6 +311,48 @@ public void shouldApplyFiltersAtSubscription() {
assertThat(request.headers().getFirst("Custom")).isEqualTo("value");
}

@Test // gh-23880
public void onStatusHandlersOrderIsPreserved() {

ClientResponse response = ClientResponse.create(HttpStatus.BAD_REQUEST).build();
given(exchangeFunction.exchange(any())).willReturn(Mono.just(response));

Mono<Void> result = this.builder.build().get()
.uri("/path")
.retrieve()
.onStatus(HttpStatus::is4xxClientError, resp -> Mono.error(new IllegalStateException("1")))
.onStatus(HttpStatus::is4xxClientError, resp -> Mono.error(new IllegalStateException("2")))
.bodyToMono(Void.class);

StepVerifier.create(result).expectErrorMessage("1").verify();
}

@Test // gh-23880
@SuppressWarnings("unchecked")
public void onStatusHandlersDefaultHandlerIsLast() {

ClientResponse response = ClientResponse.create(HttpStatus.BAD_REQUEST).build();
given(exchangeFunction.exchange(any())).willReturn(Mono.just(response));

Predicate<HttpStatus> predicate1 = mock(Predicate.class);
Predicate<HttpStatus> predicate2 = mock(Predicate.class);

given(predicate1.test(HttpStatus.BAD_REQUEST)).willReturn(false);
given(predicate2.test(HttpStatus.BAD_REQUEST)).willReturn(false);

Mono<Void> result = this.builder.build().get()
.uri("/path")
.retrieve()
.onStatus(predicate1, resp -> Mono.error(new IllegalStateException()))
.onStatus(predicate2, resp -> Mono.error(new IllegalStateException()))
.bodyToMono(Void.class);

StepVerifier.create(result).expectError(WebClientResponseException.class).verify();

verify(predicate1).test(HttpStatus.BAD_REQUEST);
verify(predicate2).test(HttpStatus.BAD_REQUEST);
}

private ClientRequest verifyAndGetRequest() {
ClientRequest request = this.captor.getValue();
verify(this.exchangeFunction).exchange(request);
Expand Down

0 comments on commit 3b3eb0f

Please sign in to comment.