diff --git a/build.gradle b/build.gradle index d67ab1471..f9db85d35 100644 --- a/build.gradle +++ b/build.gradle @@ -61,6 +61,7 @@ configure(allprojects) { testCompile("org.easymock:easymock:3.1") testCompile("xmlunit:xmlunit:1.5") testCompile("com.sun.mail:javax.mail:1.5.4") + testCompile("commons-io:commons-io:2.5") testRuntime("org.codehaus.woodstox:woodstox-core-asl:4.2.0") } diff --git a/spring-ws-core/src/main/java/org/springframework/ws/transport/http/AbstractHttpSenderConnection.java b/spring-ws-core/src/main/java/org/springframework/ws/transport/http/AbstractHttpSenderConnection.java index 8cfa03c53..38507c265 100644 --- a/spring-ws-core/src/main/java/org/springframework/ws/transport/http/AbstractHttpSenderConnection.java +++ b/spring-ws-core/src/main/java/org/springframework/ws/transport/http/AbstractHttpSenderConnection.java @@ -16,14 +16,13 @@ package org.springframework.ws.transport.http; -import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; +import java.io.PushbackInputStream; import java.util.Iterator; import java.util.zip.GZIPInputStream; import javax.xml.namespace.QName; -import org.springframework.util.FileCopyUtils; import org.springframework.util.StringUtils; import org.springframework.ws.transport.AbstractSenderConnection; import org.springframework.ws.transport.FaultAwareWebServiceConnection; @@ -33,13 +32,21 @@ * Abstract base class for {@link WebServiceConnection} implementations that send request over HTTP. * * @author Arjen Poutsma + * @author Andreas Veithen * @since 1.0.0 */ public abstract class AbstractHttpSenderConnection extends AbstractSenderConnection implements FaultAwareWebServiceConnection { - /** Buffer used for reading the response, when the content length is invalid. */ - private byte[] responseBuffer; + /** + * Cached result of {@link #hasResponse}. + */ + private Boolean hasResponse; + + /** + * The raw response input stream to use instead of calling {@link #getRawResponseInputStream()}. + */ + private PushbackInputStream rawResponseInputStream; @Override public final boolean hasError() throws IOException { @@ -69,23 +76,31 @@ protected final boolean hasResponse() throws IOException { HttpTransportConstants.STATUS_NO_CONTENT == responseCode) { return false; } + if (hasResponse != null) { + return hasResponse; + } long contentLength = getResponseContentLength(); if (contentLength < 0) { - if (responseBuffer == null) { - responseBuffer = FileCopyUtils.copyToByteArray(getRawResponseInputStream()); + rawResponseInputStream = new PushbackInputStream(getRawResponseInputStream()); + int b = rawResponseInputStream.read(); + if (b == -1) { + hasResponse = Boolean.FALSE; + } + else { + hasResponse = Boolean.TRUE; + rawResponseInputStream.unread(b); } - contentLength = responseBuffer.length; } - return contentLength > 0; + else { + hasResponse = contentLength > 0; + } + return hasResponse; } @Override protected final InputStream getResponseInputStream() throws IOException { - InputStream inputStream; - if (responseBuffer != null) { - inputStream = new ByteArrayInputStream(responseBuffer); - } - else { + InputStream inputStream = rawResponseInputStream; + if (inputStream == null) { inputStream = getRawResponseInputStream(); } return isGzipResponse() ? new GZIPInputStream(inputStream) : inputStream; diff --git a/spring-ws-core/src/test/java/org/springframework/ws/transport/http/AbstractHttpSenderConnectionTest.java b/spring-ws-core/src/test/java/org/springframework/ws/transport/http/AbstractHttpSenderConnectionTest.java new file mode 100644 index 000000000..b369350ed --- /dev/null +++ b/spring-ws-core/src/test/java/org/springframework/ws/transport/http/AbstractHttpSenderConnectionTest.java @@ -0,0 +1,83 @@ +/* + * Copyright 2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ws.transport.http; + +import static org.easymock.EasyMock.*; +import static org.junit.Assert.*; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.Collections; +import java.util.Random; + +import org.apache.commons.io.IOUtils; +import org.apache.commons.io.input.CountingInputStream; +import org.easymock.Capture; +import org.junit.Test; +import org.springframework.ws.WebServiceMessage; +import org.springframework.ws.WebServiceMessageFactory; + +/** + * @author Andreas Veithen + */ +public class AbstractHttpSenderConnectionTest { + /** + * Tests that {@link AbstractHttpSenderConnection} doesn't consume the response stream before + * passing it to the message factory. This is a regression test for SWS-707. + * + * @param chunking + * Specifies whether the test should simulate a response with chunking enabled. + * @throws Exception + */ + private void testSupportsStreaming(boolean chunking) throws Exception { + byte[] content = new byte[16*1024]; + new Random().nextBytes(content); + CountingInputStream rawInputStream = new CountingInputStream(new ByteArrayInputStream(content)); + + AbstractHttpSenderConnection connection = createNiceMock(AbstractHttpSenderConnection.class); + expect(connection.getResponseCode()).andReturn(200); + // Simulate response with chunking enabled + expect(connection.getResponseContentLength()).andReturn(chunking ? -1L : content.length); + expect(connection.getRawResponseInputStream()).andReturn(rawInputStream); + expect(connection.getResponseHeaders(anyObject())).andReturn(Collections.emptyIterator()); + + // Create a mock message factory to capture the InputStream passed to it + WebServiceMessageFactory messageFactory = createNiceMock(WebServiceMessageFactory.class); + WebServiceMessage message = createNiceMock(WebServiceMessage.class); + Capture inputStreamCapture = new Capture<>(); + expect(messageFactory.createWebServiceMessage(capture(inputStreamCapture))).andReturn(message); + + replay(connection, messageFactory, message); + + connection.receive(messageFactory); + + assertTrue("The raw input stream has been completely consumed", + rawInputStream.getCount() < content.length); + assertArrayEquals("Unexpected content received by the message factory", + content, IOUtils.toByteArray(inputStreamCapture.getValue())); + } + + @Test + public void testSupportsStreamingWithChunkingEnabled() throws Exception { + testSupportsStreaming(true); + } + + @Test + public void testSupportsStreamingWithChunkingDisabled() throws Exception { + testSupportsStreaming(false); + } +}