Skip to content

Commit

Permalink
Fix Adding Filter Relative to Custom Filter
Browse files Browse the repository at this point in the history
Closes gh-9787
  • Loading branch information
marcusdacoregio committed Jun 14, 2021
1 parent fe13b48 commit 53870ab
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 3 deletions.
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -114,8 +114,18 @@ final class FilterOrderRegistration {
put(SwitchUserFilter.class, order.next());
}

private void put(Class<? extends Filter> filter, int position) {
/**
* Register a {@link Filter} with its specific position. If the {@link Filter} was
* already registered before, the position previously defined is not going to be
* overriden
* @param filter the {@link Filter} to register
* @param position the position to associate with the {@link Filter}
*/
void put(Class<? extends Filter> filter, int position) {
String className = filter.getName();
if (this.filterToOrder.containsKey(className)) {
return;
}
this.filterToOrder.put(className, position);
}

Expand Down
Expand Up @@ -2653,6 +2653,7 @@ public HttpSecurity addFilterBefore(Filter filter, Class<? extends Filter> befor
private HttpSecurity addFilterAtOffsetOf(Filter filter, int offset, Class<? extends Filter> registeredFilter) {
int order = this.filterOrders.getOrder(registeredFilter) + offset;
this.filters.add(new OrderedFilter(filter, order));
this.filterOrders.put(filter.getClass(), order);
return this;
}

Expand Down
@@ -0,0 +1,75 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.config.annotation.web.builders;

import java.io.IOException;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;

import org.junit.Test;

import org.springframework.security.web.access.channel.ChannelProcessingFilter;

import static org.assertj.core.api.Assertions.assertThat;

public class FilterOrderRegistrationTests {

private final FilterOrderRegistration filterOrderRegistration = new FilterOrderRegistration();

@Test
public void putWhenNewFilterThenInsertCorrect() {
int position = 153;
this.filterOrderRegistration.put(MyFilter.class, position);
Integer order = this.filterOrderRegistration.getOrder(MyFilter.class);
assertThat(order).isEqualTo(position);
}

@Test
public void putWhenCustomFilterAlreadyExistsThenDoesNotOverride() {
int position = 160;
this.filterOrderRegistration.put(MyFilter.class, position);
this.filterOrderRegistration.put(MyFilter.class, 173);
Integer order = this.filterOrderRegistration.getOrder(MyFilter.class);
assertThat(order).isEqualTo(position);
}

@Test
public void putWhenPredefinedFilterThenDoesNotOverride() {
int position = 100;
Integer predefinedFilterOrderBefore = this.filterOrderRegistration.getOrder(ChannelProcessingFilter.class);
this.filterOrderRegistration.put(MyFilter.class, position);
Integer myFilterOrder = this.filterOrderRegistration.getOrder(MyFilter.class);
Integer predefinedFilterOrderAfter = this.filterOrderRegistration.getOrder(ChannelProcessingFilter.class);
assertThat(myFilterOrder).isEqualTo(position);
assertThat(predefinedFilterOrderAfter).isEqualTo(predefinedFilterOrderBefore).isEqualTo(position);
}

static class MyFilter implements Filter {

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
throws IOException, ServletException {
filterChain.doFilter(servletRequest, servletResponse);
}

}

}
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,14 +30,18 @@
import org.junit.Rule;
import org.junit.Test;

import org.springframework.context.annotation.Bean;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.access.ExceptionTranslationFilter;
import org.springframework.security.web.access.channel.ChannelProcessingFilter;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.springframework.security.web.context.SecurityContextPersistenceFilter;
import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter;
import org.springframework.security.web.header.HeaderWriterFilter;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -70,6 +74,46 @@ public void addFilterAtWhenSameFilterDifferentPlacesThenOrderCorrect() {
ExceptionTranslationFilter.class);
}

@Test
public void addFilterAfterWhenAfterCustomFilterThenOrderCorrect() {
this.spring.register(MyOtherFilterRelativeToMyFilterAfterConfig.class).autowire();

assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyFilter.class,
MyOtherFilter.class);
}

