diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java
index 189bf9b21884..6615dd8cfdf8 100644
--- a/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java
+++ b/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java
@@ -17,11 +17,13 @@
package org.springframework.http.converter.xml;
import java.io.IOException;
+
import javax.xml.transform.Result;
import javax.xml.transform.Source;
import org.springframework.beans.TypeMismatchException;
import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageNotReadableException;
import org.springframework.http.converter.HttpMessageNotWritableException;
import org.springframework.oxm.Marshaller;
@@ -50,7 +52,6 @@ public class MarshallingHttpMessageConverter extends AbstractXmlHttpMessageConve
private Unmarshaller unmarshaller;
-
/**
* Construct a new {@code MarshallingHttpMessageConverter} with no {@link Marshaller} or
* {@link Unmarshaller} set. The Marshaller and Unmarshaller must be set after construction
@@ -88,7 +89,6 @@ public MarshallingHttpMessageConverter(Marshaller marshaller, Unmarshaller unmar
this.unmarshaller = unmarshaller;
}
-
/**
* Set the {@link Marshaller} to be used by this message converter.
*/
@@ -103,10 +103,24 @@ public void setUnmarshaller(Unmarshaller unmarshaller) {
this.unmarshaller = unmarshaller;
}
+ @Override
+ public boolean canRead(Class> clazz, MediaType mediaType) {
+ Assert.notNull(this.unmarshaller, "Property 'unmarshaller' is required");
+
+ return canRead(mediaType) && unmarshaller.supports(clazz);
+ }
@Override
- public boolean supports(Class> clazz) {
- return this.unmarshaller.supports(clazz);
+ public boolean canWrite(Class> clazz, MediaType mediaType) {
+ Assert.notNull(this.marshaller, "Property 'marshaller' is required");
+
+ return canWrite(mediaType) && marshaller.supports(clazz);
+ }
+
+ @Override
+ protected boolean supports(Class> clazz) {
+ // should not be called, since we override canRead()/canWrite()
+ throw new UnsupportedOperationException();
}
@Override
@@ -131,8 +145,7 @@ protected void writeToResult(Object o, HttpHeaders headers, Result result) throw
this.marshaller.marshal(o, result);
}
catch (MarshallingFailureException ex) {
- throw new HttpMessageNotWritableException("Could not write [" + o + "]", ex);
+ throw new HttpMessageNotWritableException("Could not write [" + o.getClass() + "]", ex);
}
}
-
}
diff --git a/spring-web/src/test/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverterTests.java
index 75de8e40199c..93d143488e8d 100644
--- a/spring-web/src/test/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverterTests.java
+++ b/spring-web/src/test/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverterTests.java
@@ -16,36 +16,69 @@
package org.springframework.http.converter.xml;
-import javax.xml.transform.stream.StreamResult;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Matchers.isA;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import javax.xml.transform.Result;
import javax.xml.transform.stream.StreamSource;
-import org.junit.Before;
import org.junit.Test;
+import org.springframework.beans.TypeMismatchException;
import org.springframework.http.MediaType;
import org.springframework.http.MockHttpInputMessage;
import org.springframework.http.MockHttpOutputMessage;
+import org.springframework.http.converter.HttpMessageNotReadableException;
+import org.springframework.http.converter.HttpMessageNotWritableException;
import org.springframework.oxm.Marshaller;
+import org.springframework.oxm.MarshallingFailureException;
import org.springframework.oxm.Unmarshaller;
-
-import static org.junit.Assert.*;
-import static org.mockito.BDDMockito.*;
+import org.springframework.oxm.UnmarshallingFailureException;
/**
+ * Tests for {@link MarshallingHttpMessageConverter}.
+ *
* @author Arjen Poutsma
*/
public class MarshallingHttpMessageConverterTests {
- private MarshallingHttpMessageConverter converter;
+ @Test
+ public void canRead() throws Exception {
+ Unmarshaller unmarshaller = mock(Unmarshaller.class);
+
+ MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter();
+
+ converter.setUnmarshaller(unmarshaller);
+
+ when(unmarshaller.supports(Integer.class)).thenReturn(false);
+ when(unmarshaller.supports(String.class)).thenReturn(true);
+
+ assertFalse(converter.canRead(Boolean.class, MediaType.TEXT_PLAIN));
+ assertFalse(converter.canRead(Integer.class, MediaType.TEXT_XML));
+ assertTrue(converter.canRead(String.class, MediaType.TEXT_XML));
+ }
+
+ @Test
+ public void canWrite() throws Exception {
+ Marshaller marshaller = mock(Marshaller.class);
+
+ MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter();
- private Marshaller marshaller;
+ converter.setMarshaller(marshaller);
- private Unmarshaller unmarshaller;
+ when(marshaller.supports(Integer.class)).thenReturn(false);
+ when(marshaller.supports(String.class)).thenReturn(true);
- @Before
- public void setUp() {
- marshaller = mock(Marshaller.class);
- unmarshaller = mock(Unmarshaller.class);
- converter = new MarshallingHttpMessageConverter(marshaller, unmarshaller);
+ assertFalse(converter.canWrite(Boolean.class, MediaType.TEXT_PLAIN));
+ assertFalse(converter.canWrite(Integer.class, MediaType.TEXT_XML));
+ assertTrue(converter.canWrite(String.class, MediaType.TEXT_XML));
}
@Test
@@ -53,20 +86,92 @@ public void read() throws Exception {
String body = "Hello World";
MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8"));
- given(unmarshaller.unmarshal(isA(StreamSource.class))).willReturn(body);
+ Unmarshaller unmarshaller = mock(Unmarshaller.class);
+
+ MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter();
+
+ converter.setUnmarshaller(unmarshaller);
+
+ when(unmarshaller.unmarshal(isA(StreamSource.class))).thenReturn(body);
String result = (String) converter.read(Object.class, inputMessage);
assertEquals("Invalid result", body, result);
}
+ @Test(expected = TypeMismatchException.class)
+ public void readWithTypeMismatchException() throws Exception {
+ MockHttpInputMessage inputMessage = new MockHttpInputMessage(new byte[0]);
+
+ Marshaller marshaller = mock(Marshaller.class);
+ Unmarshaller unmarshaller = mock(Unmarshaller.class);
+
+ MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(marshaller, unmarshaller);
+
+ when(unmarshaller.unmarshal(isA(StreamSource.class))).thenReturn(Integer.valueOf(3));
+
+ converter.read(String.class, inputMessage);
+ }
+
+ @Test
+ public void readWithMarshallingFailureException() throws Exception {
+ MockHttpInputMessage inputMessage = new MockHttpInputMessage(new byte[0]);
+ UnmarshallingFailureException ex = new UnmarshallingFailureException("forced");
+
+ Unmarshaller unmarshaller = mock(Unmarshaller.class);
+ MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter();
+
+ converter.setUnmarshaller(unmarshaller);
+
+ when(unmarshaller.unmarshal(isA(StreamSource.class))).thenThrow(ex);
+
+ try {
+ converter.read(Object.class, inputMessage);
+ fail("HttpMessageNotReadableException should be thrown");
+ }
+ catch (HttpMessageNotReadableException e) {
+ assertTrue("Invalid exception hierarchy", e.getCause() == ex);
+ }
+ }
+
@Test
public void write() throws Exception {
String body = "Hello World";
MockHttpOutputMessage outputMessage = new MockHttpOutputMessage();
+ Marshaller marshaller = mock(Marshaller.class);
+ MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(marshaller);
+
+ doNothing().when(marshaller).marshal(eq(body), isA(Result.class));
+
converter.write(body, null, outputMessage);
- assertEquals("Invalid content-type", new MediaType("application", "xml"),
- outputMessage.getHeaders().getContentType());
- verify(marshaller).marshal(eq(body), isA(StreamResult.class));
+
+ assertEquals("Invalid content-type", new MediaType("application", "xml"), outputMessage.getHeaders()
+ .getContentType());
+ }
+
+ @Test
+ public void writeWithMarshallingFailureException() throws Exception {
+ String body = "Hello World";
+ MockHttpOutputMessage outputMessage = new MockHttpOutputMessage();
+ MarshallingFailureException ex = new MarshallingFailureException("forced");
+
+ Marshaller marshaller = mock(Marshaller.class);
+
+ MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(marshaller);
+
+ doThrow(ex).when(marshaller).marshal(eq(body), isA(Result.class));
+
+ try {
+ converter.write(body, null, outputMessage);
+ fail("HttpMessageNotWritableException should be thrown");
+ }
+ catch (HttpMessageNotWritableException e) {
+ assertTrue("Invalid exception hierarchy", e.getCause() == ex);
+ }
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void supports() throws Exception {
+ new MarshallingHttpMessageConverter().supports(Object.class);
}
}