Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add servletPath support to AuthorizeHttpRequests #13857

Merged
merged 2 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,6 +26,7 @@
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletRegistration;
import jakarta.servlet.http.HttpServletRequest;

import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext;
Expand Down Expand Up @@ -203,11 +204,30 @@ public C requestMatchers(HttpMethod method, String... patterns) {
if (!hasDispatcherServlet(registrations)) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
if (registrations.size() > 1) {
String errorMessage = computeErrorMessage(registrations.values());
throw new IllegalArgumentException(errorMessage);
ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations);
if (dispatcherServlet != null) {
if (registrations.size() == 1) {
return requestMatchers(createMvcMatchers(method, patterns).toArray(RequestMatcher[]::new));
}
List<RequestMatcher> matchers = new ArrayList<>();
for (String pattern : patterns) {
AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null);
MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0);
matchers.add(new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext));
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
}
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations);
if (dispatcherServlet != null) {
String mapping = dispatcherServlet.getMappings().iterator().next();
List<MvcRequestMatcher> matchers = createMvcMatchers(method, patterns);
for (MvcRequestMatcher matcher : matchers) {
matcher.setServletPath(mapping.substring(0, mapping.length() - 2));
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
}
String errorMessage = computeErrorMessage(registrations.values());
throw new IllegalArgumentException(errorMessage);
}

private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
Expand All @@ -225,22 +245,66 @@ private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration>
if (registrations == null) {
return false;
}
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
null);
for (ServletRegistration registration : registrations.values()) {
try {
Class<?> clazz = Class.forName(registration.getClassName());
if (dispatcherServlet.isAssignableFrom(clazz)) {
return true;
}
}
catch (ClassNotFoundException ex) {
return false;
if (isDispatcherServlet(registration)) {
return true;
}
}
return false;
}

private ServletRegistration requireOneRootDispatcherServlet(
Map<String, ? extends ServletRegistration> registrations) {
ServletRegistration rootDispatcherServlet = null;
for (ServletRegistration registration : registrations.values()) {
if (!isDispatcherServlet(registration)) {
continue;
}
if (registration.getMappings().size() > 1) {
return null;
}
if (!"/".equals(registration.getMappings().iterator().next())) {
return null;
}
rootDispatcherServlet = registration;
}
return rootDispatcherServlet;
}

private ServletRegistration requireOnlyPathMappedDispatcherServlet(
Map<String, ? extends ServletRegistration> registrations) {
ServletRegistration pathDispatcherServlet = null;
for (ServletRegistration registration : registrations.values()) {
if (!isDispatcherServlet(registration)) {
return null;
}
if (registration.getMappings().size() > 1) {
return null;
}
String mapping = registration.getMappings().iterator().next();
if (!mapping.startsWith("/") || !mapping.endsWith("/*")) {
return null;
}
if (pathDispatcherServlet != null) {
return null;
}
pathDispatcherServlet = registration;
}
return pathDispatcherServlet;
}

private boolean isDispatcherServlet(ServletRegistration registration) {
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
null);
try {
Class<?> clazz = Class.forName(registration.getClassName());
return dispatcherServlet.isAssignableFrom(clazz);
}
catch (ClassNotFoundException ex) {
return false;
}
}

private String computeErrorMessage(Collection<? extends ServletRegistration> registrations) {
String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
+ "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); "
Expand Down Expand Up @@ -380,4 +444,55 @@ static List<RequestMatcher> regexMatchers(String... regexPatterns) {

}

static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {

private final AntPathRequestMatcher ant;

private final MvcRequestMatcher mvc;

private final ServletContext servletContext;

DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc,
ServletContext servletContext) {
this.ant = ant;
this.mvc = mvc;
this.servletContext = servletContext;
}

@Override
public boolean matches(HttpServletRequest request) {
String name = request.getHttpServletMapping().getServletName();
ServletRegistration registration = this.servletContext.getServletRegistration(name);
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
if (isDispatcherServlet(registration)) {
return this.mvc.matches(request);
}
return this.ant.matches(request);
}

@Override
public MatchResult matcher(HttpServletRequest request) {
String name = request.getHttpServletMapping().getServletName();
ServletRegistration registration = this.servletContext.getServletRegistration(name);
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
if (isDispatcherServlet(registration)) {
return this.mvc.matcher(request);
}
return this.ant.matcher(request);
}

private boolean isDispatcherServlet(ServletRegistration registration) {
Class<?> dispatcherServlet = ClassUtils
.resolveClassName("org.springframework.web.servlet.DispatcherServlet", null);
try {
Class<?> clazz = Class.forName(registration.getClassName());
return dispatcherServlet.isAssignableFrom(clazz);
}
catch (ClassNotFoundException ex) {
return false;
}
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.config.annotation.web.configurers;

import org.springframework.context.ApplicationContext;
import org.springframework.http.HttpMethod;
import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry;
import org.springframework.security.web.util.matcher.RequestMatcher;

abstract class AbstractRequestMatcherBuilderRegistry<C> extends AbstractRequestMatcherRegistry<C> {

private final RequestMatcherBuilder builder;

AbstractRequestMatcherBuilderRegistry(ApplicationContext context) {
this(context, RequestMatcherBuilders.createDefault(context));
}

AbstractRequestMatcherBuilderRegistry(ApplicationContext context, RequestMatcherBuilder builder) {
setApplicationContext(context);
this.builder = builder;
}

@Override
public final C requestMatchers(String... patterns) {
return requestMatchers(null, patterns);
}

@Override
public final C requestMatchers(HttpMethod method, String... patterns) {
return requestMatchers(this.builder.matchers(method, patterns).toArray(RequestMatcher[]::new));
}

@Override
public final C requestMatchers(HttpMethod method) {
return requestMatchers(method, "/**");
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.config.annotation.web.configurers;

import org.springframework.http.HttpMethod;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;

final class AntPathRequestMatcherBuilder implements RequestMatcherBuilder {

private final String servletPath;

private AntPathRequestMatcherBuilder(String servletPath) {
this.servletPath = servletPath;
}

static AntPathRequestMatcherBuilder absolute() {
return new AntPathRequestMatcherBuilder(null);
}

static AntPathRequestMatcherBuilder relativeTo(String path) {
return new AntPathRequestMatcherBuilder(path);
}

@Override
public AntPathRequestMatcher matcher(String pattern) {
return matcher((String) null, pattern);
}

@Override
public AntPathRequestMatcher matcher(HttpMethod method, String pattern) {
return matcher((method != null) ? method.name() : null, pattern);
}

private AntPathRequestMatcher matcher(String method, String pattern) {
return new AntPathRequestMatcher(prependServletPath(pattern), method);
}

private String prependServletPath(String pattern) {
if (this.servletPath == null) {
return pattern;
}
return this.servletPath + pattern;
}

}