Skip to content

Commit

Permalink
Empty body checks in ConsumesRequestCondition
Browse files Browse the repository at this point in the history
Normally consumes matches the "Content-Type" header but what should be done if
there is no content? This commit adds checks for method parameters with
@RequestBody(required=false) and if "false" then also match requests with no content.

Closes gh-22010
  • Loading branch information
rstoyanchev committed May 8, 2019
1 parent cdf51c3 commit 45147c2
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
import java.util.List;
import java.util.Set;

import org.springframework.http.HttpHeaders;
import org.springframework.http.InvalidMediaTypeException;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.cors.reactive.CorsUtils;
import org.springframework.web.server.ServerWebExchange;
Expand All @@ -50,6 +53,8 @@ public final class ConsumesRequestCondition extends AbstractRequestCondition<Con

private final List<ConsumeMediaTypeExpression> expressions;

private boolean bodyRequired = true;


/**
* Creates a new instance from 0 or more "consumes" expressions.
Expand Down Expand Up @@ -141,6 +146,29 @@ protected String getToStringInfix() {
return " || ";
}

/**
* Whether this condition should expect requests to have a body.
* <p>By default this is set to {@code true} in which case it is assumed a
* request body is required and this condition matches to the "Content-Type"
* header or falls back on "Content-Type: application/octet-stream".
* <p>If set to {@code false}, and the request does not have a body, then this
* condition matches automatically, i.e. without checking expressions.
* @param bodyRequired whether requests are expected to have a body
* @since 5.2
*/
public void setBodyRequired(boolean bodyRequired) {
this.bodyRequired = bodyRequired;
}

/**
* Return the setting for {@link #setBodyRequired(boolean)}.
* @since 5.2
*/
public boolean isBodyRequired() {
return this.bodyRequired;
}


