diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java index 504f3b98408..29d3e1b848f 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java @@ -427,12 +427,20 @@ RequestMatcher requestMatcher(ServletContext servletContext) { @Override public boolean matches(HttpServletRequest request) { - return this.requestMatcherFactory.apply(request.getServletContext()).matches(request); + ServletContext servletContext = request.getServletContext(); + if (servletContext == null) { + return false; + } + return this.requestMatcherFactory.apply(servletContext).matches(request); } @Override public MatchResult matcher(HttpServletRequest request) { - return this.requestMatcherFactory.apply(request.getServletContext()).matcher(request); + ServletContext servletContext = request.getServletContext(); + if (servletContext == null) { + return MatchResult.notMatch(); + } + return this.requestMatcherFactory.apply(servletContext).matcher(request); } @Override diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java index 3d5b25c7661..e829f4c453a 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java @@ -42,6 +42,7 @@ import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher; import org.springframework.security.web.util.matcher.RegexRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; @@ -358,6 +359,19 @@ public void matchesWhenNoMappingThenException() { assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> requestMatcher.matcher(request)); } + @Test + public void requestMatchersWhenServletContextCanBeNullThenDisallow() { + TestDeferredRequestMatcherRegistry deferredRequestMatcherRegistry = new TestDeferredRequestMatcherRegistry(); + deferredRequestMatcherRegistry.setApplicationContext(this.context); + List requestMatchers = deferredRequestMatcherRegistry.requestMatchers("/**"); + MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.GET.name(), "/endpoint"); + + ReflectionTestUtils.setField(request, "servletContext", null); + + assertThat(requestMatchers).isNotEmpty().hasSize(1); + assertThat(requestMatchers.get(0).matches(request)).isFalse(); + } + private void mockMvcIntrospector(boolean isPresent) { ApplicationContext context = this.matcherRegistry.getApplicationContext(); given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent); @@ -392,6 +406,16 @@ private List unwrap(List wrappedMatchers) { } + private static class TestDeferredRequestMatcherRegistry + extends AbstractRequestMatcherRegistry> { + + @Override + protected List chainRequestMatchers(List requestMatchers) { + return requestMatchers; + } + + } + @Configuration @EnableWebSecurity @EnableWebMvc diff --git a/web/src/main/java/org/springframework/security/web/FilterInvocation.java b/web/src/main/java/org/springframework/security/web/FilterInvocation.java index 26874c31be5..cb99c74d865 100644 --- a/web/src/main/java/org/springframework/security/web/FilterInvocation.java +++ b/web/src/main/java/org/springframework/security/web/FilterInvocation.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.Enumeration; import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -178,6 +179,8 @@ static class DummyRequest extends HttpServletRequestWrapper { private final Map parameters = new LinkedHashMap<>(); + private final Map attributes = new LinkedHashMap<>(); + DummyRequest() { super(UNSUPPORTED_REQUEST); } @@ -189,7 +192,7 @@ public String getCharacterEncoding() { @Override public Object getAttribute(String attributeName) { - return null; + return this.attributes.get(attributeName); } void setRequestURI(String requestURI) { @@ -317,6 +320,23 @@ void setServletContext(ServletContext servletContext) { this.servletContext = servletContext; } + @Override + public void setAttribute(String name, Object value) { + Assert.notNull(name, "name can not be null"); + this.attributes.put(name, value); + } + + @Override + public void removeAttribute(String name) { + Assert.notNull(name, "name can not be null"); + this.attributes.remove(name); + } + + @Override + public Enumeration getAttributeNames() { + return Collections.enumeration(new LinkedHashSet<>(this.attributes.keySet())); + } + } static final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler { diff --git a/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java b/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java index 98dafa27fc3..6f5715c850b 100644 --- a/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java @@ -163,4 +163,15 @@ public void testDummyRequestGetHeadersNull() { assertThatExceptionOfType(NoSuchElementException.class).isThrownBy(headers::nextElement); } + @Test + public void testDummyRequestGetAttribute() { + DummyRequest request = new DummyRequest(); + request.setAttribute("name", "value"); + request.setAttribute("removeName", "removeValue"); + request.removeAttribute("removeName"); + Enumeration attributeNames = request.getAttributeNames(); + assertThat(attributeNames.nextElement()).isEqualTo("name"); + assertThatExceptionOfType(NoSuchElementException.class).isThrownBy(attributeNames::nextElement); + } + }