diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java index 189bf9b21884..6615dd8cfdf8 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverter.java @@ -17,11 +17,13 @@ package org.springframework.http.converter.xml; import java.io.IOException; + import javax.xml.transform.Result; import javax.xml.transform.Source; import org.springframework.beans.TypeMismatchException; import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; import org.springframework.http.converter.HttpMessageNotReadableException; import org.springframework.http.converter.HttpMessageNotWritableException; import org.springframework.oxm.Marshaller; @@ -50,7 +52,6 @@ public class MarshallingHttpMessageConverter extends AbstractXmlHttpMessageConve private Unmarshaller unmarshaller; - /** * Construct a new {@code MarshallingHttpMessageConverter} with no {@link Marshaller} or * {@link Unmarshaller} set. The Marshaller and Unmarshaller must be set after construction @@ -88,7 +89,6 @@ public MarshallingHttpMessageConverter(Marshaller marshaller, Unmarshaller unmar this.unmarshaller = unmarshaller; } - /** * Set the {@link Marshaller} to be used by this message converter. */ @@ -103,10 +103,24 @@ public void setUnmarshaller(Unmarshaller unmarshaller) { this.unmarshaller = unmarshaller; } + @Override + public boolean canRead(Class clazz, MediaType mediaType) { + Assert.notNull(this.unmarshaller, "Property 'unmarshaller' is required"); + + return canRead(mediaType) && unmarshaller.supports(clazz); + } @Override - public boolean supports(Class clazz) { - return this.unmarshaller.supports(clazz); + public boolean canWrite(Class clazz, MediaType mediaType) { + Assert.notNull(this.marshaller, "Property 'marshaller' is required"); + + return canWrite(mediaType) && marshaller.supports(clazz); + } + + @Override + protected boolean supports(Class clazz) { + // should not be called, since we override canRead()/canWrite() + throw new UnsupportedOperationException(); } @Override @@ -131,8 +145,7 @@ protected void writeToResult(Object o, HttpHeaders headers, Result result) throw this.marshaller.marshal(o, result); } catch (MarshallingFailureException ex) { - throw new HttpMessageNotWritableException("Could not write [" + o + "]", ex); + throw new HttpMessageNotWritableException("Could not write [" + o.getClass() + "]", ex); } } - } diff --git a/spring-web/src/test/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverterTests.java index 75de8e40199c..93d143488e8d 100644 --- a/spring-web/src/test/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverterTests.java +++ b/spring-web/src/test/java/org/springframework/http/converter/xml/MarshallingHttpMessageConverterTests.java @@ -16,36 +16,69 @@ package org.springframework.http.converter.xml; -import javax.xml.transform.stream.StreamResult; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.eq; +import static org.mockito.Matchers.isA; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import javax.xml.transform.Result; import javax.xml.transform.stream.StreamSource; -import org.junit.Before; import org.junit.Test; +import org.springframework.beans.TypeMismatchException; import org.springframework.http.MediaType; import org.springframework.http.MockHttpInputMessage; import org.springframework.http.MockHttpOutputMessage; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; import org.springframework.oxm.Marshaller; +import org.springframework.oxm.MarshallingFailureException; import org.springframework.oxm.Unmarshaller; - -import static org.junit.Assert.*; -import static org.mockito.BDDMockito.*; +import org.springframework.oxm.UnmarshallingFailureException; /** + * Tests for {@link MarshallingHttpMessageConverter}. + * * @author Arjen Poutsma */ public class MarshallingHttpMessageConverterTests { - private MarshallingHttpMessageConverter converter; + @Test + public void canRead() throws Exception { + Unmarshaller unmarshaller = mock(Unmarshaller.class); + + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(); + + converter.setUnmarshaller(unmarshaller); + + when(unmarshaller.supports(Integer.class)).thenReturn(false); + when(unmarshaller.supports(String.class)).thenReturn(true); + + assertFalse(converter.canRead(Boolean.class, MediaType.TEXT_PLAIN)); + assertFalse(converter.canRead(Integer.class, MediaType.TEXT_XML)); + assertTrue(converter.canRead(String.class, MediaType.TEXT_XML)); + } + + @Test + public void canWrite() throws Exception { + Marshaller marshaller = mock(Marshaller.class); + + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(); - private Marshaller marshaller; + converter.setMarshaller(marshaller); - private Unmarshaller unmarshaller; + when(marshaller.supports(Integer.class)).thenReturn(false); + when(marshaller.supports(String.class)).thenReturn(true); - @Before - public void setUp() { - marshaller = mock(Marshaller.class); - unmarshaller = mock(Unmarshaller.class); - converter = new MarshallingHttpMessageConverter(marshaller, unmarshaller); + assertFalse(converter.canWrite(Boolean.class, MediaType.TEXT_PLAIN)); + assertFalse(converter.canWrite(Integer.class, MediaType.TEXT_XML)); + assertTrue(converter.canWrite(String.class, MediaType.TEXT_XML)); } @Test @@ -53,20 +86,92 @@ public void read() throws Exception { String body = "Hello World"; MockHttpInputMessage inputMessage = new MockHttpInputMessage(body.getBytes("UTF-8")); - given(unmarshaller.unmarshal(isA(StreamSource.class))).willReturn(body); + Unmarshaller unmarshaller = mock(Unmarshaller.class); + + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(); + + converter.setUnmarshaller(unmarshaller); + + when(unmarshaller.unmarshal(isA(StreamSource.class))).thenReturn(body); String result = (String) converter.read(Object.class, inputMessage); assertEquals("Invalid result", body, result); } + @Test(expected = TypeMismatchException.class) + public void readWithTypeMismatchException() throws Exception { + MockHttpInputMessage inputMessage = new MockHttpInputMessage(new byte[0]); + + Marshaller marshaller = mock(Marshaller.class); + Unmarshaller unmarshaller = mock(Unmarshaller.class); + + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(marshaller, unmarshaller); + + when(unmarshaller.unmarshal(isA(StreamSource.class))).thenReturn(Integer.valueOf(3)); + + converter.read(String.class, inputMessage); + } + + @Test + public void readWithMarshallingFailureException() throws Exception { + MockHttpInputMessage inputMessage = new MockHttpInputMessage(new byte[0]); + UnmarshallingFailureException ex = new UnmarshallingFailureException("forced"); + + Unmarshaller unmarshaller = mock(Unmarshaller.class); + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(); + + converter.setUnmarshaller(unmarshaller); + + when(unmarshaller.unmarshal(isA(StreamSource.class))).thenThrow(ex); + + try { + converter.read(Object.class, inputMessage); + fail("HttpMessageNotReadableException should be thrown"); + } + catch (HttpMessageNotReadableException e) { + assertTrue("Invalid exception hierarchy", e.getCause() == ex); + } + } + @Test public void write() throws Exception { String body = "Hello World"; MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + Marshaller marshaller = mock(Marshaller.class); + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(marshaller); + + doNothing().when(marshaller).marshal(eq(body), isA(Result.class)); + converter.write(body, null, outputMessage); - assertEquals("Invalid content-type", new MediaType("application", "xml"), - outputMessage.getHeaders().getContentType()); - verify(marshaller).marshal(eq(body), isA(StreamResult.class)); + + assertEquals("Invalid content-type", new MediaType("application", "xml"), outputMessage.getHeaders() + .getContentType()); + } + + @Test + public void writeWithMarshallingFailureException() throws Exception { + String body = "Hello World"; + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + MarshallingFailureException ex = new MarshallingFailureException("forced"); + + Marshaller marshaller = mock(Marshaller.class); + + MarshallingHttpMessageConverter converter = new MarshallingHttpMessageConverter(marshaller); + + doThrow(ex).when(marshaller).marshal(eq(body), isA(Result.class)); + + try { + converter.write(body, null, outputMessage); + fail("HttpMessageNotWritableException should be thrown"); + } + catch (HttpMessageNotWritableException e) { + assertTrue("Invalid exception hierarchy", e.getCause() == ex); + } + } + + @Test(expected = UnsupportedOperationException.class) + public void supports() throws Exception { + new MarshallingHttpMessageConverter().supports(Object.class); } }