Skip to content

Commit

Permalink
Support multipart/* MediaTypes in RestTemplate
Browse files Browse the repository at this point in the history
Prior to this commit, RestTemplate posted multipart with Content-Type
"multipart/form-data" even if the FormHttpMessageConverter configured
in the RestTemplate had been configured to support additional multipart
subtypes. This made it impossible to POST form data using a content
type such as "multipart/mixed" or "multipart/related".

This commit addresses this issue by updating FormHttpMessageConverter
to support custom multipart subtypes for writing form data.

For example, the following use case is now supported.

MediaType multipartMixed = new MediaType("multipart", "mixed");

restTemplate.getMessageConverters().stream()
    .filter(FormHttpMessageConverter.class::isInstance)
    .map(FormHttpMessageConverter.class::cast)
    .findFirst()
    .orElseThrow(() ->
        new IllegalStateException("Failed to find FormHttpMessageConverter"))
    .addSupportedMediaTypes(multipartMixed);

MultiValueMap<String, Object> parts = new LinkedMultiValueMap<>();
parts.add("field 1", "value 1");
parts.add("file", new ClassPathResource("myFile.jpg"));

HttpHeaders requestHeaders = new HttpHeaders();
requestHeaders.setContentType(multipartMixed);
HttpEntity<MultiValueMap<String, Object>> requestEntity =
    new HttpEntity<>(parts, requestHeaders);

restTemplate.postForLocation("https://example.com/myFileUpload", requestEntity);

Closes gh-23159
  • Loading branch information
sbrannen committed Jun 28, 2019
1 parent 7bc727c commit 5008423
Show file tree
Hide file tree
Showing 6 changed files with 242 additions and 46 deletions.
Expand Up @@ -55,33 +55,76 @@
* write (but not read) the {@code "multipart/form-data"} media type as
* {@link MultiValueMap MultiValueMap&lt;String, Object&gt;}.
*
* <h3>Multipart Data</h3>
*
* <p>By default, {@code "multipart/form-data"} is used as the content type when
* {@linkplain #write writing} multipart data. As of Spring Framework 5.2 it is
* also possible to write multipart data using other multipart subtypes such as
* {@code "multipart/mixed"} and {@code "multipart/related"}, as long as the
* multipart subtype is registered as a {@linkplain #getSupportedMediaTypes
* supported media type} <em>and</em> the desired multipart subtype is specified
* as the content type when {@linkplain #write writing} the multipart data.
*
* <p>When writing multipart data, this converter uses other
* {@link HttpMessageConverter HttpMessageConverters} to write the respective
* MIME parts. By default, basic converters are registered (e.g., for {@code String}
* and {@code Resource}). These can be overridden through the
* {@link #setPartConverters partConverters} property.
* MIME parts. By default, basic converters are registered for byte array,
* {@code String}, and {@code Resource}. These can be overridden via
* {@link #setPartConverters} or augmented via {@link #addPartConverter}.
*
* <h3>Examples</h3>
*
* <p>The following snippet shows how to submit an HTML form using the
* {@code "multipart/form-data"} content type.
*
* <p>For example, the following snippet shows how to submit an HTML form:
* <pre class="code">
* RestTemplate template = new RestTemplate();
* RestTemplate restTemplate = new RestTemplate();
* // AllEncompassingFormHttpMessageConverter is configured by default
*
* MultiValueMap&lt;String, Object&gt; form = new LinkedMultiValueMap&lt;&gt;();
* form.add("field 1", "value 1");
* form.add("field 2", "value 2");
* form.add("field 2", "value 3");
* form.add("field 3", 4); // non-String form values supported as of 5.1.4
* template.postForLocation("https://example.com/myForm", form);
* restTemplate.postForLocation("https://example.com/myForm", form);
* </pre>
*
* <p>The following snippet shows how to do a file upload:
* <p>The following snippet shows how to do a file upload using the
* {@code "multipart/form-data"} content type.
*
* <pre class="code">
* MultiValueMap&lt;String, Object&gt; parts = new LinkedMultiValueMap&lt;&gt;();
* parts.add("field 1", "value 1");
* parts.add("file", new ClassPathResource("myFile.jpg"));
* template.postForLocation("https://example.com/myFileUpload", parts);
* restTemplate.postForLocation("https://example.com/myFileUpload", parts);
* </pre>
*
* <p>The following snippet shows how to do a file upload using the
* {@code "multipart/mixed"} content type.
*
* <pre class="code">
* MediaType multipartMixed = new MediaType("multipart", "mixed");
*
* restTemplate.getMessageConverters().stream()
* .filter(FormHttpMessageConverter.class::isInstance)
* .map(FormHttpMessageConverter.class::cast)
* .findFirst()
* .orElseThrow(() -&gt; new IllegalStateException("Failed to find FormHttpMessageConverter"))
* .addSupportedMediaTypes(multipartMixed);
*
* MultiValueMap&lt;String, Object&gt; parts = new LinkedMultiValueMap&lt;&gt;();
* parts.add("field 1", "value 1");
* parts.add("file", new ClassPathResource("myFile.jpg"));
*
* HttpHeaders requestHeaders = new HttpHeaders();
* requestHeaders.setContentType(multipartMixed);
* HttpEntity&lt;MultiValueMap&lt;String, Object&gt;&gt; requestEntity =
* new HttpEntity&lt;&gt;(parts, requestHeaders);
*
* restTemplate.postForLocation("https://example.com/myFileUpload", requestEntity);
* </pre>
*
* <h3>Miscellaneous</h3>
*
* <p>Some methods in this class were inspired by
* {@code org.apache.commons.httpclient.methods.multipart.MultipartRequestEntity}.
*
Expand All @@ -95,6 +138,8 @@
*/
public class FormHttpMessageConverter implements HttpMessageConverter<MultiValueMap<String, ?>> {

private static final MediaType MULTIPART_ALL = new MediaType("multipart", "*");

/**
* The default charset used by the converter.
*/
Expand Down Expand Up @@ -154,6 +199,12 @@ public void addSupportedMediaTypes(MediaType... supportedMediaTypes) {
}
}

