Skip to content

Commit

Permalink
Make sure ErrorPageFilter is only applied once per request
Browse files Browse the repository at this point in the history
Fixes gh-1257
  • Loading branch information
Dave Syer committed Jul 17, 2014
1 parent 0c52817 commit 4a33ab5
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;

/**
* A special {@link AbstractConfigurableEmbeddedServletContainer} for non-embedded
Expand Down Expand Up @@ -76,21 +77,28 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple
private final Map<Class<?>, String> exceptions = new HashMap<Class<?>, String>();

private final Map<Class<?>, Class<?>> subtypes = new HashMap<Class<?>, Class<?>>();

private final OncePerRequestFilter delegate = new OncePerRequestFilter(
) {

@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain chain)
throws ServletException, IOException {
ErrorPageFilter.this.doFilter(request, response, chain);
}

};

@Override
public void init(FilterConfig filterConfig) throws ServletException {
delegate.init(filterConfig);
}

@Override
public void doFilter(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException {
if (request instanceof HttpServletRequest
&& response instanceof HttpServletResponse) {
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
else {
chain.doFilter(request, response);
}
delegate.doFilter(request, response, chain);
}

private void doFilter(HttpServletRequest request, HttpServletResponse response,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@

package org.springframework.boot.context.web;

import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;

import java.io.IOException;

import javax.servlet.RequestDispatcher;
Expand All @@ -29,13 +34,10 @@
import org.springframework.boot.context.embedded.ErrorPage;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockFilterConfig;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;

import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;

/**
* Tests for {@link ErrorPageFilter}.
*
Expand Down Expand Up @@ -97,6 +99,21 @@ public void doFilter(ServletRequest request, ServletResponse response)
equalTo(400));
}

@Test
public void oncePerRequest() throws Exception {
this.chain = new MockFilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
((HttpServletResponse) response).sendError(400, "BAD");
assertNotNull(request.getAttribute("FILTER.FILTERED"));
super.doFilter(request, response);
}
};
filter.init(new MockFilterConfig("FILTER"));
this.filter.doFilter(this.request, this.response, this.chain);
}

@Test
public void globalError() throws Exception {
this.filter.addErrorPages(new ErrorPage("/error"));
Expand Down

0 comments on commit 4a33ab5

Please sign in to comment.