/**
* Returns the "other" instance if it has any expressions; returns "this"
* instance otherwise. Practically that means a method-level "consumes"
Expand All @@ -163,16 +191,27 @@ public ConsumesRequestCondition combine(ConsumesRequestCondition other) {
*/
@Override
public ConsumesRequestCondition getMatchingCondition(ServerWebExchange exchange) {
if (CorsUtils.isPreFlightRequest(exchange.getRequest())) {
ServerHttpRequest request = exchange.getRequest();
if (CorsUtils.isPreFlightRequest(request)) {
return EMPTY_CONDITION;
}
if (isEmpty()) {
return this;
}
if (!hasBody(request) && !this.bodyRequired) {
return EMPTY_CONDITION;
}
List<ConsumeMediaTypeExpression> result = getMatchingExpressions(exchange);
return !CollectionUtils.isEmpty(result) ? new ConsumesRequestCondition(result) : null;
}

private boolean hasBody(ServerHttpRequest request) {
String contentLength = request.getHeaders().getFirst(HttpHeaders.CONTENT_LENGTH);
String transferEncoding = request.getHeaders().getFirst(HttpHeaders.TRANSFER_ENCODING);
return StringUtils.hasText(transferEncoding) ||
(StringUtils.hasText(contentLength) && !contentLength.trim().equals("0"));
}

@Nullable
private List<ConsumeMediaTypeExpression> getMatchingExpressions(ServerWebExchange exchange) {
List<ConsumeMediaTypeExpression> result = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,31 @@

import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Predicate;

import org.springframework.context.EmbeddedValueResolverAware;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Controller;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.util.StringValueResolver;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.reactive.accept.RequestedContentTypeResolver;
import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder;
import org.springframework.web.reactive.result.condition.ConsumesRequestCondition;
import org.springframework.web.reactive.result.condition.RequestCondition;
import org.springframework.web.reactive.result.method.RequestMappingInfo;
import org.springframework.web.reactive.result.method.RequestMappingInfoHandlerMapping;
Expand Down Expand Up @@ -255,6 +260,31 @@ protected String[] resolveEmbeddedValuesInPatterns(String[] patterns) {
}
}

@Override
public void registerMapping(RequestMappingInfo mapping, Object handler, Method method) {
super.registerMapping(mapping, handler, method);
updateConsumesCondition(mapping, method);
}

@Override
protected void registerHandlerMethod(Object handler, Method method, RequestMappingInfo mapping) {
super.registerHandlerMethod(handler, method, mapping);
updateConsumesCondition(mapping, method);
}

private void updateConsumesCondition(RequestMappingInfo info, Method method) {
ConsumesRequestCondition condition = info.getConsumesCondition();
if (!condition.isEmpty()) {
for (Parameter parameter : method.getParameters()) {
MergedAnnotation<RequestBody> annot = MergedAnnotations.from(parameter).get(RequestBody.class);
if (annot.isPresent()) {
condition.setBodyRequired(annot.getBoolean("required"));
break;
}
}
}
}

@Override
protected CorsConfiguration initCorsConfiguration(Object handler, Method method, RequestMappingInfo mappingInfo) {
HandlerMethod handlerMethod = createHandlerMethod(handler, method);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2012 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -99,6 +99,24 @@ public void consumesParseErrorWithNegation() throws Exception {
assertNull(condition.getMatchingCondition(exchange));
}

@Test // gh-22010
public void consumesNoContent() {
ConsumesRequestCondition condition = new ConsumesRequestCondition("text/plain");
condition.setBodyRequired(false);

MockServerHttpRequest request = MockServerHttpRequest.get("/").build();
assertNotNull(condition.getMatchingCondition(MockServerWebExchange.from(request)));

request = MockServerHttpRequest.get("/").header(HttpHeaders.CONTENT_LENGTH, "0").build();
assertNotNull(condition.getMatchingCondition(MockServerWebExchange.from(request)));

request = MockServerHttpRequest.get("/").header(HttpHeaders.CONTENT_LENGTH, "21").build();
assertNull(condition.getMatchingCondition(MockServerWebExchange.from(request)));

request = MockServerHttpRequest.get("/").header(HttpHeaders.TRANSFER_ENCODING, "chunked").build();
assertNull(condition.getMatchingCondition(MockServerWebExchange.from(request)));
}

@Test
public void compareToSingle() throws Exception {
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Set;

Expand All @@ -32,16 +31,20 @@
import org.springframework.core.annotation.AliasFor;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Controller;
import org.springframework.util.ClassUtils;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PatchMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.context.support.StaticWebApplicationContext;
import org.springframework.web.method.HandlerTypePredicate;
import org.springframework.web.reactive.result.condition.ConsumesRequestCondition;
import org.springframework.web.reactive.result.condition.PatternsRequestCondition;
import org.springframework.web.reactive.result.method.RequestMappingInfo;
import org.springframework.web.util.pattern.PathPattern;
import org.springframework.web.util.pattern.PathPatternParser;
Expand Down Expand Up @@ -103,10 +106,26 @@ public void resolveRequestMappingViaComposedAnnotation() throws Exception {

@Test // SPR-14988
public void getMappingOverridesConsumesFromTypeLevelAnnotation() throws Exception {
RequestMappingInfo requestMappingInfo = assertComposedAnnotationMapping(RequestMethod.GET);
RequestMappingInfo requestMappingInfo = assertComposedAnnotationMapping(RequestMethod.POST);

assertArrayEquals(new MediaType[]{MediaType.ALL}, new ArrayList<>(
requestMappingInfo.getConsumesCondition().getConsumableMediaTypes()).toArray());
ConsumesRequestCondition condition = requestMappingInfo.getConsumesCondition();
assertEquals(Collections.singleton(MediaType.APPLICATION_XML), condition.getConsumableMediaTypes());
}

@Test // gh-22010
public void consumesWithOptionalRequestBody() {
this.wac.registerSingleton("testController", ComposedAnnotationController.class);
this.wac.refresh();
this.handlerMapping.afterPropertiesSet();
RequestMappingInfo info = this.handlerMapping.getHandlerMethods().keySet().stream()
.filter(i -> {
PatternsRequestCondition condition = i.getPatternsCondition();
return condition.getPatterns().iterator().next().getPatternString().equals("/post");
})
.findFirst()
.orElseThrow(() -> new AssertionError("No /post"));

assertFalse(info.getConsumesCondition().isBodyRequired());
}

@Test
Expand Down Expand Up @@ -146,7 +165,7 @@ private RequestMappingInfo assertComposedAnnotationMapping(String methodName, St
RequestMethod requestMethod) throws Exception {

Class<?> clazz = ComposedAnnotationController.class;
Method method = clazz.getMethod(methodName);
Method method = ClassUtils.getMethod(clazz, methodName, null);
RequestMappingInfo info = this.handlerMapping.getMappingForMethod(method, clazz);

assertNotNull(info);
Expand Down Expand Up @@ -175,12 +194,12 @@ public void handle() {
public void postJson() {
}

@GetMapping(value = "/get", consumes = MediaType.ALL_VALUE)
@GetMapping("/get")
public void get() {
}

@PostMapping("/post")
public void post() {
@PostMapping(path = "/post", consumes = MediaType.APPLICATION_XML_VALUE)
public void post(@RequestBody(required = false) Foo foo) {
}

@PutMapping("/put")
Expand All @@ -196,6 +215,9 @@ public void patch() {
}
}

private static class Foo {
}


@RequestMapping(method = RequestMethod.POST,
produces = MediaType.APPLICATION_JSON_VALUE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Set;
import javax.servlet.http.HttpServletRequest;

import org.springframework.http.HttpHeaders;
import org.springframework.http.InvalidMediaTypeException;
import org.springframework.http.MediaType;
import org.springframework.lang.Nullable;
Expand All @@ -49,8 +50,11 @@ public final class ConsumesRequestCondition extends AbstractRequestCondition<Con

private static final ConsumesRequestCondition EMPTY_CONDITION = new ConsumesRequestCondition();


private final List<ConsumeMediaTypeExpression> expressions;

private boolean bodyRequired = true;


/**
* Creates a new instance from 0 or more "consumes" expressions.
Expand Down Expand Up @@ -141,6 +145,29 @@ protected String getToStringInfix() {
return " || ";
}

/**
* Whether this condition should expect requests to have a body.
* <p>By default this is set to {@code true} in which case it is assumed a
* request body is required and this condition matches to the "Content-Type"
* header or falls back on "Content-Type: application/octet-stream".
* <p>If set to {@code false}, and the request does not have a body, then this
* condition matches automatically, i.e. without checking expressions.
* @param bodyRequired whether requests are expected to have a body
* @since 5.2
*/
public void setBodyRequired(boolean bodyRequired) {
this.bodyRequired = bodyRequired;
}

/**
* Return the setting for {@link #setBodyRequired(boolean)}.
* @since 5.2
*/
public boolean isBodyRequired() {
return this.bodyRequired;
}


/**
* Returns the "other" instance if it has any expressions; returns "this"
* instance otherwise. Practically that means a method-level "consumes"
Expand Down Expand Up @@ -170,14 +197,17 @@ public ConsumesRequestCondition getMatchingCondition(HttpServletRequest request)
if (isEmpty()) {
return this;
}
if (!hasBody(request) && !this.bodyRequired) {
return EMPTY_CONDITION;
}

// Common media types are cached at the level of MimeTypeUtils

MediaType contentType;
try {
contentType = (StringUtils.hasLength(request.getContentType()) ?
contentType = StringUtils.hasLength(request.getContentType()) ?
MediaType.parseMediaType(request.getContentType()) :
MediaType.APPLICATION_OCTET_STREAM);
MediaType.APPLICATION_OCTET_STREAM;
}
catch (InvalidMediaTypeException ex) {
return null;
Expand All @@ -187,6 +217,13 @@ public ConsumesRequestCondition getMatchingCondition(HttpServletRequest request)
return !CollectionUtils.isEmpty(result) ? new ConsumesRequestCondition(result) : null;
}

private boolean hasBody(HttpServletRequest request) {
String contentLength = request.getHeader(HttpHeaders.CONTENT_LENGTH);
String transferEncoding = request.getHeader(HttpHeaders.TRANSFER_ENCODING);
return StringUtils.hasText(transferEncoding) ||
(StringUtils.hasText(contentLength) && !contentLength.trim().equals("0"));
}

@Nullable
private List<ConsumeMediaTypeExpression> getMatchingExpressions(MediaType contentType) {
List<ConsumeMediaTypeExpression> result = null;
Expand Down
Loading

0 comments on commit 45147c2

Please sign in to comment.