/**
* {@inheritDoc}
*
* @see #setSupportedMediaTypes(List)
* @see #addSupportedMediaTypes(MediaType...)
*/
@Override
public List<MediaType> getSupportedMediaTypes() {
return Collections.unmodifiableList(this.supportedMediaTypes);
Expand Down Expand Up @@ -236,8 +287,11 @@ public boolean canRead(Class<?> clazz, @Nullable MediaType mediaType) {
return true;
}
for (MediaType supportedMediaType : getSupportedMediaTypes()) {
// We can't read multipart....
if (!supportedMediaType.equals(MediaType.MULTIPART_FORM_DATA) && supportedMediaType.includes(mediaType)) {
if (MULTIPART_ALL.includes(supportedMediaType)) {
// We can't read multipart, so skip this supported media type.
continue;
}
if (supportedMediaType.includes(mediaType)) {
return true;
}
}
Expand Down Expand Up @@ -291,7 +345,7 @@ public void write(MultiValueMap<String, ?> map, @Nullable MediaType contentType,
throws IOException, HttpMessageNotWritableException {

if (isMultipart(map, contentType)) {
writeMultipart((MultiValueMap<String, Object>) map, outputMessage);
writeMultipart((MultiValueMap<String, Object>) map, contentType, outputMessage);
}
else {
writeForm((MultiValueMap<String, Object>) map, contentType, outputMessage);
Expand All @@ -301,7 +355,7 @@ public void write(MultiValueMap<String, ?> map, @Nullable MediaType contentType,

private boolean isMultipart(MultiValueMap<String, ?> map, @Nullable MediaType contentType) {
if (contentType != null) {
return MediaType.MULTIPART_FORM_DATA.includes(contentType);
return MULTIPART_ALL.includes(contentType);
}
for (List<?> values : map.values()) {
for (Object value : values) {
Expand Down Expand Up @@ -368,19 +422,26 @@ protected String serializeForm(MultiValueMap<String, Object> formData, Charset c
return builder.toString();
}

private void writeMultipart(final MultiValueMap<String, Object> parts, HttpOutputMessage outputMessage)
private void writeMultipart(final MultiValueMap<String, Object> parts, MediaType contentType, HttpOutputMessage outputMessage)
throws IOException {

// If the supplied content type is null, fall back to multipart/form-data.
// Otherwise rely on the fact that isMultipart() already verified the
// supplied content type is multipart/*.
if (contentType == null) {
contentType = MediaType.MULTIPART_FORM_DATA;
}

final byte[] boundary = generateMultipartBoundary();
Map<String, String> parameters = new LinkedHashMap<>(2);
if (!isFilenameCharsetSet()) {
parameters.put("charset", this.charset.name());
}
parameters.put("boundary", new String(boundary, StandardCharsets.US_ASCII));

MediaType contentType = new MediaType(MediaType.MULTIPART_FORM_DATA, parameters);
HttpHeaders headers = outputMessage.getHeaders();
headers.setContentType(contentType);
// Add parameters to output content type
contentType = new MediaType(contentType, parameters);
outputMessage.getHeaders().setContentType(contentType);

if (outputMessage instanceof StreamingHttpOutputMessage) {
StreamingHttpOutputMessage streamingOutputMessage = (StreamingHttpOutputMessage) outputMessage;
Expand Down
Expand Up @@ -47,6 +47,9 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED;
import static org.springframework.http.MediaType.MULTIPART_FORM_DATA;
import static org.springframework.http.MediaType.TEXT_XML;

/**
* Unit tests for {@link FormHttpMessageConverter} and
Expand All @@ -58,24 +61,46 @@
*/
public class FormHttpMessageConverterTests {

protected static final MediaType MULTIPART_MIXED = new MediaType("multipart", "mixed");
protected static final MediaType MULTIPART_RELATED = new MediaType("multipart", "related");
private static final MediaType MULTIPART_ALL = new MediaType("multipart", "*");
private static final MediaType MULTIPART_MIXED = new MediaType("multipart", "mixed");
private static final MediaType MULTIPART_RELATED = new MediaType("multipart", "related");

private final FormHttpMessageConverter converter = new AllEncompassingFormHttpMessageConverter();


@Test
public void canRead() {
assertThat(this.converter.canRead(MultiValueMap.class, MediaType.APPLICATION_FORM_URLENCODED)).isTrue();
assertThat(this.converter.canRead(MultiValueMap.class, MediaType.MULTIPART_FORM_DATA)).isFalse();
assertCanRead(MultiValueMap.class, null);
assertCanRead(APPLICATION_FORM_URLENCODED);

assertCannotRead(String.class, null);
assertCannotRead(String.class, APPLICATION_FORM_URLENCODED);
}

@Test
public void cannotReadMultipart() {
// Without custom multipart types supported
assertCannotRead(MULTIPART_ALL);
assertCannotRead(MULTIPART_FORM_DATA);
assertCannotRead(MULTIPART_MIXED);
assertCannotRead(MULTIPART_RELATED);

this.converter.addSupportedMediaTypes(MULTIPART_MIXED, MULTIPART_RELATED);

// With custom multipart types supported
assertCannotRead(MULTIPART_ALL);
assertCannotRead(MULTIPART_FORM_DATA);
assertCannotRead(MULTIPART_MIXED);
assertCannotRead(MULTIPART_RELATED);
}

@Test
public void canWrite() {
assertCanWrite(MediaType.APPLICATION_FORM_URLENCODED);
assertCanWrite(MediaType.MULTIPART_FORM_DATA);
assertCanWrite(APPLICATION_FORM_URLENCODED);
assertCanWrite(MULTIPART_FORM_DATA);
assertCanWrite(new MediaType("multipart", "form-data", StandardCharsets.UTF_8));
assertCanWrite(MediaType.ALL);
assertCanWrite(null);
}

@Test
Expand Down Expand Up @@ -103,14 +128,6 @@ public void addSupportedMediaTypes() {
assertCanWrite(MULTIPART_RELATED);
}

private void assertCanWrite(MediaType mediaType) {
assertThat(this.converter.canWrite(MultiValueMap.class, mediaType)).isTrue();
}

private void assertCannotWrite(MediaType mediaType) {
assertThat(this.converter.canWrite(MultiValueMap.class, mediaType)).isFalse();
}

@Test
public void readForm() throws Exception {
String body = "name+1=value+1&name+2=value+2%2B1&name+2=value+2%2B2&name+3";
Expand All @@ -136,7 +153,7 @@ public void writeForm() throws IOException {
body.add("name 2", "value 2+2");
body.add("name 3", null);
MockHttpOutputMessage outputMessage = new MockHttpOutputMessage();
this.converter.write(body, MediaType.APPLICATION_FORM_URLENCODED, outputMessage);
this.converter.write(body, APPLICATION_FORM_URLENCODED, outputMessage);

assertThat(outputMessage.getBodyAsString(StandardCharsets.UTF_8)).as("Invalid result").isEqualTo("name+1=value+1&name+2=value+2%2B1&name+2=value+2%2B2&name+3");
assertThat(outputMessage.getHeaders().getContentType().toString()).as("Invalid content-type").isEqualTo("application/x-www-form-urlencoded;charset=UTF-8");
Expand Down Expand Up @@ -165,7 +182,7 @@ public String getFilename() {

Source xml = new StreamSource(new StringReader("<root><child/></root>"));
HttpHeaders entityHeaders = new HttpHeaders();
entityHeaders.setContentType(MediaType.TEXT_XML);
entityHeaders.setContentType(TEXT_XML);
HttpEntity<Source> entity = new HttpEntity<>(xml, entityHeaders);
parts.add("xml", entity);

Expand Down Expand Up @@ -226,7 +243,7 @@ public void writeMultipartOrder() throws Exception {
parts.add("part1", myBean);

HttpHeaders entityHeaders = new HttpHeaders();
entityHeaders.setContentType(MediaType.TEXT_XML);
entityHeaders.setContentType(TEXT_XML);
HttpEntity<MyBean> entity = new HttpEntity<>(myBean, entityHeaders);
parts.add("part2", entity);

Expand Down Expand Up @@ -261,6 +278,32 @@ public void writeMultipartOrder() throws Exception {
.endsWith("><string>foo</string></MyBean>");
}

private void assertCanRead(MediaType mediaType) {
assertCanRead(MultiValueMap.class, mediaType);
}

private void assertCanRead(Class<?> clazz, MediaType mediaType) {
assertThat(this.converter.canRead(clazz, mediaType)).as(clazz.getSimpleName() + " : " + mediaType).isTrue();
}

private void assertCannotRead(MediaType mediaType) {
assertCannotRead(MultiValueMap.class, mediaType);
}

private void assertCannotRead(Class<?> clazz, MediaType mediaType) {
assertThat(this.converter.canRead(clazz, mediaType)).as(clazz.getSimpleName() + " : " + mediaType).isFalse();
}

private void assertCanWrite(MediaType mediaType) {
Class<?> clazz = MultiValueMap.class;
assertThat(this.converter.canWrite(clazz, mediaType)).as(clazz.getSimpleName() + " : " + mediaType).isTrue();
}

private void assertCannotWrite(MediaType mediaType) {
Class<?> clazz = MultiValueMap.class;
assertThat(this.converter.canWrite(clazz, mediaType)).as(clazz.getSimpleName() + " : " + mediaType).isFalse();
}


private static class MockHttpOutputMessageRequestContext implements RequestContext {

Expand Down
Expand Up @@ -35,13 +35,17 @@
import static org.springframework.http.HttpHeaders.CONTENT_LENGTH;
import static org.springframework.http.HttpHeaders.CONTENT_TYPE;
import static org.springframework.http.HttpHeaders.LOCATION;
import static org.springframework.http.MediaType.MULTIPART_FORM_DATA;

/**
* @author Brian Clozel
* @author Sam Brannen
*/
public class AbstractMockWebServerTestCase {

protected static final MediaType MULTIPART_MIXED = new MediaType("multipart", "mixed");
protected static final MediaType MULTIPART_RELATED = new MediaType("multipart", "related");

protected static final MediaType textContentType =
new MediaType("text", "plain", Collections.singletonMap("charset", "UTF-8"));

Expand Down Expand Up @@ -120,10 +124,31 @@ private MockResponse jsonPostRequest(RecordedRequest request, String location, S
.setResponseCode(201);
}

private MockResponse multipartRequest(RecordedRequest request) {
MediaType mediaType = MediaType.parseMediaType(request.getHeader("Content-Type"));
assertThat(mediaType.isCompatibleWith(MediaType.MULTIPART_FORM_DATA)).isTrue();
private MockResponse multipartFormDataRequest(RecordedRequest request) {
MediaType mediaType = MediaType.parseMediaType(request.getHeader(CONTENT_TYPE));
assertThat(mediaType.isCompatibleWith(MULTIPART_FORM_DATA)).as(MULTIPART_FORM_DATA.toString()).isTrue();
assertMultipart(request, mediaType);
return new MockResponse().setResponseCode(200);
}

private MockResponse multipartMixedRequest(RecordedRequest request) {
MediaType mediaType = MediaType.parseMediaType(request.getHeader(CONTENT_TYPE));
assertThat(mediaType.isCompatibleWith(MULTIPART_MIXED)).as(MULTIPART_MIXED.toString()).isTrue();
assertMultipart(request, mediaType);
return new MockResponse().setResponseCode(200);
}

private MockResponse multipartRelatedRequest(RecordedRequest request) {
MediaType mediaType = MediaType.parseMediaType(request.getHeader(CONTENT_TYPE));
assertThat(mediaType.isCompatibleWith(MULTIPART_RELATED)).as(MULTIPART_RELATED.toString()).isTrue();
assertMultipart(request, mediaType);
return new MockResponse().setResponseCode(200);
}

private void assertMultipart(RecordedRequest request, MediaType mediaType) {
assertThat(mediaType.isCompatibleWith(new MediaType("multipart", "*"))).as("multipart/*").isTrue();
String boundary = mediaType.getParameter("boundary");
assertThat(boundary).as("boundary").isNotBlank();
Buffer body = request.getBody();
try {
assertPart(body, "form-data", boundary, "name 1", "text/plain", "value 1");
Expand All @@ -132,9 +157,8 @@ private MockResponse multipartRequest(RecordedRequest request) {
assertFilePart(body, "form-data", boundary, "logo", "logo.jpg", "image/jpeg");
}
catch (EOFException ex) {
throw new IllegalStateException(ex);
throw new AssertionError(ex);
}
return new MockResponse().setResponseCode(200);
}

private void assertPart(Buffer buffer, String disposition, String boundary, String name,
Expand Down Expand Up @@ -245,8 +269,14 @@ else if (request.getPath().equals("/status/server")) {
else if (request.getPath().contains("/uri/")) {
return new MockResponse().setBody(request.getPath()).setHeader(CONTENT_TYPE, "text/plain");
}
else if (request.getPath().equals("/multipart")) {
return multipartRequest(request);
else if (request.getPath().equals("/multipartFormData")) {
return multipartFormDataRequest(request);
}
else if (request.getPath().equals("/multipartMixed")) {
return multipartMixedRequest(request);
}
else if (request.getPath().equals("/multipartRelated")) {
return multipartRelatedRequest(request);
}
else if (request.getPath().equals("/form")) {
return formRequest(request);
Expand Down

0 comments on commit 5008423

Please sign in to comment.