Skip to content

Commit

Permalink
Revert "Cache Control only written if not set"
Browse files Browse the repository at this point in the history
This reverts commit 242b831.
Spring MVC fixed the issue we were working around and the changes
in Spring Security were unreliable.

Fixes gh-3975
  • Loading branch information
Rob Winch committed Oct 24, 2016
1 parent e62596f commit 57d7ad0
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 217 deletions.
Expand Up @@ -15,17 +15,15 @@
*/
package org.springframework.security.web.header;

import java.io.IOException;
import java.util.List;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.springframework.security.web.util.OnCommittedResponseWrapper;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
import java.io.IOException;
import java.util.*;

/**
* Filter implementation to add headers to the current response. Can be useful to add
Expand Down Expand Up @@ -58,52 +56,12 @@ public HeaderWriterFilter(List<HeaderWriter> headerWriters) {
@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
throws ServletException, IOException {

HeaderWriterResponse headerWriterResponse = new HeaderWriterResponse(request,
response, this.headerWriters);
try {
filterChain.doFilter(request, headerWriterResponse);
}
finally {
headerWriterResponse.writeHeaders();
for (HeaderWriter headerWriter : headerWriters) {
headerWriter.writeHeaders(request, response);
}
filterChain.doFilter(request, response);
}

static class HeaderWriterResponse extends OnCommittedResponseWrapper {
private final HttpServletRequest request;
private final List<HeaderWriter> headerWriters;

HeaderWriterResponse(HttpServletRequest request, HttpServletResponse response,
List<HeaderWriter> headerWriters) {
super(response);
this.request = request;
this.headerWriters = headerWriters;
}

/*
* (non-Javadoc)
*
* @see org.springframework.security.web.util.OnCommittedResponseWrapper#
* onResponseCommitted()
*/
@Override
protected void onResponseCommitted() {
writeHeaders();
this.disableOnResponseCommitted();
}

protected void writeHeaders() {
if (isDisableOnResponseCommitted()) {
return;
}
for (HeaderWriter headerWriter : this.headerWriters) {
headerWriter.writeHeaders(this.request, getHttpResponse());
}
}

private HttpServletResponse getHttpResponse() {
return (HttpServletResponse) getResponse();
}
}
}
Expand Up @@ -15,20 +15,14 @@
*/
package org.springframework.security.web.header.writers;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.springframework.security.web.header.Header;
import org.springframework.security.web.header.HeaderWriter;
import org.springframework.util.ReflectionUtils;

