Skip to content

Commit

Permalink
Close response on errors writing to the outputStream (#1372)
Browse files Browse the repository at this point in the history
Currently all errors during writing to the outputStream are ignored. So the
response is read to the end. This is a problem when the response is
infite or very long running (`hystrix.stream` for example) and results in the
socket never being closed and eating up threads (on zuul and the proxied
backend). With this commit the errors during writing are not ignored resulting
in closing the response.
  • Loading branch information
joshiste authored and spencergibb committed Feb 10, 2017
1 parent 919c8fa commit 3e7b196
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 24 deletions.
Expand Up @@ -17,6 +17,7 @@
package org.springframework.cloud.netflix.zuul.filters.post; package org.springframework.cloud.netflix.zuul.filters.post;


import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
Expand Down Expand Up @@ -184,28 +185,33 @@ else if (context.getResponseGZipped() && isGzipRequested) {
} }
} }
finally { finally {
/**
* Closing the wrapping InputStream itself has no effect on closing the underlying tcp connection since it's a wrapped stream. I guess for http
* keep-alive. When closing the wrapping stream it tries to reach the end of the current request, which is impossible for infinite http streams. So
* instead of closing the InputStream we close the HTTP response.
*
* @author Johannes Edmeier
*/
try { try {
if (is != null) { Object zuulResponse = RequestContext.getCurrentContext()
is.close(); .get("zuulResponse");
if (zuulResponse instanceof Closeable) {
((Closeable) zuulResponse).close();
} }
outStream.flush(); outStream.flush();
// The container will close the stream for us // The container will close the stream for us
} }
catch (IOException ex) { catch (IOException ex) {
log.warn("Error while sending response to client: " + ex.getMessage());
} }
} }
} }


private void writeResponse(InputStream zin, OutputStream out) throws Exception { private void writeResponse(InputStream zin, OutputStream out) throws Exception {
try { byte[] bytes = buffers.get();
byte[] bytes = buffers.get(); int bytesRead = -1;
int bytesRead = -1; while ((bytesRead = zin.read(bytes)) != -1) {
while ((bytesRead = zin.read(bytes)) != -1) { out.write(bytes, 0, bytesRead);
out.write(bytes, 0, bytesRead);
}
}
catch(IOException ioe) {
log.warn("Error while sending response to client: "+ioe.getMessage());
} }
} }


Expand Down
Expand Up @@ -218,6 +218,7 @@ protected String getVerb(HttpServletRequest request) {


protected void setResponse(ClientHttpResponse resp) protected void setResponse(ClientHttpResponse resp)
throws ClientException, IOException { throws ClientException, IOException {
RequestContext.getCurrentContext().set("zuulResponse", resp);
this.helper.setResponse(resp.getStatusCode().value(), this.helper.setResponse(resp.getStatusCode().value(),
resp.getBody() == null ? null : resp.getBody(), resp.getHeaders()); resp.getBody() == null ? null : resp.getBody(), resp.getHeaders());
} }
Expand Down
Expand Up @@ -42,10 +42,10 @@
import org.apache.http.HttpRequest; import org.apache.http.HttpRequest;
import org.apache.http.HttpResponse; import org.apache.http.HttpResponse;
import org.apache.http.ProtocolException; import org.apache.http.ProtocolException;
import org.apache.http.client.HttpClient;
import org.apache.http.client.RedirectStrategy; import org.apache.http.client.RedirectStrategy;
import org.apache.http.client.config.CookieSpecs; import org.apache.http.client.config.CookieSpecs;
import org.apache.http.client.config.RequestConfig; import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPatch; import org.apache.http.client.methods.HttpPatch;
import org.apache.http.client.methods.HttpPost; import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpPut; import org.apache.http.client.methods.HttpPut;
Expand Down Expand Up @@ -192,8 +192,8 @@ public Object run() {
this.helper.addIgnoredHeaders(); this.helper.addIgnoredHeaders();


try { try {
HttpResponse response = forward(this.httpClient, verb, uri, request, headers, CloseableHttpResponse response = forward(this.httpClient, verb, uri, request,
params, requestEntity); headers, params, requestEntity);
setResponse(response); setResponse(response);
} }
catch (Exception ex) { catch (Exception ex) {
Expand Down Expand Up @@ -278,8 +278,8 @@ public HttpUriRequest getRedirect(HttpRequest request,
}).build(); }).build();
} }


