Skip to content

Commit

Permalink
Update GraphQlRequestPredicates to map application/graphql content
Browse files Browse the repository at this point in the history
Closes gh-948
  • Loading branch information
rstoyanchev committed Apr 22, 2024
1 parent 03c11d0 commit b1cb364
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 34 deletions.
Expand Up @@ -16,7 +16,6 @@

package org.springframework.graphql.server.webflux;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

Expand All @@ -40,6 +39,7 @@
* {@link RequestPredicate} implementations tailored for GraphQL reactive endpoints.
*
* @author Brian Clozel
* @author Rossen Stoyanchev
* @since 1.3.0
*/
public final class GraphQlRequestPredicates {
Expand All @@ -56,7 +56,8 @@ private GraphQlRequestPredicates() {
* @see GraphQlHttpHandler
*/
public static RequestPredicate graphQlHttp(String path) {
return new GraphQlHttpRequestPredicate(path, MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE);
return new GraphQlHttpRequestPredicate(
path, List.of(MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE));
}

/**
Expand All @@ -65,59 +66,67 @@ public static RequestPredicate graphQlHttp(String path) {
* @see GraphQlSseHandler
*/
public static RequestPredicate graphQlSse(String path) {
return new GraphQlHttpRequestPredicate(path, MediaType.TEXT_EVENT_STREAM);
return new GraphQlHttpRequestPredicate(path, List.of(MediaType.TEXT_EVENT_STREAM));
}

private static class GraphQlHttpRequestPredicate implements RequestPredicate {

private final PathPattern pattern;

private final List<MediaType> contentTypes;

private final List<MediaType> acceptedMediaTypes;


GraphQlHttpRequestPredicate(String path, MediaType... accepted) {
GraphQlHttpRequestPredicate(String path, List<MediaType> accepted) {
Assert.notNull(path, "'path' must not be null");
Assert.notEmpty(accepted, "'accepted' must not be empty");
PathPatternParser parser = PathPatternParser.defaultInstance;
path = parser.initFullPathPattern(path);
this.pattern = parser.parse(path);
this.acceptedMediaTypes = Arrays.asList(accepted);
this.contentTypes = List.of(MediaType.APPLICATION_JSON, MediaType.parseMediaType("application/graphql"));
this.acceptedMediaTypes = accepted;
}

@Override
public boolean test(ServerRequest request) {
return methodMatch(request, HttpMethod.POST)
&& contentTypeMatch(request, MediaType.APPLICATION_JSON)
return httpMethodMatch(request, HttpMethod.POST)
&& contentTypeMatch(request, this.contentTypes)
&& acceptMatch(request, this.acceptedMediaTypes)
&& pathMatch(request, this.pattern);
}

private static boolean methodMatch(ServerRequest request, HttpMethod expected) {
HttpMethod actual = resolveMethod(request);
private static boolean httpMethodMatch(ServerRequest request, HttpMethod expected) {
HttpMethod actual = resolveHttpMethod(request);
boolean methodMatch = expected.equals(actual);
traceMatch("Method", expected, actual, methodMatch);
return methodMatch;
}

private static HttpMethod resolveMethod(ServerRequest request) {
private static HttpMethod resolveHttpMethod(ServerRequest request) {
if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) {
String accessControlRequestMethod =
request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
if (accessControlRequestMethod != null) {
return HttpMethod.valueOf(accessControlRequestMethod);
String httpMethod = request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
if (httpMethod != null) {
return HttpMethod.valueOf(httpMethod);
}
}
return request.method();
}

private static boolean contentTypeMatch(ServerRequest request, MediaType expected) {
private static boolean contentTypeMatch(ServerRequest request, List<MediaType> contentTypes) {
if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) {
return true;
}
ServerRequest.Headers headers = request.headers();
MediaType actual = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
boolean contentTypeMatch = expected.includes(actual);
traceMatch("Content-Type", expected, actual, contentTypeMatch);
boolean contentTypeMatch = false;
for (MediaType contentType : contentTypes) {
contentTypeMatch = contentType.includes(actual);
traceMatch("Content-Type", contentTypes, actual, contentTypeMatch);
if (contentTypeMatch) {
break;
}
}
return contentTypeMatch;
}

Expand Down
Expand Up @@ -16,7 +16,6 @@

package org.springframework.graphql.server.webmvc;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

Expand All @@ -40,6 +39,7 @@
* {@link RequestPredicate} implementations tailored for GraphQL endpoints.
*
* @author Brian Clozel
* @author Rossen Stoyanchev
* @since 1.3.0
*/
public final class GraphQlRequestPredicates {
Expand All @@ -56,7 +56,8 @@ private GraphQlRequestPredicates() {
* @see GraphQlHttpHandler
*/
public static RequestPredicate graphQlHttp(String path) {
return new GraphQlHttpRequestPredicate(path, MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE);
return new GraphQlHttpRequestPredicate(
path, List.of(MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE));
}

/**
Expand All @@ -65,59 +66,67 @@ public static RequestPredicate graphQlHttp(String path) {
* @see GraphQlSseHandler
*/
public static RequestPredicate graphQlSse(String path) {
return new GraphQlHttpRequestPredicate(path, MediaType.TEXT_EVENT_STREAM);
return new GraphQlHttpRequestPredicate(path, List.of(MediaType.TEXT_EVENT_STREAM));
}

private static class GraphQlHttpRequestPredicate implements RequestPredicate {

private final PathPattern pattern;

private final List<MediaType> contentTypes;

private final List<MediaType> acceptedMediaTypes;


GraphQlHttpRequestPredicate(String path, MediaType... accepted) {
GraphQlHttpRequestPredicate(String path, List<MediaType> accepted) {
Assert.notNull(path, "'path' must not be null");
Assert.notEmpty(accepted, "'accepted' must not be empty");
PathPatternParser parser = PathPatternParser.defaultInstance;
path = parser.initFullPathPattern(path);
this.pattern = parser.parse(path);
this.acceptedMediaTypes = Arrays.asList(accepted);
this.contentTypes = List.of(MediaType.APPLICATION_JSON, MediaType.parseMediaType("application/graphql"));
this.acceptedMediaTypes = accepted;
}

@Override
public boolean test(ServerRequest request) {
return methodMatch(request, HttpMethod.POST)
&& contentTypeMatch(request, MediaType.APPLICATION_JSON)
return httpMethodMatch(request, HttpMethod.POST)
&& contentTypeMatch(request, this.contentTypes)
&& acceptMatch(request, this.acceptedMediaTypes)
&& pathMatch(request, this.pattern);
}

private static boolean methodMatch(ServerRequest request, HttpMethod expected) {
HttpMethod actual = resolveMethod(request);
private static boolean httpMethodMatch(ServerRequest request, HttpMethod expected) {
HttpMethod actual = resolveHttpMethod(request);
boolean methodMatch = expected.equals(actual);
traceMatch("Method", expected, actual, methodMatch);
return methodMatch;
}

private static HttpMethod resolveMethod(ServerRequest request) {
private static HttpMethod resolveHttpMethod(ServerRequest request) {
if (CorsUtils.isPreFlightRequest(request.servletRequest())) {
String accessControlRequestMethod =
request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
if (accessControlRequestMethod != null) {
return HttpMethod.valueOf(accessControlRequestMethod);
String httpMethod = request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
if (httpMethod != null) {
return HttpMethod.valueOf(httpMethod);
}
}
return request.method();
}

private static boolean contentTypeMatch(ServerRequest request, MediaType expected) {
private static boolean contentTypeMatch(ServerRequest request, List<MediaType> contentTypes) {
if (CorsUtils.isPreFlightRequest(request.servletRequest())) {
return true;
}
ServerRequest.Headers headers = request.headers();
MediaType actual = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
boolean contentTypeMatch = expected.includes(actual);
traceMatch("Content-Type", expected, actual, contentTypeMatch);
boolean contentTypeMatch = false;
for (MediaType contentType : contentTypes) {
contentTypeMatch = contentType.includes(actual);
traceMatch("Content-Type", contentTypes, actual, contentTypeMatch);
if (contentTypeMatch) {
break;
}
}
return contentTypeMatch;
}

Expand Down
Expand Up @@ -77,6 +77,18 @@ void shouldRejectRequestWithDifferentPath() {
assertThat(httpPredicate.test(serverRequest)).isFalse();
}

@Test
void shouldMapApplicationGraphQlRequestContent() {
ServerWebExchange exchange = createMatchingHttpExchange()
.mutate().request(builder -> builder.headers(headers -> {
MediaType contentType = MediaType.parseMediaType("application/graphql");
headers.setContentType(contentType);
}))
.build();
ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList());
assertThat(httpPredicate.test(serverRequest)).isTrue();
}

@Test
void shouldRejectRequestWithDifferentContentType() {
ServerWebExchange exchange = createMatchingHttpExchange()
Expand Down
Expand Up @@ -74,6 +74,14 @@ void shouldRejectRequestWithDifferentPath() {
assertThat(httpPredicate.test(serverRequest)).isFalse();
}

@Test
void shouldMapApplicationGraphQlRequestContent() {
MockHttpServletRequest request = createMatchingHttpRequest();
request.setContentType("application/graphql");
ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList());
assertThat(httpPredicate.test(serverRequest)).isTrue();
}

@Test
void shouldRejectRequestWithDifferentContentType() {
MockHttpServletRequest request = createMatchingHttpRequest();
Expand Down

0 comments on commit b1cb364

Please sign in to comment.