@Test
public void addFilterBeforeWhenBeforeCustomFilterThenOrderCorrect() {
this.spring.register(MyOtherFilterRelativeToMyFilterBeforeConfig.class).autowire();

assertThatFilters().containsSubsequence(MyOtherFilter.class, MyFilter.class,
WebAsyncManagerIntegrationFilter.class);
}

@Test
public void addFilterAtWhenAtCustomFilterThenOrderCorrect() {
this.spring.register(MyOtherFilterRelativeToMyFilterAtConfig.class).autowire();

assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyFilter.class,
MyOtherFilter.class, SecurityContextPersistenceFilter.class);
}

@Test
public void addFilterBeforeWhenCustomFilterDifferentPlacesThenOrderCorrect() {
this.spring.register(MyOtherFilterBeforeToMyFilterMultipleAfterConfig.class).autowire();

assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyOtherFilter.class,
MyFilter.class, ExceptionTranslationFilter.class);
}

@Test
public void addFilterBeforeAndAfterWhenCustomFiltersDifferentPlacesThenOrderCorrect() {
this.spring.register(MyAnotherFilterRelativeToMyCustomFiltersMultipleConfig.class).autowire();

assertThatFilters().containsSubsequence(HeaderWriterFilter.class, MyFilter.class, MyOtherFilter.class,
MyOtherFilter.class, MyAnotherFilter.class, MyFilter.class, ExceptionTranslationFilter.class);
}

private ListAssert<Class<?>> assertThatFilters() {
FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class);
List<Class<?>> filters = filterChain.getFilters("/").stream().map(Object::getClass)
Expand All @@ -87,6 +131,26 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo

}

static class MyOtherFilter implements Filter {

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
throws IOException, ServletException {
filterChain.doFilter(servletRequest, servletResponse);
}

}

static class MyAnotherFilter implements Filter {

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
throws IOException, ServletException {
filterChain.doFilter(servletRequest, servletResponse);
}

}

@EnableWebSecurity
static class MyFilterMultipleAfterConfig extends WebSecurityConfigurerAdapter {

Expand Down Expand Up @@ -129,4 +193,83 @@ protected void configure(HttpSecurity http) throws Exception {

}

@EnableWebSecurity
static class MyOtherFilterRelativeToMyFilterAfterConfig {

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.addFilterAfter(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
.addFilterAfter(new MyOtherFilter(), MyFilter.class);
// @formatter:on
return http.build();
}

}

@EnableWebSecurity
static class MyOtherFilterRelativeToMyFilterBeforeConfig {

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.addFilterBefore(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
.addFilterBefore(new MyOtherFilter(), MyFilter.class);
// @formatter:on
return http.build();
}

}

@EnableWebSecurity
static class MyOtherFilterRelativeToMyFilterAtConfig {

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.addFilterAt(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
.addFilterAt(new MyOtherFilter(), MyFilter.class);
// @formatter:on
return http.build();
}

}

@EnableWebSecurity
static class MyOtherFilterBeforeToMyFilterMultipleAfterConfig {

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.addFilterAfter(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
.addFilterAfter(new MyFilter(), ExceptionTranslationFilter.class)
.addFilterBefore(new MyOtherFilter(), MyFilter.class);
// @formatter:on
return http.build();
}

}

@EnableWebSecurity
static class MyAnotherFilterRelativeToMyCustomFiltersMultipleConfig {

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.addFilterAfter(new MyFilter(), HeaderWriterFilter.class)
.addFilterBefore(new MyOtherFilter(), ExceptionTranslationFilter.class)
.addFilterAfter(new MyOtherFilter(), MyFilter.class)
.addFilterAt(new MyAnotherFilter(), MyOtherFilter.class)
.addFilterAfter(new MyFilter(), MyAnotherFilter.class);
// @formatter:on
return http.build();
}

}

}

0 comments on commit 53870ab

Please sign in to comment.