Skip to content

Commit

Permalink
Remove Origin header when forwarding (#3357)
Browse files Browse the repository at this point in the history
This prevents forwarded requests, such as those from
circuit breaker fallbacks, from failing in CORS checks,
which require a fully populated scheme and host.

Fixes gh-3350
  • Loading branch information
spikymonkey committed Apr 16, 2024
1 parent 66aa480 commit d68648c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,11 @@ public static Mono<Void> handle(DispatcherHandler handler, ServerWebExchange exc
// remove attributes that may disrupt the forwarded request
exchange.getAttributes().remove(GATEWAY_PREDICATE_PATH_CONTAINER_ATTR);

// CORS check is applied to the original request, but should not be applied to
// internally forwarded requests.
// See https://github.com/spring-cloud/spring-cloud-gateway/issues/3350.
exchange = exchange.mutate().request(request -> request.headers(headers -> headers.setOrigin(null))).build();

return handler.handle(exchange);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.springframework.web.util.UriComponentsBuilder;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.assertArg;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
Expand Down Expand Up @@ -93,9 +94,8 @@ public void shouldFilterWhenGatewayRequestUrlSchemeIsForward() {
forwardRoutingFilter.filter(exchange, chain);

verifyNoMoreInteractions(chain);
verify(dispatcherHandler).handle(exchange);

assertThat(exchange.getAttributes().get(GATEWAY_ALREADY_ROUTED_ATTR)).isNull();
verify(dispatcherHandler).handle(
assertArg(exchange -> assertThat(exchange.getAttributes().get(GATEWAY_ALREADY_ROUTED_ATTR)).isNull()));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ public void filterFallbackForward() {
.isOk().expectBody().json("{\"from\":\"circuitbreakerfallbackcontroller3\"}");
}

@Test
public void filterFallbackForwardWithCORS() {
testClient.get().uri("/delay/3?a=b").header("Host", "www.circuitbreakerforward.org")
.header("Origin", "https://cors.withcircuitbreaker.org").exchange().expectStatus().isOk().expectBody()
.json("{\"from\":\"circuitbreakerfallbackcontroller3\"}");
}

@Test
public void filterStatusCodeFallback() {
testClient.get().uri("/status/500").header("Host", "www.circuitbreakerstatuscode.org").exchange().expectStatus()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,27 @@

import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import reactor.core.publisher.Mono;

import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DefaultDataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.web.reactive.DispatcherHandler;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.assertArg;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CACHED_REQUEST_BODY_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_PREDICATE_PATH_CONTAINER_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.expand;
import static org.springframework.http.server.PathContainer.parsePath;

public class ServerWebExchangeUtilsTests {

Expand Down Expand Up @@ -94,6 +103,26 @@ public void duplicatedCachingDataBufferHandling() {
Assertions.assertThat(dataBufferBeforeCaching).isEqualTo(dataBufferAfterCached);
}

@Test
public void forwardedRequestsHaveDisruptiveAttributesAndHeadersRemoved() {
DispatcherHandler handler = Mockito.mock(DispatcherHandler.class);
Mockito.when(handler.handle(any(ServerWebExchange.class))).thenReturn(Mono.empty());

ServerWebExchange originalExchange = mockExchange(Map.of()).mutate()
.request(request -> request.headers(headers -> headers.setOrigin("https://example.com"))).build();
originalExchange.getAttributes().put(GATEWAY_PREDICATE_PATH_CONTAINER_ATTR, parsePath("/example/path"));

ServerWebExchangeUtils.handle(handler, originalExchange).block();

Mockito.verify(handler).handle(assertArg(exchange -> {
Assertions.assertThat(exchange.getAttributes()).as("exchange attributes")
.doesNotContainKey(GATEWAY_PREDICATE_PATH_CONTAINER_ATTR);

Assertions.assertThat(exchange.getRequest().getHeaders()).as("request headers")
.doesNotContainKey(HttpHeaders.ORIGIN);
}));
}

private MockServerWebExchange mockExchange(Map<String, String> vars) {
return mockExchange(HttpMethod.GET, vars);
}
Expand Down

0 comments on commit d68648c

Please sign in to comment.