Skip to content

Commit

Permalink
Add DispatcherServlet to Tests
Browse files Browse the repository at this point in the history
Issue gh-13551
  • Loading branch information
jzheaux committed Jul 17, 2023
1 parent df239b6 commit bb46a54
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 68 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.config;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

import javax.servlet.MultipartConfigElement;
import javax.servlet.Servlet;
import javax.servlet.ServletRegistration;
import javax.servlet.ServletSecurityElement;

import org.springframework.lang.NonNull;
import org.springframework.web.servlet.DispatcherServlet;

public class MockServletContext extends org.springframework.mock.web.MockServletContext {

private final Map<String, ServletRegistration> registrations = new LinkedHashMap<>();

public static MockServletContext mvc() {
MockServletContext servletContext = new MockServletContext();
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
return servletContext;
}

@NonNull
@Override
public ServletRegistration.Dynamic addServlet(@NonNull String servletName, Class<? extends Servlet> clazz) {
ServletRegistration.Dynamic dynamic = new MockServletRegistration(servletName, clazz);
this.registrations.put(servletName, dynamic);
return dynamic;
}

@NonNull
@Override
public Map<String, ? extends ServletRegistration> getServletRegistrations() {
return this.registrations;
}

private static class MockServletRegistration implements ServletRegistration.Dynamic {

private final String name;

private final Class<?> clazz;

MockServletRegistration(String name, Class<?> clazz) {
this.name = name;
this.clazz = clazz;
}

@Override
public void setLoadOnStartup(int loadOnStartup) {

}

@Override
public Set<String> setServletSecurity(ServletSecurityElement constraint) {
return null;
}

@Override
public void setMultipartConfig(MultipartConfigElement multipartConfig) {

}

@Override
public void setRunAsRole(String roleName) {

}

@Override
public void setAsyncSupported(boolean isAsyncSupported) {

}

@Override
public Set<String> addMapping(String... urlPatterns) {
return null;
}

@Override
public Collection<String> getMappings() {
return null;
}

@Override
public String getRunAsRole() {
return null;
}

@Override
public String getName() {
return this.name;
}

@Override
public String getClassName() {
return this.clazz.getName();
}

@Override
public boolean setInitParameter(String name, String value) {
return false;
}

@Override
public String getInitParameter(String name) {
return null;
}

@Override
public Set<String> setInitParameters(Map<String, String> initParameters) {
return null;
}

@Override
public Map<String, String> getInitParameters() {
return null;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,18 @@

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import javax.servlet.DispatcherType;
import javax.servlet.Servlet;
import javax.servlet.ServletContext;
import javax.servlet.ServletRegistration;

import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext;
import org.springframework.http.HttpMethod;
import org.springframework.security.config.MockServletContext;
import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
Expand Down Expand Up @@ -70,10 +66,8 @@ public <O> O postProcess(O object) {
public void setUp() {
this.matcherRegistry = new TestRequestMatcherRegistry();
this.context = mock(WebApplicationContext.class);
ServletContext servletContext = new MockServletContext();
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
given(this.context.getBean(ObjectPostProcessor.class)).willReturn(NO_OP_OBJECT_POST_PROCESSOR);
given(this.context.getServletContext()).willReturn(servletContext);
given(this.context.getServletContext()).willReturn(MockServletContext.mvc());
this.matcherRegistry.setApplicationContext(this.context);
}

Expand Down Expand Up @@ -256,25 +250,4 @@ protected List<RequestMatcher> chainRequestMatchers(List<RequestMatcher> request

}

private static class MockServletContext extends org.springframework.mock.web.MockServletContext {

private final Map<String, ServletRegistration> registrations = new LinkedHashMap<>();

@NotNull
@Override
public ServletRegistration.Dynamic addServlet(@NotNull String servletName, Class<? extends Servlet> clazz) {
ServletRegistration.Dynamic dynamic = mock(ServletRegistration.Dynamic.class);
given(dynamic.getClassName()).willReturn(clazz.getName());
this.registrations.put(servletName, dynamic);
return dynamic;
}

@NotNull
@Override
public Map<String, ? extends ServletRegistration> getServletRegistrations() {
return this.registrations;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.access.hierarchicalroles.RoleHierarchy;
import org.springframework.security.access.hierarchicalroles.RoleHierarchyImpl;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.config.MockServletContext;
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
Expand Down Expand Up @@ -75,7 +75,7 @@ public class AuthorizeRequestsTests {

@BeforeEach
public void setup() {
this.servletContext = spy(new MockServletContext());
this.servletContext = spy(MockServletContext.mvc());
this.request = new MockHttpServletRequest("GET", "");
this.request.setMethod("GET");
this.response = new MockHttpServletResponse();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,9 @@

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.LinkedHashMap;
import java.util.Map;

import javax.servlet.Servlet;
import javax.servlet.ServletRegistration;
import javax.servlet.http.HttpServletResponse;

import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -39,6 +34,7 @@
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.config.MockServletContext;
import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
Expand All @@ -52,15 +48,12 @@
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.servlet.config.annotation.PathMatchConfigurer;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import org.springframework.web.servlet.handler.HandlerMappingIntrospector;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.springframework.security.config.Customizer.withDefaults;

/**
Expand Down Expand Up @@ -240,9 +233,7 @@ public void securityMatchersWhenMultiMvcMatcherThenAllPathsAreDenied() throws Ex
public void loadConfig(Class<?>... configs) {
this.context = new AnnotationConfigWebApplicationContext();
this.context.register(configs);
MockServletContext servletContext = new MockServletContext();
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
this.context.setServletContext(servletContext);
this.context.setServletContext(MockServletContext.mvc());
this.context.refresh();
this.context.getAutowireCapableBeanFactory().autowireBean(this);
}
Expand Down Expand Up @@ -573,25 +564,4 @@ public void configurePathMatch(PathMatchConfigurer configurer) {

}

private static class MockServletContext extends org.springframework.mock.web.MockServletContext {

private final Map<String, ServletRegistration> registrations = new LinkedHashMap<>();

@NotNull
@Override
public ServletRegistration.Dynamic addServlet(@NotNull String servletName, Class<? extends Servlet> clazz) {
ServletRegistration.Dynamic dynamic = mock(ServletRegistration.Dynamic.class);
given(dynamic.getClassName()).willReturn(clazz.getName());
this.registrations.put(servletName, dynamic);
return dynamic;
}

@NotNull
@Override
public Map<String, ? extends ServletRegistration> getServletRegistrations() {
return this.registrations;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.MockServletContext;
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
Expand Down Expand Up @@ -167,7 +167,7 @@ public void multiMvcMatchersConfig() throws Exception {
public void loadConfig(Class<?>... configs) {
this.context = new AnnotationConfigWebApplicationContext();
this.context.register(configs);
this.context.setServletContext(new MockServletContext());
this.context.setServletContext(MockServletContext.mvc());
this.context.refresh();
this.context.getAutowireCapableBeanFactory().autowireBean(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor;
import org.springframework.mock.web.MockServletConfig;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.config.BeanIds;
import org.springframework.security.config.MockServletContext;
import org.springframework.security.config.util.InMemoryXmlWebApplicationContext;
import org.springframework.test.context.web.GenericXmlWebContextLoader;
import org.springframework.test.web.servlet.MockMvc;
Expand Down Expand Up @@ -129,15 +129,15 @@ private SpringTestContext addFilter(Filter filter) {

public ConfigurableWebApplicationContext getContext() {
if (!this.context.isRunning()) {
this.context.setServletContext(new MockServletContext());
this.context.setServletContext(MockServletContext.mvc());
this.context.setServletConfig(new MockServletConfig());
this.context.refresh();
}
return this.context;
}

public void autowire() {
this.context.setServletContext(new MockServletContext());
this.context.setServletContext(MockServletContext.mvc());
this.context.setServletConfig(new MockServletConfig());
for (Consumer<ConfigurableWebApplicationContext> postProcessor : this.postProcessors) {
postProcessor.accept(this.context);
Expand Down

0 comments on commit bb46a54

Please sign in to comment.