Skip to content

Commit

Permalink
Add @request(Param/Part) support for multipart requests
Browse files Browse the repository at this point in the history
Issue: SPR-14546
  • Loading branch information
sdeleuze committed Apr 28, 2017
1 parent 4bfd04b commit b804ba1
Show file tree
Hide file tree
Showing 7 changed files with 323 additions and 4 deletions.
Expand Up @@ -262,6 +262,7 @@ public Mono<HandlerMethod> getHandlerInternal(ServerWebExchange exchange) {
try {
// Ensure form data is parsed for "params" conditions...
return exchange.getRequestParams()
.then(exchange.getMultipartData())
.then(Mono.defer(() -> {
HandlerMethod handlerMethod = null;
try {
Expand Down
Expand Up @@ -133,6 +133,7 @@ private void addResolversTo(ArgumentResolverRegistrar registrar,

// Annotation-based...
registrar.add(new RequestParamMethodArgumentResolver(beanFactory, reactiveRegistry, false));
registrar.add(new RequestPartMethodArgumentResolver(beanFactory, reactiveRegistry, false));
registrar.add(new RequestParamMapMethodArgumentResolver(reactiveRegistry));
registrar.add(new PathVariableMethodArgumentResolver(beanFactory, reactiveRegistry));
registrar.add(new PathVariableMapMethodArgumentResolver(reactiveRegistry));
Expand Down
Expand Up @@ -21,6 +21,8 @@

import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.http.codec.multipart.Part;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
Expand All @@ -42,6 +44,7 @@
* request parameters have multiple values.
*
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
* @since 5.0
* @see RequestParamMethodArgumentResolver
*/
Expand All @@ -67,12 +70,17 @@ private boolean allParams(RequestParam requestParam, Class<?> type) {
public Optional<Object> resolveArgumentValue(MethodParameter methodParameter,
BindingContext context, ServerWebExchange exchange) {

Class<?> paramType = methodParameter.getParameterType();
boolean isMultiValueMap = MultiValueMap.class.isAssignableFrom(paramType);
ResolvableType paramType = ResolvableType.forType(methodParameter.getGenericParameterType());
boolean isMultiValueMap = MultiValueMap.class.isAssignableFrom(paramType.getRawClass());


if (paramType.getGeneric(1).getRawClass() == Part.class) {
MultiValueMap<String, Part> requestParts = exchange.getMultipartData().subscribe().peek();
Assert.notNull(requestParts, "Expected multipart data (if any) to be parsed.");
return Optional.of(isMultiValueMap ? requestParts : requestParts.toSingleValueMap());
}
MultiValueMap<String, String> requestParams = exchange.getRequestParams().subscribe().peek();
Assert.notNull(requestParams, "Expected form data (if any) to be parsed.");

return Optional.of(isMultiValueMap ? requestParams : requestParams.toSingleValueMap());
}

Expand Down
Expand Up @@ -25,6 +25,7 @@
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.codec.multipart.Part;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -102,7 +103,7 @@ protected NamedValueInfo createNamedValueInfo(MethodParameter parameter) {
protected Optional<Object> resolveNamedValue(String name, MethodParameter parameter,
ServerWebExchange exchange) {

List<String> paramValues = getRequestParams(exchange).get(name);
List<?> paramValues = parameter.getParameter().getType() == Part.class ? getMultipartData(exchange).get(name) : getRequestParams(exchange).get(name);
Object result = null;
if (paramValues != null) {
result = (paramValues.size() == 1 ? paramValues.get(0) : paramValues);
Expand All @@ -116,6 +117,12 @@ private MultiValueMap<String, String> getRequestParams(ServerWebExchange exchang
return params;
}

private MultiValueMap<String, Part> getMultipartData(ServerWebExchange exchange) {
MultiValueMap<String, Part> params = exchange.getMultipartData().subscribe().peek();
Assert.notNull(params, "Expected multipart data (if any) to be parsed.");
return params;
}

@Override
protected void handleMissingValue(String name, MethodParameter parameter, ServerWebExchange exchange) {
String type = parameter.getNestedParameterType().getSimpleName();
Expand Down
@@ -0,0 +1,128 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.web.reactive.result.method.annotation;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.codec.multipart.Part;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.bind.annotation.ValueConstants;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.ServerWebInputException;

/**
* Resolver for method arguments annotated with @{@link RequestPart}.
*
* @author Sebastien Deleuze
* @since 5.0
* @see RequestParamMapMethodArgumentResolver
*/
public class RequestPartMethodArgumentResolver extends AbstractNamedValueSyncArgumentResolver {

private final boolean useDefaultResolution;


/**
* Class constructor with a default resolution mode flag.
* @param factory a bean factory used for resolving ${...} placeholder
* and #{...} SpEL expressions in default values, or {@code null} if default
* values are not expected to contain expressions
* @param registry for checking reactive type wrappers
* @param useDefaultResolution in default resolution mode a method argument
* that is a simple type, as defined in {@link BeanUtils#isSimpleProperty},
* is treated as a request parameter even if it isn't annotated, the
* request parameter name is derived from the method parameter name.
*/
public RequestPartMethodArgumentResolver(
ConfigurableBeanFactory factory, ReactiveAdapterRegistry registry, boolean useDefaultResolution) {

super(factory, registry);
this.useDefaultResolution = useDefaultResolution;
}


@Override
public boolean supportsParameter(MethodParameter param) {
if (checkAnnotatedParamNoReactiveWrapper(param, RequestPart.class, this::singleParam)) {
return true;
}
else if (this.useDefaultResolution) {
return checkParameterTypeNoReactiveWrapper(param, BeanUtils::isSimpleProperty) ||
BeanUtils.isSimpleProperty(param.nestedIfOptional().getNestedParameterType());
}
return false;
}

private boolean singleParam(RequestPart requestParam, Class<?> type) {
return !Map.class.isAssignableFrom(type) || StringUtils.hasText(requestParam.name());
}

@Override
protected NamedValueInfo createNamedValueInfo(MethodParameter parameter) {
RequestPart ann = parameter.getParameterAnnotation(RequestPart.class);
return (ann != null ? new RequestPartNamedValueInfo(ann) : new RequestPartNamedValueInfo());
}

@Override
protected Optional<Object> resolveNamedValue(String name, MethodParameter parameter,
ServerWebExchange exchange) {

List<?> paramValues = getMultipartData(exchange).get(name);
Object result = null;
if (paramValues != null) {
result = (paramValues.size() == 1 ? paramValues.get(0) : paramValues);
}
return Optional.ofNullable(result);
}

private MultiValueMap<String, Part> getMultipartData(ServerWebExchange exchange) {
MultiValueMap<String, Part> params = exchange.getMultipartData().subscribe().peek();
Assert.notNull(params, "Expected multipart data (if any) to be parsed.");
return params;
}

@Override
protected void handleMissingValue(String name, MethodParameter parameter, ServerWebExchange exchange) {
String type = parameter.getNestedParameterType().getSimpleName();
String reason = "Required " + type + " parameter '" + name + "' is not present";
throw new ServerWebInputException(reason, parameter);
}


private static class RequestPartNamedValueInfo extends NamedValueInfo {

RequestPartNamedValueInfo() {
super("", false, ValueConstants.DEFAULT_NONE);
}

RequestPartNamedValueInfo(RequestPart annotation) {
super(annotation.name(), annotation.required(), ValueConstants.DEFAULT_NONE);
}
}

}
Expand Up @@ -91,6 +91,7 @@ public void requestMappingArgumentResolvers() throws Exception {

AtomicInteger index = new AtomicInteger(-1);
assertEquals(RequestParamMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestPartMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestParamMapMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMapMethodArgumentResolver.class, next(resolvers, index).getClass());
Expand Down Expand Up @@ -129,6 +130,7 @@ public void modelAttributeArgumentResolvers() throws Exception {

AtomicInteger index = new AtomicInteger(-1);
assertEquals(RequestParamMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestPartMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestParamMapMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMapMethodArgumentResolver.class, next(resolvers, index).getClass());
Expand Down Expand Up @@ -165,6 +167,7 @@ public void initBinderArgumentResolvers() throws Exception {

AtomicInteger index = new AtomicInteger(-1);
assertEquals(RequestParamMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestPartMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestParamMapMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMapMethodArgumentResolver.class, next(resolvers, index).getClass());
Expand Down Expand Up @@ -195,6 +198,7 @@ public void exceptionHandlerArgumentResolvers() throws Exception {

AtomicInteger index = new AtomicInteger(-1);
assertEquals(RequestParamMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestPartMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestParamMapMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMapMethodArgumentResolver.class, next(resolvers, index).getClass());
Expand Down

0 comments on commit b804ba1

Please sign in to comment.