Skip to content

Commit

Permalink
Fix HttpSecurity.addFilter* Ordering
Browse files Browse the repository at this point in the history
Closes gh-9633
  • Loading branch information
rwinch committed Apr 14, 2021
1 parent 67d5520 commit 26788a7
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.springframework.security.config.annotation.web.builders;

import java.io.Serializable;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -46,7 +45,6 @@
import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter;
import org.springframework.security.web.session.ConcurrentSessionFilter;
import org.springframework.security.web.session.SessionManagementFilter;
import org.springframework.util.Assert;
import org.springframework.web.filter.CorsFilter;

/**
Expand All @@ -58,15 +56,15 @@
*/

@SuppressWarnings("serial")
final class FilterComparator implements Comparator<Filter>, Serializable {
final class FilterOrderRegistration {

private static final int INITIAL_ORDER = 100;

private static final int ORDER_STEP = 100;

private final Map<String, Integer> filterToOrder = new HashMap<>();

FilterComparator() {
FilterOrderRegistration() {
Step order = new Step(INITIAL_ORDER, ORDER_STEP);
put(ChannelProcessingFilter.class, order.next());
order.next(); // gh-8105
Expand Down Expand Up @@ -114,60 +112,6 @@ final class FilterComparator implements Comparator<Filter>, Serializable {
put(SwitchUserFilter.class, order.next());
}

@Override
public int compare(Filter lhs, Filter rhs) {
Integer left = getOrder(lhs.getClass());
Integer right = getOrder(rhs.getClass());
return left - right;
}

/**
* Determines if a particular {@link Filter} is registered to be sorted
* @param filter
* @return
*/
boolean isRegistered(Class<? extends Filter> filter) {
return getOrder(filter) != null;
}

/**
* Registers a {@link Filter} to exist after a particular {@link Filter} that is
* already registered.
* @param filter the {@link Filter} to register
* @param afterFilter the {@link Filter} that is already registered and that
* {@code filter} should be placed after.
*/
void registerAfter(Class<? extends Filter> filter, Class<? extends Filter> afterFilter) {
Integer position = getOrder(afterFilter);
Assert.notNull(position, () -> "Cannot register after unregistered Filter " + afterFilter);
put(filter, position + 1);
}

/**
* Registers a {@link Filter} to exist at a particular {@link Filter} position
* @param filter the {@link Filter} to register
* @param atFilter the {@link Filter} that is already registered and that
* {@code filter} should be placed at.
*/
void registerAt(Class<? extends Filter> filter, Class<? extends Filter> atFilter) {
Integer position = getOrder(atFilter);
Assert.notNull(position, () -> "Cannot register after unregistered Filter " + atFilter);
put(filter, position);
}

/**
* Registers a {@link Filter} to exist before a particular {@link Filter} that is
* already registered.
* @param filter the {@link Filter} to register
* @param beforeFilter the {@link Filter} that is already registered and that
* {@code filter} should be placed before.
*/
void registerBefore(Class<? extends Filter> filter, Class<? extends Filter> beforeFilter) {
Integer position = getOrder(beforeFilter);
Assert.notNull(position, () -> "Cannot register after unregistered Filter " + beforeFilter);
put(filter, position - 1);
}

private void put(Class<? extends Filter> filter, int position) {
String className = filter.getName();
this.filterToOrder.put(className, position);
Expand All @@ -179,7 +123,7 @@ private void put(Class<? extends Filter> filter, int position) {
* @param clazz the {@link Filter} class to determine the sort order
* @return the sort order or null if not defined
*/
private Integer getOrder(Class<?> clazz) {
Integer getOrder(Class<?> clazz) {
while (clazz != null) {
Integer result = this.filterToOrder.get(clazz.getName());
if (result != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,21 @@

package org.springframework.security.config.annotation.web.builders;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;

import org.springframework.context.ApplicationContext;
import org.springframework.core.OrderComparator;
import org.springframework.core.Ordered;
import org.springframework.http.HttpMethod;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.AuthenticationProvider;
Expand Down Expand Up @@ -127,11 +134,11 @@ public final class HttpSecurity extends AbstractConfiguredSecurityBuilder<Defaul

private final RequestMatcherConfigurer requestMatcherConfigurer;

private List<Filter> filters = new ArrayList<>();
private List<OrderedFilter> filters = new ArrayList<>();

private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE;

private FilterComparator comparator = new FilterComparator();
private FilterOrderRegistration filterOrders = new FilterOrderRegistration();

/**
* Creates a new instance
Expand Down Expand Up @@ -2522,8 +2529,12 @@ protected void beforeConfigure() throws Exception {

@Override
protected DefaultSecurityFilterChain performBuild() {
this.filters.sort(this.comparator);
return new DefaultSecurityFilterChain(this.requestMatcher, this.filters);
this.filters.sort(OrderComparator.INSTANCE);
List<Filter> sortedFilters = new ArrayList<>(this.filters.size());
for (Filter filter : this.filters) {
sortedFilters.add(((OrderedFilter) filter).filter);
}
return new DefaultSecurityFilterChain(this.requestMatcher, sortedFilters);
}

@Override
Expand All @@ -2544,24 +2555,28 @@ private AuthenticationManagerBuilder getAuthenticationRegistry() {

@Override
public HttpSecurity addFilterAfter(Filter filter, Class<? extends Filter> afterFilter) {
this.comparator.registerAfter(filter.getClass(), afterFilter);
return addFilter(filter);
return addFilterAtOffsetOf(filter, 1, afterFilter);
}

@Override
public HttpSecurity addFilterBefore(Filter filter, Class<? extends Filter> beforeFilter) {
this.comparator.registerBefore(filter.getClass(), beforeFilter);
return addFilter(filter);
return addFilterAtOffsetOf(filter, -1, beforeFilter);
}

private HttpSecurity addFilterAtOffsetOf(Filter filter, int offset, Class<? extends Filter> registeredFilter) {
int order = this.filterOrders.getOrder(registeredFilter) + offset;
this.filters.add(new OrderedFilter(filter, order));
return this;
}

@Override
public HttpSecurity addFilter(Filter filter) {
Class<? extends Filter> filterClass = filter.getClass();
if (!this.comparator.isRegistered(filterClass)) {
throw new IllegalArgumentException("The Filter class " + filterClass.getName()
Integer order = this.filterOrders.getOrder(filter.getClass());
if (order == null) {
throw new IllegalArgumentException("The Filter class " + filter.getClass().getName()
+ " does not have a registered order and cannot be added without a specified order. Consider using addFilterBefore or addFilterAfter instead.");
}
this.filters.add(filter);
this.filters.add(new OrderedFilter(filter, order));
return this;
}

Expand All @@ -2584,8 +2599,7 @@ public HttpSecurity addFilter(Filter filter) {
* @return the {@link HttpSecurity} for further customizations
*/
public HttpSecurity addFilterAt(Filter filter, Class<? extends Filter> atFilter) {
this.comparator.registerAt(filter.getClass(), atFilter);
return addFilter(filter);
return addFilterAtOffsetOf(filter, 0, atFilter);
}

/**
Expand Down Expand Up @@ -2973,4 +2987,37 @@ public HttpSecurity and() {

}

/**
* A Filter that implements Ordered to be sorted. After sorting occurs, the original
* filter is what is used by FilterChainProxy
*/
private static final class OrderedFilter implements Ordered, Filter {

private final Filter filter;

private final int order;

private OrderedFilter(Filter filter, int order) {
this.filter = filter;
this.order = order;
}

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
throws IOException, ServletException {
this.filter.doFilter(servletRequest, servletResponse, filterChain);
}

@Override
public int getOrder() {
return this.order;
}

@Override
public String toString() {
return "OrderedFilter{" + "filter=" + this.filter + ", order=" + this.order + '}';
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.config.annotation.web.builders;

import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;

import org.assertj.core.api.ListAssert;
import org.junit.Rule;
import org.junit.Test;

import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.access.ExceptionTranslationFilter;
import org.springframework.security.web.access.channel.ChannelProcessingFilter;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter;

import static org.assertj.core.api.Assertions.assertThat;

public class HttpSecurityAddFilterTest {

@Rule
public final SpringTestRule spring = new SpringTestRule();

@Test
public void addFilterAfterWhenSameFilterDifferentPlacesThenOrderCorrect() {
this.spring.register(MyFilterMultipleAfterConfig.class).autowire();

assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyFilter.class,
ExceptionTranslationFilter.class, MyFilter.class);
}

@Test
public void addFilterBeforeWhenSameFilterDifferentPlacesThenOrderCorrect() {
this.spring.register(MyFilterMultipleBeforeConfig.class).autowire();

assertThatFilters().containsSubsequence(MyFilter.class, WebAsyncManagerIntegrationFilter.class, MyFilter.class,
ExceptionTranslationFilter.class);
}

@Test
public void addFilterAtWhenSameFilterDifferentPlacesThenOrderCorrect() {
this.spring.register(MyFilterMultipleAtConfig.class).autowire();

assertThatFilters().containsSubsequence(MyFilter.class, WebAsyncManagerIntegrationFilter.class, MyFilter.class,
ExceptionTranslationFilter.class);
}

private ListAssert<Class<?>> assertThatFilters() {
FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class);
List<Class<?>> filters = filterChain.getFilters("/").stream().map(Object::getClass)
.collect(Collectors.toList());
return assertThat(filters);
}

public static class MyFilter implements Filter {

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
throws IOException, ServletException {
filterChain.doFilter(servletRequest, servletResponse);
}

}

@EnableWebSecurity
static class MyFilterMultipleAfterConfig extends WebSecurityConfigurerAdapter {

@Override
protected void configure(HttpSecurity http) throws Exception {
// @formatter:off
http
.addFilterAfter(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
.addFilterAfter(new MyFilter(), ExceptionTranslationFilter.class);
// @formatter:on
}

}

@EnableWebSecurity
static class MyFilterMultipleBeforeConfig extends WebSecurityConfigurerAdapter {

@Override
protected void configure(HttpSecurity http) throws Exception {
// @formatter:off
http
.addFilterBefore(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
.addFilterBefore(new MyFilter(), ExceptionTranslationFilter.class);
// @formatter:on
}

}

@EnableWebSecurity
static class MyFilterMultipleAtConfig extends WebSecurityConfigurerAdapter {

@Override
protected void configure(HttpSecurity http) throws Exception {
// @formatter:off
http
.addFilterAt(new MyFilter(), ChannelProcessingFilter.class)
.addFilterAt(new MyFilter(), UsernamePasswordAuthenticationFilter.class);
// @formatter:on
}

}

}

0 comments on commit 26788a7

Please sign in to comment.