/**
* Inserts headers to prevent caching if no cache control headers have been specified.
* Specifically it adds the following headers:
* A {@link StaticHeadersWriter} that inserts headers to prevent caching. Specifically it
* adds the following headers:
* <ul>
* <li>Cache-Control: no-cache, no-store, max-age=0, must-revalidate</li>
* <li>Pragma: no-cache</li>
Expand All @@ -38,47 +32,21 @@
* @author Rob Winch
* @since 3.2
*/
public final class CacheControlHeadersWriter implements HeaderWriter {
private static final String EXPIRES = "Expires";
private static final String PRAGMA = "Pragma";
private static final String CACHE_CONTROL = "Cache-Control";

private final Method getHeaderMethod;

private final HeaderWriter delegate;
public final class CacheControlHeadersWriter extends StaticHeadersWriter {

/**
* Creates a new instance
*/
public CacheControlHeadersWriter() {
this.delegate = new StaticHeadersWriter(createHeaders());
this.getHeaderMethod = ReflectionUtils.findMethod(HttpServletResponse.class,
"getHeader", String.class);
}

@Override
public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
if (hasHeader(response, CACHE_CONTROL) || hasHeader(response, EXPIRES)
|| hasHeader(response, PRAGMA)) {
return;
}
this.delegate.writeHeaders(request, response);
}

private boolean hasHeader(HttpServletResponse response, String headerName) {
if (this.getHeaderMethod == null) {
return false;
}
return ReflectionUtils.invokeMethod(this.getHeaderMethod, response,
headerName) != null;
super(createHeaders());
}

private static List<Header> createHeaders() {
List<Header> headers = new ArrayList<Header>(2);
headers.add(new Header(CACHE_CONTROL,
headers.add(new Header("Cache-Control",
"no-cache, no-store, max-age=0, must-revalidate"));
headers.add(new Header(PRAGMA, "no-cache"));
headers.add(new Header(EXPIRES, "0"));
headers.add(new Header("Pragma", "no-cache"));
headers.add(new Header("Expires", "0"));
return headers;
}
}
Expand Up @@ -15,32 +15,21 @@
*/
package org.springframework.security.web.header;

import java.io.IOException;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.verify;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;

import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;
import org.springframework.security.web.header.HeaderWriter;
import org.springframework.security.web.header.HeaderWriterFilter;

/**
* Tests for the {@code HeadersFilter}
Expand Down Expand Up @@ -71,8 +60,8 @@ public void constructorNullWriters() throws Exception {
@Test
public void additionalHeadersShouldBeAddedToTheResponse() throws Exception {
List<HeaderWriter> headerWriters = new ArrayList<HeaderWriter>();
headerWriters.add(this.writer1);
headerWriters.add(this.writer2);
headerWriters.add(writer1);
headerWriters.add(writer2);

HeaderWriterFilter filter = new HeaderWriterFilter(headerWriters);

Expand All @@ -82,34 +71,9 @@ public void additionalHeadersShouldBeAddedToTheResponse() throws Exception {

filter.doFilter(request, response, filterChain);

verify(this.writer1).writeHeaders(request, response);
verify(this.writer2).writeHeaders(request, response);
verify(writer1).writeHeaders(request, response);
verify(writer2).writeHeaders(request, response);
assertThat(filterChain.getRequest()).isEqualTo(request); // verify the filterChain
// continued
}

// gh-2953
@Test
public void headersDelayed() throws Exception {
HeaderWriterFilter filter = new HeaderWriterFilter(
Arrays.<HeaderWriter>asList(this.writer1));

MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();

filter.doFilter(request, response, new FilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
verifyZeroInteractions(HeaderWriterFilterTests.this.writer1);

response.flushBuffer();

verify(HeaderWriterFilterTests.this.writer1).writeHeaders(
any(HttpServletRequest.class), any(HttpServletResponse.class));
}
});

verifyNoMoreInteractions(this.writer1);
}
}
Expand Up @@ -15,32 +15,19 @@
*/
package org.springframework.security.web.header.writers;

import java.util.Arrays;
import static org.assertj.core.api.Assertions.assertThat;

import javax.servlet.http.HttpServletResponse;
import java.util.Arrays;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareOnlyThisForTest;
import org.powermock.modules.junit4.PowerMockRunner;

import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.util.ReflectionUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.when;
import static org.powermock.api.mockito.PowerMockito.spy;

/**
* @author Rob Winch
*
*/
@RunWith(PowerMockRunner.class)
@PrepareOnlyThisForTest(ReflectionUtils.class)
public class CacheControlHeadersWriterTests {

private MockHttpServletRequest request;
Expand All @@ -51,79 +38,20 @@ public class CacheControlHeadersWriterTests {

@Before
public void setup() {
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
this.writer = new CacheControlHeadersWriter();
request = new MockHttpServletRequest();
response = new MockHttpServletResponse();
writer = new CacheControlHeadersWriter();
}

@Test
public void writeHeaders() {
this.writer.writeHeaders(this.request, this.response);

assertThat(this.response.getHeaderNames().size()).isEqualTo(3);
assertThat(this.response.getHeaderValues("Cache-Control")).isEqualTo(
Arrays.asList("no-cache, no-store, max-age=0, must-revalidate"));
assertThat(this.response.getHeaderValues("Pragma"))
.isEqualTo(Arrays.asList("no-cache"));
assertThat(this.response.getHeaderValues("Expires"))
.isEqualTo(Arrays.asList("0"));
}

@Test
public void writeHeadersServlet25() {
spy(ReflectionUtils.class);
when(ReflectionUtils.findMethod(HttpServletResponse.class, "getHeader",
String.class)).thenReturn(null);
this.response = spy(this.response);
doThrow(NoSuchMethodError.class).when(this.response).getHeader(anyString());
this.writer = new CacheControlHeadersWriter();
writer.writeHeaders(request, response);

this.writer.writeHeaders(this.request, this.response);

assertThat(this.response.getHeaderNames().size()).isEqualTo(3);
assertThat(this.response.getHeaderValues("Cache-Control")).isEqualTo(
assertThat(response.getHeaderNames().size()).isEqualTo(3);
assertThat(response.getHeaderValues("Cache-Control")).isEqualTo(
Arrays.asList("no-cache, no-store, max-age=0, must-revalidate"));
assertThat(this.response.getHeaderValues("Pragma"))
.isEqualTo(Arrays.asList("no-cache"));
assertThat(this.response.getHeaderValues("Expires"))
.isEqualTo(Arrays.asList("0"));
}

// gh-2953
@Test
public void writeHeadersDisabledIfCacheControl() {
this.response.setHeader("Cache-Control", "max-age: 123");

this.writer.writeHeaders(this.request, this.response);

assertThat(this.response.getHeaderNames()).hasSize(1);
assertThat(this.response.getHeaderValues("Cache-Control"))
.containsOnly("max-age: 123");
assertThat(this.response.getHeaderValue("Pragma")).isNull();
assertThat(this.response.getHeaderValue("Expires")).isNull();
}

@Test
public void writeHeadersDisabledIfPragma() {
this.response.setHeader("Pragma", "mock");

this.writer.writeHeaders(this.request, this.response);

assertThat(this.response.getHeaderNames()).hasSize(1);
assertThat(this.response.getHeaderValues("Pragma")).containsOnly("mock");
assertThat(this.response.getHeaderValue("Expires")).isNull();
assertThat(this.response.getHeaderValue("Cache-Control")).isNull();
}

@Test
public void writeHeadersDisabledIfExpires() {
this.response.setHeader("Expires", "mock");

this.writer.writeHeaders(this.request, this.response);

assertThat(this.response.getHeaderNames()).hasSize(1);
assertThat(this.response.getHeaderValues("Expires")).containsOnly("mock");
assertThat(this.response.getHeaderValue("Cache-Control")).isNull();
assertThat(this.response.getHeaderValue("Pragma")).isNull();
assertThat(response.getHeaderValues("Pragma")).isEqualTo(
Arrays.asList("no-cache"));
assertThat(response.getHeaderValues("Expires")).isEqualTo(Arrays.asList("0"));
}
}

0 comments on commit 57d7ad0

Please sign in to comment.