private HttpResponse forward(HttpClient httpclient, String verb, String uri, private CloseableHttpResponse forward(CloseableHttpClient httpclient, String verb,
HttpServletRequest request, MultiValueMap<String, String> headers, String uri, HttpServletRequest request, MultiValueMap<String, String> headers,
MultiValueMap<String, String> params, InputStream requestEntity) MultiValueMap<String, String> params, InputStream requestEntity)
throws Exception { throws Exception {
Map<String, Object> info = this.helper.debug(verb, uri, headers, params, Map<String, Object> info = this.helper.debug(verb, uri, headers, params,
Expand All @@ -301,7 +301,8 @@ private HttpResponse forward(HttpClient httpclient, String verb, String uri,
try { try {
log.debug(httpHost.getHostName() + " " + httpHost.getPort() + " " log.debug(httpHost.getHostName() + " " + httpHost.getPort() + " "
+ httpHost.getSchemeName()); + httpHost.getSchemeName());
HttpResponse zuulResponse = forwardRequest(httpclient, httpHost, httpRequest); CloseableHttpResponse zuulResponse = forwardRequest(httpclient, httpHost,
httpRequest);
this.helper.appendDebug(info, zuulResponse.getStatusLine().getStatusCode(), this.helper.appendDebug(info, zuulResponse.getStatusLine().getStatusCode(),
revertHeaders(zuulResponse.getAllHeaders())); revertHeaders(zuulResponse.getAllHeaders()));
return zuulResponse; return zuulResponse;
Expand Down Expand Up @@ -379,8 +380,8 @@ private Header[] convertHeaders(MultiValueMap<String, String> headers) {
return list.toArray(new BasicHeader[0]); return list.toArray(new BasicHeader[0]);
} }


private HttpResponse forwardRequest(HttpClient httpclient, HttpHost httpHost, private CloseableHttpResponse forwardRequest(CloseableHttpClient httpclient,
HttpRequest httpRequest) throws IOException { HttpHost httpHost, HttpRequest httpRequest) throws IOException {
return httpclient.execute(httpHost, httpRequest); return httpclient.execute(httpHost, httpRequest);
} }


Expand All @@ -407,6 +408,7 @@ private String getVerb(HttpServletRequest request) {
} }


private void setResponse(HttpResponse response) throws IOException { private void setResponse(HttpResponse response) throws IOException {
RequestContext.getCurrentContext().set("zuulResponse", response);
this.helper.setResponse(response.getStatusLine().getStatusCode(), this.helper.setResponse(response.getStatusLine().getStatusCode(),
response.getEntity() == null ? null : response.getEntity().getContent(), response.getEntity() == null ? null : response.getEntity().getContent(),
revertHeaders(response.getAllHeaders())); revertHeaders(response.getAllHeaders()));
Expand Down
Expand Up @@ -16,9 +16,26 @@


package org.springframework.cloud.netflix.zuul.filters.post; package org.springframework.cloud.netflix.zuul.filters.post;


import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.isA;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.cloud.netflix.zuul.filters.support.FilterConstants.X_ZUUL_DEBUG_HEADER;
import static org.hamcrest.Matchers.equalTo;

import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.IOException;
import java.lang.reflect.UndeclaredThrowableException;


import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;


import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
Expand All @@ -34,11 +51,6 @@
import com.netflix.zuul.context.Debug; import com.netflix.zuul.context.Debug;
import com.netflix.zuul.context.RequestContext; import com.netflix.zuul.context.RequestContext;


import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.springframework.cloud.netflix.zuul.filters.support.FilterConstants.X_ZUUL_DEBUG_HEADER;

/** /**
* @author Spencer Gibb * @author Spencer Gibb
*/ */
Expand Down Expand Up @@ -96,6 +108,34 @@ public void runWithOriginContentLength() throws Exception {
assertThat("wrong origin content length", contentLength, equalTo("6")); assertThat("wrong origin content length", contentLength, equalTo("6"));
} }


@Test
public void closeResponseOutpusStreamError() throws Exception {
HttpServletResponse response = mock(HttpServletResponse.class);

RequestContext context = new RequestContext();
context.setRequest(new MockHttpServletRequest());
context.setResponse(response);
context.setResponseDataStream(new ByteArrayInputStream("Hello\n".getBytes("UTF-8")));
Closeable zuulResponse = mock(Closeable.class);
context.set("zuulResponse", zuulResponse);
RequestContext.testSetCurrentContext(context);

SendResponseFilter filter = new SendResponseFilter();

ServletOutputStream zuuloutputstream = mock(ServletOutputStream.class);
doThrow(new IOException("Response to client closed")).when(zuuloutputstream).write(isA(byte[].class), anyInt(), anyInt());

when(response.getOutputStream()).thenReturn(zuuloutputstream);

try {
filter.run();
} catch (UndeclaredThrowableException ex) {
assertThat(ex.getUndeclaredThrowable().getMessage(), is("Response to client closed"));
}

verify(zuulResponse).close();
}

private void runFilter(String characterEncoding, String content, boolean streamContent) throws Exception { private void runFilter(String characterEncoding, String content, boolean streamContent) throws Exception {
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
SendResponseFilter filter = createFilter(content, characterEncoding, response, streamContent); SendResponseFilter filter = createFilter(content, characterEncoding, response, streamContent);
Expand Down

0 comments on commit 3e7b196

Please sign in to comment.